多线程完善 特征整合
This commit is contained in:
parent
30ffaa2a46
commit
7c8145b52a
@ -14,6 +14,10 @@ model_file = os.path.join("./bert/pytorch_model.bin")
|
|||||||
tokenizer_file = os.path.join("./bert/tokenizer-inst.all.json")
|
tokenizer_file = os.path.join("./bert/tokenizer-inst.all.json")
|
||||||
config_file = os.path.join('./bert/bert.json')
|
config_file = os.path.join('./bert/bert.json')
|
||||||
|
|
||||||
|
|
||||||
|
# 禁用分词器多线程
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
# from my_data_collator import MyDataCollatorForPreTraining
|
# from my_data_collator import MyDataCollatorForPreTraining
|
||||||
# model_file = os.path.join("./pytorch_model.bin")
|
# model_file = os.path.join("./pytorch_model.bin")
|
||||||
# tokenizer_file = os.path.join("./tokenizer-inst.all.json")
|
# tokenizer_file = os.path.join("./tokenizer-inst.all.json")
|
||||||
@ -23,7 +27,7 @@ config_file = os.path.join('./bert/bert.json')
|
|||||||
def load_model():
|
def load_model():
|
||||||
config = BertConfig.from_json_file(config_file)
|
config = BertConfig.from_json_file(config_file)
|
||||||
model = BertForPreTraining(config)
|
model = BertForPreTraining(config)
|
||||||
state_dict = torch.load(model_file)
|
state_dict = torch.load(model_file, map_location='cpu')
|
||||||
model.load_state_dict(state_dict)
|
model.load_state_dict(state_dict)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
@ -101,7 +105,7 @@ def generate_inst_vec(inst, method="mean"):
|
|||||||
|
|
||||||
def bb2vec(inst):
    """Turn one basic block into a single flat vector of built-in floats.

    Calls ``generate_inst_vec(inst, method="mean")`` and averages the result
    over axis 0.

    :param inst: forwarded unchanged to ``generate_inst_vec`` (presumably the
        instruction/opcode sequence of one basic block -- TODO confirm).
    :return: the averaged vector as a list of Python ``float`` values, so it
        can be written out with ``json.dumps`` directly.
    """
    inst_vecs = generate_inst_vec(inst, method="mean")
    block_vec = np.mean(inst_vecs.detach().numpy(), axis=0).astype(float)
    # tolist() converts numpy scalars to built-in floats; list(ndarray) would
    # keep numpy scalar objects as the elements.
    return block_vec.tolist()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
