From 25093d8e47c801780d1554b9e2eed2e69a176972 Mon Sep 17 00:00:00 2001 From: TinyCaviar <379645931@qq.com> Date: Thu, 21 Sep 2023 17:15:21 +0800 Subject: [PATCH] backup --- src/DistTrainModel.py | 6 ++++-- src/models/HierarchicalGraphModel.py | 11 ++++++++--- src/utils/util.py | 20 ++++++++++++++++---- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/DistTrainModel.py b/src/DistTrainModel.py index 180a3d8..574443d 100644 --- a/src/DistTrainModel.py +++ b/src/DistTrainModel.py @@ -15,9 +15,11 @@ from hydra.utils import to_absolute_path from omegaconf import DictConfig from prefetch_generator import BackgroundGenerator from sklearn.metrics import roc_auc_score, roc_curve +import matplotlib.pyplot as plt from torch import nn from torch_geometric.data import DataLoader from tqdm import tqdm +from typing import List from models.HierarchicalGraphModel import HierarchicalGraphNeuralNetwork from utils.FunctionHelpers import write_into, params_print_log, find_threshold_with_fixed_fpr @@ -95,7 +97,7 @@ def train_one_epoch(local_rank, train_loader, valid_loader, model, criterion, op _eval_flag = "Valid_In_Train_Epoch_{}_Batch_{}".format(idx_epoch, _idx_bt) valid_result = validate(local_rank=local_rank, valid_loader=valid_loader, model=model, criterion=criterion, evaluate_flag=_eval_flag, distributed=True, nprocs=nprocs, original_valid_length=original_valid_length, result_file=result_file, details=False) - + if best_auc < valid_result.ROC_AUC_Score: _info = "[AUC Increased!] In evaluation of epoch-{} / batch-{}: AUC increased from {:.5f} < {:.5f}! Saving the model into {}".format(idx_epoch, _idx_bt, @@ -246,7 +248,7 @@ def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, m time_start = datetime.now() if local_rank == 0: write_into(log_result_file, "\n{} start of {}-epoch, init best_auc={}, start time={} {}".format("-" * 50, epoch, best_auc, time_start.strftime("%Y-%m-%d@%H:%M:%S"), "-" * 50)) - + smooth_avg_reduced_loss_list, best_auc = train_one_epoch(local_rank=local_rank, train_loader=train_loader, valid_loader=valid_loader, model=model, criterion=criterion, optimizer=optimizer, nprocs=nprocs, idx_epoch=epoch, best_auc=best_auc, best_model_file=best_model_path, original_valid_length=ori_valid_length, result_file=log_result_file) diff --git a/src/models/HierarchicalGraphModel.py b/src/models/HierarchicalGraphModel.py index c7ddec2..4656201 100644 --- a/src/models/HierarchicalGraphModel.py +++ b/src/models/HierarchicalGraphModel.py @@ -13,6 +13,8 @@ sys.path.append("..") from utils.ParameterClasses import ModelParams # noqa from utils.Vocabulary import Vocab # noqa +perform_MID = True + def div_with_small_value(n, d, eps=1e-8): d = d * (d > eps).float() + eps * (d <= eps).float() @@ -119,8 +121,11 @@ class HierarchicalGraphNeuralNetwork(nn.Module): # self.last_activation = nn.Softmax(dim=1) # self.last_activation = nn.LogSoftmax(dim=1) - + def forward_cfg_gnn(self, local_batch: Batch): + if perform_MID: + return self.forward_MID_cfg_gnn(local_batch) + in_x, edge_index = local_batch.x, local_batch.edge_index for i in range(self.cfg_filter_length - 1): out_x = getattr(self, 'CFG_gnn_{}'.format(i + 1))(x=in_x, edge_index=edge_index) @@ -138,7 +143,7 @@ class HierarchicalGraphNeuralNetwork(nn.Module): for acfg in cfg_subgraph_loader: subgraph_embeddings = [] for subgraph in acfg: - in_x, edge_index = subgraph.x, subgraph.edge_index + in_x, edge_index = subgraph.x.to(device), subgraph.edge_index.to(device) batch = torch.zeros(in_x.size(0), dtype=torch.long, device=device) for i in range(self.cfg_filter_length - 1): out_x = getattr(self, 'CFG_gnn_{}'.format(i + 1))(x=in_x, edge_index=edge_index) @@ -206,7 +211,7 @@ class HierarchicalGraphNeuralNetwork(nn.Module): def forward(self, real_local_batch: Batch, real_bt_positions: list, bt_external_names: list, bt_all_function_edges: list, local_device: torch.device): - rtn_local_batch = self.forward_MID_cfg_gnn(local_batch=real_local_batch) + rtn_local_batch = self.forward_cfg_gnn(local_batch=real_local_batch) x_cfg_pool = self.aggregate_cfg_batch_pooling(local_batch=rtn_local_batch) # build the Function Call Graph (FCG) Data/Batch datasets diff --git a/src/utils/util.py b/src/utils/util.py index 81f91fa..868eb13 100644 --- a/src/utils/util.py +++ b/src/utils/util.py @@ -19,10 +19,10 @@ def transfer_remote(): shutil.copy(os.path.join(file_dir, file), os.path.join(dest_dir, str(index))) index += 1 - delete_all_remote() + delete_remote() -def delete_all_remote(): +def delete_remote(): samples_dir = '/root/autodl-tmp' sample = ['malware', 'benign'] tags = ['test', 'train', 'valid'] @@ -33,6 +33,16 @@ def delete_all_remote(): os.remove(os.path.join(file_dir, f)) +def delete_remote_backup(): + samples_dir = '/root/autodl-tmp' + dir_name = ['all', 'all_benign', 'one_family_malware', 'test_malware_backup', 'valid_malware_backup', 'train_malware_backup'] + for name in dir_name: + file_dir = os.path.join(samples_dir, name) + if os.path.exists(file_dir): + for f in os.listdir(file_dir): + os.remove(os.path.join(file_dir, f)) + + # 重命名pt文件使之与代码相符 def rename(file_dir, mal_or_be, postfix): tag_set = ['train', 'test', 'valid'] @@ -65,7 +75,7 @@ def split_samples(flag): os_list = os.listdir(path) random.shuffle(os_list) - # 8/1/1 分数据 + # 6/2/2 分数据 train_len = int(len(os_list) * 0.6) test_len = int(train_len / 3) for index, f in enumerate(os_list): @@ -79,7 +89,9 @@ def split_samples(flag): if __name__ == '__main__': + # delete_remote_backup() # transfer_remote() - delete_all_remote() + # delete_remote() + split_samples('standard') split_samples('one_family') split_samples('benign')