线程池改进
This commit is contained in:
parent
795e5f050e
commit
bdde25a9ab
@ -16,7 +16,7 @@ config_file = os.path.join('./bert/bert.json')
|
||||
|
||||
|
||||
# 禁用分词器多线程
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
# from my_data_collator import MyDataCollatorForPreTraining
|
||||
# model_file = os.path.join("./pytorch_model.bin")
|
||||
@ -27,7 +27,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
def load_model():
|
||||
config = BertConfig.from_json_file(config_file)
|
||||
model = BertForPreTraining(config)
|
||||
state_dict = torch.load(model_file, map_location='cpu')
|
||||
state_dict = torch.load(model_file)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
import random
|
||||
|
||||
import r2pipe
|
||||
import hashlib
|
||||
from my_utils import *
|
||||
@ -14,7 +16,7 @@ import multiprocessing
|
||||
|
||||
ret_trap_opcode_family = ["ret", "hlt", "int3", "ud2"]
|
||||
|
||||
sample_type = 'benign'
|
||||
sample_type = 'malware'
|
||||
|
||||
|
||||
def extract_opcode(disasm_text):
|
||||
@ -252,11 +254,11 @@ def exe_to_json(file_path):
|
||||
|
||||
if __name__ == '__main__':
|
||||
logger = init_logging()
|
||||
sample_file_path = f"/mnt/d/bishe/sample_{sample_type}"
|
||||
sample_file_path = f"/mnt/d/bishe/dataset/sample_{sample_type}"
|
||||
sample_file_list = os.listdir(sample_file_path)
|
||||
print(f"max worker {os.cpu_count()}")
|
||||
with multiprocessing.Pool(processes=os.cpu_count()) as pool:
|
||||
result = list(tqdm(pool.imap_unordered(exe_to_json, [os.path.join(sample_file_path, file_name) for file_name in sample_file_list[::-1]]), total=len(sample_file_list)))
|
||||
result = list(tqdm(pool.imap_unordered(exe_to_json, [os.path.join(sample_file_path, file_name) for file_name in sample_file_list[:10000]]), total=10000))
|
||||
|
||||
|
||||
# with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
|
||||
|
@ -1,12 +1,11 @@
|
||||
import os
|
||||
from my_utils import save_json, load_json, setup_logger
|
||||
from my_utils import save_json, load_json, setup_logger, multi_thread_order, THREAD_HALF,THREAD_FULL
|
||||
from bert.obtain_inst_vec import bb2vec
|
||||
import multiprocessing
|
||||
from tqdm import tqdm
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
|
||||
# 忽略输出torch cpu执行警告
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
|
||||
def addr2vec(base_file_path):
|
||||
@ -39,12 +38,20 @@ def addr2vec(base_file_path):
|
||||
|
||||
if __name__ == '__main__':
|
||||
logger = setup_logger('feature2json', './log/feature2json.log')
|
||||
sample_type = 'benign'
|
||||
sample_type = 'malware'
|
||||
# json_path = os.path.join(f'./out/json/{sample_type}')
|
||||
json_path = os.path.join(f'./out/json/{sample_type}')
|
||||
json_files = os.listdir(json_path)
|
||||
with multiprocessing.Pool(processes=os.cpu_count()) as pool:
|
||||
result = list(tqdm(pool.imap_unordered(addr2vec, [os.path.join(json_path, file) for file in json_files]),
|
||||
total=len(json_files)))
|
||||
now = datetime.now()
|
||||
formatted_now = now.strftime("%Y-%m-%d %H:%M:%S")
|
||||
print("start time:", formatted_now)
|
||||
# with multiprocessing.Pool(processes=os.cpu_count()) as pool:
|
||||
# result = list(tqdm(pool.imap_unordered(addr2vec, [os.path.join(json_path, file) for file in json_files[:1] if os.path.isfile(os.path.join(json_path, file))]),
|
||||
# total=len(json_files)))
|
||||
multi_thread_order(addr2vec, [os.path.join(json_path, file) for file in json_files[::-1] if os.path.isfile(os.path.join(json_path, file))], thread_num=THREAD_FULL)
|
||||
|
||||
|
||||
|
||||
|
||||
# for json_file in json_files:
|
||||
# addr2vec(json_path, json_file)
|
||||
|
Loading…
Reference in New Issue
Block a user