参数更改

This commit is contained in:
huihun 2024-01-10 10:32:38 +08:00
parent 601a61157b
commit 37b3d9c4cf
9 changed files with 583 additions and 139 deletions

View File

@ -1,6 +1,6 @@
Data: Data:
preprocess_root: "../data/processed_dataset/DatasetJSON/" preprocess_root: "/home/king/python/data/processed_dataset/DatasetJSON"
train_vocab_file: "../data/processed_dataset/train_external_function_name_vocab.jsonl" train_vocab_file: "/home/king/python/data/processed_dataset/train_external_function_name_vocab.jsonl"
max_vocab_size: 10000 # modify according to the result of 1BuildExternalVocab.py max_vocab_size: 10000 # modify according to the result of 1BuildExternalVocab.py
Training: Training:
cuda: True # enable GPU training if cuda is available cuda: True # enable GPU training if cuda is available

View File

@ -1,79 +1,83 @@
# This file may be used to create an environment using:
antlr4-python3-runtime==4.8 # $ conda create --name <env> --file <this file>
ase==3.21.1 # platform: linux-64
ca-certificates==2021.1.19 _libgcc_mutex=0.1=main
cached-property==1.5.2 antlr4-python3-runtime=4.8=pypi_0
certifi==2020.12.5 ase=3.21.1=pypi_0
cffi==1.14.5 ca-certificates=2021.1.19=h06a4308_1
chardet==4.0.0 cached-property=1.5.2=pypi_0
cmake==3.18.4.post1 certifi=2020.12.5=py37h06a4308_0
cycler==0.10.0 cffi=1.14.5=pypi_0
dataclasses==0.6 chardet=4.0.0=pypi_0
decorator==4.4.2 cmake=3.18.4.post1=pypi_0
future==0.18.2 cycler=0.10.0=pypi_0
googledrivedownloader==0.4 dataclasses=0.6=pypi_0
h5py==3.2.1 decorator=4.4.2=pypi_0
hydra-core==1.0.6 future=0.18.2=pypi_0
idna==2.10 googledrivedownloader=0.4=pypi_0
importlib-resources==5.1.2 h5py=3.2.1=pypi_0
intel-openmp==2021.1.2 hydra-core=1.0.6=pypi_0
isodate==0.6.0 idna=2.10=pypi_0
jinja2==2.11.3 importlib-resources=5.1.2=pypi_0
joblib==1.0.1 intel-openmp=2021.1.2=pypi_0
kiwisolver==1.3.1 isodate=0.6.0=pypi_0
ld_impl_linux-64==2.33.1 jinja2=2.11.3=pypi_0
libedit==3.1.20191231 joblib=1.0.1=pypi_0
libffi==3.3 kiwisolver=1.3.1=pypi_0
libgcc-ng==9.1.0 ld_impl_linux-64=2.33.1=h53a641e_7
libstdcxx-ng==9.1.0 libedit=3.1.20191231=h14c3975_1
llvmlite==0.35.0 libffi=3.3=he6710b0_2
magma-cuda112==2.5.2 libgcc-ng=9.1.0=hdf63c60_0
markupsafe==1.1.1 libstdcxx-ng=9.1.0=hdf63c60_0
matplotlib==3.3.4 llvmlite=0.35.0=pypi_0
mkl==2021.1.1 magma-cuda112=2.5.2=1
mkl-include==2021.1.1 markupsafe=1.1.1=pypi_0
ncurses==6.2 matplotlib=3.3.4=pypi_0
networkx==2.5 mkl=2021.1.1=pypi_0
ninja==1.10.0.post2 mkl-include=2021.1.1=pypi_0
numba==0.52.0 ncurses=6.2=he6710b0_1
numpy==1.20.1 networkx=2.5=pypi_0
omegaconf==2.0.6 ninja=1.10.0.post2=pypi_0
openssl==1.1.1j numba=0.52.0=pypi_0
pandas==1.2.3 numpy=1.20.1=pypi_0
pillow==8.1.2 omegaconf=2.0.6=pypi_0
pip==21.0.1 openssl=1.1.1j=h27cfd23_0
prefetch-generator==1.0.1 pandas=1.2.3=pypi_0
pycparser==2.20 pillow=8.1.2=pypi_0
pyparsing==2.4.7 pip=21.0.1=py37h06a4308_0
python-dateutil==2.8.1 prefetch-generator=1.0.1=pypi_0
python-louvain==0.15 pycparser=2.20=pypi_0
pytz==2021.1 pyparsing=2.4.7=pypi_0
pyyaml==5.4.1 python=3.7.9=h7579374_0
rdflib==5.0.0 python-dateutil=2.8.1=pypi_0
readline==8.1 python-louvain=0.15=pypi_0
requests==2.25.1 pytz=2021.1=pypi_0
scikit-learn==0.24.1 pyyaml=5.4.1=pypi_0
scipy==1.6.1 rdflib=5.0.0=pypi_0
seaborn==0.11.1 readline=8.1=h27cfd23_0
setuptools==52.0.0 requests=2.25.1=pypi_0
six==1.15.0 scikit-learn=0.24.1=pypi_0
sqlite==3.33.0 scipy=1.6.1=pypi_0
tbb==2021.1.1 seaborn=0.11.1=pypi_0
texttable==1.6.3 setuptools=52.0.0=py37h06a4308_0
threadpoolctl==2.1.0 six=1.15.0=pypi_0
tk==8.6.10 sqlite=3.33.0=h62c20be_0
torch==1.8.0+cu111 tbb=2021.1.1=pypi_0
torch-cluster==1.5.9 texttable=1.6.3=pypi_0
torch-geometric==1.6.3 threadpoolctl=2.1.0=pypi_0
torch-scatter==2.0.6 tk=8.6.10=hbc83047_0
torch-sparse==0.6.9 torch=1.8.0+cu111=pypi_0
torch-spline-conv==1.2.1 torch-cluster=1.5.9=pypi_0
torchaudio==0.8.0 torch-geometric=1.6.3=pypi_0
torchvision==0.9.0+cu111 torch-scatter=2.0.6=pypi_0
tqdm==4.59.0 torch-sparse=0.6.9=pypi_0
typing-extensions==3.7.4.3 torch-spline-conv=1.2.1=pypi_0
urllib3==1.26.3 torchaudio=0.8.0=pypi_0
wheel==0.36.2 torchvision=0.9.0+cu111=pypi_0
xz==5.2.5 tqdm=4.59.0=pypi_0
zipp==3.4.1 typing-extensions=3.7.4.3=pypi_0
zlib==1.2.11 urllib3=1.26.3=pypi_0
wheel=0.36.2=pyhd3eb1b0_0
xz=5.2.5=h7b6447c_0
zipp=3.4.1=pypi_0
zlib=1.2.11=h7b6447c_3

