TinyCaviar 2023-09-20 16:31:33 +08:00
parent f0f6a55d84
commit 5d1e7e1ed0
4 changed files with 98 additions and 14 deletions

View File

@@ -130,33 +130,34 @@ class HierarchicalGraphNeuralNetwork(nn.Module):
         local_batch.x = in_x
         return local_batch
 
-    # CFG embedding learning based on multi-instance decomposition
-    def forward_MID_cfg_gnn(self, local_batch):
+    # multi-instance decomposition (MID) CFG embedding learning
+    def forward_MID_cfg_gnn(self, local_batch: Batch):
+        device = torch.device('cuda')
         cfg_embeddings = []
+        # cfg_subgraph_loader is a 2-D list of decomposed CFGs: each element is the list of subgraphs one ACFG was split into
         cfg_subgraph_loader = local_batch.cfg_subgraph_loader
+        # aggregate the subgraph embeddings to represent the original CFG
         for acfg in cfg_subgraph_loader:
             subgraph_embeddings = []
+            # each entry of the current CFG's subgraph list is a Data object;
+            # compute an embedding for every subgraph
             for subgraph in acfg:
                 in_x, edge_index = subgraph.x, subgraph.edge_index
+                batch = torch.zeros(in_x.size(0), dtype=torch.long, device=device)
                 for i in range(self.cfg_filter_length - 1):
                     out_x = getattr(self, 'CFG_gnn_{}'.format(i + 1))(x=in_x, edge_index=edge_index)
                     out_x = pt_f.relu(out_x, inplace=True)
                     out_x = self.dropout(out_x)
                     in_x = out_x
-                subgraph_embedding = torch.max(in_x, dim=0).values
-                subgraph_embeddings.append(subgraph_embedding)
+                subgraph_embedding = global_mean_pool(in_x, batch)
+                subgraph_embeddings.append(subgraph_embedding.squeeze(0))
             cfg_embedding = torch.stack(subgraph_embeddings).mean(dim=0)
             cfg_embeddings.append(cfg_embedding)
         cfg_embeddings = torch.stack(cfg_embeddings)
-        local_batch.x = cfg_embeddings
+        # build a new batch vector: one graph id per CFG embedding
+        batch_size = cfg_embeddings.size(0)
+        new_batch = torch.arange(batch_size)
+        local_batch.x = cfg_embeddings.to(device)
+        local_batch.batch = new_batch.to(device)
         return local_batch
 
     def aggregate_cfg_batch_pooling(self, local_batch: Batch):
         if self.pool == 'global_max_pool':
             x_pool = global_max_pool(x=local_batch.x, batch=local_batch.batch)
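Review note: with an all-zeros batch vector every node belongs to a single graph, so the new global_mean_pool readout is simply the node-feature mean, and squeeze(0) drops the singleton graph dimension. A minimal sanity check, assuming torch and torch_geometric are installed (the tensor shapes here are made up):

import torch
from torch_geometric.nn import global_mean_pool

x = torch.randn(5, 8)                      # a toy subgraph: 5 nodes, 8-dim features
batch = torch.zeros(5, dtype=torch.long)   # every node assigned to graph 0
pooled = global_mean_pool(x, batch)        # shape [1, 8]
assert torch.allclose(pooled.squeeze(0), x.mean(dim=0))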

View File

@@ -69,8 +69,8 @@ def _simulating(_dataset, _batch_size: int):
 
 if __name__ == '__main__':
-    # root_path: str = '/root/autodl-tmp/'
-    root_path: str = 'D:\\hkn\\infected\\datasets\\proprecessed_pt'
+    root_path: str = '/root/autodl-tmp/'
+    # root_path: str = 'D:\\hkn\\infected\\datasets\\proprecessed_pt'
     i_batch_size = 2
     train_dataset = MalwareDetectionDataset(root=root_path, train_or_test='train')

View File

@@ -81,8 +81,6 @@ def multi_instance_decompose(acfg: Data):
     # g.add_edges_from(edge_index2edges(acfg.edge_index))
     return metis_MID(acfg)
-    # return structure_MID(acfg, g)
-    # return topological_MID(acfg, g)
 
 def metis_MID(acfg):
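Review note: the body of metis_MID is cut off in this hunk. As a rough sketch of the general subgraph-splitting pattern only (the helper name, the pymetis reference, and the membership format are assumptions, not this repo's code):

import torch
from torch_geometric.data import Data
from torch_geometric.utils import subgraph

def metis_like_split(acfg: Data, membership) -> list:
    # membership[i] is the partition id of node i, e.g. the second value
    # returned by pymetis.part_graph(nparts, adjacency=...) (assumption)
    member = torch.as_tensor(membership)
    parts = []
    for part_id in member.unique():
        mask = member == part_id
        # keep only intra-partition edges and renumber surviving nodes
        edge_index, _ = subgraph(mask, acfg.edge_index,
                                 relabel_nodes=True, num_nodes=acfg.num_nodes)
        parts.append(Data(x=acfg.x[mask], edge_index=edge_index))
    return parts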

src/utils/util.py (new file, 85 additions)
View File

@@ -0,0 +1,85 @@
import os
import shutil
import random


def transfer_remote():
    # gather every sample scattered across the train/test/valid directories
    # into two flat directories, then clear the originals
    samples_dir = '/root/autodl-tmp'
    all_benign = '/root/autodl-tmp/all_benign'
    one_family_malware = '/root/autodl-tmp/one_family_malware'
    sample = ['malware', 'benign']
    tags = ['test', 'train', 'valid']
    for s in sample:
        index = 0
        for t in tags:
            file_dir = os.path.join(samples_dir, '{}_{}'.format(t, s))
            for file in os.listdir(file_dir):
                dest_dir = all_benign if s == 'benign' else one_family_malware
                shutil.copy(os.path.join(file_dir, file), os.path.join(dest_dir, str(index)))
                index += 1
    delete_all_remote()


def delete_all_remote():
    # remove every file from the train/test/valid sample directories
    samples_dir = '/root/autodl-tmp'
    sample = ['malware', 'benign']
    tags = ['test', 'train', 'valid']
    for s in sample:
        for t in tags:
            file_dir = os.path.join(samples_dir, '{}_{}'.format(t, s))
            for f in os.listdir(file_dir):
                os.remove(os.path.join(file_dir, f))
# rename the .pt files so their names match what the dataset code expects
def rename(file_dir, mal_or_be, postfix):
    tag_set = ['train', 'test', 'valid']
    # first pass: prefix every file with 'm' so the final names assigned
    # below cannot collide with files that are not yet renamed
    for tag in tag_set:
        data_dir = os.path.join(file_dir, '{}_{}{}/'.format(tag, mal_or_be, postfix))
        for index, f in enumerate(os.listdir(data_dir)):
            os.rename(os.path.join(data_dir, f), os.path.join(data_dir, 'm' + f))
    # second pass: assign sequential names such as malware_0.pt
    for tag in tag_set:
        data_dir = os.path.join(file_dir, '{}_{}{}/'.format(tag, mal_or_be, postfix))
        for index, f in enumerate(os.listdir(data_dir)):
            os.rename(os.path.join(data_dir, f), os.path.join(data_dir, '{}_{}.pt'.format(mal_or_be, index)))
def split_samples(flag):
    postfix = ''
    file_dir = '/root/autodl-tmp'
    if flag == 'one_family':
        path = os.path.join(file_dir, 'one_family_malware')
        tag = 'malware'
    elif flag == 'standard':
        path = os.path.join(file_dir, 'all')
        postfix = '_backup'
        tag = 'malware'
    elif flag == 'benign':
        path = os.path.join(file_dir, 'all_benign')
        tag = 'benign'
    else:
        print('flag not implemented')
        return
    os_list = os.listdir(path)
    random.shuffle(os_list)
    # split into train/test/valid at a 60/20/20 ratio
    train_len = int(len(os_list) * 0.6)
    test_len = int(train_len / 3)
    for index, f in enumerate(os_list):
        if index < train_len:
            shutil.copy(os.path.join(path, f), os.path.join(file_dir, 'train_{}'.format(tag) + postfix))
        elif train_len <= index < train_len + test_len:
            shutil.copy(os.path.join(path, f), os.path.join(file_dir, 'test_{}'.format(tag) + postfix))
        else:
            shutil.copy(os.path.join(path, f), os.path.join(file_dir, 'valid_{}'.format(tag) + postfix))
    rename(file_dir, tag, postfix)
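For example, with 1,000 shuffled samples train_len = 600 and test_len = int(600 / 3) = 200, so 600 files are copied to train, 200 to test, and the remaining 200 to valid: a 60/20/20 split.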
if __name__ == '__main__':
    # transfer_remote()
    delete_all_remote()
    split_samples('one_family')
    split_samples('benign')