import sys

import dgl
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv

# All "GCN_<family>_<dim>" variants share the GCN_node_feature architecture
# and differ only in the input feature dimension <dim> encoded in the name.
_GCN_VARIANT_DIMS = {
    "GCN_malconv": (16, 32, 64, 128, 256),
    "GCN_n_gram": (16, 32, 64, 128, 256, 512),
    "GCN_word_frequency": (16, 32, 64, 128, 256),
    "GCN_asm2vec_base": (16, 32, 64, 128, 256),
    "GCN_asm2vec_s_base": (16, 32, 64, 128, 256),
    "GCN_asm2vec_s368_base": (16, 32, 64, 128, 256),
    "GCN_asm2vec_plus": (16, 32, 64, 128, 256),
}

# Flat name -> feature-dimension map, e.g. "GCN_malconv_64" -> 64.
_GCN_NODE_FEATURE_MODELS = {
    f"{family}_{dim}": dim
    for family, dims in _GCN_VARIANT_DIMS.items()
    for dim in dims
}


def load_model(model_name="", device="cpu"):
    """Instantiate the model registered under `model_name` on `device`.

    Raises ValueError for unknown model names.
    """
    if model_name == "GCN_base":
        return GCN_base(1, 256, 2).to(device)
    if model_name == "GAT":
        return GAT(100, 100, 100, 2, heads=[8, 1]).to(device)
    if model_name in _GCN_NODE_FEATURE_MODELS:
        in_dim = _GCN_NODE_FEATURE_MODELS[model_name]
        return GCN_node_feature(in_dim, 32, 2).to(device)
    raise ValueError(f"Unknown model name: {model_name!r}")
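
# Example usage (illustrative sketch; the device string and graph variable
# are assumptions, not part of the original code):
#
#     model = load_model("GCN_malconv_64", device="cuda:0")
#     logits = model(batched_graph)   # batched_graph built with dgl.batch()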


class GCN_base(nn.Module):
    """Two-layer GCN that uses node in-degrees as the initial node feature."""

    def __init__(self, in_dim, hidden_dim, n_classes):
        super(GCN_base, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)      # first graph-convolution layer
        self.conv1.set_allow_zero_in_degree(True)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)  # second graph-convolution layer
        self.conv2.set_allow_zero_in_degree(True)
        self.classify = nn.Linear(hidden_dim, n_classes)  # graph-level classifier

    def forward(self, g):
        """g is a batched graph; N is its total node count, n the number of graphs."""
        # Use the node degree as the initial node feature.
        # For an undirected graph, in-degree == out-degree.
        h = g.in_degrees().view(-1, 1).float()  # [N, 1]
        # Graph convolutions followed by ReLU activations.
        h = F.relu(self.conv1(g, h))  # [N, hidden_dim]
        h = F.relu(self.conv2(g, h))  # [N, hidden_dim]
        g.ndata['h'] = h  # store the node representations on the graph
        # Graph-level representation via mean pooling over each graph's nodes.
        hg = dgl.mean_nodes(g, 'h')  # [n, hidden_dim]
        return self.classify(hg)  # [n, n_classes]
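
# A minimal smoke-test sketch for GCN_base (random graphs and sizes are
# illustrative assumptions):
#
#     bg = dgl.batch([dgl.rand_graph(10, 30), dgl.rand_graph(6, 12)])
#     logits = GCN_base(1, 256, 2)(bg)   # -> torch.Size([2, 2]), one row per graph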


class GCN_node_feature(nn.Module):
    """Two-layer GCN that reads precomputed node features from g.ndata['feature']."""

    def __init__(self, in_dim, hidden_dim, n_classes):
        super(GCN_node_feature, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)      # first graph-convolution layer
        self.conv1.set_allow_zero_in_degree(True)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)  # second graph-convolution layer
        self.conv2.set_allow_zero_in_degree(True)
        self.classify = nn.Linear(hidden_dim, n_classes)  # graph-level classifier

    def forward(self, g):
        """g is a batched graph; N is its total node count, n the number of graphs."""
        # Use the precomputed per-node embeddings as input features;
        # their width must equal in_dim.
        h = g.ndata['feature']  # [N, in_dim]
        # Graph convolutions followed by ReLU activations.
        h = F.relu(self.conv1(g, h))  # [N, hidden_dim]
        h = F.relu(self.conv2(g, h))  # [N, hidden_dim]
        g.ndata['h'] = h  # store the node representations on the graph
        # Graph-level representation via mean pooling over each graph's nodes.
        hg = dgl.mean_nodes(g, 'h')  # [n, hidden_dim]
        return self.classify(hg)  # [n, n_classes]
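
# GCN_node_feature expects g.ndata["feature"] to be pre-populated with width
# in_dim. A sketch for in_dim=64, using random placeholder features:
#
#     bg.ndata["feature"] = torch.randn(bg.num_nodes(), 64)
#     logits = GCN_node_feature(64, 32, 2)(bg)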


class GAT(nn.Module):
    """Two-layer GAT followed by mean pooling and a linear graph classifier."""

    def __init__(self, in_size, hid_size, out_size, class_size, heads):
        super().__init__()
        self.gat_layers = nn.ModuleList()
        # Two-layer GAT: the first layer's heads are concatenated, so the
        # second layer consumes hid_size * heads[0] input features.
        self.gat_layers.append(
            dglnn.GATConv(
                in_size,
                hid_size,
                heads[0],
                feat_drop=0.6,
                attn_drop=0.6,
                activation=F.elu,
            )
        )
        self.gat_layers.append(
            dglnn.GATConv(
                hid_size * heads[0],
                out_size,
                heads[1],
                feat_drop=0.6,
                attn_drop=0.6,
                activation=None,
            )
        )
        # The pooled representation has out_size features (the last layer's
        # heads are averaged), so the classifier takes out_size inputs.
        self.classify = nn.Linear(out_size, class_size)

    def forward(self, g):
        h = g.ndata['feature']  # [N, in_size]
        for i, layer in enumerate(self.gat_layers):
            h = layer(g, h)
            if i == len(self.gat_layers) - 1:  # last layer: average the heads
                h = h.mean(1)
            else:  # hidden layers: concatenate the heads
                h = h.flatten(1)
        g.ndata['h'] = h  # store the node representations on the graph
        # Graph-level representation via mean pooling over each graph's nodes.
        hg = dgl.mean_nodes(g, 'h')  # [n, out_size]
        return self.classify(hg)  # [n, n_classes]
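

# A minimal end-to-end smoke test, assuming DGL and PyTorch are installed.
# The random graphs, feature widths, and self-loops below are illustrative
# assumptions, not part of the original training pipeline.
if __name__ == "__main__":
    # Self-loops guarantee every node has in-degree > 0, which GATConv
    # requires by default.
    g1 = dgl.add_self_loop(dgl.rand_graph(10, 30))
    g2 = dgl.add_self_loop(dgl.rand_graph(6, 12))
    bg = dgl.batch([g1, g2])

    # GCN_base derives its input features from node degrees.
    print(load_model("GCN_base")(bg).shape)  # torch.Size([2, 2])

    # GCN_node_feature variants read g.ndata["feature"] of width in_dim.
    bg.ndata["feature"] = torch.randn(bg.num_nodes(), 64)
    print(load_model("GCN_malconv_64")(bg).shape)  # torch.Size([2, 2])

    # GAT expects 100-dimensional node features (see load_model).
    bg.ndata["feature"] = torch.randn(bg.num_nodes(), 100)
    print(load_model("GAT")(bg).shape)  # torch.Size([2, 2])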