View File

@ -1,4 +1,6 @@
import json import json
import os
import torch import torch
from torch_geometric.data import Data from torch_geometric.data import Data
from tqdm import tqdm from tqdm import tqdm
@ -6,37 +8,68 @@ from tqdm import tqdm
from src.utils.Vocabulary import Vocab from src.utils.Vocabulary import Vocab
def parse_json_list_2_pyg_object(jsonl_file: str, label: int, vocab: Vocab): def parse_json_list_2_pyg_object(jsonl_file: str, label: int, vocab: Vocab, save_path: str, file_type: str):
#def parse_json_list_2_pyg_object(jsonl_file: str): # def parse_json_list_2_pyg_object(jsonl_file: str):
train_type = ['train', 'valid', 'test']
index = 0 index = 0
with open(jsonl_file, "r", encoding="utf-8") as file: file_index = 0
for item in tqdm(file): type_index = 0
item = json.loads(item) valid_flag = True
item_hash = item['hash'] test_flag = True
file_len = len(os.listdir(jsonl_file))
acfg_list = []
for one_acfg in item['acfg_list']: # list of dict of acfg for file in tqdm(os.listdir(jsonl_file)):
block_features = one_acfg['block_features'] if index >= file_len * 0.8 and valid_flag:
block_edges = one_acfg['block_edges'] type_index += 1
one_acfg_data = Data(x=torch.tensor(block_features, dtype=torch.float), edge_index=torch.tensor(block_edges, dtype=torch.long)) valid_flag = False
acfg_list.append(one_acfg_data) file_index = 0
print("make valid set")
item_function_names = item['function_names'] elif index >= file_len * 0.9 and test_flag:
item_function_edges = item['function_edges'] type_index += 1
test_flag = False
local_function_name_list = item_function_names[:len(acfg_list)] file_index = 0
assert len(acfg_list) == len(local_function_name_list), "The length of ACFG_List should be equal to the length of Local_Function_List" print("make test set")
external_function_name_list = item_function_names[len(acfg_list):] j = json_to_pt(file=jsonl_file + file, label=label, vocab=vocab, save_path=save_path, file_type=file_type, train_type=train_type[type_index], index=file_index)
index += 1
external_function_index_list = [vocab[f_name] for f_name in external_function_name_list] file_index += 1
index += 1
torch.save(Data(hash=item_hash, local_acfgs=acfg_list, external_list=external_function_index_list, function_edges=item_function_edges, targets=label), "./cache/benign_{}.pt".format(index))
def json_to_pt(file: str, label: int, vocab: Vocab, save_path: str, file_type: str, train_type: str, index: int):
with open(file, "r", encoding="utf-8") as item:
line = item.readline()
item = json.loads(line)
item_hash = item['hash']
acfg_list = []
for one_acfg in item['acfg_list']: # list of dict of acfg
block_features = one_acfg['block_features']
block_edges = one_acfg['block_edges']
one_acfg_data = Data(x=torch.tensor(block_features, dtype=torch.float),
edge_index=torch.tensor(block_edges, dtype=torch.long))
acfg_list.append(one_acfg_data)
item_function_names = item['function_names']
item_function_edges = item['function_edges']
local_function_name_list = item_function_names[:len(acfg_list)]
assert len(acfg_list) == len(
local_function_name_list), "The length of ACFG_List should be equal to the length of Local_Function_List"
external_function_name_list = item_function_names[len(acfg_list):]
external_function_index_list = [vocab[f_name] for f_name in external_function_name_list]
torch.save(Data(hash=item_hash, local_acfgs=acfg_list, external_list=external_function_index_list,
function_edges=item_function_edges, targets=label),
save_path + "{}_{}/{}_{}.pt".format(train_type, file_type, file_type, index))
return True
if __name__ == '__main__': if __name__ == '__main__':
json_path = "./benign_result.jsonl" json_path = "./jsonl/infected_jsonl/"
train_vocab_file = "../data/processed_dataset/train_external_function_name_vocab.jsonl" train_vocab_file = "../data/processed_dataset/train_external_function_name_vocab.jsonl"
# train_vocab_file = "./res.jsonl" save_vocab_file = "../data/processed_dataset/DatasetJSON/"
file_type = "malware"
max_vocab_size = 10000 max_vocab_size = 10000
vocabulary = Vocab(freq_file=train_vocab_file, max_vocab_size=max_vocab_size) vocabulary = Vocab(freq_file=train_vocab_file, max_vocab_size=max_vocab_size)
parse_json_list_2_pyg_object(jsonl_file=json_path, label=1, vocab=vocabulary) parse_json_list_2_pyg_object(jsonl_file=json_path, label=1, vocab=vocabulary, save_path=save_vocab_file,
file_type=file_type)

