Thread pool improvement
commit bdde25a9ab (parent 795e5f050e)
@@ -16,7 +16,7 @@ config_file = os.path.join('./bert/bert.json')
 
 
 # disable tokenizer multi-threading
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
+# os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 # from my_data_collator import MyDataCollatorForPreTraining
 # model_file = os.path.join("./pytorch_model.bin")
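The TOKENIZERS_PARALLELISM flag only matters once worker processes fork: HuggingFace fast tokenizers print a warning and disable their internal Rust-level parallelism when a fork happens after the tokenizer has already been used. A minimal sketch of the usual guard, assuming the transformers/tokenizers stack this module wraps; commenting the assignment out, as in the hunk above, simply accepts the library's default behaviour.

import os

# Must be set before the tokenizer is first used; assigning it later has no
# effect on worker threads the tokenizer has already spawned.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")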
@@ -27,7 +27,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
 def load_model():
     config = BertConfig.from_json_file(config_file)
     model = BertForPreTraining(config)
-    state_dict = torch.load(model_file, map_location='cpu')
+    state_dict = torch.load(model_file)
     model.load_state_dict(state_dict)
     model.eval()
 
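Dropping map_location='cpu' means torch.load restores tensors to whatever device is recorded in the checkpoint, which can fail in CPU-only worker processes if the weights were saved from a GPU. A small self-contained sketch of the same loading flow with the explicit mapping kept (function name and paths are placeholders, not the repository's exact code):

import torch
from transformers import BertConfig, BertForPreTraining

def load_bert(config_file: str = "./bert/bert.json",
              model_file: str = "./bert/pytorch_model.bin") -> BertForPreTraining:
    # Rebuild the architecture from its JSON config, then restore the weights.
    config = BertConfig.from_json_file(config_file)
    model = BertForPreTraining(config)
    # map_location="cpu" keeps GPU-saved checkpoints loadable on CPU-only hosts.
    state_dict = torch.load(model_file, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()  # inference mode: disables dropout
    return model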
@@ -1,3 +1,5 @@
+import random
+
 import r2pipe
 import hashlib
 from my_utils import *
@@ -14,7 +16,7 @@ import multiprocessing
 
 ret_trap_opcode_family = ["ret", "hlt", "int3", "ud2"]
 
-sample_type = 'benign'
+sample_type = 'malware'
 
 
 def extract_opcode(disasm_text):
@@ -252,11 +254,11 @@ def exe_to_json(file_path):
 
 if __name__ == '__main__':
     logger = init_logging()
-    sample_file_path = f"/mnt/d/bishe/sample_{sample_type}"
+    sample_file_path = f"/mnt/d/bishe/dataset/sample_{sample_type}"
     sample_file_list = os.listdir(sample_file_path)
     print(f"max worker {os.cpu_count()}")
     with multiprocessing.Pool(processes=os.cpu_count()) as pool:
-        result = list(tqdm(pool.imap_unordered(exe_to_json, [os.path.join(sample_file_path, file_name) for file_name in sample_file_list[::-1]]), total=len(sample_file_list)))
+        result = list(tqdm(pool.imap_unordered(exe_to_json, [os.path.join(sample_file_path, file_name) for file_name in sample_file_list[:10000]]), total=10000))
 
 
     # with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
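For reference, the process-pool pattern this hunk tweaks, isolated as a runnable sketch; the worker body and the sample directory are placeholders for the real exe_to_json and dataset path. Using len(files) as the progress total also stays correct when fewer than 10000 samples are present.

import multiprocessing
import os
from tqdm import tqdm

def exe_to_json(file_path: str) -> str:
    # Placeholder worker: the real one disassembles the binary with r2pipe
    # and writes its basic-block features to JSON.
    return file_path

if __name__ == '__main__':
    sample_dir = "./samples"  # hypothetical path for this sketch
    files = [os.path.join(sample_dir, name) for name in os.listdir(sample_dir)][:10000]
    with multiprocessing.Pool(processes=os.cpu_count()) as pool:
        # imap_unordered yields results as soon as any worker finishes,
        # so the tqdm bar advances steadily instead of in bursts.
        results = list(tqdm(pool.imap_unordered(exe_to_json, files), total=len(files)))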
@@ -1,12 +1,11 @@
 import os
-from my_utils import save_json, load_json, setup_logger
+from my_utils import save_json, load_json, setup_logger, multi_thread_order, THREAD_HALF,THREAD_FULL
 from bert.obtain_inst_vec import bb2vec
 import multiprocessing
 from tqdm import tqdm
 import warnings
+from datetime import datetime
 
-# suppress the torch CPU execution warning output
-warnings.filterwarnings('ignore', category=UserWarning)
 
 
 def addr2vec(base_file_path):
@@ -39,12 +38,20 @@ def addr2vec(base_file_path):
 
 if __name__ == '__main__':
     logger = setup_logger('feature2json', './log/feature2json.log')
-    sample_type = 'benign'
+    sample_type = 'malware'
+    # json_path = os.path.join(f'./out/json/{sample_type}')
     json_path = os.path.join(f'./out/json/{sample_type}')
     json_files = os.listdir(json_path)
-    with multiprocessing.Pool(processes=os.cpu_count()) as pool:
-        result = list(tqdm(pool.imap_unordered(addr2vec, [os.path.join(json_path, file) for file in json_files]),
-                           total=len(json_files)))
+    now = datetime.now()
+    formatted_now = now.strftime("%Y-%m-%d %H:%M:%S")
+    print("start time:", formatted_now)
+    # with multiprocessing.Pool(processes=os.cpu_count()) as pool:
+    #     result = list(tqdm(pool.imap_unordered(addr2vec, [os.path.join(json_path, file) for file in json_files[:1] if os.path.isfile(os.path.join(json_path, file))]),
+    #                        total=len(json_files)))
+    multi_thread_order(addr2vec, [os.path.join(json_path, file) for file in json_files[::-1] if os.path.isfile(os.path.join(json_path, file))], thread_num=THREAD_FULL)
+
+
+
 
     # for json_file in json_files:
     #     addr2vec(json_path, json_file)
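multi_thread_order, THREAD_HALF, and THREAD_FULL come from my_utils, which is not part of this diff. As a hypothetical sketch only (names, signature, and behaviour are assumptions, not the repository's actual implementation), an ordered thread-pool helper matching how it is called above could be built on concurrent.futures:

import os
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

# Assumed thread-count presets mirroring the imported names.
THREAD_FULL = os.cpu_count() or 1
THREAD_HALF = max(1, THREAD_FULL // 2)

def multi_thread_order(func, items, thread_num=THREAD_FULL):
    """Map func over items on a thread pool, keeping results in input order."""
    items = list(items)
    with ThreadPoolExecutor(max_workers=thread_num) as executor:
        # executor.map preserves the order of items, unlike Pool.imap_unordered.
        return list(tqdm(executor.map(func, items), total=len(items)))

One plausible motivation for moving from a process pool to threads here is that a single in-process BERT model can be shared by all workers instead of being reloaded in every multiprocessing worker, though the diff itself does not state this.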