diff --git a/exe2json.py b/exe2json.py index 6e0eb1a..3445d73 100644 --- a/exe2json.py +++ b/exe2json.py @@ -5,18 +5,19 @@ import hashlib from my_utils import * import json # 基础块抽取 -# from bert.obtain_inst_vec import bb2vec +from bert.obtain_inst_vec import bb2vec from tqdm import tqdm import numpy as np import os +import warnings import concurrent.futures import multiprocessing - ret_trap_opcode_family = ["ret", "hlt", "int3", "ud2"] -sample_type = 'malware' +sample_type = 'benign' + def extract_opcode(disasm_text): @@ -73,7 +74,6 @@ def get_graph_cfg_r2pipe(r2pipe_open, file_path, output_path, file_name): block_addr = block["addr"] block_Statement = [] - node_list.append(block["addr"]) # 获取基本块的反汇编指令 disasm = r2pipe_open.cmdj("pdj " + str(block["ninstr"]) + " @" + str(block["addr"])) @@ -84,7 +84,7 @@ def get_graph_cfg_r2pipe(r2pipe_open, file_path, output_path, file_name): block_len = len(disasm) for op_index, op in enumerate(disasm): # 提取操作码并转换为bert模型输入格式 - op_disasm = extract_opcode(op["disasm"]) + op_disasm = extract_opcode(op["opcode"]) # 如果单个基础块的长度大于20且操作码重复,则跳过 if block_len > 20 and op_disasm in block_Statement: continue @@ -129,7 +129,8 @@ def get_graph_cfg_r2pipe(r2pipe_open, file_path, output_path, file_name): if op not in ret_trap_opcode_family or op["type"] not in ["ret", "trap"]: temp_edge_list.append([block_addr, op["offset"] + op["size"]]) if block_len > 20: - logger.warning(f"二进制可执行文件解析警告,基础块长度大于20,文件{file_path},基础块地址{block_addr},操作码长度{block_len}->{len(block_Statement)}") + logger.warning( + f"二进制可执行文件解析警告,基础块长度大于20,文件{file_path},基础块地址{block_addr},操作码长度{block_len}->{len(block_Statement)}") # debugger # print(len(disasm)) @@ -214,9 +215,9 @@ def exe_to_json(file_path): output_path = f"./out/json/{sample_type}" file_fingerprint = calc_sha256(file_path) if os.path.exists(os.path.join(output_path, file_fingerprint + '.jsonl')): - # logger.info(f"二进制可执行文件已解析,文件名{file_path}") + logger.info(f"二进制可执行文件已解析,文件名{file_path}") return - # logger.info(f"开始解析,文件名{file_path}") + logger.info(f"开始解析,文件名{file_path}") # 获取r2pipe并解析文件 解析完即释放r2 r2 = r2pipe.open(file_path, flags=['-2']) @@ -248,18 +249,20 @@ def exe_to_json(file_path): # json写入 os.makedirs(output_path, exist_ok=True) save_json(os.path.join(output_path, file_fingerprint + '.jsonl'), json_obj) - # logger.info(f"解析完成,文件名{file_path}") + logger.info(f"解析完成,文件名{file_path}") return if __name__ == '__main__': logger = init_logging() - sample_file_path = f"/mnt/d/bishe/dataset/sample_{sample_type}" + sample_file_path = f"/mnt/d/bishe/sample_{sample_type}" sample_file_list = os.listdir(sample_file_path) print(f"max worker {os.cpu_count()}") with multiprocessing.Pool(processes=os.cpu_count()) as pool: - result = list(tqdm(pool.imap_unordered(exe_to_json, [os.path.join(sample_file_path, file_name) for file_name in sample_file_list[:10000]]), total=10000)) - + result = list(tqdm(pool.imap_unordered(exe_to_json, [os.path.join(sample_file_path, file_name) for file_name in + sample_file_list[:10000]]), total=10000)) + # result = list(tqdm(pool.imap_unordered(exe_to_json, [os.path.join(sample_file_path, file_name) for file_name in + # [os.path.join(sample_file_path, '00e64dab6a0a572f0474ff92353794fc.exe')]]), total=10000)) # with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor: # futures_to_args = { @@ -269,7 +272,5 @@ if __name__ == '__main__': # for future in tqdm(concurrent.futures.as_completed(futures_to_args), total=len(futures_to_args)): # pass - - # test_file_path = '/mnt/d/bishe/exe2json/sample/VirusShare_0a3b625380161cf92c4bb10135326bb5' # exe_to_json(test_file_path) diff --git a/json_feature2json.py b/json_feature2json.py index 1927565..a31f387 100644 --- a/json_feature2json.py +++ b/json_feature2json.py @@ -1,5 +1,5 @@ import os -from my_utils import save_json, load_json, setup_logger, multi_thread_order, THREAD_HALF, THREAD_FULL +from my_utils import save_json, load_json, setup_logger, multi_thread_order, THREAD_HALF, THREAD_FULL, continuation_json from bert.obtain_inst_vec import bb2vec import multiprocessing from tqdm import tqdm @@ -25,31 +25,38 @@ def addr2vec(base_file_path, index): feature_json = load_json(os.path.join(file_path, 'feature', file_name)) if os.path.exists( os.path.join(file_path, 'feature', file_name)) else None if feature_json is not None: - feature_set = {} + + # 对于长度过长的文件先不处理 + if len(feature_json) > 10000: + data = { + 'file_name': file_name, + 'feature_len': len(feature_json) + } + continuation_json(os.path.join(f'./out/json/too_long_{sample_type}.json'), data) + return # 多线程预测bert - # with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor: + feature_set = {} + # with multiprocessing.Pool(processes=os.cpu_count()) as pool: # try: - # future_args = { - # executor.submit(bb2vec, item): item for item in feature_json - # } - # for future in concurrent.futures.as_completed(future_args): - # result = future.result() + # results = list(tqdm(pool.imap_unordered(bb2vec, [item for item in feature_json]), + # total=len(feature_json), + # desc=f'{file_name} Progress:{index}/{json_files_len} ', + # leave=True, + # dynamic_ncols=True)) + # for result in results: # feature_set[result[0]] = result[1] # except Exception as e: # logger.error(f"bert 解析出错 {file_name},{e}") - - with multiprocessing.Pool(processes=os.cpu_count()) as pool: - try: - results = list(tqdm(pool.imap_unordered(bb2vec, [item for item in feature_json]), - total=len(feature_json), - desc=f'{file_name} Progress:{index}/{json_files_len} ', - leave=True, - dynamic_ncols=True)) - for result in results: - feature_set[result[0]] = result[1] - except Exception as e: - logger.error(f"bert 解析出错 {file_name},{e}") + # debug + try: + for index, feature in tqdm(enumerate(feature_json), total=len(feature_json)): + addr, feature = bb2vec(feature) + feature_set[addr] = feature + except Exception as e: + print(index) + print(e) + print(feature['opcode']) try: for item in file_json['acfg_list']: @@ -57,6 +64,7 @@ def addr2vec(base_file_path, index): item['block_features'] = [feature_set[key] for key in bb_feature_addr_list] except Exception as e: logger.error(f"地址对应出错{file_name}, {e}") + return save_json(os.path.join(file_path, 'final', file_name), file_json) else: @@ -69,7 +77,8 @@ if __name__ == '__main__': sample_type = 'benign' # json_path = os.path.join(f'./out/json/{sample_type}') json_path = os.path.join(f'./out/json/{sample_type}') - json_files = os.listdir(json_path) + # json_files = os.listdir(json_path) + json_files = ['1710ae16c54ac149f353ba58e752ba7069f88326e6b71107598283bd0fffcbd6.jsonl'] json_files_len = len(json_files) now = datetime.now() formatted_now = now.strftime("%Y-%m-%d %H:%M:%S") diff --git a/my_utils.py b/my_utils.py index 6ec6caf..b090d19 100644 --- a/my_utils.py +++ b/my_utils.py @@ -10,7 +10,7 @@ import json """ -def setup_logger(name, log_file, level=logging.INFO): +def setup_logger(name, log_file, level=logging.INFO, reset=False): """Function setup as many loggers as you want""" if not os.path.exists(os.path.dirname(log_file)): os.makedirs(os.path.dirname(log_file)) @@ -32,7 +32,7 @@ def setup_logger(name, log_file, level=logging.INFO): # 刷新原有log文件 - if os.path.exists(log_file): + if reset: open(log_file, 'w').close() return logger @@ -73,7 +73,7 @@ def multi_thread_order(func, args, thread_num=THREAD_FULL): return result -def multi_thread_disorder(func, thread_num=THREAD_FULL, **args): +def multi_thread_disorder(func, thread_num=THREAD_FULL, **args): import multiprocessing from tqdm import tqdm with multiprocessing.Pool(processes=thread_num) as pool: @@ -87,6 +87,13 @@ def save_json(filename, data): file.close() +def continuation_json(filename, data): + data = json.dumps(data) + file = open(filename, 'a') + file.write(data) + file.close() + + def load_json(filename): file = open(filename, 'r') data = json.loads(file.read())