View File

@ -1,31 +1,38 @@
import json import json
import os
from itertools import islice
import heapq
from tqdm import tqdm from tqdm import tqdm
if __name__ == '__main__': if __name__ == '__main__':
mal_file_name = './malware_result.jsonl' mal_file_name = './jsonl/infected_jsonl/'
ben_file_name = './benign-result.jsonl' ben_file_name = './jsonl/refind_jsonl/'
fil = open(mal_file_name, mode='r')
fun_name_dict = {} fun_name_dict = {}
for item in tqdm(fil): for file in tqdm(os.listdir(mal_file_name)):
item = json.loads(item) with open(mal_file_name + file, 'r') as item:
item_fun_list = item['function_names'] item = json.loads(item.readline())
for fun_name in item_fun_list: item_fun_list = item['function_names']
if fun_name_dict.get(fun_name) is not None: for fun_name in item_fun_list:
fun_name_dict[fun_name] += 1 if fun_name != 'start' and fun_name != 'start_0' and 'sub_' not in fun_name:
else: if fun_name_dict.get(fun_name) is not None:
fun_name_dict[fun_name] = 1 fun_name_dict[fun_name] += 1
fil = open(mal_file_name, mode='r') else:
for item in tqdm(fil): fun_name_dict[fun_name] = 1
item = json.loads(item) for file in tqdm(os.listdir(ben_file_name)):
item_fun_list = item['function_names'] with open(ben_file_name + file, 'r') as item:
for fun_name in item_fun_list: item = json.loads(item.readline())
if fun_name_dict.get(fun_name) is not None: item_fun_list = item['function_names']
fun_name_dict[fun_name] += 1 for fun_name in item_fun_list:
else: if fun_name != 'start' and fun_name != 'start_0' and 'sub_' not in fun_name:
fun_name_dict[fun_name] = 1 if fun_name_dict.get(fun_name) is not None:
fun_name_dict[fun_name] += 1
else:
fun_name_dict[fun_name] = 1
with open('./res.jsonl', 'w') as file: with open('./res.jsonl', 'w') as file:
for key, value in fun_name_dict.items(): largest_10000_items = heapq.nlargest(10000, fun_name_dict.items(), key=lambda item: item[1])
for key, value in largest_10000_items:
temp = {"f_name": key, "count": value} temp = {"f_name": key, "count": value}
file.write(json.dumps(temp) + '\n') file.write(json.dumps(temp) + '\n')

