乱序bert预测
This commit is contained in:
parent
bdde25a9ab
commit
7f1d7de95d
@ -27,7 +27,8 @@ config_file = os.path.join('./bert/bert.json')
|
|||||||
def load_model():
|
def load_model():
|
||||||
config = BertConfig.from_json_file(config_file)
|
config = BertConfig.from_json_file(config_file)
|
||||||
model = BertForPreTraining(config)
|
model = BertForPreTraining(config)
|
||||||
state_dict = torch.load(model_file)
|
# state_dict = torch.load(model_file)
|
||||||
|
state_dict = torch.load(model_file, map_location='cpu')
|
||||||
model.load_state_dict(state_dict)
|
model.load_state_dict(state_dict)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
@ -103,9 +104,9 @@ def generate_inst_vec(inst, method="mean"):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def bb2vec(inst):
|
def bb2vec(item):
|
||||||
tmp = generate_inst_vec(inst, method="mean")
|
tmp = generate_inst_vec(item['opcode'], method="mean")
|
||||||
return list(np.mean(tmp.detach().numpy(), axis=0).astype(float))
|
return item['addr'], list(np.mean(tmp.detach().numpy(), axis=0).astype(float))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
from my_utils import save_json, load_json, setup_logger, multi_thread_order, THREAD_HALF,THREAD_FULL
|
from my_utils import save_json, load_json, setup_logger, multi_thread_order, THREAD_HALF, THREAD_FULL
|
||||||
from bert.obtain_inst_vec import bb2vec
|
from bert.obtain_inst_vec import bb2vec
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import warnings
|
import warnings
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
|
||||||
def addr2vec(base_file_path):
|
def addr2vec(base_file_path):
|
||||||
@ -13,32 +15,56 @@ def addr2vec(base_file_path):
|
|||||||
file_name = str(os.path.basename(base_file_path))
|
file_name = str(os.path.basename(base_file_path))
|
||||||
file_path = str(os.path.dirname(base_file_path))
|
file_path = str(os.path.dirname(base_file_path))
|
||||||
|
|
||||||
# 忽略已生成的文件
|
|
||||||
if os.path.exists(os.path.join(file_path, 'final', file_name)):
|
|
||||||
return
|
|
||||||
|
|
||||||
# 如果不是路径则开始转化
|
# 如果不是路径则开始转化
|
||||||
if file_name:
|
if file_name:
|
||||||
|
# 忽略已生成的文件
|
||||||
|
if os.path.exists(os.path.join(file_path, 'final', file_name)):
|
||||||
|
process_bar.update(1)
|
||||||
|
return
|
||||||
file_json = load_json(base_file_path)
|
file_json = load_json(base_file_path)
|
||||||
# 确保存在基础文件而不存在特征文件的情况
|
# 确保存在基础文件而不存在特征文件的情况
|
||||||
feature_json = load_json(os.path.join(file_path, 'feature', file_name)) if os.path.exists(
|
feature_json = load_json(os.path.join(file_path, 'feature', file_name)) if os.path.exists(
|
||||||
os.path.join(file_path, 'feature', file_name)) else None
|
os.path.join(file_path, 'feature', file_name)) else None
|
||||||
if feature_json is not None:
|
if feature_json is not None:
|
||||||
feature_set = {}
|
feature_set = {}
|
||||||
for item in feature_json:
|
|
||||||
feature_set[item['addr']] = bb2vec(item['opcode'])
|
# 多线程预测bert
|
||||||
|
# with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
|
||||||
|
# try:
|
||||||
|
# future_args = {
|
||||||
|
# executor.submit(bb2vec, item): item for item in feature_json
|
||||||
|
# }
|
||||||
|
# for future in concurrent.futures.as_completed(future_args):
|
||||||
|
# result = future.result()
|
||||||
|
# feature_set[result[0]] = result[1]
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"bert 解析出错 {file_name},{e}")
|
||||||
|
|
||||||
|
with multiprocessing.Pool(processes=os.cpu_count()) as pool:
|
||||||
|
try:
|
||||||
|
results = pool.imap_unordered(bb2vec, [item for item in feature_json])
|
||||||
|
for result in results:
|
||||||
|
feature_set[result[0]] = result[1]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"bert 解析出错 {file_name},{e}")
|
||||||
|
|
||||||
|
try:
|
||||||
for item in file_json['acfg_list']:
|
for item in file_json['acfg_list']:
|
||||||
bb_feature_addr_list = item['block_features']
|
bb_feature_addr_list = item['block_features']
|
||||||
item['block_features'] = [feature_set[key] for key in bb_feature_addr_list]
|
item['block_features'] = [feature_set[key] for key in bb_feature_addr_list]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"地址对应出错{file_name}, {e}")
|
||||||
save_json(os.path.join(file_path, 'final', file_name), file_json)
|
save_json(os.path.join(file_path, 'final', file_name), file_json)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.error(f'文件{file_name}不存在特征文件')
|
logger.error(f'文件{file_name}不存在特征文件')
|
||||||
|
process_bar.update(1)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
logger = setup_logger('feature2json', './log/feature2json.log')
|
logger = setup_logger('feature2json', './log/feature2json.log')
|
||||||
sample_type = 'malware'
|
sample_type = 'benign'
|
||||||
# json_path = os.path.join(f'./out/json/{sample_type}')
|
# json_path = os.path.join(f'./out/json/{sample_type}')
|
||||||
json_path = os.path.join(f'./out/json/{sample_type}')
|
json_path = os.path.join(f'./out/json/{sample_type}')
|
||||||
json_files = os.listdir(json_path)
|
json_files = os.listdir(json_path)
|
||||||
@ -48,10 +74,9 @@ if __name__ == '__main__':
|
|||||||
# with multiprocessing.Pool(processes=os.cpu_count()) as pool:
|
# with multiprocessing.Pool(processes=os.cpu_count()) as pool:
|
||||||
# result = list(tqdm(pool.imap_unordered(addr2vec, [os.path.join(json_path, file) for file in json_files[:1] if os.path.isfile(os.path.join(json_path, file))]),
|
# result = list(tqdm(pool.imap_unordered(addr2vec, [os.path.join(json_path, file) for file in json_files[:1] if os.path.isfile(os.path.join(json_path, file))]),
|
||||||
# total=len(json_files)))
|
# total=len(json_files)))
|
||||||
multi_thread_order(addr2vec, [os.path.join(json_path, file) for file in json_files[::-1] if os.path.isfile(os.path.join(json_path, file))], thread_num=THREAD_FULL)
|
# multi_thread_order(addr2vec, [os.path.join(json_path, file) for file in json_files if
|
||||||
|
# os.path.isfile(os.path.join(json_path, file))], thread_num=THREAD_FULL)
|
||||||
|
process_bar = tqdm(total=len(json_files))
|
||||||
|
for json_file in json_files:
|
||||||
|
if os.path.isfile(os.path.join(json_path, json_file)):
|
||||||
# for json_file in json_files:
|
addr2vec(os.path.join(json_path, json_file))
|
||||||
# addr2vec(json_path, json_file)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user