线程池改进

This commit is contained in:
huihun 2024-04-18 16:44:31 +08:00
parent 795e5f050e
commit bdde25a9ab
3 changed files with 21 additions and 12 deletions

View File

@ -16,7 +16,7 @@ config_file = os.path.join('./bert/bert.json')
# 禁用分词器多线程
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
# from my_data_collator import MyDataCollatorForPreTraining
# model_file = os.path.join("./pytorch_model.bin")
@ -27,7 +27,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
def load_model():
config = BertConfig.from_json_file(config_file)
model = BertForPreTraining(config)
state_dict = torch.load(model_file, map_location='cpu')
state_dict = torch.load(model_file)
model.load_state_dict(state_dict)
model.eval()

View File

@ -1,3 +1,5 @@
import random
import r2pipe
import hashlib
from my_utils import *
@ -14,7 +16,7 @@ import multiprocessing
ret_trap_opcode_family = ["ret", "hlt", "int3", "ud2"]
sample_type = 'benign'
sample_type = 'malware'
def extract_opcode(disasm_text):
@ -252,11 +254,11 @@ def exe_to_json(file_path):
if __name__ == '__main__':
logger = init_logging()
sample_file_path = f"/mnt/d/bishe/sample_{sample_type}"
sample_file_path = f"/mnt/d/bishe/dataset/sample_{sample_type}"
sample_file_list = os.listdir(sample_file_path)
print(f"max worker {os.cpu_count()}")
with multiprocessing.Pool(processes=os.cpu_count()) as pool:
result = list(tqdm(pool.imap_unordered(exe_to_json, [os.path.join(sample_file_path, file_name) for file_name in sample_file_list[::-1]]), total=len(sample_file_list)))
result = list(tqdm(pool.imap_unordered(exe_to_json, [os.path.join(sample_file_path, file_name) for file_name in sample_file_list[:10000]]), total=10000))
# with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:

View File

@ -1,12 +1,11 @@
import os
from my_utils import save_json, load_json, setup_logger
from my_utils import save_json, load_json, setup_logger, multi_thread_order, THREAD_HALF,THREAD_FULL
from bert.obtain_inst_vec import bb2vec
import multiprocessing
from tqdm import tqdm
import warnings
from datetime import datetime
# 忽略输出torch cpu执行警告
warnings.filterwarnings('ignore', category=UserWarning)
def addr2vec(base_file_path):
@ -39,12 +38,20 @@ def addr2vec(base_file_path):
if __name__ == '__main__':
logger = setup_logger('feature2json', './log/feature2json.log')
sample_type = 'benign'
sample_type = 'malware'
# json_path = os.path.join(f'./out/json/{sample_type}')
json_path = os.path.join(f'./out/json/{sample_type}')
json_files = os.listdir(json_path)
with multiprocessing.Pool(processes=os.cpu_count()) as pool:
result = list(tqdm(pool.imap_unordered(addr2vec, [os.path.join(json_path, file) for file in json_files]),
total=len(json_files)))
now = datetime.now()
formatted_now = now.strftime("%Y-%m-%d %H:%M:%S")
print("start time:", formatted_now)
# with multiprocessing.Pool(processes=os.cpu_count()) as pool:
# result = list(tqdm(pool.imap_unordered(addr2vec, [os.path.join(json_path, file) for file in json_files[:1] if os.path.isfile(os.path.join(json_path, file))]),
# total=len(json_files)))
multi_thread_order(addr2vec, [os.path.join(json_path, file) for file in json_files[::-1] if os.path.isfile(os.path.join(json_path, file))], thread_num=THREAD_FULL)
# for json_file in json_files:
# addr2vec(json_path, json_file)