多线程完善 特征整合
This commit is contained in:
parent
30ffaa2a46
commit
7c8145b52a
@ -14,6 +14,10 @@ model_file = os.path.join("./bert/pytorch_model.bin")
|
||||
tokenizer_file = os.path.join("./bert/tokenizer-inst.all.json")
|
||||
config_file = os.path.join('./bert/bert.json')
|
||||
|
||||
|
||||
# 禁用分词器多线程
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
# from my_data_collator import MyDataCollatorForPreTraining
|
||||
# model_file = os.path.join("./pytorch_model.bin")
|
||||
# tokenizer_file = os.path.join("./tokenizer-inst.all.json")
|
||||
@ -23,7 +27,7 @@ config_file = os.path.join('./bert/bert.json')
|
||||
def load_model():
|
||||
config = BertConfig.from_json_file(config_file)
|
||||
model = BertForPreTraining(config)
|
||||
state_dict = torch.load(model_file)
|
||||
state_dict = torch.load(model_file, map_location='cpu')
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
@ -101,7 +105,7 @@ def generate_inst_vec(inst, method="mean"):
|
||||
|
||||
def bb2vec(inst):
|
||||
tmp = generate_inst_vec(inst, method="mean")
|
||||
return list(np.mean(tmp.detach().numpy(), axis=0))
|
||||
return list(np.mean(tmp.detach().numpy(), axis=0).astype(float))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
15
exe2json.py
15
exe2json.py
@ -11,11 +11,10 @@ import os
|
||||
import concurrent.futures
|
||||
import multiprocessing
|
||||
|
||||
# 禁用分词器多线程
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
ret_trap_opcode_family = ["ret", "hlt", "int3", "ud2"]
|
||||
|
||||
sample_type = 'benign'
|
||||
sample_type = 'malware'
|
||||
|
||||
|
||||
def extract_opcode(disasm_text):
|
||||
@ -144,7 +143,7 @@ def get_graph_cfg_r2pipe(r2pipe_open, file_path, output_path, file_name):
|
||||
# block_feature_list = []
|
||||
|
||||
# 暂时将bb地址作为特征 后续将操作码集中转化为特征
|
||||
block_feature_list = block_addr
|
||||
block_feature_list.append(block_addr)
|
||||
acfg_feature_item.append({'addr': block_addr, 'opcode': block_Statement})
|
||||
|
||||
# 过滤不存在的边
|
||||
@ -214,9 +213,9 @@ def exe_to_json(file_path):
|
||||
output_path = f"./out/json/{sample_type}"
|
||||
file_fingerprint = calc_sha256(file_path)
|
||||
if os.path.exists(os.path.join(output_path, file_fingerprint + '.jsonl')):
|
||||
logger.info(f"二进制可执行文件已解析,文件名{file_path}")
|
||||
# logger.info(f"二进制可执行文件已解析,文件名{file_path}")
|
||||
return
|
||||
logger.info(f"开始解析,文件名{file_path}")
|
||||
# logger.info(f"开始解析,文件名{file_path}")
|
||||
|
||||
# 获取r2pipe并解析文件 解析完即释放r2
|
||||
r2 = r2pipe.open(file_path, flags=['-2'])
|
||||
@ -248,7 +247,7 @@ def exe_to_json(file_path):
|
||||
# json写入
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
save_json(os.path.join(output_path, file_fingerprint + '.jsonl'), json_obj)
|
||||
logger.info(f"解析完成,文件名{file_path}")
|
||||
# logger.info(f"解析完成,文件名{file_path}")
|
||||
return
|
||||
|
||||
|
||||
@ -258,7 +257,7 @@ if __name__ == '__main__':
|
||||
sample_file_list = os.listdir(sample_file_path)
|
||||
print(f"max worker {os.cpu_count()}")
|
||||
with multiprocessing.Pool(processes=os.cpu_count()) as pool:
|
||||
result = list(tqdm(pool.imap(exe_to_json, [os.path.join(sample_file_path, file_name) for file_name in sample_file_list]), total=len(sample_file_list)))
|
||||
result = list(tqdm(pool.imap_unordered(exe_to_json, [os.path.join(sample_file_path, file_name) for file_name in sample_file_list]), total=len(sample_file_list)))
|
||||
|
||||
|
||||
# with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
|
||||
|
50
json_feature2json.py
Normal file
50
json_feature2json.py
Normal file
@ -0,0 +1,50 @@
|
||||
import os
|
||||
from my_utils import save_json, load_json, setup_logger
|
||||
from bert.obtain_inst_vec import bb2vec
|
||||
import multiprocessing
|
||||
from tqdm import tqdm
|
||||
import warnings
|
||||
|
||||
# 忽略输出torch cpu执行警告
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
|
||||
def addr2vec(base_file_path):
|
||||
# 从路径拆分文件名与路径
|
||||
file_name = str(os.path.basename(base_file_path))
|
||||
file_path = str(os.path.dirname(base_file_path))
|
||||
|
||||
# 忽略已生成的文件
|
||||
if os.path.exists(os.path.join(file_path, 'final', file_name)):
|
||||
return
|
||||
|
||||
# 如果不是路径则开始转化
|
||||
if file_name:
|
||||
file_json = load_json(base_file_path)
|
||||
# 确保存在基础文件而不存在特征文件的情况
|
||||
feature_json = load_json(os.path.join(file_path, 'feature', file_name)) if os.path.exists(
|
||||
os.path.join(file_path, 'feature', file_name)) else None
|
||||
if feature_json is not None:
|
||||
feature_set = {}
|
||||
for item in feature_json:
|
||||
feature_set[item['addr']] = bb2vec(item['opcode'])
|
||||
for item in file_json['acfg_list']:
|
||||
bb_feature_addr_list = item['block_features']
|
||||
item['block_features'] = [feature_set[key] for key in bb_feature_addr_list]
|
||||
save_json(os.path.join(file_path, 'final', file_name), file_json)
|
||||
else:
|
||||
logger.error(f'文件{file_name}不存在特征文件')
|
||||
return
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logger = setup_logger('feature2json', './log/feature2json.log')
|
||||
sample_type = 'benign'
|
||||
json_path = os.path.join(f'./out/json/{sample_type}')
|
||||
json_files = os.listdir(json_path)
|
||||
with multiprocessing.Pool(processes=os.cpu_count()) as pool:
|
||||
result = list(tqdm(pool.imap_unordered(addr2vec, [os.path.join(json_path, file) for file in json_files]),
|
||||
total=len(json_files)))
|
||||
|
||||
# for json_file in json_files:
|
||||
# addr2vec(json_path, json_file)
|
20
my_utils.py
20
my_utils.py
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import json
|
||||
|
||||
"""
|
||||
日志工具
|
||||
@ -45,7 +45,7 @@ THREAD_FULL = os.cpu_count()
|
||||
THREAD_HALF = int(os.cpu_count() / 2)
|
||||
|
||||
|
||||
def multi_thread(func, args, thread_num=THREAD_FULL):
|
||||
def multi_thread_order(func, args, thread_num=THREAD_FULL):
|
||||
"""
|
||||
多线程执行函数
|
||||
:param func: 函数
|
||||
@ -73,10 +73,22 @@ def multi_thread(func, args, thread_num=THREAD_FULL):
|
||||
return result
|
||||
|
||||
|
||||
def save_json(filename, data):
|
||||
import json
|
||||
def multi_thread_disorder(func, thread_num=THREAD_FULL, **args):
|
||||
import multiprocessing
|
||||
from tqdm import tqdm
|
||||
with multiprocessing.Pool(processes=thread_num) as pool:
|
||||
list(tqdm(pool.imap_unordered(func, args), total=len(args)))
|
||||
|
||||
|
||||
def save_json(filename, data):
|
||||
data = json.dumps(data)
|
||||
file = open(filename, 'w')
|
||||
file.write(data)
|
||||
file.close()
|
||||
|
||||
|
||||
def load_json(filename):
|
||||
file = open(filename, 'r')
|
||||
data = json.loads(file.read())
|
||||
file.close()
|
||||
return data
|
||||
|
Loading…
Reference in New Issue
Block a user