15
exe2json.py
15
exe2json.py
@ -11,11 +11,10 @@ import os
|
|||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
|
||||||
# 禁用分词器多线程
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
||||||
ret_trap_opcode_family = ["ret", "hlt", "int3", "ud2"]
|
ret_trap_opcode_family = ["ret", "hlt", "int3", "ud2"]
|
||||||
|
|
||||||
sample_type = 'benign'
|
sample_type = 'malware'
|
||||||
|
|
||||||
|
|
||||||
def extract_opcode(disasm_text):
|
def extract_opcode(disasm_text):
|
||||||
@ -144,7 +143,7 @@ def get_graph_cfg_r2pipe(r2pipe_open, file_path, output_path, file_name):
|
|||||||
# block_feature_list = []
|
# block_feature_list = []
|
||||||
|
|
||||||
# 暂时将bb地址作为特征 后续将操作码集中转化为特征
|
# 暂时将bb地址作为特征 后续将操作码集中转化为特征
|
||||||
block_feature_list = block_addr
|
block_feature_list.append(block_addr)
|
||||||
acfg_feature_item.append({'addr': block_addr, 'opcode': block_Statement})
|
acfg_feature_item.append({'addr': block_addr, 'opcode': block_Statement})
|
||||||
|
|
||||||
# 过滤不存在的边
|
# 过滤不存在的边
|
||||||
@ -214,9 +213,9 @@ def exe_to_json(file_path):
|
|||||||
output_path = f"./out/json/{sample_type}"
|
output_path = f"./out/json/{sample_type}"
|
||||||
file_fingerprint = calc_sha256(file_path)
|
file_fingerprint = calc_sha256(file_path)
|
||||||
if os.path.exists(os.path.join(output_path, file_fingerprint + '.jsonl')):
|
if os.path.exists(os.path.join(output_path, file_fingerprint + '.jsonl')):
|
||||||
logger.info(f"二进制可执行文件已解析,文件名{file_path}")
|
# logger.info(f"二进制可执行文件已解析,文件名{file_path}")
|
||||||
return
|
return
|
||||||
logger.info(f"开始解析,文件名{file_path}")
|
# logger.info(f"开始解析,文件名{file_path}")
|
||||||
|
|
||||||
# 获取r2pipe并解析文件 解析完即释放r2
|
# 获取r2pipe并解析文件 解析完即释放r2
|
||||||
r2 = r2pipe.open(file_path, flags=['-2'])
|
r2 = r2pipe.open(file_path, flags=['-2'])
|
||||||
@ -248,7 +247,7 @@ def exe_to_json(file_path):
|
|||||||
# json写入
|
# json写入
|
||||||
os.makedirs(output_path, exist_ok=True)
|
os.makedirs(output_path, exist_ok=True)
|
||||||
save_json(os.path.join(output_path, file_fingerprint + '.jsonl'), json_obj)
|
save_json(os.path.join(output_path, file_fingerprint + '.jsonl'), json_obj)
|
||||||
logger.info(f"解析完成,文件名{file_path}")
|
# logger.info(f"解析完成,文件名{file_path}")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
@ -258,7 +257,7 @@ if __name__ == '__main__':
|
|||||||
sample_file_list = os.listdir(sample_file_path)
|
sample_file_list = os.listdir(sample_file_path)
|
||||||
print(f"max worker {os.cpu_count()}")
|
print(f"max worker {os.cpu_count()}")
|
||||||
with multiprocessing.Pool(processes=os.cpu_count()) as pool:
|
with multiprocessing.Pool(processes=os.cpu_count()) as pool:
|
||||||
result = list(tqdm(pool.imap(exe_to_json, [os.path.join(sample_file_path, file_name) for file_name in sample_file_list]), total=len(sample_file_list)))
|
result = list(tqdm(pool.imap_unordered(exe_to_json, [os.path.join(sample_file_path, file_name) for file_name in sample_file_list]), total=len(sample_file_list)))
|
||||||
|
|
||||||
|
|
||||||
# with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
|
# with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
|
||||||
|
50
json_feature2json.py
Normal file
50
json_feature2json.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
import os
|
||||||
|
from my_utils import save_json, load_json, setup_logger
|
||||||
|
from bert.obtain_inst_vec import bb2vec
|
||||||
|
import multiprocessing
|
||||||
|
from tqdm import tqdm
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
# 忽略输出torch cpu执行警告
|
||||||
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
|
|
||||||
|
|
||||||
|
def addr2vec(base_file_path):
    """Replace basic-block addresses in one ACFG JSON file with vectors.

    For ``<dir>/<name>`` the matching feature file ``<dir>/feature/<name>`` is
    loaded; each of its entries supplies an ``addr`` and an ``opcode`` value,
    and ``bb2vec(opcode)`` becomes the vector for that address. Every
    ``block_features`` list under ``acfg_list`` is then rewritten from
    addresses to vectors and the result is saved as ``<dir>/final/<name>``.

    :param base_file_path: path of the base JSON file to convert.
    :return: None. Nothing is written when the output already exists, when
        the path is a directory, or when the feature file is missing.
    :raises KeyError: if a block address has no entry in the feature file.
    """
    import logging

    # Split the incoming path into directory and file name.
    file_name = str(os.path.basename(base_file_path))
    file_path = str(os.path.dirname(base_file_path))
    final_dir = os.path.join(file_path, 'final')
    feature_file = os.path.join(file_path, 'feature', file_name)

    # Skip files whose output has already been generated.
    if os.path.exists(os.path.join(final_dir, file_name)):
        return

    # The caller lists the whole directory, which also holds the 'feature'
    # and 'final' sub-directories; those are not inputs.
    if not file_name or os.path.isdir(base_file_path):
        return

    # A base file may exist without its feature file; report and move on.
    feature_json = load_json(feature_file) if os.path.exists(feature_file) else None
    if feature_json is None:
        # `logger` is only assigned under the __main__ guard, so it is missing
        # in worker processes started with the "spawn" method. Fall back to
        # the logger of the same name (assumes setup_logger registers it
        # under 'feature2json' -- TODO confirm).
        log = globals().get('logger') or logging.getLogger('feature2json')
        log.error(f'文件{file_name}不存在特征文件')
        return

    file_json = load_json(base_file_path)
    feature_set = {item['addr']: bb2vec(item['opcode']) for item in feature_json}
    for acfg in file_json['acfg_list']:
        acfg['block_features'] = [feature_set[addr] for addr in acfg['block_features']]

    # The output directory is not created anywhere else.
    os.makedirs(final_dir, exist_ok=True)
    save_json(os.path.join(final_dir, file_name), file_json)
    return
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Entry point: convert every base JSON file of one sample type in parallel.
    logger = setup_logger('feature2json', './log/feature2json.log')
    sample_type = 'benign'
    json_path = os.path.join(f'./out/json/{sample_type}')

    # os.listdir also returns the 'feature' and 'final' sub-directories that
    # addr2vec reads from and writes to; only regular files are work items,
    # which also keeps the progress-bar total accurate.
    json_files = [
        os.path.join(json_path, entry)
        for entry in os.listdir(json_path)
        if os.path.isfile(os.path.join(json_path, entry))
    ]

    # One worker process per CPU; completion order does not matter because
    # each worker writes its own output file.
    with multiprocessing.Pool(processes=os.cpu_count()) as pool:
        result = list(tqdm(pool.imap_unordered(addr2vec, json_files),
                           total=len(json_files)))
