# feature2json: run BERT over basic-block features and merge the vectors into ACFG JSON files.
import os
|
||
from my_utils import save_json, load_json, setup_logger, multi_thread_order, THREAD_HALF, THREAD_FULL, continuation_json
|
||
from bert.obtain_inst_vec import bb2vec
|
||
import multiprocessing
|
||
from tqdm import tqdm
|
||
import warnings
|
||
from datetime import datetime
|
||
import concurrent.futures
|
||
|
||
warnings.filterwarnings("ignore")
|
||
|
||
|
||
def addr2vec(base_file_path, index):
    """Attach BERT basic-block vectors to one ACFG JSON file.

    Loads the base file at ``base_file_path`` and its sibling feature file
    ``<dir>/feature/<name>``, runs ``bb2vec`` on every basic-block feature to
    get ``(address, vector)`` pairs, rewrites each ``acfg_list`` entry's
    ``block_features`` from addresses to vectors, and saves the result to
    ``<dir>/final/<name>``.

    Args:
        base_file_path: Path to the base JSON file to convert.
        index: Position of this file in the overall batch (progress/log context
            only; does not affect the result).

    Reads module-level globals set in ``__main__``: ``logger`` and
    ``sample_type`` — TODO confirm this module is only ever run as a script.
    """
    # Split the path into directory and file name.
    file_name = str(os.path.basename(base_file_path))
    file_path = str(os.path.dirname(base_file_path))

    # Guard clause: a path ending in a separator has no file name to convert.
    if not file_name:
        return

    # Skip files already converted by a previous run.
    if os.path.exists(os.path.join(file_path, 'final', file_name)):
        return

    file_json = load_json(base_file_path)

    # The base file can exist without its feature file; bail out with a log
    # entry instead of crashing.
    feature_path = os.path.join(file_path, 'feature', file_name)
    if not os.path.exists(feature_path):
        logger.error(f'文件{file_name}不存在特征文件')
        return
    feature_json = load_json(feature_path)

    # Defer over-long files for now: record them and move on.
    if len(feature_json) > 10000:
        data = {
            'file_name': file_name,
            'feature_len': len(feature_json)
        }
        # BUGFIX: dropped the pointless single-argument os.path.join() wrapper.
        continuation_json(f'./out/json/too_long_{sample_type}.json', data)
        return

    # BERT inference per basic block: address -> embedding vector.
    # BUGFIX: the loop variable no longer shadows the `index` parameter.
    feature_set = {}
    try:
        for i, feature in tqdm(enumerate(feature_json), total=len(feature_json)):
            addr, vec = bb2vec(feature)
            feature_set[addr] = vec
    except Exception as e:
        # BUGFIX: was bare print() calls; log with file context like the rest
        # of this module, and stop here — continuing with a partial
        # feature_set would only fail again at the mapping step below.
        logger.error(f"bert 解析出错 {file_name},{e}")
        return

    # Replace each block-address list with the corresponding vectors.
    try:
        for item in file_json['acfg_list']:
            bb_feature_addr_list = item['block_features']
            item['block_features'] = [feature_set[key] for key in bb_feature_addr_list]
    except Exception as e:
        logger.error(f"地址对应出错{file_name}, {e}")
        return

    save_json(os.path.join(file_path, 'final', file_name), file_json)
|
||
|
||
|
||
if __name__ == '__main__':
    # Module-level state read by addr2vec(): keep these names as-is.
    logger = setup_logger('feature2json', './log/feature2json.log')
    sample_type = 'benign'

    # Directory holding the per-sample ACFG JSON files for this sample type.
    json_path = os.path.join(f'./out/json/{sample_type}')
    # NOTE(debug): pinned to one sample instead of os.listdir(json_path).
    json_files = ['1710ae16c54ac149f353ba58e752ba7069f88326e6b71107598283bd0fffcbd6.jsonl']
    json_files_len = len(json_files)

    now = datetime.now()
    formatted_now = now.strftime("%Y-%m-%d %H:%M:%S")
    print("start time:", formatted_now)

    # Sequential conversion (a multiprocessing variant was used previously).
    for index, json_file in enumerate(json_files):
        candidate = os.path.join(json_path, json_file)
        if os.path.isfile(candidate):
            addr2vec(candidate, index)
|