From 7c8145b52a32f4897c407f7821fdc4968c1535b0 Mon Sep 17 00:00:00 2001
From: huihun <781165206@qq.com>
Date: Wed, 17 Apr 2024 15:54:00 +0800
Subject: [PATCH] Improve multiprocessing; integrate features
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 bert/obtain_inst_vec.py |  8 +++++--
 exe2json.py             | 15 ++++++-------
 json_feature2json.py    | 50 +++++++++++++++++++++++++++++++++++++++++
 my_utils.py             | 20 +++++++++++++----
 4 files changed, 79 insertions(+), 14 deletions(-)
 create mode 100644 json_feature2json.py

diff --git a/bert/obtain_inst_vec.py b/bert/obtain_inst_vec.py
index 7f3af0e..675f36f 100644
--- a/bert/obtain_inst_vec.py
+++ b/bert/obtain_inst_vec.py
@@ -14,6 +14,10 @@ model_file = os.path.join("./bert/pytorch_model.bin")
 tokenizer_file = os.path.join("./bert/tokenizer-inst.all.json")
 config_file = os.path.join('./bert/bert.json')
 
+
+# Disable tokenizer parallelism
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
 # from my_data_collator import MyDataCollatorForPreTraining
 # model_file = os.path.join("./pytorch_model.bin")
 # tokenizer_file = os.path.join("./tokenizer-inst.all.json")
@@ -23,7 +27,7 @@ config_file = os.path.join('./bert/bert.json')
 def load_model():
     config = BertConfig.from_json_file(config_file)
     model = BertForPreTraining(config)
-    state_dict = torch.load(model_file)
+    state_dict = torch.load(model_file, map_location='cpu')
     model.load_state_dict(state_dict)
     model.eval()
 
@@ -101,7 +105,7 @@ def generate_inst_vec(inst, method="mean"):
 
 def bb2vec(inst):
     tmp = generate_inst_vec(inst, method="mean")
-    return list(np.mean(tmp.detach().numpy(), axis=0))
+    return list(np.mean(tmp.detach().numpy(), axis=0).astype(float))
 
 
 if __name__ == "__main__":
diff --git a/exe2json.py b/exe2json.py
index ad575ff..37574ba 100644
--- a/exe2json.py
+++ b/exe2json.py
@@ -11,11 +11,10 @@ import os
 import concurrent.futures
 import multiprocessing
 
-# Disable tokenizer parallelism
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
 ret_trap_opcode_family = ["ret", "hlt", "int3", "ud2"]
 
-sample_type = 'benign'
+sample_type = 'malware'
 
 
 def extract_opcode(disasm_text):
@@ -144,7 +143,7 @@ def get_graph_cfg_r2pipe(r2pipe_open, file_path, output_path, file_name):
 
             # block_feature_list = []
             # For now use the basic-block address as the feature; opcodes are converted to features in a later pass
-            block_feature_list = block_addr
+            block_feature_list.append(block_addr)
             acfg_feature_item.append({'addr': block_addr, 'opcode': block_Statement})
 
             # Filter out edges that do not exist
@@ -214,9 +213,9 @@ def exe_to_json(file_path):
     output_path = f"./out/json/{sample_type}"
     file_fingerprint = calc_sha256(file_path)
     if os.path.exists(os.path.join(output_path, file_fingerprint + '.jsonl')):
-        logger.info(f"binary executable already parsed: {file_path}")
+        # logger.info(f"binary executable already parsed: {file_path}")
         return
-    logger.info(f"start parsing: {file_path}")
+    # logger.info(f"start parsing: {file_path}")
 
     # Open r2pipe and parse the file; release r2 as soon as parsing finishes
     r2 = r2pipe.open(file_path, flags=['-2'])
@@ -248,7 +247,7 @@ def exe_to_json(file_path):
     # Write the JSON output
     os.makedirs(output_path, exist_ok=True)
     save_json(os.path.join(output_path, file_fingerprint + '.jsonl'), json_obj)
-    logger.info(f"parsing finished: {file_path}")
+    # logger.info(f"parsing finished: {file_path}")
     return
 
 
@@ -258,7 +257,7 @@ if __name__ == '__main__':
     sample_file_list = os.listdir(sample_file_path)
     print(f"max worker {os.cpu_count()}")
     with multiprocessing.Pool(processes=os.cpu_count()) as pool:
-        result = list(tqdm(pool.imap(exe_to_json, [os.path.join(sample_file_path, file_name) for file_name in sample_file_list]), total=len(sample_file_list)))
+        result = list(tqdm(pool.imap_unordered(exe_to_json, [os.path.join(sample_file_path, file_name) for file_name in sample_file_list]), total=len(sample_file_list)))
 
     # with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
diff --git a/json_feature2json.py b/json_feature2json.py
new file mode 100644
index 0000000..408ede7
--- /dev/null
+++ b/json_feature2json.py
@@ -0,0 +1,50 @@
+import os
+from my_utils import save_json, load_json, setup_logger
+from bert.obtain_inst_vec import bb2vec
+import multiprocessing
+from tqdm import tqdm
+import warnings
+
+# Suppress the torch CPU-execution UserWarning; logger lives at module level so worker processes can use it
+warnings.filterwarnings('ignore', category=UserWarning)
+logger = setup_logger('feature2json', './log/feature2json.log')
+
+
+def addr2vec(base_file_path):
+    # Split the path into directory and file name
+    file_name = str(os.path.basename(base_file_path))
+    file_path = str(os.path.dirname(base_file_path))
+
+    # Skip files whose final output already exists
+    if os.path.exists(os.path.join(file_path, 'final', file_name)):
+        return
+
+    # Convert only when the path actually names a file
+    if file_name:
+        file_json = load_json(base_file_path)
+        # Guard against a base file that has no matching feature file
+        feature_json = load_json(os.path.join(file_path, 'feature', file_name)) if os.path.exists(
+            os.path.join(file_path, 'feature', file_name)) else None
+        if feature_json is not None:
+            feature_set = {}
+            for item in feature_json:
+                feature_set[item['addr']] = bb2vec(item['opcode'])
+            for item in file_json['acfg_list']:
+                bb_feature_addr_list = item['block_features']
+                item['block_features'] = [feature_set[key] for key in bb_feature_addr_list]
+            save_json(os.path.join(file_path, 'final', file_name), file_json)
+        else:
+            logger.error(f'no feature file for {file_name}')
+    return
+
+
+if __name__ == '__main__':
+    sample_type = 'benign'
+    json_path = os.path.join(f'./out/json/{sample_type}')
+    json_files = os.listdir(json_path)
+    with multiprocessing.Pool(processes=os.cpu_count()) as pool:
+        result = list(tqdm(pool.imap_unordered(addr2vec, [os.path.join(json_path, file) for file in json_files]),
+                           total=len(json_files)))
+
+    # for json_file in json_files:
+    #     addr2vec(json_path, json_file)
diff --git a/my_utils.py b/my_utils.py
index 0874262..6ec6caf 100644
--- a/my_utils.py
+++ b/my_utils.py
@@ -1,6 +1,6 @@
 import logging
 import os
-
+import json
 
 """
 Logging utilities
@@ -45,7 +45,7 @@ THREAD_FULL = os.cpu_count()
 THREAD_HALF = int(os.cpu_count() / 2)
 
 
-def multi_thread(func, args, thread_num=THREAD_FULL):
+def multi_thread_order(func, args, thread_num=THREAD_FULL):
     """
     Run a function over arguments with multiple threads
     :param func: the function to run
@@ -73,10 +73,22 @@
     return result
 
 
-def save_json(filename, data):
-    import json
-
+def multi_thread_disorder(func, args, thread_num=THREAD_FULL):
+    import multiprocessing
+    from tqdm import tqdm
+    with multiprocessing.Pool(processes=thread_num) as pool:
+        list(tqdm(pool.imap_unordered(func, args), total=len(args)))
+
+
+def save_json(filename, data):
     data = json.dumps(data)
     file = open(filename, 'w')
     file.write(data)
     file.close()
+
+
+def load_json(filename):
+    # Read and parse a JSON file
+    with open(filename, 'r') as file:
+        data = json.load(file)
+    return data
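
The parallel-driver idiom this patch converges on, multiprocessing.Pool.imap_unordered wrapped in tqdm, can be exercised on its own. Below is a minimal sketch of the pattern, assuming a hypothetical worker parse_one and input directory ./samples in place of exe_to_json/addr2vec and the real sample folders:

    import multiprocessing
    import os

    from tqdm import tqdm


    def parse_one(path):
        # Hypothetical stand-in for exe_to_json / addr2vec:
        # do the per-file work and return a small result.
        return os.path.getsize(path)


    if __name__ == '__main__':
        sample_dir = './samples'  # hypothetical input directory
        paths = [os.path.join(sample_dir, name) for name in os.listdir(sample_dir)]
        with multiprocessing.Pool(processes=os.cpu_count()) as pool:
            # imap_unordered is lazy and yields results as workers finish;
            # list(...) drains it, and total=len(paths) fixes the tqdm bar length.
            results = list(tqdm(pool.imap_unordered(parse_one, paths), total=len(paths)))

imap_unordered hands back each result the moment its worker finishes, so the progress bar advances steadily even when per-file parse times vary widely; plain imap preserves input order at the cost of head-of-line blocking, which is why the batch drivers in this patch switch to the unordered variant.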