Archive: reproduction complete

This commit is contained in:
parent 37b3d9c4cf
commit 3df2fe07cb
@@ -1,5 +1,5 @@
 Data:
-  preprocess_root: "/home/king/python/data/processed_dataset/DatasetJSON"
+  preprocess_root: "/home/king/python/data/processed_dataset/DatasetJSON_remake"
   train_vocab_file: "/home/king/python/data/processed_dataset/train_external_function_name_vocab.jsonl"
   max_vocab_size: 10000 # modify according to the result of 1BuildExternalVocab.py
 Training:
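The renamed preprocess_root lives under the Data: block of what appears to be the project's Hydra/OmegaConf configuration (both hydra-core and omegaconf are pinned in req_pip.txt below, and main_app takes a DictConfig). A minimal sketch of reading these keys with OmegaConf; the file name config.yaml is an assumption:

```python
from omegaconf import OmegaConf

config = OmegaConf.load("config.yaml")  # file name is an assumption
print(config.Data.preprocess_root)      # .../DatasetJSON_remake
print(config.Data.max_vocab_size)       # 10000
```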
req_pip.txt (new file, 25 lines)
@@ -0,0 +1,25 @@
+antlr4-python3-runtime==4.8
+ase==3.21.1
+cmake==3.18.4.post1
+dataclasses==0.6
+googledrivedownloader==0.4
+hydra-core==1.0.6
+importlib-resources==5.1.2
+intel-openmp==2021.1.2
+magma-cuda112==2.5.2
+mkl==2021.1.1
+mkl-include==2021.1.1
+ninja==1.10.0.post2
+omegaconf==2.0.6
+prefetch-generator==1.0.1
+rdflib==5.0.0
+tbb==2021.1.1
+texttable==1.6.3
+torch==1.8.0+cu111
+torch-cluster==1.5.9
+torch-geometric==1.6.3
+torch-scatter==2.0.6
+torch-sparse==0.6.9
+torch-spline-conv==1.2.1
+torchaudio==0.8.0
+torchvision==0.9.0+cu111
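The file pins CUDA 11.1 builds of torch 1.8.0 together with the matching PyTorch Geometric 1.6.3 stack. A quick post-install sanity check, assuming the pinned wheels installed cleanly:

```python
import torch
import torch_geometric

print(torch.__version__)            # expected: 1.8.0+cu111
print(torch.version.cuda)           # expected: 11.1
print(torch_geometric.__version__)  # expected: 1.6.3
print(torch.cuda.is_available())    # True on a CUDA 11.1-capable machine
```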
@@ -1,10 +1,10 @@
 import json
 import os

 import sys
 import torch
 from torch_geometric.data import Data
 from tqdm import tqdm

 sys.path.append(os.path.dirname(sys.path[0]))
 from src.utils.Vocabulary import Vocab
@@ -35,6 +35,8 @@ def parse_json_list_2_pyg_object(jsonl_file: str, label: int, vocab: Vocab, save


 def json_to_pt(file: str, label: int, vocab: Vocab, save_path: str, file_type: str, train_type: str, index: int):
+    if not os.path.exists(save_path+f"{train_type}_{file_type}/"):
+        os.mkdir(save_path+f"{train_type}_{file_type}/")
     with open(file, "r", encoding="utf-8") as item:
         line = item.readline()
         item = json.loads(line)
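The added guard creates the output directory with os.mkdir, which raises FileNotFoundError when the parent directory is missing and FileExistsError when two workers race to create it. A sketch of the more defensive os.makedirs variant (not what the commit uses; the path is illustrative):

```python
import os

save_dir = "/home/king/python/data/processed_dataset/DatasetJSON_remake/train_malware/"  # illustrative
os.makedirs(save_dir, exist_ok=True)  # creates missing parents; no error if it already exists
```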
@@ -65,11 +67,14 @@ def json_to_pt(file: str, label: int, vocab: Vocab, save_path: str, file_type: s


 if __name__ == '__main__':
-    json_path = "./jsonl/infected_jsonl/"
-    train_vocab_file = "../data/processed_dataset/train_external_function_name_vocab.jsonl"
-    save_vocab_file = "../data/processed_dataset/DatasetJSON/"
-    file_type = "malware"
+    malware_json_path = "/home/king/python/data/jsonl/infected_jsonl/"
+    benign_json_path = "/home/king/python/data/jsonl/refind_jsonl/"
+    train_vocab_file = "/home/king/python/data/processed_dataset/train_external_function_name_vocab.jsonl"
+    save_vocab_file = "/home/king/python/data/processed_dataset/DatasetJSON_remake/"
+    file_type = ["malware", "benign"]
     max_vocab_size = 10000
     vocabulary = Vocab(freq_file=train_vocab_file, max_vocab_size=max_vocab_size)
-    parse_json_list_2_pyg_object(jsonl_file=json_path, label=1, vocab=vocabulary, save_path=save_vocab_file,
-                                 file_type=file_type)
+    parse_json_list_2_pyg_object(jsonl_file=malware_json_path, label=1, vocab=vocabulary, save_path=save_vocab_file,
+                                 file_type=file_type[0])
+    parse_json_list_2_pyg_object(jsonl_file=benign_json_path, label=0, vocab=vocabulary, save_path=save_vocab_file,
+                                 file_type=file_type[1])
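The script now serializes both classes, labeling malware 1 and benign 0. A hedged spot-check that one emitted .pt file loads back as a graph object; the directory and file name below are assumptions about the {train_type}_{file_type}/ layout created in json_to_pt:

```python
import torch

# hypothetical output file; adjust to the actual naming scheme
sample = torch.load("/home/king/python/data/processed_dataset/DatasetJSON_remake/train_malware/0.pt")
print(sample)  # expect a torch_geometric Data object (or a wrapper holding one)
```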
@@ -6,8 +6,8 @@ import heapq
 from tqdm import tqdm

 if __name__ == '__main__':
-    mal_file_name = './jsonl/infected_jsonl/'
-    ben_file_name = './jsonl/refind_jsonl/'
+    mal_file_name = '/home/king/python/data/jsonl/infected_jsonl/'
+    ben_file_name = '/home/king/python/data/jsonl/refind_jsonl/'
     fun_name_dict = {}
     for file in tqdm(os.listdir(mal_file_name)):
         with open(mal_file_name + file, 'r') as item:
@@ -29,7 +29,7 @@ if __name__ == '__main__':
                     fun_name_dict[fun_name] += 1
                 else:
                     fun_name_dict[fun_name] = 1
-    with open('./res.jsonl', 'w') as file:
+    with open('/home/king/python/data/processed_dataset/train_external_function_name_vocab.jsonl', 'w') as file:
         largest_10000_items = heapq.nlargest(10000, fun_name_dict.items(), key=lambda item: item[1])
         for key, value in largest_10000_items:
             temp = {"f_name": key, "count": value}
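The hunk ends at the temp record, so the actual write is not shown; given the .jsonl suffix, one JSON object per line via json.dumps is the natural serialization. A self-contained sketch under that assumption:

```python
import heapq
import json

fun_name_dict = {"CreateFileW": 120, "VirtualAlloc": 95, "strlen": 40}  # toy counts
largest_items = heapq.nlargest(10000, fun_name_dict.items(), key=lambda item: item[1])
with open("vocab_demo.jsonl", "w") as file:   # demo path, not the real vocab file
    for key, value in largest_items:
        temp = {"f_name": key, "count": value}
        file.write(json.dumps(temp) + "\n")   # one record per line
```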
@@ -170,6 +170,7 @@ def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distribu

     # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
     # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html
+
     _roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_prods)
     _fpr, _tpr, _thresholds = roc_curve(y_true=gather_true_classes, y_score=gather_positive_prods)
     if details is True:
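roc_auc_score raises ValueError when y_true holds only one class, which can happen on a small validation shard; a later hunk in this commit shows exactly that guard, commented out. A runnable sketch of the guard, with the fallback value taken from that commented-out code:

```python
from sklearn.metrics import roc_auc_score

y_true = [1, 1, 1]            # degenerate batch: only one class present
y_score = [0.9, 0.8, 0.7]
try:
    auc = roc_auc_score(y_true=y_true, y_score=y_score)
except ValueError:
    auc = 0.001               # sentinel from the commented-out code
print(auc)
```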
@@ -185,8 +186,8 @@ def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distribu

 def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, model_params: ModelParams, optimizer_params: OptimizerParams, global_log: logging.Logger,
                       log_result_file: str):
-    dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:12345', world_size=nprocs, rank=local_rank)
-    # dist.init_process_group(backend='nccl', init_method='env://', world_size=nprocs, rank=local_rank)
+    # dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:12345', world_size=nprocs, rank=local_rank)
+    dist.init_process_group(backend='nccl', init_method='env://', world_size=nprocs, rank=local_rank)
     torch.cuda.set_device(local_rank)

     # model configure
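Switching init_method to 'env://' makes torch.distributed read the rendezvous address from the environment; since world_size and rank are passed explicitly, only MASTER_ADDR and MASTER_PORT must be set before the workers call init_process_group. A sketch with illustrative values:

```python
import os

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # illustrative; any reachable host
os.environ.setdefault("MASTER_PORT", "12345")      # illustrative free port
```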
@@ -304,7 +305,7 @@ def main_app(config: DictConfig):
         num_gpus = torch.cuda.device_count()
         log.info("Total number of GPUs = {}".format(num_gpus))
         torch_mp.spawn(main_train_worker, nprocs=num_gpus, args=(num_gpus, _train_params, _model_params, _optim_params, log, log_result_file,))
         # main_train_worker(num_gpus, _train_params, _model_params, _optim_params, log, log_result_file, "")

         best_model_file = os.path.join(os.getcwd(), 'LocalRank_{}_best_model.pt'.format(0))

     else:
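torch.multiprocessing.spawn calls fn(i, *args) for each i in range(nprocs), which is why main_train_worker receives local_rank as its first parameter while args supplies only the remaining ones. A minimal runnable illustration:

```python
import torch.multiprocessing as torch_mp

def worker(local_rank: int, nprocs: int) -> None:
    print(f"worker {local_rank} of {nprocs}")

if __name__ == "__main__":
    torch_mp.spawn(worker, nprocs=2, args=(2,))  # spawn injects local_rank
```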
@@ -312,7 +313,7 @@ def main_app(config: DictConfig):

     # model re-init and loading
     log.info("\n\nstarting to load the model & re-validation & testing from the file of \"{}\" \n".format(best_model_file))
-    device = torch.device('cuda:0')
+    device = torch.device('cuda')
     train_vocab_path = _train_params.external_func_vocab_file
     vocab = Vocab(freq_file=train_vocab_path, max_vocab_size=_train_params.max_vocab_size)
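torch.device('cuda') resolves to the current CUDA device, which defaults to index 0, so the change is behaviorally equivalent to 'cuda:0' unless torch.cuda.set_device() ran earlier in the process. A quick check:

```python
import torch

if torch.cuda.is_available():
    print(torch.device("cuda"), torch.cuda.current_device())  # e.g. "cuda" and 0
```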
@@ -343,4 +344,4 @@ def main_app(config: DictConfig):


 if __name__ == '__main__':
-    main_app()
+    main_app()
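main_app() is invoked with no arguments because Hydra (hydra-core 1.0.6 is pinned above) injects the DictConfig through its decorator. A minimal sketch of that entrypoint pattern; the config_path and config_name values are assumptions:

```python
import hydra
from omegaconf import DictConfig

@hydra.main(config_path="config", config_name="config")  # assumed layout
def main_app(config: DictConfig) -> None:
    print(config.Data.preprocess_root)

if __name__ == "__main__":
    main_app()
```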
@@ -190,16 +190,6 @@ def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distribu

     # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
     # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html
-    # try:
-    #     _roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_prods)
-    # except ValueError:
-    #     _roc_auc_score = 0.001
-    print(gather_true_classes)
-    print(gather_positive_prods)
-    # try:
-    #     _roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_prods)
-    # except ValueError:
-    #     pass
     _roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_prods)
     _fpr, _tpr, _thresholds = roc_curve(y_true=gather_true_classes, y_score=gather_positive_prods)
     if details is True:
@@ -55,7 +55,6 @@ def _simulating(_dataset, _batch_size: int):
         if index >= 3:
             break
         _real_batch, _position, _hash, _external_list, _function_edges, _true_classes = create_real_batch_data(one_batch=data)
         print(data)
         print("Hash: ", _hash)
         print("Position: ", _position)
         print("\n")
@@ -8,7 +8,6 @@ def create_real_batch_data(one_batch: Batch):
     real = []
     position = [0]
     count = 0

     assert len(one_batch.external_list) == len(one_batch.function_edges) == len(one_batch.local_acfgs) == len(one_batch.hash), "size of each component must be equal to each other"

     for item in one_batch.local_acfgs:
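The assert protects a batching invariant: the custom Batch carries one external_list, function_edges, local_acfgs, and hash entry per binary, and downstream indexing breaks if they drift apart. A toy illustration of the same parallel-list check, with invented values:

```python
external_list = [["CreateFileW"], ["strlen"]]   # toy per-binary external calls
function_edges = [[(0, 1)], [(0, 1)]]           # toy per-binary call-graph edges
local_acfgs = [["acfg_a"], ["acfg_b"]]          # toy per-binary ACFG lists
hashes = ["aa11", "bb22"]                       # toy per-binary hashes

assert len(external_list) == len(function_edges) == len(local_acfgs) == len(hashes), \
    "size of each component must be equal to each other"
```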