backup
This commit is contained in:
parent
5d1e7e1ed0
commit
25093d8e47
@ -15,9 +15,11 @@ from hydra.utils import to_absolute_path
|
|||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
from prefetch_generator import BackgroundGenerator
|
from prefetch_generator import BackgroundGenerator
|
||||||
from sklearn.metrics import roc_auc_score, roc_curve
|
from sklearn.metrics import roc_auc_score, roc_curve
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch_geometric.data import DataLoader
|
from torch_geometric.data import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from models.HierarchicalGraphModel import HierarchicalGraphNeuralNetwork
|
from models.HierarchicalGraphModel import HierarchicalGraphNeuralNetwork
|
||||||
from utils.FunctionHelpers import write_into, params_print_log, find_threshold_with_fixed_fpr
|
from utils.FunctionHelpers import write_into, params_print_log, find_threshold_with_fixed_fpr
|
||||||
@ -95,7 +97,7 @@ def train_one_epoch(local_rank, train_loader, valid_loader, model, criterion, op
|
|||||||
_eval_flag = "Valid_In_Train_Epoch_{}_Batch_{}".format(idx_epoch, _idx_bt)
|
_eval_flag = "Valid_In_Train_Epoch_{}_Batch_{}".format(idx_epoch, _idx_bt)
|
||||||
valid_result = validate(local_rank=local_rank, valid_loader=valid_loader, model=model, criterion=criterion, evaluate_flag=_eval_flag, distributed=True, nprocs=nprocs,
|
valid_result = validate(local_rank=local_rank, valid_loader=valid_loader, model=model, criterion=criterion, evaluate_flag=_eval_flag, distributed=True, nprocs=nprocs,
|
||||||
original_valid_length=original_valid_length, result_file=result_file, details=False)
|
original_valid_length=original_valid_length, result_file=result_file, details=False)
|
||||||
|
|
||||||
if best_auc < valid_result.ROC_AUC_Score:
|
if best_auc < valid_result.ROC_AUC_Score:
|
||||||
_info = "[AUC Increased!] In evaluation of epoch-{} / batch-{}: AUC increased from {:.5f} < {:.5f}! Saving the model into {}".format(idx_epoch,
|
_info = "[AUC Increased!] In evaluation of epoch-{} / batch-{}: AUC increased from {:.5f} < {:.5f}! Saving the model into {}".format(idx_epoch,
|
||||||
_idx_bt,
|
_idx_bt,
|
||||||
@ -246,7 +248,7 @@ def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, m
|
|||||||
time_start = datetime.now()
|
time_start = datetime.now()
|
||||||
if local_rank == 0:
|
if local_rank == 0:
|
||||||
write_into(log_result_file, "\n{} start of {}-epoch, init best_auc={}, start time={} {}".format("-" * 50, epoch, best_auc, time_start.strftime("%Y-%m-%d@%H:%M:%S"), "-" * 50))
|
write_into(log_result_file, "\n{} start of {}-epoch, init best_auc={}, start time={} {}".format("-" * 50, epoch, best_auc, time_start.strftime("%Y-%m-%d@%H:%M:%S"), "-" * 50))
|
||||||
|
|
||||||
smooth_avg_reduced_loss_list, best_auc = train_one_epoch(local_rank=local_rank, train_loader=train_loader, valid_loader=valid_loader, model=model, criterion=criterion,
|
smooth_avg_reduced_loss_list, best_auc = train_one_epoch(local_rank=local_rank, train_loader=train_loader, valid_loader=valid_loader, model=model, criterion=criterion,
|
||||||
optimizer=optimizer, nprocs=nprocs, idx_epoch=epoch, best_auc=best_auc, best_model_file=best_model_path,
|
optimizer=optimizer, nprocs=nprocs, idx_epoch=epoch, best_auc=best_auc, best_model_file=best_model_path,
|
||||||
original_valid_length=ori_valid_length, result_file=log_result_file)
|
original_valid_length=ori_valid_length, result_file=log_result_file)
|
||||||
|
@ -13,6 +13,8 @@ sys.path.append("..")
|
|||||||
from utils.ParameterClasses import ModelParams # noqa
|
from utils.ParameterClasses import ModelParams # noqa
|
||||||
from utils.Vocabulary import Vocab # noqa
|
from utils.Vocabulary import Vocab # noqa
|
||||||
|
|
||||||
|
perform_MID = True
|
||||||
|
|
||||||
|
|
||||||
def div_with_small_value(n, d, eps=1e-8):
|
def div_with_small_value(n, d, eps=1e-8):
|
||||||
d = d * (d > eps).float() + eps * (d <= eps).float()
|
d = d * (d > eps).float() + eps * (d <= eps).float()
|
||||||
@ -119,8 +121,11 @@ class HierarchicalGraphNeuralNetwork(nn.Module):
|
|||||||
|
|
||||||
# self.last_activation = nn.Softmax(dim=1)
|
# self.last_activation = nn.Softmax(dim=1)
|
||||||
# self.last_activation = nn.LogSoftmax(dim=1)
|
# self.last_activation = nn.LogSoftmax(dim=1)
|
||||||
|
|
||||||
def forward_cfg_gnn(self, local_batch: Batch):
|
def forward_cfg_gnn(self, local_batch: Batch):
|
||||||
|
if perform_MID:
|
||||||
|
return self.forward_MID_cfg_gnn(local_batch)
|
||||||
|
|
||||||
in_x, edge_index = local_batch.x, local_batch.edge_index
|
in_x, edge_index = local_batch.x, local_batch.edge_index
|
||||||
for i in range(self.cfg_filter_length - 1):
|
for i in range(self.cfg_filter_length - 1):
|
||||||
out_x = getattr(self, 'CFG_gnn_{}'.format(i + 1))(x=in_x, edge_index=edge_index)
|
out_x = getattr(self, 'CFG_gnn_{}'.format(i + 1))(x=in_x, edge_index=edge_index)
|
||||||
@ -138,7 +143,7 @@ class HierarchicalGraphNeuralNetwork(nn.Module):
|
|||||||
for acfg in cfg_subgraph_loader:
|
for acfg in cfg_subgraph_loader:
|
||||||
subgraph_embeddings = []
|
subgraph_embeddings = []
|
||||||
for subgraph in acfg:
|
for subgraph in acfg:
|
||||||
in_x, edge_index = subgraph.x, subgraph.edge_index
|
in_x, edge_index = subgraph.x.to(device), subgraph.edge_index.to(device)
|
||||||
batch = torch.zeros(in_x.size(0), dtype=torch.long, device=device)
|
batch = torch.zeros(in_x.size(0), dtype=torch.long, device=device)
|
||||||
for i in range(self.cfg_filter_length - 1):
|
for i in range(self.cfg_filter_length - 1):
|
||||||
out_x = getattr(self, 'CFG_gnn_{}'.format(i + 1))(x=in_x, edge_index=edge_index)
|
out_x = getattr(self, 'CFG_gnn_{}'.format(i + 1))(x=in_x, edge_index=edge_index)
|
||||||
@ -206,7 +211,7 @@ class HierarchicalGraphNeuralNetwork(nn.Module):
|
|||||||
|
|
||||||
def forward(self, real_local_batch: Batch, real_bt_positions: list, bt_external_names: list, bt_all_function_edges: list, local_device: torch.device):
|
def forward(self, real_local_batch: Batch, real_bt_positions: list, bt_external_names: list, bt_all_function_edges: list, local_device: torch.device):
|
||||||
|
|
||||||
rtn_local_batch = self.forward_MID_cfg_gnn(local_batch=real_local_batch)
|
rtn_local_batch = self.forward_cfg_gnn(local_batch=real_local_batch)
|
||||||
x_cfg_pool = self.aggregate_cfg_batch_pooling(local_batch=rtn_local_batch)
|
x_cfg_pool = self.aggregate_cfg_batch_pooling(local_batch=rtn_local_batch)
|
||||||
|
|
||||||
# build the Function Call Graph (FCG) Data/Batch datasets
|
# build the Function Call Graph (FCG) Data/Batch datasets
|
||||||
|
@ -19,10 +19,10 @@ def transfer_remote():
|
|||||||
shutil.copy(os.path.join(file_dir, file), os.path.join(dest_dir, str(index)))
|
shutil.copy(os.path.join(file_dir, file), os.path.join(dest_dir, str(index)))
|
||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
delete_all_remote()
|
delete_remote()
|
||||||
|
|
||||||
|
|
||||||
def delete_all_remote():
|
def delete_remote():
|
||||||
samples_dir = '/root/autodl-tmp'
|
samples_dir = '/root/autodl-tmp'
|
||||||
sample = ['malware', 'benign']
|
sample = ['malware', 'benign']
|
||||||
tags = ['test', 'train', 'valid']
|
tags = ['test', 'train', 'valid']
|
||||||
@ -33,6 +33,16 @@ def delete_all_remote():
|
|||||||
os.remove(os.path.join(file_dir, f))
|
os.remove(os.path.join(file_dir, f))
|
||||||
|
|
||||||
|
|
||||||
|
def delete_remote_backup():
|
||||||
|
samples_dir = '/root/autodl-tmp'
|
||||||
|
dir_name = ['all', 'all_benign', 'one_family_malware', 'test_malware_backup', 'valid_malware_backup', 'train_malware_backup']
|
||||||
|
for name in dir_name:
|
||||||
|
file_dir = os.path.join(samples_dir, name)
|
||||||
|
if os.path.exists(file_dir):
|
||||||
|
for f in os.listdir(file_dir):
|
||||||
|
os.remove(os.path.join(file_dir, f))
|
||||||
|
|
||||||
|
|
||||||
# 重命名pt文件使之与代码相符
|
# 重命名pt文件使之与代码相符
|
||||||
def rename(file_dir, mal_or_be, postfix):
|
def rename(file_dir, mal_or_be, postfix):
|
||||||
tag_set = ['train', 'test', 'valid']
|
tag_set = ['train', 'test', 'valid']
|
||||||
@ -65,7 +75,7 @@ def split_samples(flag):
|
|||||||
|
|
||||||
os_list = os.listdir(path)
|
os_list = os.listdir(path)
|
||||||
random.shuffle(os_list)
|
random.shuffle(os_list)
|
||||||
# 8/1/1 分数据
|
# 6/2/2 分数据
|
||||||
train_len = int(len(os_list) * 0.6)
|
train_len = int(len(os_list) * 0.6)
|
||||||
test_len = int(train_len / 3)
|
test_len = int(train_len / 3)
|
||||||
for index, f in enumerate(os_list):
|
for index, f in enumerate(os_list):
|
||||||
@ -79,7 +89,9 @@ def split_samples(flag):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
# delete_remote_backup()
|
||||||
# transfer_remote()
|
# transfer_remote()
|
||||||
delete_all_remote()
|
# delete_remote()
|
||||||
|
split_samples('standard')
|
||||||
split_samples('one_family')
|
split_samples('one_family')
|
||||||
split_samples('benign')
|
split_samples('benign')
|
||||||
|
Loading…
Reference in New Issue
Block a user