View File

@ -185,8 +185,8 @@ def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distribu
def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, model_params: ModelParams, optimizer_params: OptimizerParams, global_log: logging.Logger, def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, model_params: ModelParams, optimizer_params: OptimizerParams, global_log: logging.Logger,
log_result_file: str): log_result_file: str):
# dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:12345', world_size=nprocs, rank=local_rank) dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:12345', world_size=nprocs, rank=local_rank)
dist.init_process_group(backend='nccl', init_method='env://', world_size=nprocs, rank=local_rank) # dist.init_process_group(backend='nccl', init_method='env://', world_size=nprocs, rank=local_rank)
torch.cuda.set_device(local_rank) torch.cuda.set_device(local_rank)
# model configure # model configure
@ -304,7 +304,7 @@ def main_app(config: DictConfig):
num_gpus = torch.cuda.device_count() num_gpus = torch.cuda.device_count()
log.info("Total number of GPUs = {}".format(num_gpus)) log.info("Total number of GPUs = {}".format(num_gpus))
torch_mp.spawn(main_train_worker, nprocs=num_gpus, args=(num_gpus, _train_params, _model_params, _optim_params, log, log_result_file,)) torch_mp.spawn(main_train_worker, nprocs=num_gpus, args=(num_gpus, _train_params, _model_params, _optim_params, log, log_result_file,))
# main_train_worker(num_gpus, _train_params, _model_params, _optim_params, log, log_result_file, "")
best_model_file = os.path.join(os.getcwd(), 'LocalRank_{}_best_model.pt'.format(0)) best_model_file = os.path.join(os.getcwd(), 'LocalRank_{}_best_model.pt'.format(0))
else: else:
@ -312,7 +312,7 @@ def main_app(config: DictConfig):
# model re-init and loading # model re-init and loading
log.info("\n\nstarting to load the model & re-validation & testing from the file of \"{}\" \n".format(best_model_file)) log.info("\n\nstarting to load the model & re-validation & testing from the file of \"{}\" \n".format(best_model_file))
device = torch.device('cuda') device = torch.device('cuda:0')
train_vocab_path = _train_params.external_func_vocab_file train_vocab_path = _train_params.external_func_vocab_file
vocab = Vocab(freq_file=train_vocab_path, max_vocab_size=_train_params.max_vocab_size) vocab = Vocab(freq_file=train_vocab_path, max_vocab_size=_train_params.max_vocab_size)
@ -343,4 +343,4 @@ def main_app(config: DictConfig):
if __name__ == '__main__': if __name__ == '__main__':
main_app() main_app()

392
src/DistTrainModel_dual.py Normal file
View File

