backup
parent f0f6a55d84
commit 5d1e7e1ed0

@@ -130,33 +130,34 @@ class HierarchicalGraphNeuralNetwork(nn.Module):
         local_batch.x = in_x
         return local_batch
 
-    # CFG embedding learning based on multi-instance decomposition
-    def forward_MID_cfg_gnn(self, local_batch):
+    # CFG embedding learning via multi-instance decomposition (MID)
+    def forward_MID_cfg_gnn(self, local_batch: Batch):
+        device = torch.device('cuda')
         cfg_embeddings = []
-        # cfg_subgraph_loader is the list of (decomposed) CFGs; it is 2-D, each element being the list of sub-graphs one acfg was decomposed into
         cfg_subgraph_loader = local_batch.cfg_subgraph_loader
-        # aggregate the sub-graph embeddings to represent the original CFG
         for acfg in cfg_subgraph_loader:
             subgraph_embeddings = []
-            # iterate over the current CFG's sub-graph list; each element is a sub-graph (a Data object)
-            # compute the sub-graph embedding
             for subgraph in acfg:
                 in_x, edge_index = subgraph.x, subgraph.edge_index
+                batch = torch.zeros(in_x.size(0), dtype=torch.long, device=device)
                 for i in range(self.cfg_filter_length - 1):
                     out_x = getattr(self, 'CFG_gnn_{}'.format(i + 1))(x=in_x, edge_index=edge_index)
                     out_x = pt_f.relu(out_x, inplace=True)
                     out_x = self.dropout(out_x)
                     in_x = out_x
-                subgraph_embedding = torch.max(in_x, dim=0).values
-                subgraph_embeddings.append(subgraph_embedding)
+                subgraph_embedding = global_mean_pool(in_x, batch)
+                subgraph_embeddings.append(subgraph_embedding.squeeze(0))
             cfg_embedding = torch.stack(subgraph_embeddings).mean(dim=0)
             cfg_embeddings.append(cfg_embedding)
 
         cfg_embeddings = torch.stack(cfg_embeddings)
-        local_batch.x = cfg_embeddings
+        # create a new batch vector
+        batch_size = cfg_embeddings.size(0)
+        new_batch = torch.arange(batch_size)
+        local_batch.x = cfg_embeddings.to(device)
+        local_batch.batch = new_batch.to(device)
         return local_batch
 
 
     def aggregate_cfg_batch_pooling(self, local_batch: Batch):
         if self.pool == 'global_max_pool':
             x_pool = global_max_pool(x=local_batch.x, batch=local_batch.batch)
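Note on the readout change above: the old code took an element-wise max over a sub-graph's node features, while the new code calls global_mean_pool with an all-zeros batch vector, so every node of the sub-graph is assigned to graph 0 and mean-pooled. A minimal standalone sketch of how the two readouts relate in shape (assuming global_mean_pool comes from torch_geometric.nn, as elsewhere in this module):

    import torch
    from torch_geometric.nn import global_mean_pool

    in_x = torch.randn(5, 16)                      # 5 nodes, 16-dim features for one sub-graph
    max_readout = torch.max(in_x, dim=0).values    # old readout, shape [16]
    batch = torch.zeros(in_x.size(0), dtype=torch.long)
    mean_readout = global_mean_pool(in_x, batch)   # new readout, shape [1, 16]
    assert mean_readout.squeeze(0).shape == max_readout.shape  # hence the added .squeeze(0)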

@@ -69,8 +69,8 @@ def _simulating(_dataset, _batch_size: int):
 
 
 if __name__ == '__main__':
-    # root_path: str = '/root/autodl-tmp/'
-    root_path: str = 'D:\\hkn\\infected\\datasets\\proprecessed_pt'
+    root_path: str = '/root/autodl-tmp/'
+    # root_path: str = 'D:\\hkn\\infected\\datasets\\proprecessed_pt'
     i_batch_size = 2
 
     train_dataset = MalwareDetectionDataset(root=root_path, train_or_test='train')
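This hunk only switches root_path from the local Windows copy to the remote /root/autodl-tmp/ machine; _simulating itself is unchanged and not shown here. As a hedged usage sketch (an assumption for illustration, not code from this commit), the dataset built in this __main__ block would typically be consumed through a PyG DataLoader with the same batch size:

    from torch_geometric.loader import DataLoader

    # hypothetical: wrap the train split from above into mini-batches of i_batch_size graphs
    loader = DataLoader(train_dataset, batch_size=i_batch_size, shuffle=True)
    for batch in loader:
        print(batch.num_graphs)  # i_batch_size, except possibly the final batch
        break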

@@ -81,8 +81,6 @@ def multi_instance_decompose(acfg: Data):
     # g.add_edges_from(edge_index2edges(acfg.edge_index))
 
     return metis_MID(acfg)
-    # return structure_MID(acfg, g)
-    # return topological_MID(acfg, g)
 
 
 def metis_MID(acfg):
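multi_instance_decompose now commits to the METIS-based decomposition and drops the commented-out structure_MID / topological_MID alternatives; the body of metis_MID lies outside this hunk. Purely for orientation, a hedged sketch of one way a METIS partition of an ACFG's edge_index could be computed (pymetis and the helper name are assumptions, not this repository's code):

    import torch
    import pymetis  # assumed dependency, illustrative only

    def metis_partition_sketch(edge_index: torch.Tensor, num_nodes: int, nparts: int = 2):
        # build an undirected adjacency list from a PyG-style edge_index of shape [2, num_edges]
        adjacency = [[] for _ in range(num_nodes)]
        for src, dst in edge_index.t().tolist():
            adjacency[src].append(dst)
            adjacency[dst].append(src)
        # pymetis returns (edge_cut, membership); membership[i] is the part id of node i
        _, membership = pymetis.part_graph(nparts, adjacency=adjacency)
        return membership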

src/utils/util.py (new file, 85 lines)
@@ -0,0 +1,85 @@
+import os
+import shutil
+import random
+
+
+def transfer_remote():
+    samples_dir = '/root/autodl-tmp'
+    all_benign = '/root/autodl-tmp/all_benign'
+    one_family_malware = '/root/autodl-tmp/one_family_malware'
+
+    sample = ['malware', 'benign']
+    tags = ['test', 'train', 'valid']
+    for s in sample:
+        index = 0
+        for t in tags:
+            file_dir = os.path.join(samples_dir, '{}_{}'.format(t, s))
+            for file in os.listdir(file_dir):
+                dest_dir = all_benign if s == 'benign' else one_family_malware
+                shutil.copy(os.path.join(file_dir, file), os.path.join(dest_dir, str(index)))
+                index += 1
+
+    delete_all_remote()
+
+
+def delete_all_remote():
+    samples_dir = '/root/autodl-tmp'
+    sample = ['malware', 'benign']
+    tags = ['test', 'train', 'valid']
+    for s in sample:
+        for t in tags:
+            file_dir = os.path.join(samples_dir, '{}_{}'.format(t, s))
+            for f in os.listdir(file_dir):
+                os.remove(os.path.join(file_dir, f))
+
+
+# rename the .pt files so that they match what the code expects
+def rename(file_dir, mal_or_be, postfix):
+    tag_set = ['train', 'test', 'valid']
+    for tag in tag_set:
+        data_dir = os.path.join(file_dir, '{}_{}{}/'.format(tag, mal_or_be, postfix))
+        for index, f in enumerate(os.listdir(data_dir)):
+            os.rename(os.path.join(data_dir, f), os.path.join(data_dir, 'm' + f))
+    for tag in tag_set:
+        data_dir = os.path.join(file_dir, '{}_{}{}/'.format(tag, mal_or_be, postfix))
+        for index, f in enumerate(os.listdir(data_dir)):
+            os.rename(os.path.join(data_dir, f), os.path.join(data_dir, '{}_{}.pt'.format(mal_or_be, index)))
+
+
+def split_samples(flag):
+    postfix = ''
+    file_dir = '/root/autodl-tmp'
+    if flag == 'one_family':
+        path = os.path.join(file_dir, 'one_family_malware')
+        tag = 'malware'
+    elif flag == 'standard':
+        path = os.path.join(file_dir, 'all')
+        postfix = '_backup'
+        tag = 'malware'
+    elif flag == 'benign':
+        path = os.path.join(file_dir, 'all_benign')
+        tag = 'benign'
+    else:
+        print('flag not implemented')
+        return
+
+    os_list = os.listdir(path)
+    random.shuffle(os_list)
+    # split the data into train/test/valid (the ratios below give roughly 0.6 / 0.2 / 0.2)
+    train_len = int(len(os_list) * 0.6)
+    test_len = int(train_len / 3)
+    for index, f in enumerate(os_list):
+        if index < train_len:
+            shutil.copy(os.path.join(path, f), os.path.join(file_dir, 'train_{}'.format(tag) + postfix))
+        elif train_len <= index < train_len + test_len:
+            shutil.copy(os.path.join(path, f), os.path.join(file_dir, 'test_{}'.format(tag) + postfix))
+        else:
+            shutil.copy(os.path.join(path, f), os.path.join(file_dir, 'valid_{}'.format(tag) + postfix))
+    rename(file_dir, tag, postfix)
+
+
+if __name__ == '__main__':
+    # transfer_remote()
+    delete_all_remote()
+    split_samples('one_family')
+    split_samples('benign')