Improve multithreading; integrate features

This commit is contained in:
huihun 2024-04-17 15:54:00 +08:00
parent 30ffaa2a46
commit 7c8145b52a
4 changed files with 79 additions and 14 deletions

View File

@@ -14,6 +14,10 @@ model_file = os.path.join("./bert/pytorch_model.bin")
 tokenizer_file = os.path.join("./bert/tokenizer-inst.all.json")
 config_file = os.path.join('./bert/bert.json')
+# Disable tokenizer parallelism
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
 # from my_data_collator import MyDataCollatorForPreTraining
 # model_file = os.path.join("./pytorch_model.bin")
 # tokenizer_file = os.path.join("./tokenizer-inst.all.json")
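The env var is set at module import time, before any tokenizer is constructed. A minimal sketch of why that ordering matters, assuming the HuggingFace `tokenizers` package backs the tokenizer file above: the library's internal Rust thread pool is not fork-safe, and it warns about possible deadlocks when a process forks after the pool has been used, which is exactly what multiprocessing.Pool does on Linux.

import os

# Must be set before the first tokenizer call, or forked Pool workers
# print "The current process just got forked ..." warnings and risk
# deadlocking on the tokenizer's internal thread pool.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from tokenizers import Tokenizer

tok = Tokenizer.from_file("./bert/tokenizer-inst.all.json")
print(tok.encode("mov eax, 1").tokens)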
@@ -23,7 +27,7 @@ config_file = os.path.join('./bert/bert.json')
 def load_model():
     config = BertConfig.from_json_file(config_file)
     model = BertForPreTraining(config)
-    state_dict = torch.load(model_file)
+    state_dict = torch.load(model_file, map_location='cpu')
     model.load_state_dict(state_dict)
     model.eval()
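map_location='cpu' remaps every tensor onto the CPU during deserialization; without it, a checkpoint saved on a CUDA machine tries to restore straight onto CUDA and raises a RuntimeError on CPU-only hosts such as the forked feature workers. A minimal sketch, assuming the checkpoint holds only tensors:

import torch

state_dict = torch.load("./bert/pytorch_model.bin", map_location='cpu')

# Every tensor now lives on the CPU, so load_state_dict succeeds even in
# worker processes that have no CUDA context of their own.
assert all(t.device.type == 'cpu' for t in state_dict.values())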
@@ -101,7 +105,7 @@ def generate_inst_vec(inst, method="mean"):
 def bb2vec(inst):
     tmp = generate_inst_vec(inst, method="mean")
-    return list(np.mean(tmp.detach().numpy(), axis=0))
+    return list(np.mean(tmp.detach().numpy(), axis=0).astype(float))


 if __name__ == "__main__":
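The added .astype(float) widens the float32 mean vector to float64 before it is listed. That matters downstream: numpy.float32 is not a subclass of Python float, so json.dumps rejects it, while numpy.float64 is and serializes cleanly. A quick sketch:

import json
import numpy as np

vec32 = np.mean(np.ones((3, 4), dtype=np.float32), axis=0)
# json.dumps(list(vec32))       # TypeError: Object of type float32 is not JSON serializable

vec64 = vec32.astype(float)     # float64 elements unbox as Python floats
print(json.dumps(list(vec64)))  # [1.0, 1.0, 1.0, 1.0]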

View File

@@ -11,11 +11,10 @@ import os
 import concurrent.futures
 import multiprocessing
-# Disable tokenizer parallelism
-os.environ["TOKENIZERS_PARALLELISM"] = "false"

 ret_trap_opcode_family = ["ret", "hlt", "int3", "ud2"]
-sample_type = 'benign'
+sample_type = 'malware'


 def extract_opcode(disasm_text):
@@ -144,7 +143,7 @@ def get_graph_cfg_r2pipe(r2pipe_open, file_path, output_path, file_name):
         # block_feature_list = []
         # For now, use the basic-block address as the feature; the opcodes are batch-converted to features later
-        block_feature_list = block_addr
+        block_feature_list.append(block_addr)
         acfg_feature_item.append({'addr': block_addr, 'opcode': block_Statement})

         # Filter out edges that do not exist
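The old assignment rebound block_feature_list to a single address on every iteration, so only the last basic block survived; append accumulates one address per block, which is what the per-address lookup in json_feature2json.py below expects. A toy illustration with hypothetical addresses:

block_feature_list = []
for block_addr in (0x401000, 0x401020, 0x40104c):  # hypothetical block addresses
    block_feature_list.append(block_addr)
print(block_feature_list)  # [4198400, 4198432, 4198476], one entry per block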
@@ -214,9 +213,9 @@ def exe_to_json(file_path):
     output_path = f"./out/json/{sample_type}"
     file_fingerprint = calc_sha256(file_path)
     if os.path.exists(os.path.join(output_path, file_fingerprint + '.jsonl')):
-        logger.info(f"binary executable already parsed: {file_path}")
+        # logger.info(f"binary executable already parsed: {file_path}")
         return
-    logger.info(f"start parsing: {file_path}")
+    # logger.info(f"start parsing: {file_path}")

     # Open an r2pipe handle and parse the file; release r2 as soon as parsing finishes
     r2 = r2pipe.open(file_path, flags=['-2'])
@@ -248,7 +247,7 @@ def exe_to_json(file_path):
     # Write the JSON output
     os.makedirs(output_path, exist_ok=True)
     save_json(os.path.join(output_path, file_fingerprint + '.jsonl'), json_obj)
-    logger.info(f"parsing finished: {file_path}")
+    # logger.info(f"parsing finished: {file_path}")
     return
@@ -258,7 +257,7 @@ if __name__ == '__main__':
     sample_file_list = os.listdir(sample_file_path)
     print(f"max worker {os.cpu_count()}")
     with multiprocessing.Pool(processes=os.cpu_count()) as pool:
-        result = list(tqdm(pool.imap(exe_to_json, [os.path.join(sample_file_path, file_name) for file_name in sample_file_list]), total=len(sample_file_list)))
+        result = list(tqdm(pool.imap_unordered(exe_to_json, [os.path.join(sample_file_path, file_name) for file_name in sample_file_list]), total=len(sample_file_list)))
     # with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
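imap preserves input order, so one slow binary early in the list stalls the tqdm bar and delays every later result; imap_unordered yields results as soon as workers finish. Since exe_to_json returns nothing useful, ordering is irrelevant here and the unordered variant is strictly better. A toy comparison:

import multiprocessing
import time

def work(n):
    time.sleep(n)  # stand-in for a variable-cost parse
    return n

if __name__ == '__main__':
    with multiprocessing.Pool(2) as pool:
        print(list(pool.imap(work, [3, 1])))            # [3, 1], blocks ~3s before the first item
        print(list(pool.imap_unordered(work, [3, 1])))  # typically [1, 3], first item after ~1s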

json_feature2json.py (new file, 50 lines)
View File

@@ -0,0 +1,50 @@
+import os
+from my_utils import save_json, load_json, setup_logger
+from bert.obtain_inst_vec import bb2vec
+import multiprocessing
+from tqdm import tqdm
+import warnings
+
+# Suppress the torch CPU-execution warnings printed in the workers
+warnings.filterwarnings('ignore', category=UserWarning)
+
+
+def addr2vec(base_file_path):
+    # Split the path into directory and file name
+    file_name = str(os.path.basename(base_file_path))
+    file_path = str(os.path.dirname(base_file_path))
+    # Skip files whose final output has already been generated
+    if os.path.exists(os.path.join(file_path, 'final', file_name)):
+        return
+    # Start the conversion only if the basename is an actual file name
+    if file_name:
+        file_json = load_json(base_file_path)
+        # Guard against a base file that has no matching feature file
+        feature_json = load_json(os.path.join(file_path, 'feature', file_name)) if os.path.exists(
+            os.path.join(file_path, 'feature', file_name)) else None
+        if feature_json is not None:
+            feature_set = {}
+            for item in feature_json:
+                feature_set[item['addr']] = bb2vec(item['opcode'])
+            for item in file_json['acfg_list']:
+                bb_feature_addr_list = item['block_features']
+                item['block_features'] = [feature_set[key] for key in bb_feature_addr_list]
+            save_json(os.path.join(file_path, 'final', file_name), file_json)
+        else:
+            logger.error(f'no feature file for {file_name}')
+    return
+
+
+if __name__ == '__main__':
+    logger = setup_logger('feature2json', './log/feature2json.log')
+    sample_type = 'benign'
+    json_path = os.path.join(f'./out/json/{sample_type}')
+    json_files = os.listdir(json_path)
+    with multiprocessing.Pool(processes=os.cpu_count()) as pool:
+        result = list(tqdm(pool.imap_unordered(addr2vec, [os.path.join(json_path, file) for file in json_files]),
+                           total=len(json_files)))
+    # for json_file in json_files:
+    #     addr2vec(json_path, json_file)
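The new script joins two JSON artifacts per sample: the base file, whose acfg_list entries carry basic-block addresses in block_features, and the feature file, which maps each address to an opcode list that bb2vec turns into a vector. A miniature sketch of the join, with hypothetical data and a stand-in encoder:

feature_json = [{'addr': 4096, 'opcode': ['push', 'mov']},
                {'addr': 4112, 'opcode': ['cmp', 'jne']}]
file_json = {'acfg_list': [{'block_features': [4096, 4112]}]}

def fake_bb2vec(opcodes):  # stand-in for the BERT-based bb2vec above
    return [0.0] * 2

feature_set = {item['addr']: fake_bb2vec(item['opcode']) for item in feature_json}
for acfg in file_json['acfg_list']:
    acfg['block_features'] = [feature_set[a] for a in acfg['block_features']]

print(file_json)  # each address replaced by its basic-block vector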

View File

@@ -1,6 +1,6 @@
 import logging
 import os
+import json

 """
 Logging utilities
@@ -45,7 +45,7 @@ THREAD_FULL = os.cpu_count()
 THREAD_HALF = int(os.cpu_count() / 2)


-def multi_thread(func, args, thread_num=THREAD_FULL):
+def multi_thread_order(func, args, thread_num=THREAD_FULL):
     """
     Execute a function over an argument list with a thread pool
     :param func: the function to run
@@ -73,10 +73,22 @@ def multi_thread(func, args, thread_num=THREAD_FULL):
     return result


+def multi_thread_disorder(func, args, thread_num=THREAD_FULL):
+    import multiprocessing
+    from tqdm import tqdm
+    with multiprocessing.Pool(processes=thread_num) as pool:
+        list(tqdm(pool.imap_unordered(func, args), total=len(args)))
+
+
 def save_json(filename, data):
-    import json
     data = json.dumps(data)
     file = open(filename, 'w')
     file.write(data)
     file.close()
+
+
+def load_json(filename):
+    file = open(filename, 'r')
+    data = json.loads(file.read())
+    file.close()
+    return data
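A short usage sketch of the new helpers, with a hypothetical task and output path. Note that despite its name, multi_thread_disorder fans work out to a process pool (mirroring the inline Pool/imap_unordered pattern in the two __main__ blocks above), so the callable must be picklable, i.e. defined at module top level:

from my_utils import save_json, load_json, multi_thread_disorder

def square(n):  # top-level def, so worker processes can pickle it
    return n * n

if __name__ == '__main__':
    multi_thread_disorder(square, [1, 2, 3, 4])

    save_json('/tmp/demo.json', {'vals': [1.0, 2.0]})  # hypothetical path
    assert load_json('/tmp/demo.json') == {'vals': [1.0, 2.0]}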