@ -0,0 +1,392 @@
import logging
import math
import os
import random
from datetime import datetime
import hydra
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as torch_mp
import torch.utils.data
import torch.utils.data.distributed
from hydra.utils import to_absolute_path
from omegaconf import DictConfig
from prefetch_generator import BackgroundGenerator
from sklearn.metrics import roc_auc_score, roc_curve
from torch import nn
from torch_geometric.data import DataLoader
from tqdm import tqdm
from models.HierarchicalGraphModel import HierarchicalGraphNeuralNetwork
from utils.FunctionHelpers import write_into, params_print_log, find_threshold_with_fixed_fpr
from utils.ParameterClasses import ModelParams, TrainParams, OptimizerParams, OneEpochResult
from utils.PreProcessedDataset import MalwareDetectionDataset
from utils.RealBatch import create_real_batch_data
from utils.Vocabulary import Vocab
class DataLoaderX(DataLoader):
def __iter__(self):
return BackgroundGenerator(super().__iter__())
def reduce_sum(tensor):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM) # noqa
return rt
def reduce_mean(tensor, nprocs):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM) # noqa
rt /= nprocs
return rt
def all_gather_concat(tensor):
tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
dist.all_gather(tensors_gather, tensor, async_op=False)
output = torch.cat(tensors_gather, dim=0)
return output
def train_one_epoch(local_rank, train_loader, valid_loader, model, criterion, optimizer, nprocs, idx_epoch, best_auc, best_model_file, original_valid_length, result_file):
# print(train_loader.dataset.__dict__)
model.train()
local_device = torch.device("cuda", local_rank)
write_into(file_name_path=result_file, log_str="The local device = {} among {} nprocs in the {}-th epoch.".format(local_device, nprocs, idx_epoch))
until_sum_reduced_loss = 0.0
smooth_avg_reduced_loss_list = []
for _idx_bt, _batch in enumerate(tqdm(train_loader, desc="reading _batch from local_rank={}".format(local_rank))):
model.train()
_real_batch, _position, _hash, _external_list, _function_edges, _true_classes = create_real_batch_data(one_batch=_batch)
if _real_batch is None:
write_into(result_file,
"{}\n_real_batch is None in creating the real batch data of training ... ".format("*-" * 100))
continue
_real_batch = _real_batch.to(local_device)
_position = torch.tensor(_position, dtype=torch.long).cuda(local_rank, non_blocking=True)
_true_classes = _true_classes.float().cuda(local_rank, non_blocking=True)
train_batch_pred = model(real_local_batch=_real_batch,
real_bt_positions=_position,
bt_external_names=_external_list,
bt_all_function_edges=_function_edges,
local_device=local_device)
train_batch_pred = train_batch_pred.squeeze()
loss = criterion(train_batch_pred, _true_classes)
torch.distributed.barrier()
optimizer.zero_grad()
loss.backward()
optimizer.step()
reduced_loss = reduce_mean(loss, nprocs)
until_sum_reduced_loss += reduced_loss.item()
smooth_avg_reduced_loss_list.append(until_sum_reduced_loss / (_idx_bt + 1))
if _idx_bt != 0 and (_idx_bt % math.ceil(len(train_loader) / 3) == 0 or _idx_bt == int(len(train_loader) - 1)):
val_start_time = datetime.now()
if local_rank == 0:
write_into(result_file, "\nIn {}-th epoch, {}-th batch, we start to validate ... ".format(idx_epoch, _idx_bt))
_eval_flag = "Valid_In_Train_Epoch_{}_Batch_{}".format(idx_epoch, _idx_bt)
valid_result = validate(local_rank=local_rank,
valid_loader=valid_loader,
model=model,
criterion=criterion,
evaluate_flag=_eval_flag,
distributed=True, # 分布式
nprocs=nprocs,
original_valid_length=original_valid_length,
result_file=result_file,
details=True # 验证细节
)
if best_auc < valid_result.ROC_AUC_Score:
_info = "[AUC Increased!] In evaluation of epoch-{} / batch-{}: AUC increased from {:.5f} < {:.5f}! Saving the model into {}".format(idx_epoch,
_idx_bt,
best_auc,
valid_result.ROC_AUC_Score,
best_model_file)
best_auc = valid_result.ROC_AUC_Score
torch.save(model.module.state_dict(), best_model_file)
else:
_info = "[AUC NOT Increased!] AUC decreased from {:.5f} to {:.5f}!".format(best_auc, valid_result.ROC_AUC_Score)
if local_rank == 0:
write_into(result_file, valid_result.__str__())
write_into(result_file, _info)
write_into(result_file, "[#One Validation Time#] Consume about {} time period for one validation.".format(datetime.now() - val_start_time))
return smooth_avg_reduced_loss_list, best_auc
def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distributed, nprocs, original_valid_length, result_file, details):
model.eval()
if distributed:
local_device = torch.device("cuda", local_rank)
else:
local_device = torch.device("cuda")
sum_loss = torch.tensor(0.0, dtype=torch.float, device=local_device)
n_samples = torch.tensor(0, dtype=torch.int, device=local_device)
all_true_classes = []
all_positive_probs = []
with torch.no_grad():
for idx_batch, data in enumerate(tqdm(valid_loader)):
_real_batch, _position, _hash, _external_list, _function_edges, _true_classes = create_real_batch_data(one_batch=data)
if _real_batch is None:
write_into(result_file, "{}\n_real_batch is None in creating the real batch data of validation ... ".format("*-" * 100))
continue
_real_batch = _real_batch.to(local_device)
_position = torch.tensor(_position, dtype=torch.long).cuda(local_rank, non_blocking=True)
_true_classes = _true_classes.float().cuda(local_rank, non_blocking=True)
batch_pred = model(real_local_batch=_real_batch,
real_bt_positions=_position,
bt_external_names=_external_list,
bt_all_function_edges=_function_edges,
local_device=local_device)
batch_pred = batch_pred.squeeze(-1)
loss = criterion(batch_pred, _true_classes)
sum_loss += loss.item()
n_samples += len(batch_pred)
all_true_classes.append(_true_classes)
all_positive_probs.append(batch_pred)
avg_loss = sum_loss / (idx_batch + 1)
all_true_classes = torch.cat(all_true_classes, dim=0)
all_positive_probs = torch.cat(all_positive_probs, dim=0)
if distributed:
torch.distributed.barrier()
reduced_n_samples = reduce_sum(n_samples)
reduced_avg_loss = reduce_mean(avg_loss, nprocs)
gather_true_classes = all_gather_concat(all_true_classes).detach().cpu().numpy()
gather_positive_prods = all_gather_concat(all_positive_probs).detach().cpu().numpy()
gather_true_classes = gather_true_classes[:original_valid_length]
gather_positive_prods = gather_positive_prods[:original_valid_length]
else:
reduced_n_samples = n_samples
reduced_avg_loss = avg_loss
gather_true_classes = all_true_classes.detach().cpu().numpy()
gather_positive_prods = all_positive_probs.detach().cpu().numpy()
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html
# try:
# _roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_prods)
# except ValueError:
# _roc_auc_score = 0.001
print(gather_true_classes)
print(gather_positive_prods)
# try:
# _roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_prods)
# except ValueError:
# pass
_roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_prods)
_fpr, _tpr, _thresholds = roc_curve(y_true=gather_true_classes, y_score=gather_positive_prods)
if details is True:
_100_info = find_threshold_with_fixed_fpr(y_true=gather_true_classes, y_pred=gather_positive_prods, fpr_target=0.01)
_1000_info = find_threshold_with_fixed_fpr(y_true=gather_true_classes, y_pred=gather_positive_prods, fpr_target=0.001)
else:
_100_info, _1000_info = "None", "None"
_eval_result = OneEpochResult(Epoch_Flag=evaluate_flag,
Number_Samples=reduced_n_samples,
Avg_Loss=reduced_avg_loss,
Info_100=_100_info,
Info_1000=_1000_info,
ROC_AUC_Score=_roc_auc_score,
Thresholds=_thresholds,
TPRs=_tpr,
FPRs=_fpr)
return _eval_result
def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, model_params: ModelParams, optimizer_params: OptimizerParams, global_log: logging.Logger,
log_result_file: str):
# dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:12345', world_size=nprocs, rank=local_rank)
dist.init_process_group(backend='nccl', init_method='env://', world_size=nprocs, rank=local_rank)
torch.cuda.set_device(local_rank)
# model configure
vocab = Vocab(freq_file=train_params.external_func_vocab_file, max_vocab_size=train_params.max_vocab_size)
if model_params.ablation_models.lower() == "full":
model = HierarchicalGraphNeuralNetwork(model_params=model_params, external_vocab=vocab, global_log=global_log)
else:
raise NotImplementedError
model.cuda(local_rank)
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
if local_rank == 0:
write_into(file_name_path=log_result_file, log_str=model.__str__())
# loss function
criterion = nn.BCELoss().cuda(local_rank)
lr = optimizer_params.lr
if optimizer_params.optimizer_name.lower() == 'adam':
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
elif optimizer_params.optimizer_name.lower() == 'adamw':
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=optimizer_params.weight_decay)
else:
raise NotImplementedError
max_epochs = train_params.max_epochs
dataset_root_path = train_params.processed_files_path
train_batch_size = train_params.train_bs
test_batch_size = train_params.test_bs
# training dataset & dataloader
train_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="train")
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = DataLoaderX(dataset=train_dataset, batch_size=train_batch_size, shuffle=False, num_workers=0, pin_memory=True, sampler=train_sampler)
# validation dataset & dataloader
valid_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="valid")
valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset)
valid_loader = DataLoaderX(dataset=valid_dataset, batch_size=test_batch_size, pin_memory=True, sampler=valid_sampler)
if local_rank == 0:
write_into(file_name_path=log_result_file, log_str="Training dataset={}, sampler={}, loader={}".format(len(train_dataset), len(train_sampler), len(train_loader)))
write_into(file_name_path=log_result_file, log_str="Validation dataset={}, sampler={}, loader={}".format(len(valid_dataset), len(valid_sampler), len(valid_loader)))
best_auc = 0.0
ori_valid_length = len(valid_dataset)
best_model_path = os.path.join(os.getcwd(), 'LocalRank_{}_best_model.pt'.format(local_rank))
all_batch_avg_smooth_loss_list = []
for epoch in range(max_epochs):
train_sampler.set_epoch(epoch)
valid_sampler.set_epoch(epoch)
# train for one epoch
time_start = datetime.now()
if local_rank == 0:
write_into(log_result_file, "\n{} start of {}-epoch, init best_auc={}, start time={} {}".format("-" * 50, epoch, best_auc, time_start.strftime("%Y-%m-%d@%H:%M:%S"), "-" * 50))
smooth_avg_reduced_loss_list, best_auc = train_one_epoch(local_rank=local_rank,
train_loader=train_loader,
valid_loader=valid_loader,
model=model,
criterion=criterion,
optimizer=optimizer,
nprocs=nprocs,
idx_epoch=epoch,
best_auc=best_auc,
best_model_file=best_model_path,
original_valid_length=ori_valid_length,
result_file=log_result_file)
all_batch_avg_smooth_loss_list.extend(smooth_avg_reduced_loss_list)
# adjust learning rate
for param_group in optimizer.param_groups:
param_group['lr'] = param_group['lr'] / optimizer_params.learning_anneal
time_end = datetime.now()
if local_rank == 0:
write_into(log_result_file, "\n{} end of {}-epoch, with best_auc={}, len of loss={}, end time={}, and time period={} {}".format("*" * 50, epoch, best_auc,
len(smooth_avg_reduced_loss_list),
time_end.strftime("%Y-%m-%d@%H:%M:%S"),
time_end - time_start, "*" * 50))
# https://hydra.cc/docs/tutorials/basic/your_first_app/defaults#overriding-a-config-group-default
@hydra.main(config_path="../configs/", config_name="default.yaml")
def main_app(config: DictConfig):
# set seed for determinism for reproduction
random.seed(config.Training.seed)
np.random.seed(config.Training.seed)
torch.manual_seed(config.Training.seed)
torch.cuda.manual_seed(config.Training.seed)
torch.cuda.manual_seed_all(config.Training.seed)
# setting hyper-parameter for Training / Model / Optimizer
_train_params = TrainParams(processed_files_path=to_absolute_path(config.Data.preprocess_root), max_epochs=config.Training.max_epoches, train_bs=config.Training.train_batch_size, test_bs=config.Training.test_batch_size, external_func_vocab_file=to_absolute_path(config.Data.train_vocab_file), max_vocab_size=config.Data.max_vocab_size)
_model_params = ModelParams(gnn_type=config.Model.gnn_type, pool_type=config.Model.pool_type, acfg_init_dims=config.Model.acfg_node_init_dims, cfg_filters=config.Model.cfg_filters, fcg_filters=config.Model.fcg_filters, number_classes=config.Model.number_classes, dropout_rate=config.Model.drapout_rate, ablation_models=config.Model.ablation_models)
_optim_params = OptimizerParams(optimizer_name=config.Optimizer.name, lr=config.Optimizer.learning_rate, weight_decay=config.Optimizer.weight_decay, learning_anneal=config.Optimizer.learning_anneal)
# logging
log = logging.getLogger("DistTrainModel.py")
log.setLevel("DEBUG")
log.warning("Hydra's Current Working Directory: {}".format(os.getcwd()))
# setting for the log directory
result_file = '{}_{}_{}_ACFG_{}_FCG_{}_Epoch_{}_TrainBS_{}_LR_{}_Time_{}.txt'.format(_model_params.ablation_models, _model_params.gnn_type, _model_params.pool_type,
_model_params.cfg_filters, _model_params.fcg_filters, _train_params.max_epochs,
_train_params.train_bs, _optim_params.lr, datetime.now().strftime("%Y%m%d_%H%M%S"))
log_result_file = os.path.join(os.getcwd(), result_file)
_other_params = {"Hydra's Current Working Directory": os.getcwd(), "seed": config.Training.seed, "log result file": log_result_file, "only_test_path": config.Training.only_test_path}
params_print_log(_train_params.__dict__, log_result_file)
params_print_log(_model_params.__dict__, log_result_file)
params_print_log(_optim_params.__dict__, log_result_file)
params_print_log(_other_params, log_result_file)
if config.Training.only_test_path == 'None':
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(config.Training.dist_port)
# num_gpus = 1
num_gpus = torch.cuda.device_count()
log.info("Total number of GPUs = {}".format(num_gpus))
torch_mp.spawn(main_train_worker, nprocs=num_gpus, args=(num_gpus, _train_params, _model_params, _optim_params, log, log_result_file,))
best_model_file = os.path.join(os.getcwd(), 'LocalRank_{}_best_model.pt'.format(0))
else:
best_model_file = config.Training.only_test_path
# model re-init and loading
log.info("\n\nstarting to load the model & re-validation & testing from the file of \"{}\" \n".format(best_model_file))
device = torch.device('cuda')
train_vocab_path = _train_params.external_func_vocab_file
vocab = Vocab(freq_file=train_vocab_path, max_vocab_size=_train_params.max_vocab_size)
if _model_params.ablation_models.lower() == "full":
model = HierarchicalGraphNeuralNetwork(model_params=_model_params, external_vocab=vocab, global_log=log)
else:
raise NotImplementedError
model.to(device)
model.load_state_dict(torch.load(best_model_file, map_location=device))
criterion = nn.BCELoss().cuda()
test_batch_size = config.Training.test_batch_size
dataset_root_path = _train_params.processed_files_path
# validation dataset & dataloader
valid_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="valid")
valid_dataloader = DataLoaderX(dataset=valid_dataset, batch_size=test_batch_size, shuffle=False)
log.info("Total number of all validation samples = {} ".format(len(valid_dataset)))
# testing dataset & dataloader
test_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="test")
test_dataloader = DataLoaderX(dataset=test_dataset, batch_size=test_batch_size, shuffle=False)
log.info("Total number of all testing samples = {} ".format(len(test_dataset)))
_valid_result = validate(valid_loader=valid_dataloader, model=model, criterion=criterion, evaluate_flag="DoubleCheckValidation", distributed=False, local_rank=None, nprocs=None, original_valid_length=len(valid_dataset), result_file=log_result_file, details=True)
log.warning("\n\n" + _valid_result.__str__())
_test_result = validate(valid_loader=test_dataloader, model=model, criterion=criterion, evaluate_flag="FinalTestingResult", distributed=False, local_rank=None, nprocs=None, original_valid_length=len(test_dataset), result_file=log_result_file, details=True)
log.warning("\n\n" + _test_result.__str__())
if __name__ == '__main__':
main_app()