|
20
my_utils.py
20
my_utils.py
@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
"""
|
"""
|
||||||
日志工具
|
日志工具
|
||||||
@ -45,7 +45,7 @@ THREAD_FULL = os.cpu_count()
|
|||||||
THREAD_HALF = int(os.cpu_count() / 2)
|
THREAD_HALF = int(os.cpu_count() / 2)
|
||||||
|
|
||||||
|
|
||||||
def multi_thread(func, args, thread_num=THREAD_FULL):
|
def multi_thread_order(func, args, thread_num=THREAD_FULL):
|
||||||
"""
|
"""
|
||||||
多线程执行函数
|
多线程执行函数
|
||||||
:param func: 函数
|
:param func: 函数
|
||||||
@ -73,10 +73,22 @@ def multi_thread(func, args, thread_num=THREAD_FULL):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def save_json(filename, data):
|
def multi_thread_disorder(func, thread_num=THREAD_FULL, **args):
    """Run *func* over a set of work items in a process pool, unordered.

    Despite the name, a ``multiprocessing.Pool`` (processes, not threads) is
    used. Results are collected in completion order while a tqdm progress bar
    is shown.

    :param func: callable taking one work item; handed to
        ``Pool.imap_unordered``.
    :param thread_num: number of worker processes.
    :param args: pass the work items as ``args=<iterable>``. When no ``args``
        keyword is given, the keyword names themselves are used as the work
        items, which is what the original implementation always did.
    :return: list of the values returned by *func*, in completion order.
    """
    import multiprocessing
    from tqdm import tqdm

    # **args always arrives as a dict, so iterating it directly only yields
    # keyword names; an explicit args=<iterable> carries the real work list.
    tasks = list(args['args']) if 'args' in args else list(args)
    if not tasks:
        return []

    with multiprocessing.Pool(processes=thread_num) as pool:
        return list(tqdm(pool.imap_unordered(func, tasks), total=len(tasks)))
|
||||||
|
|
||||||
|
|
||||||
|
def save_json(filename, data):
    """Serialise *data* as JSON and write it to *filename*.

    An existing file is overwritten. The directory containing *filename* is
    not created here.

    :param filename: destination path.
    :param data: any object ``json.dumps`` can serialise.
    :return: None
    """
    # Serialise before opening so a non-serialisable object cannot leave a
    # truncated file behind.
    text = json.dumps(data)
    # The context manager closes the handle even if the write fails.
    with open(filename, 'w') as file:
        file.write(text)
|
||||||
|
|
||||||
|
|
||||||
|
def load_json(filename):
    """Read *filename* and return its parsed JSON content.

    :param filename: path of a file containing one JSON document.
    :return: the decoded Python object (dict, list, str, number, bool or None).
    :raises json.JSONDecodeError: if the file content is not valid JSON.
    """
    # The context manager closes the handle even if reading or parsing fails.
    with open(filename, 'r') as file:
        return json.loads(file.read())
|
||||||
|
Loading…
Reference in New Issue
Block a user