乱序bert预测

This commit is contained in:
huihun 2024-04-20 13:20:21 +08:00
parent bdde25a9ab
commit 7f1d7de95d
2 changed files with 48 additions and 22 deletions

View File

@ -27,7 +27,8 @@ config_file = os.path.join('./bert/bert.json')
def load_model(): def load_model():
config = BertConfig.from_json_file(config_file) config = BertConfig.from_json_file(config_file)
model = BertForPreTraining(config) model = BertForPreTraining(config)
state_dict = torch.load(model_file) # state_dict = torch.load(model_file)
state_dict = torch.load(model_file, map_location='cpu')
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
model.eval() model.eval()
@ -103,9 +104,9 @@ def generate_inst_vec(inst, method="mean"):
return result return result
def bb2vec(inst): def bb2vec(item):
tmp = generate_inst_vec(inst, method="mean") tmp = generate_inst_vec(item['opcode'], method="mean")
return list(np.mean(tmp.detach().numpy(), axis=0).astype(float)) return item['addr'], list(np.mean(tmp.detach().numpy(), axis=0).astype(float))
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,11 +1,13 @@
import os import os
from my_utils import save_json, load_json, setup_logger, multi_thread_order, THREAD_HALF,THREAD_FULL from my_utils import save_json, load_json, setup_logger, multi_thread_order, THREAD_HALF, THREAD_FULL
from bert.obtain_inst_vec import bb2vec from bert.obtain_inst_vec import bb2vec
import multiprocessing import multiprocessing
from tqdm import tqdm from tqdm import tqdm
import warnings import warnings
from datetime import datetime from datetime import datetime
import concurrent.futures
warnings.filterwarnings("ignore")
def addr2vec(base_file_path): def addr2vec(base_file_path):
@ -13,32 +15,56 @@ def addr2vec(base_file_path):
file_name = str(os.path.basename(base_file_path)) file_name = str(os.path.basename(base_file_path))
file_path = str(os.path.dirname(base_file_path)) file_path = str(os.path.dirname(base_file_path))
# 忽略已生成的文件
if os.path.exists(os.path.join(file_path, 'final', file_name)):
return
# 如果不是路径则开始转化 # 如果不是路径则开始转化
if file_name: if file_name:
# 忽略已生成的文件
if os.path.exists(os.path.join(file_path, 'final', file_name)):
process_bar.update(1)
return
file_json = load_json(base_file_path) file_json = load_json(base_file_path)
# 确保存在基础文件而不存在特征文件的情况 # 确保存在基础文件而不存在特征文件的情况
feature_json = load_json(os.path.join(file_path, 'feature', file_name)) if os.path.exists( feature_json = load_json(os.path.join(file_path, 'feature', file_name)) if os.path.exists(
os.path.join(file_path, 'feature', file_name)) else None os.path.join(file_path, 'feature', file_name)) else None
if feature_json is not None: if feature_json is not None:
feature_set = {} feature_set = {}
for item in feature_json:
feature_set[item['addr']] = bb2vec(item['opcode']) # 多线程预测bert
for item in file_json['acfg_list']: # with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
bb_feature_addr_list = item['block_features'] # try:
item['block_features'] = [feature_set[key] for key in bb_feature_addr_list] # future_args = {
# executor.submit(bb2vec, item): item for item in feature_json
# }
# for future in concurrent.futures.as_completed(future_args):
# result = future.result()
# feature_set[result[0]] = result[1]
# except Exception as e:
# logger.error(f"bert 解析出错 {file_name}{e}")
with multiprocessing.Pool(processes=os.cpu_count()) as pool:
try:
results = pool.imap_unordered(bb2vec, [item for item in feature_json])
for result in results:
feature_set[result[0]] = result[1]
except Exception as e:
logger.error(f"bert 解析出错 {file_name}{e}")
try:
for item in file_json['acfg_list']:
bb_feature_addr_list = item['block_features']
item['block_features'] = [feature_set[key] for key in bb_feature_addr_list]
except Exception as e:
logger.error(f"地址对应出错{file_name}, {e}")
save_json(os.path.join(file_path, 'final', file_name), file_json) save_json(os.path.join(file_path, 'final', file_name), file_json)
else: else:
logger.error(f'文件{file_name}不存在特征文件') logger.error(f'文件{file_name}不存在特征文件')
process_bar.update(1)
return return
if __name__ == '__main__': if __name__ == '__main__':
logger = setup_logger('feature2json', './log/feature2json.log') logger = setup_logger('feature2json', './log/feature2json.log')
sample_type = 'malware' sample_type = 'benign'
# json_path = os.path.join(f'./out/json/{sample_type}') # json_path = os.path.join(f'./out/json/{sample_type}')
json_path = os.path.join(f'./out/json/{sample_type}') json_path = os.path.join(f'./out/json/{sample_type}')
json_files = os.listdir(json_path) json_files = os.listdir(json_path)
@ -48,10 +74,9 @@ if __name__ == '__main__':
# with multiprocessing.Pool(processes=os.cpu_count()) as pool: # with multiprocessing.Pool(processes=os.cpu_count()) as pool:
# result = list(tqdm(pool.imap_unordered(addr2vec, [os.path.join(json_path, file) for file in json_files[:1] if os.path.isfile(os.path.join(json_path, file))]), # result = list(tqdm(pool.imap_unordered(addr2vec, [os.path.join(json_path, file) for file in json_files[:1] if os.path.isfile(os.path.join(json_path, file))]),
# total=len(json_files))) # total=len(json_files)))
multi_thread_order(addr2vec, [os.path.join(json_path, file) for file in json_files[::-1] if os.path.isfile(os.path.join(json_path, file))], thread_num=THREAD_FULL) # multi_thread_order(addr2vec, [os.path.join(json_path, file) for file in json_files if
# os.path.isfile(os.path.join(json_path, file))], thread_num=THREAD_FULL)
process_bar = tqdm(total=len(json_files))
for json_file in json_files:
if os.path.isfile(os.path.join(json_path, json_file)):
# for json_file in json_files: addr2vec(os.path.join(json_path, json_file))
# addr2vec(json_path, json_file)