diff --git a/src/models/HierarchicalGraphModel.py b/src/models/HierarchicalGraphModel.py
index 2b4e9b4..c32c249 100644
--- a/src/models/HierarchicalGraphModel.py
+++ b/src/models/HierarchicalGraphModel.py
@@ -129,7 +129,34 @@ class HierarchicalGraphNeuralNetwork(nn.Module):
             in_x = out_x
         local_batch.x = in_x
         return local_batch
-    
+
+    # CFG embedding learning based on multi-instance decomposition (MID): every sub-graph is run through the CFG GNN layers and max-pooled over its nodes, and each CFG is the mean of its sub-graph embeddings. NOTE(review): local_batch.x ends up with one row per CFG - confirm it still lines up with local_batch.batch used by aggregate_cfg_batch_pooling.
+    def forward_MID_cfg_gnn(self, local_batch):
+        cfg_embeddings = []
+        # cfg_subgraph_loader is the list of decomposed CFGs; it is two-dimensional: each element is the list of sub-graphs that one ACFG was decomposed into
+        cfg_subgraph_loader = local_batch.cfg_subgraph_loader
+        # Aggregate the sub-graph embeddings to represent the original CFG
+        for acfg in cfg_subgraph_loader:
+            subgraph_embeddings = []
+            # Iterate over the sub-graphs of the current CFG; each element is one sub-graph, a Data object
+            # Compute the embedding of the sub-graph
+            for subgraph in acfg:
+                in_x, edge_index = subgraph.x, subgraph.edge_index
+                for i in range(self.cfg_filter_length - 1):
+                    out_x = getattr(self, 'CFG_gnn_{}'.format(i + 1))(x=in_x, edge_index=edge_index)
+                    out_x = pt_f.relu(out_x, inplace=True)
+                    out_x = self.dropout(out_x)
+                    in_x = out_x
+                subgraph_embedding = torch.max(in_x, dim=0).values
+                subgraph_embeddings.append(subgraph_embedding)
+            cfg_embedding = torch.stack(subgraph_embeddings).mean(dim=0)
+            cfg_embeddings.append(cfg_embedding)
+
+        cfg_embeddings = torch.stack(cfg_embeddings)
+        local_batch.x = cfg_embeddings
+        return local_batch
+
+
     def aggregate_cfg_batch_pooling(self, local_batch: Batch):
         if self.pool == 'global_max_pool':
             x_pool = global_max_pool(x=local_batch.x, batch=local_batch.batch)
@@ -178,7 +205,7 @@ class HierarchicalGraphNeuralNetwork(nn.Module):
 
     def forward(self, real_local_batch: Batch, real_bt_positions: list, bt_external_names: list, bt_all_function_edges: list, local_device: torch.device):
 
-        rtn_local_batch = self.forward_cfg_gnn(local_batch=real_local_batch)
+        rtn_local_batch = self.forward_MID_cfg_gnn(local_batch=real_local_batch)
         x_cfg_pool = self.aggregate_cfg_batch_pooling(local_batch=rtn_local_batch)
 
         # build the Function Call Graph (FCG) Data/Batch datasets
diff --git a/src/utils/PreProcessedDataset.py 
b/src/utils/PreProcessedDataset.py
index 898c785..753c54d 100644
--- a/src/utils/PreProcessedDataset.py
+++ b/src/utils/PreProcessedDataset.py
@@ -4,7 +4,7 @@ from datetime import datetime
 import torch
 from torch_geometric.data import Dataset, DataLoader
 
-from utils.RealBatch import create_real_batch_data  # noqa
+from src.utils.RealBatch import create_real_batch_data  # noqa
 
 
 class MalwareDetectionDataset(Dataset):
@@ -69,7 +69,8 @@ def _simulating(_dataset, _batch_size: int):
 
 
 if __name__ == '__main__':
-    root_path: str = '/home/xiang/MalGraph/data/processed_dataset/DatasetJSON/'
+    # root_path: str = '/root/autodl-tmp/'
+    root_path: str = 'D:\\hkn\\infected\\datasets\\proprecessed_pt'
     i_batch_size = 2
 
     train_dataset = MalwareDetectionDataset(root=root_path, train_or_test='train')
diff --git a/src/utils/RealBatch.py b/src/utils/RealBatch.py
index 39313c8..14b059b 100644
--- a/src/utils/RealBatch.py
+++ b/src/utils/RealBatch.py
@@ -1,10 +1,22 @@
 import torch
