diff --git a/src/models/HierarchicalGraphModel.py b/src/models/HierarchicalGraphModel.py
index c32c249..c7ddec2 100644
--- a/src/models/HierarchicalGraphModel.py
+++ b/src/models/HierarchicalGraphModel.py
@@ -130,33 +130,34 @@ class HierarchicalGraphNeuralNetwork(nn.Module):
         local_batch.x = in_x
         return local_batch
 
-    # CFG embedding learning based on multi-instance decomposition
-    def forward_MID_cfg_gnn(self, local_batch):
+    # CFG embedding learning via multi-instance decomposition
+    def forward_MID_cfg_gnn(self, local_batch: Batch):
+        device = torch.device('cuda')
         cfg_embeddings = []
-        # cfg_subgraph_loader is a 2-D list over the (decomposed) CFGs: each element is the list of subgraphs one acfg was decomposed into
         cfg_subgraph_loader = local_batch.cfg_subgraph_loader
-        # aggregate the subgraph embeddings to represent the original cfg
         for acfg in cfg_subgraph_loader:
             subgraph_embeddings = []
-            # walk the current cfg's subgraph list; each element is one subgraph, a Data object
-            # compute the subgraph embeddings
             for subgraph in acfg:
                 in_x, edge_index = subgraph.x, subgraph.edge_index
+                batch = torch.zeros(in_x.size(0), dtype=torch.long, device=device)
                 for i in range(self.cfg_filter_length - 1):
                     out_x = getattr(self, 'CFG_gnn_{}'.format(i + 1))(x=in_x, edge_index=edge_index)
                     out_x = pt_f.relu(out_x, inplace=True)
                     out_x = self.dropout(out_x)
                     in_x = out_x
-                subgraph_embedding = torch.max(in_x, dim=0).values
-                subgraph_embeddings.append(subgraph_embedding)
+                subgraph_embedding = global_mean_pool(in_x, batch)
+                subgraph_embeddings.append(subgraph_embedding.squeeze(0))
             cfg_embedding = torch.stack(subgraph_embeddings).mean(dim=0)
             cfg_embeddings.append(cfg_embedding)
         cfg_embeddings = torch.stack(cfg_embeddings)
-        local_batch.x = cfg_embeddings
+        # build a new batch vector: one graph id per CFG embedding
+        batch_size = cfg_embeddings.size(0)
+        new_batch = torch.arange(batch_size)
+        local_batch.x = cfg_embeddings.to(device)
+        local_batch.batch = new_batch.to(device)
         return local_batch
-
     def aggregate_cfg_batch_pooling(self, local_batch: Batch):
         if self.pool == 'global_max_pool':
             x_pool = global_max_pool(x=local_batch.x, batch=local_batch.batch)
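The functional change in forward_MID_cfg_gnn swaps per-subgraph max-pooling for global_mean_pool over an all-zeros batch vector, which for a single subgraph reduces to a column-wise mean of its node embeddings; the arange batch vector then gives each pooled CFG embedding its own graph id for the later batch-level pooling. A minimal standalone sketch of the pooling equivalence (sizes are illustrative assumptions, not values from the patch):

    import torch
    from torch_geometric.nn import global_mean_pool

    num_nodes, hidden_dim = 7, 64                 # illustrative, not from the patch
    node_embeddings = torch.randn(num_nodes, hidden_dim)

    # old behaviour: element-wise max over the subgraph's nodes
    old_embedding = torch.max(node_embeddings, dim=0).values       # (hidden_dim,)

    # new behaviour: assign every node to graph 0, mean-pool, drop the graph dim
    batch = torch.zeros(num_nodes, dtype=torch.long)
    new_embedding = global_mean_pool(node_embeddings, batch).squeeze(0)

    assert torch.allclose(new_embedding, node_embeddings.mean(dim=0))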
diff --git a/src/utils/PreProcessedDataset.py b/src/utils/PreProcessedDataset.py
index 753c54d..c01f137 100644
--- a/src/utils/PreProcessedDataset.py
+++ b/src/utils/PreProcessedDataset.py
@@ -69,8 +69,8 @@ def _simulating(_dataset, _batch_size: int):
 
 
 if __name__ == '__main__':
-    # root_path: str = '/root/autodl-tmp/'
-    root_path: str = 'D:\\hkn\\infected\\datasets\\proprecessed_pt'
+    root_path: str = '/root/autodl-tmp/'
+    # root_path: str = 'D:\\hkn\\infected\\datasets\\proprecessed_pt'
     i_batch_size = 2
     train_dataset = MalwareDetectionDataset(root=root_path, train_or_test='train')
diff --git a/src/utils/RealBatch.py b/src/utils/RealBatch.py
index 14b059b..b8845b8 100644
--- a/src/utils/RealBatch.py
+++ b/src/utils/RealBatch.py
@@ -81,8 +81,6 @@ def multi_instance_decompose(acfg: Data):
     # g.add_edges_from(edge_index2edges(acfg.edge_index))
     return metis_MID(acfg)
-    # return structure_MID(acfg, g)
-    # return topological_MID(acfg, g)
 
 
 def metis_MID(acfg):
diff --git a/src/utils/util.py b/src/utils/util.py
new file mode 100644
index 0000000..81f91fa
--- /dev/null
+++ b/src/utils/util.py
@@ -0,0 +1,85 @@
+import os
+import shutil
+import random
+
+
+def transfer_remote():
+    samples_dir = '/root/autodl-tmp'
+    all_benign = '/root/autodl-tmp/all_benign'
+    one_family_malware = '/root/autodl-tmp/one_family_malware'
+
+    sample = ['malware', 'benign']
+    tags = ['test', 'train', 'valid']
+    for s in sample:
+        index = 0
+        for t in tags:
+            file_dir = os.path.join(samples_dir, '{}_{}'.format(t, s))
+            for file in os.listdir(file_dir):
+                dest_dir = all_benign if s == 'benign' else one_family_malware
+                shutil.copy(os.path.join(file_dir, file), os.path.join(dest_dir, str(index)))
+                index += 1
+
+    delete_all_remote()
+
+
+def delete_all_remote():
+    samples_dir = '/root/autodl-tmp'
+    sample = ['malware', 'benign']
+    tags = ['test', 'train', 'valid']
+    for s in sample:
+        for t in tags:
+            file_dir = os.path.join(samples_dir, '{}_{}'.format(t, s))
+            for f in os.listdir(file_dir):
+                os.remove(os.path.join(file_dir, f))
+
+
+# rename the .pt files to match the loader's naming scheme; the first pass adds an 'm' prefix so renumbering cannot collide with existing names
+def rename(file_dir, mal_or_be, postfix):
+    tag_set = ['train', 'test', 'valid']
+    for tag in tag_set:
+        data_dir = os.path.join(file_dir, '{}_{}{}/'.format(tag, mal_or_be, postfix))
+        for index, f in enumerate(os.listdir(data_dir)):
+            os.rename(os.path.join(data_dir, f), os.path.join(data_dir, 'm' + f))
+    for tag in tag_set:
+        data_dir = os.path.join(file_dir, '{}_{}{}/'.format(tag, mal_or_be, postfix))
+        for index, f in enumerate(os.listdir(data_dir)):
+            os.rename(os.path.join(data_dir, f), os.path.join(data_dir, '{}_{}.pt'.format(mal_or_be, index)))
+
+
+def split_samples(flag):
+    postfix = ''
+    file_dir = '/root/autodl-tmp'
+    if flag == 'one_family':
+        path = os.path.join(file_dir, 'one_family_malware')
+        tag = 'malware'
+    elif flag == 'standard':
+        path = os.path.join(file_dir, 'all')
+        postfix = '_backup'
+        tag = 'malware'
+    elif flag == 'benign':
+        path = os.path.join(file_dir, 'all_benign')
+        tag = 'benign'
+    else:
+        print('flag not implemented')
+        return
+
+    os_list = os.listdir(path)
+    random.shuffle(os_list)
+    # split the samples 6/2/2 into train/test/valid
+    train_len = int(len(os_list) * 0.6)
+    test_len = int(train_len / 3)
+    for index, f in enumerate(os_list):
+        if index < train_len:
+            shutil.copy(os.path.join(path, f), os.path.join(file_dir, 'train_{}'.format(tag) + postfix))
+        elif train_len <= index < train_len + test_len:
+            shutil.copy(os.path.join(path, f), os.path.join(file_dir, 'test_{}'.format(tag) + postfix))
+        else:
+            shutil.copy(os.path.join(path, f), os.path.join(file_dir, 'valid_{}'.format(tag) + postfix))
+    rename(file_dir, tag, postfix)
+
+
+if __name__ == '__main__':
+    # transfer_remote()
+    delete_all_remote()
+    split_samples('one_family')
+    split_samples('benign')
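Note that the split arithmetic in split_samples yields a 6/2/2 train/test/valid partition: train takes 60% of the shuffled list and test takes a third of the train count, i.e. 20%, leaving the remaining 20% for validation. A quick sanity check of that arithmetic (hypothetical helper, not part of the patch):

    def split_lengths(n: int):
        # mirror split_samples' arithmetic for a corpus of n files
        train_len = int(n * 0.6)
        test_len = int(train_len / 3)
        return train_len, test_len, n - train_len - test_len

    print(split_lengths(1000))  # (600, 200, 200) -> train/test/valid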