View File

@ -63,7 +63,7 @@ class HierarchicalGraphNeuralNetwork(nn.Module):
self.global_log = global_log self.global_log = global_log
# Hierarchical 1: Control Flow Graph (CFG) embedding and pooling # Hierarchical 1: Control Flow Graph (CFG) embedding and pooling
print(type(model_params.cfg_filters), model_params.cfg_filters) # print(type(model_params.cfg_filters), model_params.cfg_filters)
if type(model_params.cfg_filters) == str: if type(model_params.cfg_filters) == str:
cfg_filter_list = [int(number_filter) for number_filter in model_params.cfg_filters.split("-")] cfg_filter_list = [int(number_filter) for number_filter in model_params.cfg_filters.split("-")]
else: else:
@ -89,7 +89,7 @@ class HierarchicalGraphNeuralNetwork(nn.Module):
# Hierarchical 2: Function Call Graph (FCG) embedding and pooling # Hierarchical 2: Function Call Graph (FCG) embedding and pooling
self.external_embedding_layer = nn.Embedding(num_embeddings=external_vocab.max_vocab_size + 2, embedding_dim=cfg_filter_list[-1], padding_idx=external_vocab.pad_idx) self.external_embedding_layer = nn.Embedding(num_embeddings=external_vocab.max_vocab_size + 2, embedding_dim=cfg_filter_list[-1], padding_idx=external_vocab.pad_idx)
print(type(model_params.fcg_filters), model_params.fcg_filters) # print(type(model_params.fcg_filters), model_params.fcg_filters)
if type(model_params.fcg_filters) == str: if type(model_params.fcg_filters) == str:
fcg_filter_list = [int(number_filter) for number_filter in model_params.fcg_filters.split("-")] fcg_filter_list = [int(number_filter) for number_filter in model_params.fcg_filters.split("-")]
else: else:

View File

@ -4,7 +4,7 @@ from datetime import datetime
import torch import torch
from torch_geometric.data import Dataset, DataLoader from torch_geometric.data import Dataset, DataLoader
from RealBatch import create_real_batch_data # noqa from utils.RealBatch import create_real_batch_data # noqa
class MalwareDetectionDataset(Dataset): class MalwareDetectionDataset(Dataset):
@ -66,7 +66,7 @@ def _simulating(_dataset, _batch_size: int):
if __name__ == '__main__': if __name__ == '__main__':
root_path: str = '/home/king/python/MalGraph-main/data/processed_dataset/DatasetJSON' root_path: str = '/home/king/python/data/processed_dataset/DatasetJSON_test'
i_batch_size = 2 i_batch_size = 2
train_dataset = MalwareDetectionDataset(root=root_path, train_or_test='train') train_dataset = MalwareDetectionDataset(root=root_path, train_or_test='train')

8
torch_test.py Normal file
View File

@ -0,0 +1,8 @@
import torch_geometric
import torch
if __name__ == '__main__':
# print(torch.__version__)
# print(torch.cuda.device_count())
# print(torch.cuda.get_device_name())
print(torch.cuda.nccl.is_available())
print(torch.cuda.nccl.version())