修改多线程执行

This commit is contained in:
huihun 2024-04-17 10:28:01 +08:00
parent 6c85875eb4
commit 30ffaa2a46
2 changed files with 20 additions and 16 deletions

View File

@ -3,12 +3,13 @@ import hashlib
from my_utils import *
import json
# 基础块抽取
from bert.obtain_inst_vec import bb2vec
# from bert.obtain_inst_vec import bb2vec
from tqdm import tqdm
import numpy as np
import os
import concurrent.futures
import multiprocessing
# 禁用分词器多线程
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@ -79,9 +80,9 @@ def get_graph_cfg_r2pipe(r2pipe_open, file_path, output_path, file_name):
# continue
if disasm:
last_op = ''
if len(disasm) > 200:
logger.warning(
f"基础块指令长度异常,{file_path},函数名称{function['name']}基础块地址{block['addr']},长度{len(disasm)}")
# if len(disasm) > 200:
# logger.warning(
# f"基础块指令长度异常,{file_path},函数名称{function['name']}基础块地址{block['addr']},长度{len(disasm)}")
for op_index, op in enumerate(disasm):
op_disasm = extract_opcode(op["disasm"])
# 防止大量重复的语句造成内存溢出
@ -116,7 +117,7 @@ def get_graph_cfg_r2pipe(r2pipe_open, file_path, output_path, file_name):
temp_edge_list.append([block_addr, op["offset"] + op["size"]])
else:
logger.warning(
f"二进制可执行文件解析警告,跳转指令识别出错,指令{op}")
f"二进制可执行文件解析警告,跳转指令识别出错,文件{file_path},指令{op}")
# 操作码不存在跳转指令
else:
@ -161,7 +162,7 @@ def get_graph_cfg_r2pipe(r2pipe_open, file_path, output_path, file_name):
'block_features': block_feature_list
}
acfg_item.append(acfg)
save_json(os.path.join(file_path, 'feature', file_name + '.jsonl'), acfg_feature_item)
save_json(os.path.join(output_path, 'feature', file_name + '.jsonl'), acfg_feature_item)
return True, "二进制可执行文件解析成功", acfg_item
except Exception as e:
@ -256,17 +257,19 @@ if __name__ == '__main__':
sample_file_path = f"/mnt/d/bishe/sample_{sample_type}"
sample_file_list = os.listdir(sample_file_path)
print(f"max worker {os.cpu_count()}")
with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
try:
futures_to_args = {
executor.submit(exe_to_json, os.path.join(sample_file_path, file_name)): file_name for file_name in
sample_file_list
}
for future in tqdm(concurrent.futures.as_completed(futures_to_args), total=len(futures_to_args)):
pass
with multiprocessing.Pool(processes=os.cpu_count()) as pool:
result = list(tqdm(pool.imap(exe_to_json, [os.path.join(sample_file_path, file_name) for file_name in sample_file_list]), total=len(sample_file_list)))
# with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
# futures_to_args = {
# executor.submit(exe_to_json, os.path.join(sample_file_path, file_name)): file_name for file_name in
# sample_file_list
# }
# for future in tqdm(concurrent.futures.as_completed(futures_to_args), total=len(futures_to_args)):
# pass
except Exception as exc:
logger.error('%r generated an exception: %s' % (futures_to_args[future], exc))
# test_file_path = '/mnt/d/bishe/exe2json/sample/VirusShare_0a3b625380161cf92c4bb10135326bb5'
# exe_to_json(test_file_path)

View File

@ -1,6 +1,7 @@
import logging
import os
"""
日志工具