2024-04-21 21:49:12 +08:00
import concurrent . futures
2024-04-17 15:54:00 +08:00
import os
2024-04-21 21:49:12 +08:00
from my_utils import save_json , load_json , setup_logger
2024-04-17 15:54:00 +08:00
from bert . obtain_inst_vec import bb2vec
import multiprocessing
from tqdm import tqdm
import warnings
2024-04-18 16:44:31 +08:00
from datetime import datetime
2024-04-17 15:54:00 +08:00
2024-04-20 13:20:21 +08:00
warnings . filterwarnings ( " ignore " )
2024-04-17 15:54:00 +08:00
2024-04-20 13:39:14 +08:00
def _get_logger():
    """Return the script-level ``logger`` if it has been set up, else a stdlib logger.

    The ``__main__`` block assigns the module-global ``logger``. When this module is
    imported rather than run, that global does not exist, so fall back to
    ``logging.getLogger('feature2json')`` instead of failing with ``NameError``.
    """
    import logging

    return globals().get('logger') or logging.getLogger('feature2json')


def addr2vec(base_file_path, index):
    """Convert the block-address lists of one base JSON file into BERT feature vectors.

    Reads the base JSON at ``base_file_path`` and its companion feature JSON at
    ``<dir>/feature/<name>``. Runs ``bb2vec`` over every entry of the feature JSON in a
    pool of 4 worker processes. Replaces each ``acfg_list[*]['block_features']``
    address list with the matching vectors and saves the result to
    ``<dir>/final/<name>``.

    Every failure is logged and the file is skipped, so one bad sample does not abort
    a batch run. The function always returns ``None``.

    Args:
        base_file_path: Path to the base JSON file. A path with an empty basename
            (e.g. one ending in a separator) is ignored.
        index: Position of this file in the batch; only shown in the progress bar.
    """
    log = _get_logger()

    # Split the path into directory and file name.
    file_name = os.path.basename(base_file_path)
    file_path = os.path.dirname(base_file_path)
    # An empty basename means there is no file to convert.
    if not file_name:
        return

    final_file = os.path.join(file_path, 'final', file_name)
    # Skip outputs that were already generated by an earlier run.
    if os.path.exists(final_file):
        return

    # A base file is only convertible together with its feature file. Check that
    # first so skipped files don't pay for parsing the base JSON.
    feature_file = os.path.join(file_path, 'feature', file_name)
    feature_json = load_json(feature_file) if os.path.exists(feature_file) else None
    if feature_json is None:
        log.error(f'文件{file_name}不存在特征文件')
        return

    # If any basic block has no opcodes, log every such block and skip the whole file.
    # A missing/None 'opcode' is treated as empty rather than crashing the batch run.
    empty_blocks = [item for item in feature_json if not item.get('opcode')]
    for item in empty_blocks:
        log.error(f"基础块无操作码{file_name},地址{item.get('addr')}")
    if empty_blocks:
        return

    # Load the base file before the expensive inference so a broken file fails fast.
    file_json = load_json(base_file_path)

    # Run BERT inference over all basic blocks in parallel worker processes.
    # ``json_files_len`` is set by the ``__main__`` block; show '?' when it is absent.
    total_files = globals().get('json_files_len', '?')
    feature_set = {}
    try:
        with multiprocessing.Pool(processes=4) as pool:
            results = list(tqdm(pool.imap_unordered(bb2vec, feature_json),
                                total=len(feature_json),
                                desc=f'{file_name} Progress: {index}/{total_files}',
                                ascii=True,
                                leave=False,
                                dynamic_ncols=True,
                                position=1))
        # Judging by how results are consumed: result[0] is a success flag.
        # On success, result[1] is the address and result[2] the vector.
        # On failure, result[1] is the address, result[2] the opcode and
        # result[3] the error.
        for result in results:
            if not result[0]:
                log.error(f"bert解析出错{file_name},地址{result[1]},操作码{result[2]},报错{result[3]}")
                return
            feature_set[result[1]] = result[2]
    except Exception as e:
        # Per-file boundary: log and skip this file instead of aborting the batch.
        log.error(f"多线程解析出错:{file_name},报错{e}")
        return

    # Replace each acfg_list entry's list of block addresses with the matching vectors.
    try:
        for acfg in file_json['acfg_list']:
            acfg['block_features'] = [feature_set[addr] for addr in acfg['block_features']]
    except Exception as e:
        log.error(f"地址对应出错{file_name},{e}")
        return

    save_json(final_file, file_json)
if __name__ == '__main__':
    # Make sure the log directory exists before the file handler is opened.
    os.makedirs('./log', exist_ok=True)
    # NOTE: addr2vec reads the module-level names `logger` and `json_files_len`.
    logger = setup_logger('feature2json', './log/feature2json.log', reset=True)

    sample_type = 'benign'
    json_path = f'./out/json/{sample_type}'
    # addr2vec writes its results into <json_path>/final; create it up front.
    os.makedirs(os.path.join(json_path, 'final'), exist_ok=True)

    # Only regular files are inputs. Filtering here (rather than inside the loop) keeps
    # subdirectories such as 'feature'/'final' out of the progress totals.
    # Sorting makes the processing order reproducible.
    json_files = sorted(
        name for name in os.listdir(json_path)
        if os.path.isfile(os.path.join(json_path, name))
    )
    json_files_len = len(json_files)

    formatted_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print("start time:", formatted_now)

    # Files are processed sequentially; addr2vec parallelizes the per-block inference.
    # start=1 so the per-file progress bar reads 1/N .. N/N.
    for index, json_file in tqdm(enumerate(json_files, start=1),
                                 total=json_files_len,
                                 ascii=True,
                                 desc='Total:',
                                 position=0):
        addr2vec(os.path.join(json_path, json_file), index)