backup

This commit is contained in:
parent 5d1e7e1ed0
commit 25093d8e47
@@ -15,9 +15,11 @@ from hydra.utils import to_absolute_path
from omegaconf import DictConfig
from prefetch_generator import BackgroundGenerator
from sklearn.metrics import roc_auc_score, roc_curve
import matplotlib.pyplot as plt
from torch import nn
from torch_geometric.data import DataLoader
from tqdm import tqdm
from typing import List

from models.HierarchicalGraphModel import HierarchicalGraphNeuralNetwork
from utils.FunctionHelpers import write_into, params_print_log, find_threshold_with_fixed_fpr
@@ -95,7 +97,7 @@ def train_one_epoch(local_rank, train_loader, valid_loader, model, criterion, op
            _eval_flag = "Valid_In_Train_Epoch_{}_Batch_{}".format(idx_epoch, _idx_bt)
            valid_result = validate(local_rank=local_rank, valid_loader=valid_loader, model=model, criterion=criterion, evaluate_flag=_eval_flag, distributed=True, nprocs=nprocs,
                                    original_valid_length=original_valid_length, result_file=result_file, details=False)


            if best_auc < valid_result.ROC_AUC_Score:
                _info = "[AUC Increased!] In evaluation of epoch-{} / batch-{}: AUC increased from {:.5f} < {:.5f}! Saving the model into {}".format(idx_epoch,
                                                                                                                                                     _idx_bt,
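For context on this hunk: the training loop runs a mid-epoch validation pass and checkpoints whenever the ROC-AUC improves. A minimal sketch of that checkpoint-on-best-AUC pattern, assuming a validation result carrying a ROC_AUC_Score field as in the diff (the helper below is illustrative, not the project's implementation):

import torch

def maybe_checkpoint(model, valid_auc, best_auc, best_model_file):
    # Save a state_dict checkpoint only when the validation AUC improves,
    # and return the (possibly updated) best AUC.
    if best_auc < valid_auc:
        torch.save(model.state_dict(), best_model_file)
        return valid_auc
    return best_auc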
@@ -246,7 +248,7 @@ def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, m
        time_start = datetime.now()
        if local_rank == 0:
            write_into(log_result_file, "\n{} start of {}-epoch, init best_auc={}, start time={} {}".format("-" * 50, epoch, best_auc, time_start.strftime("%Y-%m-%d@%H:%M:%S"), "-" * 50))


        smooth_avg_reduced_loss_list, best_auc = train_one_epoch(local_rank=local_rank, train_loader=train_loader, valid_loader=valid_loader, model=model, criterion=criterion,
                                                                  optimizer=optimizer, nprocs=nprocs, idx_epoch=epoch, best_auc=best_auc, best_model_file=best_model_path,
                                                                  original_valid_length=ori_valid_length, result_file=log_result_file)
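The `if local_rank == 0` guard above keeps the epoch banner from being written once per process under distributed training. A minimal sketch of that rank-0-only logging convention (a plain print stands in for the project's write_into helper):

from datetime import datetime

def log_epoch_start(local_rank: int, epoch: int, best_auc: float) -> None:
    # Only rank 0 writes the banner, so each message appears once per job
    # rather than once per GPU process.
    if local_rank == 0:
        stamp = datetime.now().strftime("%Y-%m-%d@%H:%M:%S")
        print("{} start of {}-epoch, init best_auc={}, start time={} {}".format("-" * 50, epoch, best_auc, stamp, "-" * 50))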
@@ -13,6 +13,8 @@ sys.path.append("..")
from utils.ParameterClasses import ModelParams  # noqa
from utils.Vocabulary import Vocab  # noqa

perform_MID = True


def div_with_small_value(n, d, eps=1e-8):
    d = d * (d > eps).float() + eps * (d <= eps).float()
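div_with_small_value clamps denominator entries at or below eps to eps before dividing, so the subsequent division never blows up on (near-)zero values. A minimal sketch of the same idea; the final return line is an assumption, since the hunk cuts off before it:

import torch

def div_with_small_value(n: torch.Tensor, d: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Entries of d above eps are kept, entries at or below eps become eps,
    # e.g. d = [2.0, 0.0] -> [2.0, 1e-8], so n / d stays finite.
    d = d * (d > eps).float() + eps * (d <= eps).float()
    return n / d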
@@ -119,8 +121,11 @@ class HierarchicalGraphNeuralNetwork(nn.Module):

        # self.last_activation = nn.Softmax(dim=1)
        # self.last_activation = nn.LogSoftmax(dim=1)


    def forward_cfg_gnn(self, local_batch: Batch):
        if perform_MID:
            return self.forward_MID_cfg_gnn(local_batch)

        in_x, edge_index = local_batch.x, local_batch.edge_index
        for i in range(self.cfg_filter_length - 1):
            out_x = getattr(self, 'CFG_gnn_{}'.format(i + 1))(x=in_x, edge_index=edge_index)
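The change above routes forward_cfg_gnn through forward_MID_cfg_gnn whenever the module-level perform_MID flag is set; otherwise it applies the stacked CFG GNN layers that were registered under numbered attribute names and are fetched with getattr. A minimal, self-contained sketch of that dispatch-plus-getattr pattern (the GCNConv layer choice and the class name are assumptions for illustration only):

import torch
from torch import nn
from torch_geometric.nn import GCNConv

PERFORM_MID = True  # mirrors the module-level `perform_MID = True` switch in the diff

class CfgEncoder(nn.Module):
    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.num_layers = num_layers
        for i in range(num_layers):
            # Layers are registered under numbered names and fetched with getattr,
            # mirroring the 'CFG_gnn_{i}' attributes used by the model.
            setattr(self, "CFG_gnn_{}".format(i + 1), GCNConv(dim, dim))

    def forward_cfg_gnn(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        if PERFORM_MID:
            return self.forward_mid_cfg_gnn(x, edge_index)
        for i in range(self.num_layers):
            x = getattr(self, "CFG_gnn_{}".format(i + 1))(x=x, edge_index=edge_index)
        return x

    def forward_mid_cfg_gnn(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        # Stand-in for the per-subgraph MID path; it reuses the same stack here
        # only so that the sketch stays runnable end to end.
        for i in range(self.num_layers):
            x = getattr(self, "CFG_gnn_{}".format(i + 1))(x=x, edge_index=edge_index)
        return x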
@@ -138,7 +143,7 @@ class HierarchicalGraphNeuralNetwork(nn.Module):
        for acfg in cfg_subgraph_loader:
            subgraph_embeddings = []
            for subgraph in acfg:
-               in_x, edge_index = subgraph.x, subgraph.edge_index
+               in_x, edge_index = subgraph.x.to(device), subgraph.edge_index.to(device)
                batch = torch.zeros(in_x.size(0), dtype=torch.long, device=device)
                for i in range(self.cfg_filter_length - 1):
                    out_x = getattr(self, 'CFG_gnn_{}'.format(i + 1))(x=in_x, edge_index=edge_index)
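The batch = torch.zeros(in_x.size(0), dtype=torch.long, device=device) line assigns every node of the current subgraph to graph index 0, which is the form torch_geometric's global pooling operators expect when collapsing a single graph into one embedding. A minimal sketch of that idea (the choice of global_mean_pool is an assumption; the diff does not show which pooling the model applies):

import torch
from torch_geometric.nn import global_mean_pool

# One small subgraph: 3 nodes with 4-dim features and a few directed edges.
x = torch.randn(3, 4)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]], dtype=torch.long)

# All-zeros batch vector: every node belongs to graph 0, so pooling
# collapses the whole subgraph into a single (1, 4) embedding.
batch = torch.zeros(x.size(0), dtype=torch.long)
subgraph_embedding = global_mean_pool(x, batch)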
@@ -206,7 +211,7 @@ class HierarchicalGraphNeuralNetwork(nn.Module):

    def forward(self, real_local_batch: Batch, real_bt_positions: list, bt_external_names: list, bt_all_function_edges: list, local_device: torch.device):

-       rtn_local_batch = self.forward_MID_cfg_gnn(local_batch=real_local_batch)
+       rtn_local_batch = self.forward_cfg_gnn(local_batch=real_local_batch)
        x_cfg_pool = self.aggregate_cfg_batch_pooling(local_batch=rtn_local_batch)

        # build the Function Call Graph (FCG) Data/Batch datasets
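After the per-CFG pooling, the comment indicates the model assembles a Function Call Graph (FCG) Data/Batch from the pooled function embeddings and the per-binary function edges. A minimal sketch of building such a graph with torch_geometric (names, shapes, and the edge layout are assumptions; the diff shows only the comment, not the construction code):

import torch
from torch_geometric.data import Data

# Assume 4 functions, each already pooled to an 8-dim CFG embedding.
x_cfg_pool = torch.randn(4, 8)

# Function-call edges as (caller, callee) pairs, e.g. from bt_all_function_edges.
function_edges = [[0, 1], [0, 2], [2, 3]]
edge_index = torch.tensor(function_edges, dtype=torch.long).t().contiguous()

# One FCG graph whose nodes are functions and whose node features are CFG embeddings.
fcg_data = Data(x=x_cfg_pool, edge_index=edge_index)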
@@ -19,10 +19,10 @@ def transfer_remote():
            shutil.copy(os.path.join(file_dir, file), os.path.join(dest_dir, str(index)))
            index += 1

-   delete_all_remote()
+   delete_remote()


-def delete_all_remote():
+def delete_remote():
    samples_dir = '/root/autodl-tmp'
    sample = ['malware', 'benign']
    tags = ['test', 'train', 'valid']
@@ -33,6 +33,16 @@ def delete_all_remote():
                os.remove(os.path.join(file_dir, f))


+def delete_remote_backup():
+    samples_dir = '/root/autodl-tmp'
+    dir_name = ['all', 'all_benign', 'one_family_malware', 'test_malware_backup', 'valid_malware_backup', 'train_malware_backup']
+    for name in dir_name:
+        file_dir = os.path.join(samples_dir, name)
+        if os.path.exists(file_dir):
+            for f in os.listdir(file_dir):
+                os.remove(os.path.join(file_dir, f))
+
+
# Rename the .pt files so that their names match what the code expects
def rename(file_dir, mal_or_be, postfix):
    tag_set = ['train', 'test', 'valid']
@@ -65,7 +75,7 @@ def split_samples(flag):

    os_list = os.listdir(path)
    random.shuffle(os_list)
-   # split the data 8/1/1
+   # split the data 6/2/2
    train_len = int(len(os_list) * 0.6)
    test_len = int(train_len / 3)
    for index, f in enumerate(os_list):
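With the 6/2/2 comment above, train_len takes 60% of the shuffled list and test_len = int(train_len / 3) works out to roughly 20%, leaving the remaining ~20% for validation. A minimal sketch of carving the shuffled names into those buckets (the slicing by position is an assumption, since the loop body that assigns files is not shown in the hunk):

import random

def split_indices(files, seed=0):
    # Shuffle, then split 60/20/20 into train/test/valid.
    os_list = list(files)
    random.Random(seed).shuffle(os_list)
    train_len = int(len(os_list) * 0.6)      # 60% train
    test_len = int(train_len / 3)            # 0.6 / 3 = 20% test
    train = os_list[:train_len]
    test = os_list[train_len:train_len + test_len]
    valid = os_list[train_len + test_len:]   # remaining ~20% valid
    return train, test, valid

# Example: 10 files -> 6 train, 2 test, 2 valid.
print([len(s) for s in split_indices(['sample_{}.pt'.format(i) for i in range(10)])])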
@@ -79,7 +89,9 @@ def split_samples(flag):


if __name__ == '__main__':
    # delete_remote_backup()
    # transfer_remote()
    delete_all_remote()
    # delete_remote()
    split_samples('standard')
    split_samples('one_family')
    split_samples('benign')