diff --git a/bert/obtain_inst_vec.py b/bert/obtain_inst_vec.py index 74d5d36..40d7bf8 100644 --- a/bert/obtain_inst_vec.py +++ b/bert/obtain_inst_vec.py @@ -27,7 +27,8 @@ config_file = os.path.join('./bert/bert.json') def load_model(): config = BertConfig.from_json_file(config_file) model = BertForPreTraining(config) - state_dict = torch.load(model_file) + # state_dict = torch.load(model_file) + state_dict = torch.load(model_file, map_location='cpu') model.load_state_dict(state_dict) model.eval() @@ -103,9 +104,9 @@ def generate_inst_vec(inst, method="mean"): return result -def bb2vec(inst): - tmp = generate_inst_vec(inst, method="mean") - return list(np.mean(tmp.detach().numpy(), axis=0).astype(float)) +def bb2vec(item): + tmp = generate_inst_vec(item['opcode'], method="mean") + return item['addr'], list(np.mean(tmp.detach().numpy(), axis=0).astype(float)) if __name__ == "__main__": diff --git a/json_feature2json.py b/json_feature2json.py index 64affc8..27c59ff 100644 --- a/json_feature2json.py +++ b/json_feature2json.py @@ -1,11 +1,13 @@ import os -from my_utils import save_json, load_json, setup_logger, multi_thread_order, THREAD_HALF,THREAD_FULL +from my_utils import save_json, load_json, setup_logger, multi_thread_order, THREAD_HALF, THREAD_FULL from bert.obtain_inst_vec import bb2vec import multiprocessing from tqdm import tqdm import warnings from datetime import datetime +import concurrent.futures +warnings.filterwarnings("ignore") def addr2vec(base_file_path): @@ -13,32 +15,56 @@ def addr2vec(base_file_path): file_name = str(os.path.basename(base_file_path)) file_path = str(os.path.dirname(base_file_path)) - # 忽略已生成的文件 - if os.path.exists(os.path.join(file_path, 'final', file_name)): - return - # 如果不是路径则开始转化 if file_name: + # 忽略已生成的文件 + if os.path.exists(os.path.join(file_path, 'final', file_name)): + process_bar.update(1) + return file_json = load_json(base_file_path) # 确保存在基础文件而不存在特征文件的情况 feature_json = load_json(os.path.join(file_path, 'feature', file_name)) if os.path.exists( os.path.join(file_path, 'feature', file_name)) else None if feature_json is not None: feature_set = {} - for item in feature_json: - feature_set[item['addr']] = bb2vec(item['opcode']) - for item in file_json['acfg_list']: - bb_feature_addr_list = item['block_features'] - item['block_features'] = [feature_set[key] for key in bb_feature_addr_list] + + # 多线程预测bert + # with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor: + # try: + # future_args = { + # executor.submit(bb2vec, item): item for item in feature_json + # } + # for future in concurrent.futures.as_completed(future_args): + # result = future.result() + # feature_set[result[0]] = result[1] + # except Exception as e: + # logger.error(f"bert 解析出错 {file_name},{e}") + + with multiprocessing.Pool(processes=os.cpu_count()) as pool: + try: + results = pool.imap_unordered(bb2vec, [item for item in feature_json]) + for result in results: + feature_set[result[0]] = result[1] + except Exception as e: + logger.error(f"bert 解析出错 {file_name},{e}") + + try: + for item in file_json['acfg_list']: + bb_feature_addr_list = item['block_features'] + item['block_features'] = [feature_set[key] for key in bb_feature_addr_list] + except Exception as e: + logger.error(f"地址对应出错{file_name}, {e}") save_json(os.path.join(file_path, 'final', file_name), file_json) + else: logger.error(f'文件{file_name}不存在特征文件') + process_bar.update(1) return if __name__ == '__main__': logger = setup_logger('feature2json', './log/feature2json.log') - sample_type = 'malware' + sample_type = 'benign' # json_path = os.path.join(f'./out/json/{sample_type}') json_path = os.path.join(f'./out/json/{sample_type}') json_files = os.listdir(json_path) @@ -48,10 +74,9 @@ if __name__ == '__main__': # with multiprocessing.Pool(processes=os.cpu_count()) as pool: # result = list(tqdm(pool.imap_unordered(addr2vec, [os.path.join(json_path, file) for file in json_files[:1] if os.path.isfile(os.path.join(json_path, file))]), # total=len(json_files))) - multi_thread_order(addr2vec, [os.path.join(json_path, file) for file in json_files[::-1] if os.path.isfile(os.path.join(json_path, file))], thread_num=THREAD_FULL) - - - - - # for json_file in json_files: - # addr2vec(json_path, json_file) + # multi_thread_order(addr2vec, [os.path.join(json_path, file) for file in json_files if + # os.path.isfile(os.path.join(json_path, file))], thread_num=THREAD_FULL) + process_bar = tqdm(total=len(json_files)) + for json_file in json_files: + if os.path.isfile(os.path.join(json_path, json_file)): + addr2vec(os.path.join(json_path, json_file))