From bdde25a9ab1b38b595274aedfa225a9676e9f9f9 Mon Sep 17 00:00:00 2001 From: huihun <781165206@qq.com> Date: Thu, 18 Apr 2024 16:44:31 +0800 Subject: [PATCH] =?UTF-8?q?=E7=BA=BF=E7=A8=8B=E6=B1=A0=E6=94=B9=E8=BF=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bert/obtain_inst_vec.py | 4 ++-- exe2json.py | 8 +++++--- json_feature2json.py | 21 ++++++++++++++------- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/bert/obtain_inst_vec.py b/bert/obtain_inst_vec.py index 675f36f..74d5d36 100644 --- a/bert/obtain_inst_vec.py +++ b/bert/obtain_inst_vec.py @@ -16,7 +16,7 @@ config_file = os.path.join('./bert/bert.json') # 禁用分词器多线程 -os.environ["TOKENIZERS_PARALLELISM"] = "false" +# os.environ["TOKENIZERS_PARALLELISM"] = "false" # from my_data_collator import MyDataCollatorForPreTraining # model_file = os.path.join("./pytorch_model.bin") @@ -27,7 +27,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" def load_model(): config = BertConfig.from_json_file(config_file) model = BertForPreTraining(config) - state_dict = torch.load(model_file, map_location='cpu') + state_dict = torch.load(model_file) model.load_state_dict(state_dict) model.eval() diff --git a/exe2json.py b/exe2json.py index 1e8dad0..6e0eb1a 100644 --- a/exe2json.py +++ b/exe2json.py @@ -1,3 +1,5 @@ +import random + import r2pipe import hashlib from my_utils import * @@ -14,7 +16,7 @@ import multiprocessing ret_trap_opcode_family = ["ret", "hlt", "int3", "ud2"] -sample_type = 'benign' +sample_type = 'malware' def extract_opcode(disasm_text): @@ -252,11 +254,11 @@ def exe_to_json(file_path): if __name__ == '__main__': logger = init_logging() - sample_file_path = f"/mnt/d/bishe/sample_{sample_type}" + sample_file_path = f"/mnt/d/bishe/dataset/sample_{sample_type}" sample_file_list = os.listdir(sample_file_path) print(f"max worker {os.cpu_count()}") with multiprocessing.Pool(processes=os.cpu_count()) as pool: - result = list(tqdm(pool.imap_unordered(exe_to_json, [os.path.join(sample_file_path, file_name) for file_name in sample_file_list[::-1]]), total=len(sample_file_list))) + result = list(tqdm(pool.imap_unordered(exe_to_json, [os.path.join(sample_file_path, file_name) for file_name in sample_file_list[:10000]]), total=10000)) # with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor: diff --git a/json_feature2json.py b/json_feature2json.py index 408ede7..64affc8 100644 --- a/json_feature2json.py +++ b/json_feature2json.py @@ -1,12 +1,11 @@ import os -from my_utils import save_json, load_json, setup_logger +from my_utils import save_json, load_json, setup_logger, multi_thread_order, THREAD_HALF,THREAD_FULL from bert.obtain_inst_vec import bb2vec import multiprocessing from tqdm import tqdm import warnings +from datetime import datetime -# 忽略输出torch cpu执行警告 -warnings.filterwarnings('ignore', category=UserWarning) def addr2vec(base_file_path): @@ -39,12 +38,20 @@ def addr2vec(base_file_path): if __name__ == '__main__': logger = setup_logger('feature2json', './log/feature2json.log') - sample_type = 'benign' + sample_type = 'malware' + # json_path = os.path.join(f'./out/json/{sample_type}') json_path = os.path.join(f'./out/json/{sample_type}') json_files = os.listdir(json_path) - with multiprocessing.Pool(processes=os.cpu_count()) as pool: - result = list(tqdm(pool.imap_unordered(addr2vec, [os.path.join(json_path, file) for file in json_files]), - total=len(json_files))) + now = datetime.now() + formatted_now = now.strftime("%Y-%m-%d %H:%M:%S") + print("start time:", formatted_now) + # with multiprocessing.Pool(processes=os.cpu_count()) as pool: + # result = list(tqdm(pool.imap_unordered(addr2vec, [os.path.join(json_path, file) for file in json_files[:1] if os.path.isfile(os.path.join(json_path, file))]), + # total=len(json_files))) + multi_thread_order(addr2vec, [os.path.join(json_path, file) for file in json_files[::-1] if os.path.isfile(os.path.join(json_path, file))], thread_num=THREAD_FULL) + + + # for json_file in json_files: # addr2vec(json_path, json_file)