import sys

import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.nn as dglnn
from dgl.nn.pytorch import GraphConv


# Input feature width for each feature-based model variant. All of these share the
# same hidden size (32) and number of classes (2); only the node-feature dimension
# encoded in the trailing suffix of the name differs.
_NODE_FEATURE_DIMS = {
    "GCN_malconv_16": 16, "GCN_malconv_32": 32, "GCN_malconv_64": 64,
    "GCN_malconv_128": 128, "GCN_malconv_256": 256,
    "GCN_n_gram_16": 16, "GCN_n_gram_32": 32, "GCN_n_gram_64": 64,
    "GCN_n_gram_128": 128, "GCN_n_gram_256": 256, "GCN_n_gram_512": 512,
    "GCN_word_frequency_16": 16, "GCN_word_frequency_32": 32, "GCN_word_frequency_64": 64,
    "GCN_word_frequency_128": 128, "GCN_word_frequency_256": 256,
    "GCN_asm2vec_base_16": 16, "GCN_asm2vec_base_32": 32, "GCN_asm2vec_base_64": 64,
    "GCN_asm2vec_base_128": 128, "GCN_asm2vec_base_256": 256,
    "GCN_asm2vec_s_base_16": 16, "GCN_asm2vec_s_base_32": 32, "GCN_asm2vec_s_base_64": 64,
    "GCN_asm2vec_s_base_128": 128, "GCN_asm2vec_s_base_256": 256,
    "GCN_asm2vec_s368_base_16": 16, "GCN_asm2vec_s368_base_32": 32, "GCN_asm2vec_s368_base_64": 64,
    "GCN_asm2vec_s368_base_128": 128, "GCN_asm2vec_s368_base_256": 256,
    "GCN_asm2vec_plus_16": 16, "GCN_asm2vec_plus_32": 32, "GCN_asm2vec_plus_64": 64,
    "GCN_asm2vec_plus_128": 128, "GCN_asm2vec_plus_256": 256,
}


def load_model(model_name="", device="cpu"):
    """Return the requested model on `device`; unknown names fall through to the "" default."""
    model = ""
    if model_name == "GCN_base":
        # Degree-based GCN: node features are the 1-dimensional in-degrees.
        model = GCN_base(1, 256, 2).to(device)
    elif model_name in _NODE_FEATURE_DIMS:
        # Feature-based GCNs: only the input feature width differs between variants.
        model = GCN_node_feature(_NODE_FEATURE_DIMS[model_name], 32, 2).to(device)
    elif model_name == "GAT":
        model = GAT(100, 100, 100, 2, heads=[8, 1]).to(device)
    return model
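# Usage sketch (illustrative only; the model name and device below are examples,
# not part of the original code):
#
#     model = load_model("GCN_malconv_64", device="cuda:0")
#     logits = model(batched_graph)      # batched_graph = dgl.batch(list_of_graphs)
#     preds = logits.argmax(dim=1)       # [n] predicted class per graph
#
# The feature-based variants (everything except GCN_base) expect each graph to
# carry a node-feature tensor in g.ndata['feature'] whose width matches the
# model's in_dim (e.g. 64 for GCN_malconv_64).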
class GCN_base(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(GCN_base, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)        # first graph convolution layer
        self.conv1.set_allow_zero_in_degree(True)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)    # second graph convolution layer
        self.conv2.set_allow_zero_in_degree(True)
        self.classify = nn.Linear(hidden_dim, n_classes)  # graph-level classifier

    def forward(self, g):
        """g is the batched graph; N is the total number of nodes in the batch,
        n is the number of graphs in the batch.
        """
        # Use node degrees as the initial node features.
        # For an undirected graph, in-degree equals out-degree.
        h = g.in_degrees().view(-1, 1).float()            # [N, 1]
        # Graph convolutions with ReLU activations.
        h = F.relu(self.conv1(g, h))                      # [N, hidden_dim]
        h = F.relu(self.conv2(g, h))                      # [N, hidden_dim]
        g.ndata['h'] = h                                  # store node representations on the graph
        # Mean-pool the node representations to get one vector per graph.
        hg = dgl.mean_nodes(g, 'h')                       # [n, hidden_dim]
        return self.classify(hg)                          # [n, n_classes]


class GCN_node_feature(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(GCN_node_feature, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)        # first graph convolution layer
        self.conv1.set_allow_zero_in_degree(True)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)    # second graph convolution layer
        self.conv2.set_allow_zero_in_degree(True)
        self.classify = nn.Linear(hidden_dim, n_classes)  # graph-level classifier

    def forward(self, g):
        """g is the batched graph; N is the total number of nodes in the batch,
        n is the number of graphs in the batch.
        """
        # Use the precomputed node features stored on the graph.
        h = g.ndata['feature']                            # [N, in_dim]
        # Graph convolutions with ReLU activations.
        h = F.relu(self.conv1(g, h))                      # [N, hidden_dim]
        h = F.relu(self.conv2(g, h))                      # [N, hidden_dim]
        g.ndata['h'] = h                                  # store node representations on the graph
        # Mean-pool the node representations to get one vector per graph.
        hg = dgl.mean_nodes(g, 'h')                       # [n, hidden_dim]
        return self.classify(hg)                          # [n, n_classes]


class GAT(nn.Module):
    def __init__(self, in_size, hid_size, out_size, class_size, heads):
        super().__init__()
        self.gat_layers = nn.ModuleList()
        # Two-layer GAT: multi-head attention in the first layer,
        # a single aggregated head in the second.
        self.gat_layers.append(
            dglnn.GATConv(
                in_size,
                hid_size,
                heads[0],
                feat_drop=0.6,
                attn_drop=0.6,
                activation=F.elu,
            )
        )
        self.gat_layers.append(
            dglnn.GATConv(
                hid_size * heads[0],
                out_size,
                heads[1],
                feat_drop=0.6,
                attn_drop=0.6,
                activation=None,
            )
        )
        # The pooled graph representation below has out_size channels.
        self.classify = nn.Linear(out_size, class_size)

    def forward(self, g):
        h = g.ndata['feature']                            # [N, in_size]
        for i, layer in enumerate(self.gat_layers):
            h = layer(g, h)                               # [N, heads[i], dim]
            if i == len(self.gat_layers) - 1:             # last layer: average over heads
                h = h.mean(1)
            else:                                         # hidden layer(s): concatenate heads
                h = h.flatten(1)
        g.ndata['h'] = h                                  # store node representations on the graph
        # Mean-pool the node representations to get one vector per graph.
        hg = dgl.mean_nodes(g, 'h')                       # [n, out_size]
        return self.classify(hg)                          # [n, class_size]
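

# Minimal smoke test (a sketch, not part of the original training pipeline):
# builds two tiny random graphs with 16-dimensional node features, batches them,
# and runs a forward pass through a feature-based GCN. The graph sizes and the
# random features are arbitrary choices for illustration.
if __name__ == "__main__":
    g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])), num_nodes=3)
    g2 = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])), num_nodes=3)
    g1.ndata['feature'] = torch.randn(3, 16)
    g2.ndata['feature'] = torch.randn(3, 16)
    bg = dgl.batch([g1, g2])

    model = load_model("GCN_malconv_16", device="cpu")
    logits = model(bg)                                   # [2, 2]: one row of class logits per graph
    print(logits.shape, logits.argmax(dim=1))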