-from torch_geometric.data import Batch
-from torch_geometric.data import DataLoader
-from pprint import pprint
+import pymetis as metis
+import numpy as np
+import networkx as nx
+from torch_geometric.data import Batch, Data
+from torch_geometric.utils import to_networkx, from_networkx
+from typing import List
+
+from src.utils.RemoveCycleEdgesTrueskill import perform_breaking_edges
+from datetime import datetime
+
+
+perform_MID = True
 
 
 def create_real_batch_data(one_batch: Batch):
+    if perform_MID:
+        return create_MID_real_batch_data(one_batch)
+
     real = []
     position = [0]
     count = 0
@@ -21,4 +33,179 @@ def create_real_batch_data(one_batch: Batch):
         return (None for _ in range(6))
     else:
         real_batch = Batch.from_data_list(real)
-        return real_batch, position, one_batch.hash, one_batch.external_list, one_batch.function_edges, one_batch.targets
\ No newline at end of file
+        return real_batch, position, one_batch.hash, one_batch.external_list, one_batch.function_edges, one_batch.targets
+
+
+# Build the multi-instance-decomposition (MID) batch for CFGs: returns (real_batch, position, hash, external_list, function_edges, targets), with the per-CFG sub-graph lists attached as real_batch.cfg_subgraph_loader; yields six None values when the batch holds a single executable with no ACFG
+def create_MID_real_batch_data(one_batch: Batch):
+    # Original (undecomposed) CFGs
+    real = []
+    # Decomposed CFGs: each element is the list[Data] of sub-graphs obtained from one CFG, so this list is two-dimensional
+    real_decomposed_cfgs = []
+    position = [0]
+    count = 0
+
+    assert len(one_batch.external_list) == len(one_batch.function_edges) == len(one_batch.local_acfgs) == len(
+        one_batch.hash), "size of each component must be equal to each other"
+
+    for pe in one_batch.local_acfgs:
+        # Iterate over the ACFGs of this PE
+        for acfg in pe:
+            # Decompose the ACFG into multiple instances; returns a list[Data]
+            sub_graphs = multi_instance_decompose(acfg)
+            real_decomposed_cfgs.append(sub_graphs)
+            real.append(acfg)
+        # Add the number of ACFGs of this executable to the running total
+        count += len(pe)
+        # Record the cumulative ACFG count after each executable
+        position.append(count)
+
+    if len(one_batch.local_acfgs) == 1 and len(one_batch.local_acfgs[0]) == 0:
+        return (None for _ in range(6))
+    else:
+        real_batch = Batch.from_data_list(real)
+        real_batch.cfg_subgraph_loader = real_decomposed_cfgs
+        return real_batch, position, one_batch.hash, one_batch.external_list, one_batch.function_edges, one_batch.targets
+
+
+# Multi-instance decomposition of a CFG (currently always delegates to metis_MID)
+# return list[Data]
+def multi_instance_decompose(acfg: Data):
+    # edge_index : torch.tensor([[0, 1, 2], [1, 2, 3]])
+    # acfg.x is the tensor holding the 11-dimensional attributes of each block
+    # A graph with a single node has no edge information (edge_index has length 0), so it needs no processing
+    # if len(acfg.x) == 1:
+    #     return [acfg]
+    #
+    # g = nx.DiGraph()
+    # g.add_edges_from(edge_index2edges(acfg.edge_index))
+
+    return metis_MID(acfg)
+    # return structure_MID(acfg, g)
+    # return topological_MID(acfg, g)
+
+
+def metis_MID(acfg):  # graphs with fewer than 10 nodes are returned whole; larger ones are partitioned by METIS into up to nparts sub-graphs
+    nparts = 3
+    node_num = len(acfg.x)
+    if node_num < 10:
+        return [acfg]
+    G = to_networkx(acfg).to_undirected()
+    adjacency_list = [list(G.neighbors(node)) for node in sorted(G.nodes)]
+    _, parts = metis.part_graph(nparts=nparts, adjacency=adjacency_list, recursive=False)  # split into 3 sub-graphs
+    sub_graphs: List[Data] = []
+    subgraph_nodes: List[List[int]] = []
+    for i, p in enumerate(parts):
+        while p >= len(subgraph_nodes):
+            subgraph_nodes.append([])
+        subgraph_nodes[p].append(i)
+
+    for sub_graph in subgraph_nodes:
+        if len(sub_graph) == 0:
+            continue
+        indices = torch.unique(torch.tensor(sub_graph)).long()
+        sub_G = G.subgraph(sub_graph)
+        sub_data = from_networkx(sub_G)
+        sub_data.x = acfg.x[indices]
+        sub_graphs.append(sub_data)
+
+    return sub_graphs
+
+
+# Save the cycle structures and the remaining hierarchical structure as separate Data objects; returns list[Data]
+def structure_MID(acfg, g):
+    result = []
+
+    # Extract the self-loops of the graph
+    # self_loop = nx.selfloop_edges(g)
+    # result += [create_data(acfg.x, torch.tensor([[loop[0]], [loop[0]]])) for loop in self_loop]
+
+    # self_loop cannot be reused here, because it is exhausted once it has been read
+    # g.remove_edges_from(nx.selfloop_edges(g))
+
+    # Extract the cycles of the graph
+    # cycles = list(nx.simple_cycles(g))
+    # if len(cycles) > 0:
+    #     max_cycle = max(len(cycle) for cycle in cycles)
+    #     max_cycle = max(cycles, key=len)
+    #     print(max_cycle)
+    # result += [create_data(acfg.x, torch.tensor([path[:-1], path[1:]])) for path in cycles]
+
+    # time_start = datetime.now()
+    # Convert the graph into a DAG while preserving the original hierarchy as much as possible
+    perform_breaking_edges(g)
+    graph_index = edges2edge_index(g.edges)
+    result.append(create_data(acfg.x, graph_index))
+    # time_end = datetime.now()
+    # print("process time = {}".format(time_end - time_start))
+
+    return result
+
+
+# Topologically sort the graph, then DFS to find all longest paths; each path is saved as a Data; returns list[Data]
+def topological_MID(acfg, g):
+    # Convert the graph into a DAG while preserving the original hierarchy as much as possible
+    perform_breaking_edges(g)
+    # Topological sort
+    topo_order = list(nx.topological_sort(g))
+    # Initialise the distance table
+    dist = {node: float('-inf') for node in g.nodes()}
+    # Initialise the predecessor table
+    prev = {node: [] for node in g.nodes()}
+    # Source nodes (in-degree 0) start at distance 0
+    for node in g.nodes():
+        if g.in_degree(node) == 0:
+            dist[node] = 0
+    # Visit every node in topological order
+    for node in topo_order:
+        # Visit every successor
+        for successor in g.successors(node):
+            # Update the distance; ties keep every predecessor that reaches the same length
+            if dist[successor] < dist[node] + 1:
+                dist[successor] = dist[node] + 1
+                prev[successor] = [node]
+            elif dist[successor] == dist[node] + 1:
+                prev[successor].append(node)
+
+    # Length of the longest path
+    max_length = max(dist.values())
+
+    # Collected longest paths
+    longest_paths = []
+
+    # Visit every sink node (out-degree 0)
+    for node in g.nodes():
+        if g.out_degree(node) == 0 and dist[node] == max_length:
+            dfs(node, [node], prev, longest_paths)
+
+    # Convert every longest path of the ACFG into a Data, i.e. one ACFG becomes a list of Data
+    return [create_data(acfg.x, torch.tensor([path[:-1], path[1:]])) for path in longest_paths]
+
+
+def dfs(node, path, prev, longest_paths):  # walks the prev links back from node and appends each complete path (source first) to longest_paths
+    if len(prev[node]) == 0:
+        longest_paths.append(path)
+    else:
+        for predecessor in prev[node]:
+            dfs(predecessor, [predecessor] + path, prev, longest_paths)
+
+
+# Collect every node id that occurs in edge_index and keep only the matching rows of x
+# Used to quickly build the x attribute of a sub-graph; note that x and edge_index are both torch.tensor
+def create_data(x, edge_index):
+    # Sorted unique node ids, plus edge_index relabeled to positions in that id list so the edges index the reduced x instead of the original node numbering
+    indices, relabeled = torch.unique(edge_index, return_inverse=True)
+    return Data(x[indices.long()], relabeled)
+
+
+# torch.tensor([[1, 2, 3], [2, 3, 4]]) => [(1, 2), (2, 3), (3, 4)]
+# Convert an edge_index tensor into a list of edges
+def edge_index2edges(edge_index):
+    return list(zip(*edge_index.tolist()))
+
+
+# OutEdgeView([(1, 2), (2, 3), (3, 4)]) => torch.tensor([[1, 2, 3], [2, 3, 4]])
+# Convert an edge view into an edge_index tensor
+def edges2edge_index(edges):
+    edges = list(edges.items())
+    return torch.tensor([list(edge[0]) for edge in edges]).t().contiguous()
diff --git a/src/utils/RemoveCycleEdgesTrueskill.py b/src/utils/RemoveCycleEdgesTrueskill.py
new file mode 100644
index 0000000..d452f3f
--- /dev/null
+++ b/src/utils/RemoveCycleEdgesTrueskill.py
@@ -0,0 +1,170 @@
+import random
+from trueskill import Rating, rate_1vs1
+import networkx as nx
+import os
+
+
+# noinspection DuplicatedCode
+def __get_big_sccs(g):  # returns copies of the strongly connected components of g that have at least 2 nodes
+    num_big_sccs = 0
+    big_sccs = []
+    for sub in (g.subgraph(c).copy() for c in nx.strongly_connected_components(g)):
+        number_of_nodes = sub.number_of_nodes()
+        if number_of_nodes >= 2:
+            # strongly connected components
+            num_big_sccs += 1
+            big_sccs.append(sub)
+    # print(" # big sccs: %d" % (num_big_sccs))
+    return big_sccs
+
+
+# noinspection DuplicatedCode
+def __nodes_in_scc(sccs):  # returns the nodes of all given SCC subgraphs as one flat list (scc_edges is collected but not returned)
+    scc_nodes = []
+    scc_edges = []
+    for scc in sccs:
+        scc_nodes += list(scc.nodes())
+        scc_edges += list(scc.edges())
+
+    # print("# nodes in big sccs: %d" % len(scc_nodes))
+    # print("# edges in big sccs: %d" % len(scc_edges))
+    return scc_nodes
+
+
+def __scores_of_nodes_in_scc(sccs, players):  # maps every node found in the given SCCs to players[node]
+    scc_nodes = __nodes_in_scc(sccs)
+    scc_nodes_score_dict = {}
+    for node in scc_nodes:
+        scc_nodes_score_dict[node] = players[node]
+    # print("# scores of nodes in scc: %d" % (len(scc_nodes_score_dict)))
+    return scc_nodes_score_dict
+
+
+def __filter_big_scc(g, edges_to_be_removed):
+    # Given a graph g and edges to be removed (the edges are removed from g in place)
+    # Return the big scc subgraphs (# of nodes >= 2) as the result of filter(...), not as a materialised list
+    g.remove_edges_from(edges_to_be_removed)
+    sub_graphs = filter(lambda scc: scc.number_of_nodes() >= 2,
+                        [g.subgraph(c).copy() for c in nx.strongly_connected_components(g)])
+    return sub_graphs
+
+
+def __remove_cycle_edges_by_agony_iterately(sccs, players, edges_to_be_removed):  # pops SCCs one at a time, drops the max-agony edge of each, re-queues the big SCCs that remain, and stops when sccs is empty; removed edges are appended to edges_to_be_removed
+    while True:
+        graph = sccs.pop()
+        pair_max_agony = None
+        max_agony = -1
+        for pair in graph.edges():
+            u, v = pair
+            agony = max(players[u] - players[v], 0)
+            if agony >= max_agony:
+                pair_max_agony = (u, v)
+                max_agony = agony
+        edges_to_be_removed.append(pair_max_agony)
+        # print("graph: (%d,%d), edge to be removed: %s, agony: %0.4f" % (graph.number_of_nodes(),graph.number_of_edges(),pair_max_agony,max_agony))
+        graph.remove_edges_from([pair_max_agony])
+        # print("graph: (%d,%d), edge to be removed: %s" % (graph.number_of_nodes(),graph.number_of_edges(),pair_max_agony))
+        sub_graphs = __filter_big_scc(graph, [pair_max_agony])
+        if sub_graphs:  # NOTE(review): under Python 3 a filter object is always truthy, so this guard never skips the loop - confirm that is intended
+            for index, sub in enumerate(sub_graphs):
+                sccs.append(sub)
+        if not sccs:
+            return
+
+
+def __compute_trueskill(pairs, players):  # one rating pass: creates a Rating per node when players is empty, shuffles pairs in place, then updates players via rate_1vs1(players[v], players[u])
+    if not players:
+        for u, v in pairs:
+            if u not in players:
+                players[u] = Rating()
+            if v not in players:
+                players[v] = Rating()
+
+    random.shuffle(pairs)
+    for u, v in pairs:
+        players[v], players[u] = rate_1vs1(players[v], players[u])
+
+    return players
+
+
+def __get_players_score(players, n_sigma):  # score per player: mu - n_sigma * sigma
+    relative_score = {}
+    for k, v in players.items():
+        relative_score[k] = players[k].mu - n_sigma * players[k].sigma
+    return relative_score
+
+
+def __measure_pairs_agreement(pairs, nodes_score):
+    # Fraction of pairs (u, v) with nodes_score[u] <= nodes_score[v]; pairs with a node missing from nodes_score are skipped, and 1 is returned when no pair was scored
+    num_correct_pairs = 0
+    num_wrong_pairs = 0
+    total_pairs = 0
+    for u, v in pairs:
+        if u in nodes_score and v in nodes_score:
+            if nodes_score[u] <= nodes_score[v]:
+                num_correct_pairs += 1
+            else:
+                num_wrong_pairs += 1
+            total_pairs += 1
+    if total_pairs != 0:
+        acc = num_correct_pairs * 1.0 / total_pairs
+        # print("correct pairs: %d, wrong pairs: %d, total pairs: %d, accuracy: %0.4f" % (num_correct_pairs,num_wrong_pairs,total_pairs,num_correct_pairs*1.0/total_pairs))
+    else:
+        acc = 1
+        # print("total pairs: 0, accuracy: 1")
+    return acc
+
+
+def __trueskill_ratings(pairs, iter_times=15, n_sigma=3, threshold=0.85):  # runs up to iter_times rating passes and returns the relative scores, stopping early once pair agreement reaches threshold
+    players = {}
+    for i in range(iter_times):
+        players = __compute_trueskill(pairs, players)
+        relative_scores = __get_players_score(players, n_sigma=n_sigma)
+        accu = __measure_pairs_agreement(pairs, relative_scores)
+        if accu >= threshold:
+            return relative_scores
+    # end = datetime.now()
+    # time_used = end - start
+    # print("time used in computing true skill: %0.4f s, iteration time is: %i" % (time_used.seconds, (i + 1)))
+    return relative_scores
+
+
+# noinspection DuplicatedCode
+# def breaking_cycles_by_TS(graph_path):
+#     g = nx.read_edgelist(graph_path, create_using=nx.DiGraph(), nodetype=int)
+#     players_score_dict = __trueskill_ratings(list(g.edges()), iter_times=15, n_sigma=3, threshold=0.95)
+#     g.remove_edges_from(list(nx.selfloop_edges(g)))
+#     big_sccs = __get_big_sccs(g)
+#     scc_nodes_score_dict = __scores_of_nodes_in_scc(big_sccs, players_score_dict)
+#     edges_to_be_removed = []
+#     if len(big_sccs) == 0:
+#         print("After removal of self loop edgs: %s" % nx.is_directed_acyclic_graph(g))
+#         return
+#
+#     __remove_cycle_edges_by_agony_iterately(big_sccs, scc_nodes_score_dict, edges_to_be_removed)
+#     g.remove_edges_from(edges_to_be_removed)
+#     nx.write_edgelist(g, out_path)
+
+
+# Breaks cycles in the graph g in place: removes self-loops, then removes the max-agony edges chosen inside each big SCC. The edge list has the form [(x0, y0), (x1, y1), (x2, y2), (x3, y3)]
+def perform_breaking_edges(g):
+    players_score_dict = __trueskill_ratings(list(g.edges()), iter_times=15, n_sigma=3, threshold=0.95)
+    g.remove_edges_from(list(nx.selfloop_edges(g)))
+    big_sccs = __get_big_sccs(g)
+    scc_nodes_score_dict = __scores_of_nodes_in_scc(big_sccs, players_score_dict)
+    edges_to_be_removed = []
+
+    # With the self-loops removed the graph is already a DAG
+    if len(big_sccs) == 0:
+        return
+
+    __remove_cycle_edges_by_agony_iterately(big_sccs, scc_nodes_score_dict, edges_to_be_removed)
+    g.remove_edges_from(edges_to_be_removed)
+
+
+if __name__ == '__main__':
+    # for test only
+    graph_path = 'D:\\hkn\\infected\\datasets\\text_only_nx\\text.edges'
+    out_path = 'D:\\hkn\\infected\\datasets\\text_only_nx\\result.edges'
+
+    # breaking_cycles_by_TS(graph_path)