Archive: reproduction complete

huihun 2024-01-26 13:10:33 +08:00
parent 37b3d9c4cf
commit 3df2fe07cb
8 changed files with 48 additions and 29 deletions

View File

@@ -1,5 +1,5 @@
 Data:
-  preprocess_root: "/home/king/python/data/processed_dataset/DatasetJSON"
+  preprocess_root: "/home/king/python/data/processed_dataset/DatasetJSON_remake"
   train_vocab_file: "/home/king/python/data/processed_dataset/train_external_function_name_vocab.jsonl"
   max_vocab_size: 10000  # modify according to the result of 1BuildExternalVocab.py
 Training:
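Note: req_pip.txt below pins omegaconf==2.0.6 and hydra-core==1.0.6, so this YAML is presumably consumed through OmegaConf/Hydra. A minimal sketch of reading the changed key (the config file name is an assumption):

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("config.yaml")  # file name assumed
    print(cfg.Data.preprocess_root)      # .../DatasetJSON_remake after this commit
    print(cfg.Data.max_vocab_size)       # 10000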

req_pip.txt (new file, +25 lines)
View File

@@ -0,0 +1,25 @@
+antlr4-python3-runtime==4.8
+ase==3.21.1
+cmake==3.18.4.post1
+dataclasses==0.6
+googledrivedownloader==0.4
+hydra-core==1.0.6
+importlib-resources==5.1.2
+intel-openmp==2021.1.2
+magma-cuda112==2.5.2
+mkl==2021.1.1
+mkl-include==2021.1.1
+ninja==1.10.0.post2
+omegaconf==2.0.6
+prefetch-generator==1.0.1
+rdflib==5.0.0
+tbb==2021.1.1
+texttable==1.6.3
+torch==1.8.0+cu111
+torch-cluster==1.5.9
+torch-geometric==1.6.3
+torch-scatter==2.0.6
+torch-sparse==0.6.9
+torch-spline-conv==1.2.1
+torchaudio==0.8.0
+torchvision==0.9.0+cu111
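Note: the +cu111 wheels are not on plain PyPI; installing this file typically needs PyTorch's wheel index (e.g. pip install -r req_pip.txt -f https://download.pytorch.org/whl/torch_stable.html), and the torch-scatter/-sparse/-cluster/-spline-conv wheels their matching PyG index. A quick post-install sanity check (a sketch):

    import torch
    import torch_geometric

    print(torch.__version__)            # expected: 1.8.0+cu111
    print(torch.version.cuda)           # expected: 11.1
    print(torch.cuda.is_available())    # True on a CUDA 11.1-capable machine
    print(torch_geometric.__version__)  # expected: 1.6.3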

View File

@@ -1,10 +1,10 @@
 import json
 import os
+import sys
 import torch
 from torch_geometric.data import Data
 from tqdm import tqdm
+sys.path.append(os.path.dirname(sys.path[0]))
 from src.utils.Vocabulary import Vocab
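The added sys.path line makes the repository root importable when this script is run from its own subdirectory: sys.path[0] is the directory containing the running script, so its dirname is one level up. An illustration (directory layout assumed):

    import os
    import sys

    print(sys.path[0])                   # directory of the running script
    print(os.path.dirname(sys.path[0]))  # its parent, presumably the repo root
    # appending the parent lets `from src.utils.Vocabulary import Vocab` resolve
    sys.path.append(os.path.dirname(sys.path[0]))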
@@ -35,6 +35,8 @@ def parse_json_list_2_pyg_object(jsonl_file: str, label: int, vocab: Vocab, save
 def json_to_pt(file: str, label: int, vocab: Vocab, save_path: str, file_type: str, train_type: str, index: int):
+    if not os.path.exists(save_path+f"{train_type}_{file_type}/"):
+        os.mkdir(save_path+f"{train_type}_{file_type}/")
     with open(file, "r", encoding="utf-8") as item:
         line = item.readline()
         item = json.loads(line)
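The added exists()/mkdir() pair fails if the parent of save_path is missing and can race if several workers create the directory at once; os.makedirs with exist_ok=True is the usual one-liner. A suggested alternative, not the commit's code:

    import os

    os.makedirs(os.path.join(save_path, f"{train_type}_{file_type}"), exist_ok=True)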
@@ -65,11 +67,14 @@ def json_to_pt(file: str, label: int, vocab: Vocab, save_path: str, file_type: s
 if __name__ == '__main__':
-    json_path = "./jsonl/infected_jsonl/"
-    train_vocab_file = "../data/processed_dataset/train_external_function_name_vocab.jsonl"
-    save_vocab_file = "../data/processed_dataset/DatasetJSON/"
-    file_type = "malware"
+    malware_json_path = "/home/king/python/data/jsonl/infected_jsonl/"
+    benign_json_path = "/home/king/python/data/jsonl/refind_jsonl/"
+    train_vocab_file = "/home/king/python/data/processed_dataset/train_external_function_name_vocab.jsonl"
+    save_vocab_file = "/home/king/python/data/processed_dataset/DatasetJSON_remake/"
+    file_type = ["malware", "benign"]
     max_vocab_size = 10000
     vocabulary = Vocab(freq_file=train_vocab_file, max_vocab_size=max_vocab_size)
-    parse_json_list_2_pyg_object(jsonl_file=json_path, label=1, vocab=vocabulary, save_path=save_vocab_file,
-                                 file_type=file_type)
+    parse_json_list_2_pyg_object(jsonl_file=malware_json_path, label=1, vocab=vocabulary, save_path=save_vocab_file,
+                                 file_type=file_type[0])
+    parse_json_list_2_pyg_object(jsonl_file=benign_json_path, label=0, vocab=vocabulary, save_path=save_vocab_file,
+                                 file_type=file_type[1])
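Since the malware and benign runs differ only in path, label, and file_type, the two calls could be driven by one table; a stylistic sketch with the same behavior as the committed code:

    for path, label, ftype in [(malware_json_path, 1, file_type[0]),
                               (benign_json_path, 0, file_type[1])]:
        parse_json_list_2_pyg_object(jsonl_file=path, label=label, vocab=vocabulary,
                                     save_path=save_vocab_file, file_type=ftype)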

View File

@@ -6,8 +6,8 @@ import heapq
 from tqdm import tqdm
 if __name__ == '__main__':
-    mal_file_name = './jsonl/infected_jsonl/'
-    ben_file_name = './jsonl/refind_jsonl/'
+    mal_file_name = '/home/king/python/data/jsonl/infected_jsonl/'
+    ben_file_name = '/home/king/python/data/jsonl/refind_jsonl/'
     fun_name_dict = {}
     for file in tqdm(os.listdir(mal_file_name)):
         with open(mal_file_name + file, 'r') as item:
@@ -29,7 +29,7 @@ if __name__ == '__main__':
                 fun_name_dict[fun_name] += 1
             else:
                 fun_name_dict[fun_name] = 1
-    with open('./res.jsonl', 'w') as file:
+    with open('/home/king/python/data/processed_dataset/train_external_function_name_vocab.jsonl', 'w') as file:
         largest_10000_items = heapq.nlargest(10000, fun_name_dict.items(), key=lambda item: item[1])
         for key, value in largest_10000_items:
             temp = {"f_name": key, "count": value}

View File

@@ -170,6 +170,7 @@ def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distribu
     # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
     # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html
     _roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_prods)
     _fpr, _tpr, _thresholds = roc_curve(y_true=gather_true_classes, y_score=gather_positive_prods)
     if details is True:
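For reference, the two sklearn calls in this hunk behave like this on toy data (example values only):

    from sklearn.metrics import roc_auc_score, roc_curve

    y_true = [0, 0, 1, 1]
    y_score = [0.1, 0.4, 0.35, 0.8]
    print(roc_auc_score(y_true=y_true, y_score=y_score))  # 0.75
    fpr, tpr, thresholds = roc_curve(y_true=y_true, y_score=y_score)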
@@ -185,8 +186,8 @@ def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distribu
 def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, model_params: ModelParams, optimizer_params: OptimizerParams, global_log: logging.Logger,
                       log_result_file: str):
-    dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:12345', world_size=nprocs, rank=local_rank)
-    # dist.init_process_group(backend='nccl', init_method='env://', world_size=nprocs, rank=local_rank)
+    # dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:12345', world_size=nprocs, rank=local_rank)
+    dist.init_process_group(backend='nccl', init_method='env://', world_size=nprocs, rank=local_rank)
     torch.cuda.set_device(local_rank)
     # model configure
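Switching from tcp:// to env:// moves the rendezvous address into environment variables, which the parent process must set before torch.multiprocessing.spawn; world_size and rank are still passed explicitly here, so only the address pair is needed. A sketch (address values assumed):

    import os

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "12345")
    # then, in each spawned worker, as in this commit:
    # dist.init_process_group(backend='nccl', init_method='env://',
    #                         world_size=nprocs, rank=local_rank)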
@@ -304,7 +305,7 @@ def main_app(config: DictConfig):
         num_gpus = torch.cuda.device_count()
         log.info("Total number of GPUs = {}".format(num_gpus))
         torch_mp.spawn(main_train_worker, nprocs=num_gpus, args=(num_gpus, _train_params, _model_params, _optim_params, log, log_result_file,))
-        # main_train_worker(num_gpus, _train_params, _model_params, _optim_params, log, log_result_file, "")
         best_model_file = os.path.join(os.getcwd(), 'LocalRank_{}_best_model.pt'.format(0))
     else:
@@ -312,7 +313,7 @@ def main_app(config: DictConfig):
     # model re-init and loading
     log.info("\n\nstarting to load the model & re-validation & testing from the file of \"{}\" \n".format(best_model_file))
-    device = torch.device('cuda:0')
+    device = torch.device('cuda')
     train_vocab_path = _train_params.external_func_vocab_file
     vocab = Vocab(freq_file=train_vocab_path, max_vocab_size=_train_params.max_vocab_size)
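torch.device('cuda') resolves to the current CUDA device rather than hard-coding index 0, so it respects an earlier torch.cuda.set_device call; a small illustration (requires a CUDA machine):

    import torch

    torch.cuda.set_device(0)            # as main_train_worker does per rank
    dev = torch.device('cuda')          # follows the current device
    print(torch.cuda.current_device())  # 0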

View File

@@ -190,16 +190,6 @@ def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distribu
     # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
     # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html
-    # try:
-    #     _roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_prods)
-    # except ValueError:
-    #     _roc_auc_score = 0.001
-    print(gather_true_classes)
-    print(gather_positive_prods)
-    # try:
-    #     _roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_prods)
-    # except ValueError:
-    #     pass
     _roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_prods)
     _fpr, _tpr, _thresholds = roc_curve(y_true=gather_true_classes, y_score=gather_positive_prods)
     if details is True:
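The deleted try/except guarded against roc_auc_score raising ValueError when the gathered labels contain only one class (possible with small or skewed validation shards). If that situation can still occur, something like the old guard is needed; a sketch using the same sentinel idea:

    from sklearn.metrics import roc_auc_score

    y_true = [1, 1, 1]        # single-class example
    y_score = [0.9, 0.8, 0.7]
    try:
        auc = roc_auc_score(y_true=y_true, y_score=y_score)
    except ValueError:        # "Only one class present in y_true"
        auc = 0.001           # sentinel used by the removed code
    print(auc)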

View File

@@ -55,7 +55,6 @@ def _simulating(_dataset, _batch_size: int):
         if index >= 3:
             break
         _real_batch, _position, _hash, _external_list, _function_edges, _true_classes = create_real_batch_data(one_batch=data)
-        print(data)
         print("Hash: ", _hash)
         print("Position: ", _position)
         print("\n")

View File

@@ -8,7 +8,6 @@ def create_real_batch_data(one_batch: Batch):
     real = []
     position = [0]
     count = 0
-    assert len(one_batch.external_list) == len(one_batch.function_edges) == len(one_batch.local_acfgs) == len(one_batch.hash), "size of each component must be equal to each other"
     for item in one_batch.local_acfgs:
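Removing the assert drops the only check that the per-sample lists in the batch stay aligned; if the lengths can now legitimately differ that is fine, but otherwise an explicit error keeps the failure close to its cause. A sketch reusing the batch fields from the function above:

    n = len(one_batch.local_acfgs)
    if not (len(one_batch.external_list) == len(one_batch.function_edges)
            == n == len(one_batch.hash)):
        raise ValueError("batch components have mismatched lengths")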