backup
This commit is contained in:
parent
f0f6a55d84
commit
5d1e7e1ed0
@ -130,33 +130,34 @@ class HierarchicalGraphNeuralNetwork(nn.Module):
|
||||
local_batch.x = in_x
|
||||
return local_batch
|
||||
|
||||
# 基于多实例分解的CFG嵌入学习
|
||||
def forward_MID_cfg_gnn(self, local_batch):
|
||||
# 多实例分解的CFG嵌入学习
|
||||
def forward_MID_cfg_gnn(self, local_batch: Batch):
|
||||
device = torch.device('cuda')
|
||||
cfg_embeddings = []
|
||||
# cfg_subgraph_loader是cfg(分解后的)列表,是二维的,每个元素都是一个acfg分解成的子图列表
|
||||
cfg_subgraph_loader = local_batch.cfg_subgraph_loader
|
||||
# 聚合子图的嵌入,用以表示原本的cfg
|
||||
for acfg in cfg_subgraph_loader:
|
||||
subgraph_embeddings = []
|
||||
# 遍历当前cfg的子图列表,每个元素都是一个子图,它是一个Data对象
|
||||
# 计算子图的嵌入
|
||||
for subgraph in acfg:
|
||||
in_x, edge_index = subgraph.x, subgraph.edge_index
|
||||
batch = torch.zeros(in_x.size(0), dtype=torch.long, device=device)
|
||||
for i in range(self.cfg_filter_length - 1):
|
||||
out_x = getattr(self, 'CFG_gnn_{}'.format(i + 1))(x=in_x, edge_index=edge_index)
|
||||
out_x = pt_f.relu(out_x, inplace=True)
|
||||
out_x = self.dropout(out_x)
|
||||
in_x = out_x
|
||||
subgraph_embedding = torch.max(in_x, dim=0).values
|
||||
subgraph_embeddings.append(subgraph_embedding)
|
||||
subgraph_embedding = global_mean_pool(in_x, batch)
|
||||
subgraph_embeddings.append(subgraph_embedding.squeeze(0))
|
||||
cfg_embedding = torch.stack(subgraph_embeddings).mean(dim=0)
|
||||
cfg_embeddings.append(cfg_embedding)
|
||||
|
||||
cfg_embeddings = torch.stack(cfg_embeddings)
|
||||
local_batch.x = cfg_embeddings
|
||||
# 创建一个新的 batch 向量
|
||||
batch_size = cfg_embeddings.size(0)
|
||||
new_batch = torch.arange(batch_size)
|
||||
local_batch.x = cfg_embeddings.to(device)
|
||||
local_batch.batch = new_batch.to(device)
|
||||
return local_batch
|
||||
|
||||
|
||||
def aggregate_cfg_batch_pooling(self, local_batch: Batch):
|
||||
if self.pool == 'global_max_pool':
|
||||
x_pool = global_max_pool(x=local_batch.x, batch=local_batch.batch)
|
||||
|
@ -69,8 +69,8 @@ def _simulating(_dataset, _batch_size: int):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# root_path: str = '/root/autodl-tmp/'
|
||||
root_path: str = 'D:\\hkn\\infected\\datasets\\proprecessed_pt'
|
||||
root_path: str = '/root/autodl-tmp/'
|
||||
# root_path: str = 'D:\\hkn\\infected\\datasets\\proprecessed_pt'
|
||||
i_batch_size = 2
|
||||
|
||||
train_dataset = MalwareDetectionDataset(root=root_path, train_or_test='train')
|
||||
|
@ -81,8 +81,6 @@ def multi_instance_decompose(acfg: Data):
|
||||
# g.add_edges_from(edge_index2edges(acfg.edge_index))
|
||||
|
||||
return metis_MID(acfg)
|
||||
# return structure_MID(acfg, g)
|
||||
# return topological_MID(acfg, g)
|
||||
|
||||
|
||||
def metis_MID(acfg):
|
||||
|
85
src/utils/util.py
Normal file
85
src/utils/util.py
Normal file
@ -0,0 +1,85 @@
|
||||
import os
|
||||
import shutil
|
||||
import random
|
||||
|
||||
|
||||
def transfer_remote():
|
||||
samples_dir = '/root/autodl-tmp'
|
||||
all_benign = '/root/autodl-tmp/all_benign'
|
||||
one_family_malware = '/root/autodl-tmp/one_family_malware'
|
||||
|
||||
sample = ['malware', 'benign']
|
||||
tags = ['test', 'train', 'valid']
|
||||
for s in sample:
|
||||
index = 0
|
||||
for t in tags:
|
||||
file_dir = os.path.join(samples_dir, '{}_{}'.format(t, s))
|
||||
for file in os.listdir(file_dir):
|
||||
dest_dir = all_benign if s == 'benign' else one_family_malware
|
||||
shutil.copy(os.path.join(file_dir, file), os.path.join(dest_dir, str(index)))
|
||||
index += 1
|
||||
|
||||
delete_all_remote()
|
||||
|
||||
|
||||
def delete_all_remote():
|
||||
samples_dir = '/root/autodl-tmp'
|
||||
sample = ['malware', 'benign']
|
||||
tags = ['test', 'train', 'valid']
|
||||
for s in sample:
|
||||
for t in tags:
|
||||
file_dir = os.path.join(samples_dir, '{}_{}'.format(t, s))
|
||||
for f in os.listdir(file_dir):
|
||||
os.remove(os.path.join(file_dir, f))
|
||||
|
||||
|
||||
# 重命名pt文件使之与代码相符
|
||||
def rename(file_dir, mal_or_be, postfix):
|
||||
tag_set = ['train', 'test', 'valid']
|
||||
for tag in tag_set:
|
||||
data_dir = os.path.join(file_dir, '{}_{}{}/'.format(tag, mal_or_be, postfix))
|
||||
for index, f in enumerate(os.listdir(data_dir)):
|
||||
os.rename(os.path.join(data_dir, f), os.path.join(data_dir, 'm' + f))
|
||||
for tag in tag_set:
|
||||
data_dir = os.path.join(file_dir, '{}_{}{}/'.format(tag, mal_or_be, postfix))
|
||||
for index, f in enumerate(os.listdir(data_dir)):
|
||||
os.rename(os.path.join(data_dir, f), os.path.join(data_dir, '{}_{}.pt'.format(mal_or_be, index)))
|
||||
|
||||
|
||||
def split_samples(flag):
|
||||
postfix = ''
|
||||
file_dir = '/root/autodl-tmp'
|
||||
if flag == 'one_family':
|
||||
path = os.path.join(file_dir, 'one_family_malware')
|
||||
tag = 'malware'
|
||||
elif flag == 'standard':
|
||||
path = os.path.join(file_dir, 'all')
|
||||
postfix = '_backup'
|
||||
tag = 'malware'
|
||||
elif flag == 'benign':
|
||||
path = os.path.join(file_dir, 'all_benign')
|
||||
tag = 'benign'
|
||||
else:
|
||||
print('flag not implemented')
|
||||
return
|
||||
|
||||
os_list = os.listdir(path)
|
||||
random.shuffle(os_list)
|
||||
# 8/1/1 分数据
|
||||
train_len = int(len(os_list) * 0.6)
|
||||
test_len = int(train_len / 3)
|
||||
for index, f in enumerate(os_list):
|
||||
if index < train_len:
|
||||
shutil.copy(os.path.join(path, f), os.path.join(file_dir, 'train_{}'.format(tag) + postfix))
|
||||
elif train_len <= index < train_len + test_len:
|
||||
shutil.copy(os.path.join(path, f), os.path.join(file_dir, 'test_{}'.format(tag) + postfix))
|
||||
else:
|
||||
shutil.copy(os.path.join(path, f), os.path.join(file_dir, 'valid_{}'.format(tag) + postfix))
|
||||
rename(file_dir, tag, postfix)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# transfer_remote()
|
||||
delete_all_remote()
|
||||
split_samples('one_family')
|
||||
split_samples('benign')
|
Loading…
Reference in New Issue
Block a user