TinyCaviar 2023-09-21 17:15:21 +08:00
parent 5d1e7e1ed0
commit 25093d8e47
3 changed files with 28 additions and 9 deletions

View File

@@ -15,9 +15,11 @@ from hydra.utils import to_absolute_path
from omegaconf import DictConfig
from prefetch_generator import BackgroundGenerator
from sklearn.metrics import roc_auc_score, roc_curve
+ import matplotlib.pyplot as plt
from torch import nn
from torch_geometric.data import DataLoader
from tqdm import tqdm
+ from typing import List
from models.HierarchicalGraphModel import HierarchicalGraphNeuralNetwork
from utils.FunctionHelpers import write_into, params_print_log, find_threshold_with_fixed_fpr
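A hedged sketch of how the imports touched in this hunk (roc_curve, matplotlib, find_threshold_with_fixed_fpr) could fit together. The real find_threshold_with_fixed_fpr lives in utils.FunctionHelpers and is not shown in this diff, so the _sketch function below, its target_fpr default, and the toy labels/scores are illustrative assumptions, not the repository's implementation:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve

def find_threshold_with_fixed_fpr_sketch(y_true, y_score, target_fpr=0.01):
    # lowest threshold whose false-positive rate still stays at or below target_fpr,
    # which maximises TPR under the given FPR budget
    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    idx = max(np.searchsorted(fpr, target_fpr, side='right') - 1, 0)
    return thresholds[idx], fpr[idx], tpr[idx]

y_true = np.array([0, 0, 1, 1, 0, 1])
y_score = np.array([0.10, 0.40, 0.35, 0.80, 0.20, 0.70])
threshold, fpr_at_thr, tpr_at_thr = find_threshold_with_fixed_fpr_sketch(y_true, y_score)
fpr, tpr, _ = roc_curve(y_true, y_score)
plt.plot(fpr, tpr)            # the matplotlib import is presumably for plots like this ROC curve
plt.savefig('roc_curve.png')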

View File

@@ -13,6 +13,8 @@ sys.path.append("..")
from utils.ParameterClasses import ModelParams # noqa
from utils.Vocabulary import Vocab # noqa
+ perform_MID = True
def div_with_small_value(n, d, eps=1e-8):
d = d * (d > eps).float() + eps * (d <= eps).float()
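A minimal, hedged illustration of div_with_small_value above, assuming the part of its body elided by the hunk simply returns n / d: denominators at or below eps are replaced by eps, so the division never produces inf on (near-)zero entries.

import torch

def div_with_small_value(n, d, eps=1e-8):
    # keep d where it exceeds eps, substitute eps elsewhere
    d = d * (d > eps).float() + eps * (d <= eps).float()
    return n / d

print(div_with_small_value(torch.tensor([1.0, 2.0]), torch.tensor([0.0, 4.0])))
# tensor([1.0000e+08, 5.0000e-01]) instead of [inf, 0.5000] from a naive division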
@@ -121,6 +123,9 @@ class HierarchicalGraphNeuralNetwork(nn.Module):
# self.last_activation = nn.LogSoftmax(dim=1)
def forward_cfg_gnn(self, local_batch: Batch):
+ if perform_MID:
+     return self.forward_MID_cfg_gnn(local_batch)
in_x, edge_index = local_batch.x, local_batch.edge_index
for i in range(self.cfg_filter_length - 1):
out_x = getattr(self, 'CFG_gnn_{}'.format(i + 1))(x=in_x, edge_index=edge_index)
@@ -138,7 +143,7 @@ class HierarchicalGraphNeuralNetwork(nn.Module):
for acfg in cfg_subgraph_loader:
subgraph_embeddings = []
for subgraph in acfg:
- in_x, edge_index = subgraph.x, subgraph.edge_index
+ in_x, edge_index = subgraph.x.to(device), subgraph.edge_index.to(device)
batch = torch.zeros(in_x.size(0), dtype=torch.long, device=device)
for i in range(self.cfg_filter_length - 1):
out_x = getattr(self, 'CFG_gnn_{}'.format(i + 1))(x=in_x, edge_index=edge_index)
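The .to(device) change above moves each subgraph's tensors onto the model's device before they reach the GNN layers. A hedged, self-contained sketch of why that is needed when the loader yields CPU-side Data objects; the GCNConv layer and the toy graph below are illustrative, not taken from this repository:

import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
conv = GCNConv(4, 8).to(device)     # layer parameters live on `device`

subgraph = Data(x=torch.randn(3, 4), edge_index=torch.tensor([[0, 1], [1, 2]]))   # loader output stays on CPU
in_x, edge_index = subgraph.x.to(device), subgraph.edge_index.to(device)          # skipping .to(device) raises a device-mismatch error on CUDA
out_x = conv(x=in_x, edge_index=edge_index)
print(out_x.shape)                  # torch.Size([3, 8])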
@@ -206,7 +211,7 @@ class HierarchicalGraphNeuralNetwork(nn.Module):
def forward(self, real_local_batch: Batch, real_bt_positions: list, bt_external_names: list, bt_all_function_edges: list, local_device: torch.device):
- rtn_local_batch = self.forward_MID_cfg_gnn(local_batch=real_local_batch)
+ rtn_local_batch = self.forward_cfg_gnn(local_batch=real_local_batch)
x_cfg_pool = self.aggregate_cfg_batch_pooling(local_batch=rtn_local_batch)
# build the Function Call Graph (FCG) Data/Batch datasets
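Taken together with the forward_cfg_gnn hunk above, forward() no longer hard-codes the MID path: it always calls forward_cfg_gnn(), and the module-level perform_MID flag decides whether that call falls through to forward_MID_cfg_gnn(). A hedged toy sketch of this dispatch pattern; the class and return strings are illustrative only:

perform_MID = True

class DispatchSketch:
    def forward_cfg_gnn(self, local_batch):
        if perform_MID:
            return self.forward_MID_cfg_gnn(local_batch)
        return f'plain CFG-GNN path over {local_batch}'

    def forward_MID_cfg_gnn(self, local_batch):
        return f'MID CFG-GNN path over {local_batch}'

print(DispatchSketch().forward_cfg_gnn('batch'))   # MID path while perform_MID is True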

View File

@@ -19,10 +19,10 @@ def transfer_remote():
shutil.copy(os.path.join(file_dir, file), os.path.join(dest_dir, str(index)))
index += 1
- delete_all_remote()
+ delete_remote()
- def delete_all_remote():
+ def delete_remote():
samples_dir = '/root/autodl-tmp'
sample = ['malware', 'benign']
tags = ['test', 'train', 'valid']
@@ -33,6 +33,16 @@ def delete_all_remote():
os.remove(os.path.join(file_dir, f))
+ def delete_remote_backup():
+     samples_dir = '/root/autodl-tmp'
+     dir_name = ['all', 'all_benign', 'one_family_malware', 'test_malware_backup', 'valid_malware_backup', 'train_malware_backup']
+     for name in dir_name:
+         file_dir = os.path.join(samples_dir, name)
+         if os.path.exists(file_dir):
+             for f in os.listdir(file_dir):
+                 os.remove(os.path.join(file_dir, f))
# Rename the .pt files so they match the names the code expects
def rename(file_dir, mal_or_be, postfix):
tag_set = ['train', 'test', 'valid']
@@ -65,7 +75,7 @@ def split_samples(flag):
os_list = os.listdir(path)
random.shuffle(os_list)
- # split the data 8/1/1
+ # split the data 6/2/2
train_len = int(len(os_list) * 0.6)
test_len = int(train_len / 3)
for index, f in enumerate(os_list):
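A quick, hedged arithmetic check of the 6/2/2 comment against the two lines above, using a hypothetical corpus of 1000 files and assuming the files left after train and test go to valid (the rest of split_samples is not shown in this hunk):

n_files = 1000
train_len = int(n_files * 0.6)               # 600
test_len = int(train_len / 3)                # 200
valid_len = n_files - train_len - test_len   # 200 -> 6/2/2 train/test/valid
print(train_len, test_len, valid_len)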
@@ -79,7 +89,9 @@ def split_samples(flag):
if __name__ == '__main__':
# delete_remote_backup()
# transfer_remote()
- delete_all_remote()
+ # delete_remote()
split_samples('standard')
split_samples('one_family')
split_samples('benign')