diff --git a/configs/default.yaml b/configs/default.yaml index 08ce646..e3ecef9 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -1,6 +1,6 @@ Data: - preprocess_root: "../data/processed_dataset/DatasetJSON/" - train_vocab_file: "../data/processed_dataset/train_external_function_name_vocab.jsonl" + preprocess_root: "/home/king/python/data/processed_dataset/DatasetJSON" + train_vocab_file: "/home/king/python/data/processed_dataset/train_external_function_name_vocab.jsonl" max_vocab_size: 10000 # modify according to the result of 1BuildExternalVocab.py Training: cuda: True # enable GPU training if cuda is available diff --git a/requirement_conda.txt b/requirement_conda.txt index ec5bd75..f9d3094 100644 --- a/requirement_conda.txt +++ b/requirement_conda.txt @@ -1,79 +1,83 @@ - -antlr4-python3-runtime==4.8 -ase==3.21.1 -ca-certificates==2021.1.19 -cached-property==1.5.2 -certifi==2020.12.5 -cffi==1.14.5 -chardet==4.0.0 -cmake==3.18.4.post1 -cycler==0.10.0 -dataclasses==0.6 -decorator==4.4.2 -future==0.18.2 -googledrivedownloader==0.4 -h5py==3.2.1 -hydra-core==1.0.6 -idna==2.10 -importlib-resources==5.1.2 -intel-openmp==2021.1.2 -isodate==0.6.0 -jinja2==2.11.3 -joblib==1.0.1 -kiwisolver==1.3.1 -ld_impl_linux-64==2.33.1 -libedit==3.1.20191231 -libffi==3.3 -libgcc-ng==9.1.0 -libstdcxx-ng==9.1.0 -llvmlite==0.35.0 -magma-cuda112==2.5.2 -markupsafe==1.1.1 -matplotlib==3.3.4 -mkl==2021.1.1 -mkl-include==2021.1.1 -ncurses==6.2 -networkx==2.5 -ninja==1.10.0.post2 -numba==0.52.0 -numpy==1.20.1 -omegaconf==2.0.6 -openssl==1.1.1j -pandas==1.2.3 -pillow==8.1.2 -pip==21.0.1 -prefetch-generator==1.0.1 -pycparser==2.20 -pyparsing==2.4.7 -python-dateutil==2.8.1 -python-louvain==0.15 -pytz==2021.1 -pyyaml==5.4.1 -rdflib==5.0.0 -readline==8.1 -requests==2.25.1 -scikit-learn==0.24.1 -scipy==1.6.1 -seaborn==0.11.1 -setuptools==52.0.0 -six==1.15.0 -sqlite==3.33.0 -tbb==2021.1.1 -texttable==1.6.3 -threadpoolctl==2.1.0 -tk==8.6.10 -torch==1.8.0+cu111 -torch-cluster==1.5.9 -torch-geometric==1.6.3 -torch-scatter==2.0.6 -torch-sparse==0.6.9 -torch-spline-conv==1.2.1 -torchaudio==0.8.0 -torchvision==0.9.0+cu111 -tqdm==4.59.0 -typing-extensions==3.7.4.3 -urllib3==1.26.3 -wheel==0.36.2 -xz==5.2.5 -zipp==3.4.1 -zlib==1.2.11 +# This file may be used to create an environment using: +# $ conda create --name --file +# platform: linux-64 +_libgcc_mutex=0.1=main +antlr4-python3-runtime=4.8=pypi_0 +ase=3.21.1=pypi_0 +ca-certificates=2021.1.19=h06a4308_1 +cached-property=1.5.2=pypi_0 +certifi=2020.12.5=py37h06a4308_0 +cffi=1.14.5=pypi_0 +chardet=4.0.0=pypi_0 +cmake=3.18.4.post1=pypi_0 +cycler=0.10.0=pypi_0 +dataclasses=0.6=pypi_0 +decorator=4.4.2=pypi_0 +future=0.18.2=pypi_0 +googledrivedownloader=0.4=pypi_0 +h5py=3.2.1=pypi_0 +hydra-core=1.0.6=pypi_0 +idna=2.10=pypi_0 +importlib-resources=5.1.2=pypi_0 +intel-openmp=2021.1.2=pypi_0 +isodate=0.6.0=pypi_0 +jinja2=2.11.3=pypi_0 +joblib=1.0.1=pypi_0 +kiwisolver=1.3.1=pypi_0 +ld_impl_linux-64=2.33.1=h53a641e_7 +libedit=3.1.20191231=h14c3975_1 +libffi=3.3=he6710b0_2 +libgcc-ng=9.1.0=hdf63c60_0 +libstdcxx-ng=9.1.0=hdf63c60_0 +llvmlite=0.35.0=pypi_0 +magma-cuda112=2.5.2=1 +markupsafe=1.1.1=pypi_0 +matplotlib=3.3.4=pypi_0 +mkl=2021.1.1=pypi_0 +mkl-include=2021.1.1=pypi_0 +ncurses=6.2=he6710b0_1 +networkx=2.5=pypi_0 +ninja=1.10.0.post2=pypi_0 +numba=0.52.0=pypi_0 +numpy=1.20.1=pypi_0 +omegaconf=2.0.6=pypi_0 +openssl=1.1.1j=h27cfd23_0 +pandas=1.2.3=pypi_0 +pillow=8.1.2=pypi_0 +pip=21.0.1=py37h06a4308_0 +prefetch-generator=1.0.1=pypi_0 +pycparser=2.20=pypi_0 +pyparsing=2.4.7=pypi_0 +python=3.7.9=h7579374_0 +python-dateutil=2.8.1=pypi_0 +python-louvain=0.15=pypi_0 +pytz=2021.1=pypi_0 +pyyaml=5.4.1=pypi_0 +rdflib=5.0.0=pypi_0 +readline=8.1=h27cfd23_0 +requests=2.25.1=pypi_0 +scikit-learn=0.24.1=pypi_0 +scipy=1.6.1=pypi_0 +seaborn=0.11.1=pypi_0 +setuptools=52.0.0=py37h06a4308_0 +six=1.15.0=pypi_0 +sqlite=3.33.0=h62c20be_0 +tbb=2021.1.1=pypi_0 +texttable=1.6.3=pypi_0 +threadpoolctl=2.1.0=pypi_0 +tk=8.6.10=hbc83047_0 +torch=1.8.0+cu111=pypi_0 +torch-cluster=1.5.9=pypi_0 +torch-geometric=1.6.3=pypi_0 +torch-scatter=2.0.6=pypi_0 +torch-sparse=0.6.9=pypi_0 +torch-spline-conv=1.2.1=pypi_0 +torchaudio=0.8.0=pypi_0 +torchvision=0.9.0+cu111=pypi_0 +tqdm=4.59.0=pypi_0 +typing-extensions=3.7.4.3=pypi_0 +urllib3=1.26.3=pypi_0 +wheel=0.36.2=pyhd3eb1b0_0 +xz=5.2.5=h7b6447c_0 +zipp=3.4.1=pypi_0 +zlib=1.2.11=h7b6447c_3 diff --git a/samples/PreProcess.py b/samples/PreProcess.py index 9970f53..6a5ba3d 100644 --- a/samples/PreProcess.py +++ b/samples/PreProcess.py @@ -1,4 +1,6 @@ import json +import os + import torch from torch_geometric.data import Data from tqdm import tqdm @@ -6,37 +8,68 @@ from tqdm import tqdm from src.utils.Vocabulary import Vocab -def parse_json_list_2_pyg_object(jsonl_file: str, label: int, vocab: Vocab): -#def parse_json_list_2_pyg_object(jsonl_file: str): +def parse_json_list_2_pyg_object(jsonl_file: str, label: int, vocab: Vocab, save_path: str, file_type: str): + # def parse_json_list_2_pyg_object(jsonl_file: str): + train_type = ['train', 'valid', 'test'] index = 0 - with open(jsonl_file, "r", encoding="utf-8") as file: - for item in tqdm(file): - item = json.loads(item) - item_hash = item['hash'] - - acfg_list = [] - for one_acfg in item['acfg_list']: # list of dict of acfg - block_features = one_acfg['block_features'] - block_edges = one_acfg['block_edges'] - one_acfg_data = Data(x=torch.tensor(block_features, dtype=torch.float), edge_index=torch.tensor(block_edges, dtype=torch.long)) - acfg_list.append(one_acfg_data) - - item_function_names = item['function_names'] - item_function_edges = item['function_edges'] - - local_function_name_list = item_function_names[:len(acfg_list)] - assert len(acfg_list) == len(local_function_name_list), "The length of ACFG_List should be equal to the length of Local_Function_List" - external_function_name_list = item_function_names[len(acfg_list):] - - external_function_index_list = [vocab[f_name] for f_name in external_function_name_list] - index += 1 - torch.save(Data(hash=item_hash, local_acfgs=acfg_list, external_list=external_function_index_list, function_edges=item_function_edges, targets=label), "./cache/benign_{}.pt".format(index)) + file_index = 0 + type_index = 0 + valid_flag = True + test_flag = True + file_len = len(os.listdir(jsonl_file)) + + for file in tqdm(os.listdir(jsonl_file)): + if index >= file_len * 0.8 and valid_flag: + type_index += 1 + valid_flag = False + file_index = 0 + print("make valid set") + elif index >= file_len * 0.9 and test_flag: + type_index += 1 + test_flag = False + file_index = 0 + print("make test set") + j = json_to_pt(file=jsonl_file + file, label=label, vocab=vocab, save_path=save_path, file_type=file_type, train_type=train_type[type_index], index=file_index) + index += 1 + file_index += 1 + + +def json_to_pt(file: str, label: int, vocab: Vocab, save_path: str, file_type: str, train_type: str, index: int): + with open(file, "r", encoding="utf-8") as item: + line = item.readline() + item = json.loads(line) + item_hash = item['hash'] + acfg_list = [] + for one_acfg in item['acfg_list']: # list of dict of acfg + block_features = one_acfg['block_features'] + block_edges = one_acfg['block_edges'] + one_acfg_data = Data(x=torch.tensor(block_features, dtype=torch.float), + edge_index=torch.tensor(block_edges, dtype=torch.long)) + acfg_list.append(one_acfg_data) + + item_function_names = item['function_names'] + item_function_edges = item['function_edges'] + + local_function_name_list = item_function_names[:len(acfg_list)] + assert len(acfg_list) == len( + local_function_name_list), "The length of ACFG_List should be equal to the length of Local_Function_List" + external_function_name_list = item_function_names[len(acfg_list):] + + external_function_index_list = [vocab[f_name] for f_name in external_function_name_list] + + torch.save(Data(hash=item_hash, local_acfgs=acfg_list, external_list=external_function_index_list, + function_edges=item_function_edges, targets=label), + save_path + "{}_{}/{}_{}.pt".format(train_type, file_type, file_type, index)) + return True + if __name__ == '__main__': - json_path = "./benign_result.jsonl" + json_path = "./jsonl/infected_jsonl/" train_vocab_file = "../data/processed_dataset/train_external_function_name_vocab.jsonl" - # train_vocab_file = "./res.jsonl" + save_vocab_file = "../data/processed_dataset/DatasetJSON/" + file_type = "malware" max_vocab_size = 10000 vocabulary = Vocab(freq_file=train_vocab_file, max_vocab_size=max_vocab_size) - parse_json_list_2_pyg_object(jsonl_file=json_path, label=1, vocab=vocabulary) + parse_json_list_2_pyg_object(jsonl_file=json_path, label=1, vocab=vocabulary, save_path=save_vocab_file, + file_type=file_type) diff --git a/samples/funCount.py b/samples/funCount.py index 42a1bdb..034bd9b 100644 --- a/samples/funCount.py +++ b/samples/funCount.py @@ -1,31 +1,38 @@ import json +import os +from itertools import islice +import heapq from tqdm import tqdm if __name__ == '__main__': - mal_file_name = './malware_result.jsonl' - ben_file_name = './benign-result.jsonl' - fil = open(mal_file_name, mode='r') + mal_file_name = './jsonl/infected_jsonl/' + ben_file_name = './jsonl/refind_jsonl/' fun_name_dict = {} - for item in tqdm(fil): - item = json.loads(item) - item_fun_list = item['function_names'] - for fun_name in item_fun_list: - if fun_name_dict.get(fun_name) is not None: - fun_name_dict[fun_name] += 1 - else: - fun_name_dict[fun_name] = 1 - fil = open(mal_file_name, mode='r') - for item in tqdm(fil): - item = json.loads(item) - item_fun_list = item['function_names'] - for fun_name in item_fun_list: - if fun_name_dict.get(fun_name) is not None: - fun_name_dict[fun_name] += 1 - else: - fun_name_dict[fun_name] = 1 - + for file in tqdm(os.listdir(mal_file_name)): + with open(mal_file_name + file, 'r') as item: + item = json.loads(item.readline()) + item_fun_list = item['function_names'] + for fun_name in item_fun_list: + if fun_name != 'start' and fun_name != 'start_0' and 'sub_' not in fun_name: + if fun_name_dict.get(fun_name) is not None: + fun_name_dict[fun_name] += 1 + else: + fun_name_dict[fun_name] = 1 + for file in tqdm(os.listdir(ben_file_name)): + with open(ben_file_name + file, 'r') as item: + item = json.loads(item.readline()) + item_fun_list = item['function_names'] + for fun_name in item_fun_list: + if fun_name != 'start' and fun_name != 'start_0' and 'sub_' not in fun_name: + if fun_name_dict.get(fun_name) is not None: + fun_name_dict[fun_name] += 1 + else: + fun_name_dict[fun_name] = 1 with open('./res.jsonl', 'w') as file: - for key, value in fun_name_dict.items(): + largest_10000_items = heapq.nlargest(10000, fun_name_dict.items(), key=lambda item: item[1]) + for key, value in largest_10000_items: temp = {"f_name": key, "count": value} file.write(json.dumps(temp) + '\n') + + diff --git a/src/DistTrainModel.py b/src/DistTrainModel.py index 9869479..91f09ee 100644 --- a/src/DistTrainModel.py +++ b/src/DistTrainModel.py @@ -185,8 +185,8 @@ def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distribu def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, model_params: ModelParams, optimizer_params: OptimizerParams, global_log: logging.Logger, log_result_file: str): - # dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:12345', world_size=nprocs, rank=local_rank) - dist.init_process_group(backend='nccl', init_method='env://', world_size=nprocs, rank=local_rank) + dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:12345', world_size=nprocs, rank=local_rank) + # dist.init_process_group(backend='nccl', init_method='env://', world_size=nprocs, rank=local_rank) torch.cuda.set_device(local_rank) # model configure @@ -304,7 +304,7 @@ def main_app(config: DictConfig): num_gpus = torch.cuda.device_count() log.info("Total number of GPUs = {}".format(num_gpus)) torch_mp.spawn(main_train_worker, nprocs=num_gpus, args=(num_gpus, _train_params, _model_params, _optim_params, log, log_result_file,)) - + # main_train_worker(num_gpus, _train_params, _model_params, _optim_params, log, log_result_file, "") best_model_file = os.path.join(os.getcwd(), 'LocalRank_{}_best_model.pt'.format(0)) else: @@ -312,7 +312,7 @@ def main_app(config: DictConfig): # model re-init and loading log.info("\n\nstarting to load the model & re-validation & testing from the file of \"{}\" \n".format(best_model_file)) - device = torch.device('cuda') + device = torch.device('cuda:0') train_vocab_path = _train_params.external_func_vocab_file vocab = Vocab(freq_file=train_vocab_path, max_vocab_size=_train_params.max_vocab_size) @@ -343,4 +343,4 @@ def main_app(config: DictConfig): if __name__ == '__main__': - main_app() \ No newline at end of file + main_app() diff --git a/src/DistTrainModel_dual.py b/src/DistTrainModel_dual.py new file mode 100644 index 0000000..fff2073 --- /dev/null +++ b/src/DistTrainModel_dual.py @@ -0,0 +1,392 @@ +import logging +import math +import os +import random +from datetime import datetime + +import hydra +import numpy as np +import torch +import torch.distributed as dist +import torch.multiprocessing as torch_mp +import torch.utils.data +import torch.utils.data.distributed +from hydra.utils import to_absolute_path +from omegaconf import DictConfig +from prefetch_generator import BackgroundGenerator +from sklearn.metrics import roc_auc_score, roc_curve +from torch import nn +from torch_geometric.data import DataLoader +from tqdm import tqdm + +from models.HierarchicalGraphModel import HierarchicalGraphNeuralNetwork +from utils.FunctionHelpers import write_into, params_print_log, find_threshold_with_fixed_fpr +from utils.ParameterClasses import ModelParams, TrainParams, OptimizerParams, OneEpochResult +from utils.PreProcessedDataset import MalwareDetectionDataset +from utils.RealBatch import create_real_batch_data +from utils.Vocabulary import Vocab + + +class DataLoaderX(DataLoader): + def __iter__(self): + return BackgroundGenerator(super().__iter__()) + + +def reduce_sum(tensor): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.ReduceOp.SUM) # noqa + return rt + + +def reduce_mean(tensor, nprocs): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.ReduceOp.SUM) # noqa + rt /= nprocs + return rt + + + +def all_gather_concat(tensor): + tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())] + dist.all_gather(tensors_gather, tensor, async_op=False) + output = torch.cat(tensors_gather, dim=0) + return output + + +def train_one_epoch(local_rank, train_loader, valid_loader, model, criterion, optimizer, nprocs, idx_epoch, best_auc, best_model_file, original_valid_length, result_file): + # print(train_loader.dataset.__dict__) + model.train() + local_device = torch.device("cuda", local_rank) + write_into(file_name_path=result_file, log_str="The local device = {} among {} nprocs in the {}-th epoch.".format(local_device, nprocs, idx_epoch)) + + until_sum_reduced_loss = 0.0 + smooth_avg_reduced_loss_list = [] + + for _idx_bt, _batch in enumerate(tqdm(train_loader, desc="reading _batch from local_rank={}".format(local_rank))): + model.train() + _real_batch, _position, _hash, _external_list, _function_edges, _true_classes = create_real_batch_data(one_batch=_batch) + if _real_batch is None: + write_into(result_file, + "{}\n_real_batch is None in creating the real batch data of training ... ".format("*-" * 100)) + continue + + _real_batch = _real_batch.to(local_device) + _position = torch.tensor(_position, dtype=torch.long).cuda(local_rank, non_blocking=True) + _true_classes = _true_classes.float().cuda(local_rank, non_blocking=True) + + train_batch_pred = model(real_local_batch=_real_batch, + real_bt_positions=_position, + bt_external_names=_external_list, + bt_all_function_edges=_function_edges, + local_device=local_device) + train_batch_pred = train_batch_pred.squeeze() + + loss = criterion(train_batch_pred, _true_classes) + + torch.distributed.barrier() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + reduced_loss = reduce_mean(loss, nprocs) + until_sum_reduced_loss += reduced_loss.item() + smooth_avg_reduced_loss_list.append(until_sum_reduced_loss / (_idx_bt + 1)) + + if _idx_bt != 0 and (_idx_bt % math.ceil(len(train_loader) / 3) == 0 or _idx_bt == int(len(train_loader) - 1)): + + val_start_time = datetime.now() + if local_rank == 0: + write_into(result_file, "\nIn {}-th epoch, {}-th batch, we start to validate ... ".format(idx_epoch, _idx_bt)) + + _eval_flag = "Valid_In_Train_Epoch_{}_Batch_{}".format(idx_epoch, _idx_bt) + valid_result = validate(local_rank=local_rank, + valid_loader=valid_loader, + model=model, + criterion=criterion, + evaluate_flag=_eval_flag, + distributed=True, # 分布式 + nprocs=nprocs, + original_valid_length=original_valid_length, + result_file=result_file, + details=True # 验证细节 + ) + + if best_auc < valid_result.ROC_AUC_Score: + _info = "[AUC Increased!] In evaluation of epoch-{} / batch-{}: AUC increased from {:.5f} < {:.5f}! Saving the model into {}".format(idx_epoch, + _idx_bt, + best_auc, + valid_result.ROC_AUC_Score, + best_model_file) + best_auc = valid_result.ROC_AUC_Score + torch.save(model.module.state_dict(), best_model_file) + else: + _info = "[AUC NOT Increased!] AUC decreased from {:.5f} to {:.5f}!".format(best_auc, valid_result.ROC_AUC_Score) + + if local_rank == 0: + write_into(result_file, valid_result.__str__()) + write_into(result_file, _info) + write_into(result_file, "[#One Validation Time#] Consume about {} time period for one validation.".format(datetime.now() - val_start_time)) + + return smooth_avg_reduced_loss_list, best_auc + + +def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distributed, nprocs, original_valid_length, result_file, details): + model.eval() + if distributed: + local_device = torch.device("cuda", local_rank) + else: + local_device = torch.device("cuda") + + sum_loss = torch.tensor(0.0, dtype=torch.float, device=local_device) + n_samples = torch.tensor(0, dtype=torch.int, device=local_device) + + all_true_classes = [] + all_positive_probs = [] + + with torch.no_grad(): + for idx_batch, data in enumerate(tqdm(valid_loader)): + _real_batch, _position, _hash, _external_list, _function_edges, _true_classes = create_real_batch_data(one_batch=data) + if _real_batch is None: + write_into(result_file, "{}\n_real_batch is None in creating the real batch data of validation ... ".format("*-" * 100)) + continue + _real_batch = _real_batch.to(local_device) + _position = torch.tensor(_position, dtype=torch.long).cuda(local_rank, non_blocking=True) + _true_classes = _true_classes.float().cuda(local_rank, non_blocking=True) + + batch_pred = model(real_local_batch=_real_batch, + real_bt_positions=_position, + bt_external_names=_external_list, + bt_all_function_edges=_function_edges, + local_device=local_device) + batch_pred = batch_pred.squeeze(-1) + loss = criterion(batch_pred, _true_classes) + sum_loss += loss.item() + + n_samples += len(batch_pred) + + all_true_classes.append(_true_classes) + all_positive_probs.append(batch_pred) + + avg_loss = sum_loss / (idx_batch + 1) + all_true_classes = torch.cat(all_true_classes, dim=0) + all_positive_probs = torch.cat(all_positive_probs, dim=0) + + if distributed: + torch.distributed.barrier() + reduced_n_samples = reduce_sum(n_samples) + reduced_avg_loss = reduce_mean(avg_loss, nprocs) + gather_true_classes = all_gather_concat(all_true_classes).detach().cpu().numpy() + gather_positive_prods = all_gather_concat(all_positive_probs).detach().cpu().numpy() + + gather_true_classes = gather_true_classes[:original_valid_length] + gather_positive_prods = gather_positive_prods[:original_valid_length] + + else: + reduced_n_samples = n_samples + reduced_avg_loss = avg_loss + gather_true_classes = all_true_classes.detach().cpu().numpy() + gather_positive_prods = all_positive_probs.detach().cpu().numpy() + + # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html + # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html + # try: + # _roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_prods) + # except ValueError: + # _roc_auc_score = 0.001 + print(gather_true_classes) + print(gather_positive_prods) + # try: + # _roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_prods) + # except ValueError: + # pass + _roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_prods) + _fpr, _tpr, _thresholds = roc_curve(y_true=gather_true_classes, y_score=gather_positive_prods) + if details is True: + _100_info = find_threshold_with_fixed_fpr(y_true=gather_true_classes, y_pred=gather_positive_prods, fpr_target=0.01) + _1000_info = find_threshold_with_fixed_fpr(y_true=gather_true_classes, y_pred=gather_positive_prods, fpr_target=0.001) + else: + _100_info, _1000_info = "None", "None" + + _eval_result = OneEpochResult(Epoch_Flag=evaluate_flag, + Number_Samples=reduced_n_samples, + Avg_Loss=reduced_avg_loss, + Info_100=_100_info, + Info_1000=_1000_info, + ROC_AUC_Score=_roc_auc_score, + Thresholds=_thresholds, + TPRs=_tpr, + FPRs=_fpr) + return _eval_result + + +def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, model_params: ModelParams, optimizer_params: OptimizerParams, global_log: logging.Logger, + log_result_file: str): + # dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:12345', world_size=nprocs, rank=local_rank) + dist.init_process_group(backend='nccl', init_method='env://', world_size=nprocs, rank=local_rank) + torch.cuda.set_device(local_rank) + + # model configure + vocab = Vocab(freq_file=train_params.external_func_vocab_file, max_vocab_size=train_params.max_vocab_size) + + if model_params.ablation_models.lower() == "full": + model = HierarchicalGraphNeuralNetwork(model_params=model_params, external_vocab=vocab, global_log=global_log) + else: + raise NotImplementedError + + model.cuda(local_rank) + model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank]) + + if local_rank == 0: + write_into(file_name_path=log_result_file, log_str=model.__str__()) + + # loss function + criterion = nn.BCELoss().cuda(local_rank) + + lr = optimizer_params.lr + if optimizer_params.optimizer_name.lower() == 'adam': + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + elif optimizer_params.optimizer_name.lower() == 'adamw': + optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=optimizer_params.weight_decay) + else: + raise NotImplementedError + + max_epochs = train_params.max_epochs + + dataset_root_path = train_params.processed_files_path + train_batch_size = train_params.train_bs + test_batch_size = train_params.test_bs + + # training dataset & dataloader + train_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="train") + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + train_loader = DataLoaderX(dataset=train_dataset, batch_size=train_batch_size, shuffle=False, num_workers=0, pin_memory=True, sampler=train_sampler) + # validation dataset & dataloader + valid_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="valid") + valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) + valid_loader = DataLoaderX(dataset=valid_dataset, batch_size=test_batch_size, pin_memory=True, sampler=valid_sampler) + + if local_rank == 0: + write_into(file_name_path=log_result_file, log_str="Training dataset={}, sampler={}, loader={}".format(len(train_dataset), len(train_sampler), len(train_loader))) + write_into(file_name_path=log_result_file, log_str="Validation dataset={}, sampler={}, loader={}".format(len(valid_dataset), len(valid_sampler), len(valid_loader))) + + best_auc = 0.0 + ori_valid_length = len(valid_dataset) + best_model_path = os.path.join(os.getcwd(), 'LocalRank_{}_best_model.pt'.format(local_rank)) + + all_batch_avg_smooth_loss_list = [] + for epoch in range(max_epochs): + train_sampler.set_epoch(epoch) + valid_sampler.set_epoch(epoch) + + # train for one epoch + time_start = datetime.now() + if local_rank == 0: + write_into(log_result_file, "\n{} start of {}-epoch, init best_auc={}, start time={} {}".format("-" * 50, epoch, best_auc, time_start.strftime("%Y-%m-%d@%H:%M:%S"), "-" * 50)) + + smooth_avg_reduced_loss_list, best_auc = train_one_epoch(local_rank=local_rank, + train_loader=train_loader, + valid_loader=valid_loader, + model=model, + criterion=criterion, + optimizer=optimizer, + nprocs=nprocs, + idx_epoch=epoch, + best_auc=best_auc, + best_model_file=best_model_path, + original_valid_length=ori_valid_length, + result_file=log_result_file) + all_batch_avg_smooth_loss_list.extend(smooth_avg_reduced_loss_list) + + # adjust learning rate + for param_group in optimizer.param_groups: + param_group['lr'] = param_group['lr'] / optimizer_params.learning_anneal + + time_end = datetime.now() + if local_rank == 0: + write_into(log_result_file, "\n{} end of {}-epoch, with best_auc={}, len of loss={}, end time={}, and time period={} {}".format("*" * 50, epoch, best_auc, + len(smooth_avg_reduced_loss_list), + time_end.strftime("%Y-%m-%d@%H:%M:%S"), + time_end - time_start, "*" * 50)) + + +# https://hydra.cc/docs/tutorials/basic/your_first_app/defaults#overriding-a-config-group-default +@hydra.main(config_path="../configs/", config_name="default.yaml") +def main_app(config: DictConfig): + # set seed for determinism for reproduction + random.seed(config.Training.seed) + np.random.seed(config.Training.seed) + torch.manual_seed(config.Training.seed) + torch.cuda.manual_seed(config.Training.seed) + torch.cuda.manual_seed_all(config.Training.seed) + + # setting hyper-parameter for Training / Model / Optimizer + _train_params = TrainParams(processed_files_path=to_absolute_path(config.Data.preprocess_root), max_epochs=config.Training.max_epoches, train_bs=config.Training.train_batch_size, test_bs=config.Training.test_batch_size, external_func_vocab_file=to_absolute_path(config.Data.train_vocab_file), max_vocab_size=config.Data.max_vocab_size) + _model_params = ModelParams(gnn_type=config.Model.gnn_type, pool_type=config.Model.pool_type, acfg_init_dims=config.Model.acfg_node_init_dims, cfg_filters=config.Model.cfg_filters, fcg_filters=config.Model.fcg_filters, number_classes=config.Model.number_classes, dropout_rate=config.Model.drapout_rate, ablation_models=config.Model.ablation_models) + _optim_params = OptimizerParams(optimizer_name=config.Optimizer.name, lr=config.Optimizer.learning_rate, weight_decay=config.Optimizer.weight_decay, learning_anneal=config.Optimizer.learning_anneal) + + # logging + log = logging.getLogger("DistTrainModel.py") + log.setLevel("DEBUG") + log.warning("Hydra's Current Working Directory: {}".format(os.getcwd())) + + # setting for the log directory + result_file = '{}_{}_{}_ACFG_{}_FCG_{}_Epoch_{}_TrainBS_{}_LR_{}_Time_{}.txt'.format(_model_params.ablation_models, _model_params.gnn_type, _model_params.pool_type, + _model_params.cfg_filters, _model_params.fcg_filters, _train_params.max_epochs, + _train_params.train_bs, _optim_params.lr, datetime.now().strftime("%Y%m%d_%H%M%S")) + log_result_file = os.path.join(os.getcwd(), result_file) + + _other_params = {"Hydra's Current Working Directory": os.getcwd(), "seed": config.Training.seed, "log result file": log_result_file, "only_test_path": config.Training.only_test_path} + + params_print_log(_train_params.__dict__, log_result_file) + params_print_log(_model_params.__dict__, log_result_file) + params_print_log(_optim_params.__dict__, log_result_file) + params_print_log(_other_params, log_result_file) + + if config.Training.only_test_path == 'None': + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = str(config.Training.dist_port) + # num_gpus = 1 + num_gpus = torch.cuda.device_count() + log.info("Total number of GPUs = {}".format(num_gpus)) + torch_mp.spawn(main_train_worker, nprocs=num_gpus, args=(num_gpus, _train_params, _model_params, _optim_params, log, log_result_file,)) + + best_model_file = os.path.join(os.getcwd(), 'LocalRank_{}_best_model.pt'.format(0)) + + else: + best_model_file = config.Training.only_test_path + + # model re-init and loading + log.info("\n\nstarting to load the model & re-validation & testing from the file of \"{}\" \n".format(best_model_file)) + device = torch.device('cuda') + train_vocab_path = _train_params.external_func_vocab_file + vocab = Vocab(freq_file=train_vocab_path, max_vocab_size=_train_params.max_vocab_size) + + if _model_params.ablation_models.lower() == "full": + model = HierarchicalGraphNeuralNetwork(model_params=_model_params, external_vocab=vocab, global_log=log) + else: + raise NotImplementedError + model.to(device) + model.load_state_dict(torch.load(best_model_file, map_location=device)) + criterion = nn.BCELoss().cuda() + + test_batch_size = config.Training.test_batch_size + dataset_root_path = _train_params.processed_files_path + # validation dataset & dataloader + valid_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="valid") + valid_dataloader = DataLoaderX(dataset=valid_dataset, batch_size=test_batch_size, shuffle=False) + log.info("Total number of all validation samples = {} ".format(len(valid_dataset))) + + # testing dataset & dataloader + test_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="test") + test_dataloader = DataLoaderX(dataset=test_dataset, batch_size=test_batch_size, shuffle=False) + log.info("Total number of all testing samples = {} ".format(len(test_dataset))) + + _valid_result = validate(valid_loader=valid_dataloader, model=model, criterion=criterion, evaluate_flag="DoubleCheckValidation", distributed=False, local_rank=None, nprocs=None, original_valid_length=len(valid_dataset), result_file=log_result_file, details=True) + log.warning("\n\n" + _valid_result.__str__()) + _test_result = validate(valid_loader=test_dataloader, model=model, criterion=criterion, evaluate_flag="FinalTestingResult", distributed=False, local_rank=None, nprocs=None, original_valid_length=len(test_dataset), result_file=log_result_file, details=True) + log.warning("\n\n" + _test_result.__str__()) + + +if __name__ == '__main__': + main_app() diff --git a/src/models/HierarchicalGraphModel.py b/src/models/HierarchicalGraphModel.py index 2b4e9b4..3855083 100644 --- a/src/models/HierarchicalGraphModel.py +++ b/src/models/HierarchicalGraphModel.py @@ -63,7 +63,7 @@ class HierarchicalGraphNeuralNetwork(nn.Module): self.global_log = global_log # Hierarchical 1: Control Flow Graph (CFG) embedding and pooling - print(type(model_params.cfg_filters), model_params.cfg_filters) + # print(type(model_params.cfg_filters), model_params.cfg_filters) if type(model_params.cfg_filters) == str: cfg_filter_list = [int(number_filter) for number_filter in model_params.cfg_filters.split("-")] else: @@ -89,7 +89,7 @@ class HierarchicalGraphNeuralNetwork(nn.Module): # Hierarchical 2: Function Call Graph (FCG) embedding and pooling self.external_embedding_layer = nn.Embedding(num_embeddings=external_vocab.max_vocab_size + 2, embedding_dim=cfg_filter_list[-1], padding_idx=external_vocab.pad_idx) - print(type(model_params.fcg_filters), model_params.fcg_filters) + # print(type(model_params.fcg_filters), model_params.fcg_filters) if type(model_params.fcg_filters) == str: fcg_filter_list = [int(number_filter) for number_filter in model_params.fcg_filters.split("-")] else: diff --git a/src/utils/PreProcessedDataset.py b/src/utils/PreProcessedDataset.py index 2797891..87e7010 100644 --- a/src/utils/PreProcessedDataset.py +++ b/src/utils/PreProcessedDataset.py @@ -4,7 +4,7 @@ from datetime import datetime import torch from torch_geometric.data import Dataset, DataLoader -from RealBatch import create_real_batch_data # noqa +from utils.RealBatch import create_real_batch_data # noqa class MalwareDetectionDataset(Dataset): @@ -66,7 +66,7 @@ def _simulating(_dataset, _batch_size: int): if __name__ == '__main__': - root_path: str = '/home/king/python/MalGraph-main/data/processed_dataset/DatasetJSON' + root_path: str = '/home/king/python/data/processed_dataset/DatasetJSON_test' i_batch_size = 2 train_dataset = MalwareDetectionDataset(root=root_path, train_or_test='train') diff --git a/torch_test.py b/torch_test.py new file mode 100644 index 0000000..9044dc9 --- /dev/null +++ b/torch_test.py @@ -0,0 +1,8 @@ +import torch_geometric +import torch +if __name__ == '__main__': + # print(torch.__version__) + # print(torch.cuda.device_count()) + # print(torch.cuda.get_device_name()) + print(torch.cuda.nccl.is_available()) + print(torch.cuda.nccl.version())