import sys

import dgl
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv

# Every GCN_node_feature variant shares the same hidden size (32) and number of
# classes (2); only the input feature dimension differs, and it is encoded as
# the trailing number in the model name.
GCN_NODE_FEATURE_MODELS = {
    f"{prefix}_{dim}": dim
    for prefix in (
        "GCN_malconv",
        "GCN_n_gram",
        "GCN_word_frequency",
        "GCN_asm2vec_base",
        "GCN_asm2vec_s_base",
        "GCN_asm2vec_s368_base",
        "GCN_asm2vec_plus",
    )
    for dim in (16, 32, 64, 128, 256)
}
GCN_NODE_FEATURE_MODELS["GCN_n_gram_512"] = 512  # n-gram is the only family with a 512-dim variant


def load_model(model_name="", device="cpu"):
    """Build the model named by `model_name` and move it to `device`.

    Returns None for unrecognized names.
    """
    model = None
    if model_name == "GCN_base":
        # GCN_base consumes 1-dim degree features and uses a larger hidden size.
        model = GCN_base(1, 256, 2).to(device)
    elif model_name in GCN_NODE_FEATURE_MODELS:
        model = GCN_node_feature(GCN_NODE_FEATURE_MODELS[model_name], 32, 2).to(device)
    elif model_name == "GAT":
        model = GAT(100, 100, 100, 2, heads=[8, 1]).to(device)
    return model
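

# Minimal usage sketch: assumes trained weights exist at a checkpoint path
# (the path below is hypothetical; the original file does not load weights here).
def _demo_load_model():
    model = load_model("GCN_malconv_64", device="cpu")
    # state_dict = torch.load("checkpoints/GCN_malconv_64.pt")  # hypothetical path
    # model.load_state_dict(state_dict)
    model.eval()  # inference mode: disables dropout layers such as those in GAT
    return model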


class GCN_base(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super().__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)  # first graph-convolution layer
        self.conv1.set_allow_zero_in_degree(True)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)  # second graph-convolution layer
        self.conv2.set_allow_zero_in_degree(True)
        self.classify = nn.Linear(hidden_dim, n_classes)  # classifier head

    def forward(self, g):
        """g is the batched graph; N is its total node count, n the number of graphs."""
        # Use node degrees as the initial node features; for an undirected
        # graph, in-degree equals out-degree.
        h = g.in_degrees().view(-1, 1).float()  # [N, 1]
        # Two rounds of graph convolution with ReLU activations.
        h = F.relu(self.conv1(g, h))  # [N, hidden_dim]
        h = F.relu(self.conv2(g, h))  # [N, hidden_dim]
        g.ndata['h'] = h  # store node representations on the graph
        # Mean-pool node representations to obtain one vector per graph.
        hg = dgl.mean_nodes(g, 'h')  # [n, hidden_dim]
        return self.classify(hg)  # [n, n_classes]
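

# Minimal sketch (an assumption, not from the original file): any small DGL
# graph works here, since GCN_base derives its own 1-dim degree features and
# tolerates zero-in-degree nodes.
def _demo_gcn_base():
    bg = dgl.batch([dgl.rand_graph(5, 8), dgl.rand_graph(3, 4)])
    model = GCN_base(1, 256, 2)
    return model(bg)  # [2, 2]: one row of class logits per graph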


class GCN_node_feature(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super().__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)  # first graph-convolution layer
        self.conv1.set_allow_zero_in_degree(True)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)  # second graph-convolution layer
        self.conv2.set_allow_zero_in_degree(True)
        self.classify = nn.Linear(hidden_dim, n_classes)  # classifier head

    def forward(self, g):
        """g is the batched graph; N is its total node count, n the number of graphs."""
        # Unlike GCN_base, this model consumes precomputed node features
        # (e.g. MalConv / n-gram / asm2vec embeddings) stored on the graph.
        h = g.ndata['feature']  # [N, in_dim]
        # Two rounds of graph convolution with ReLU activations.
        h = F.relu(self.conv1(g, h))  # [N, hidden_dim]
        h = F.relu(self.conv2(g, h))  # [N, hidden_dim]
        g.ndata['h'] = h  # store node representations on the graph
        # Mean-pool node representations to obtain one vector per graph.
        hg = dgl.mean_nodes(g, 'h')  # [n, hidden_dim]
        return self.classify(hg)  # [n, n_classes]
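

# Minimal sketch: the random 64-dim features below are placeholders for the
# real block embeddings this project attaches under g.ndata['feature'].
def _demo_gcn_node_feature():
    bg = dgl.batch([dgl.rand_graph(5, 8), dgl.rand_graph(3, 4)])
    bg.ndata['feature'] = torch.randn(bg.num_nodes(), 64)
    model = GCN_node_feature(64, 32, 2)
    return model(bg)  # [2, 2]: one row of class logits per graph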


class GAT(nn.Module):
    def __init__(self, in_size, hid_size, out_size, class_size, heads):
        super().__init__()
        self.gat_layers = nn.ModuleList()
        # Two-layer GAT: the first layer's multi-head outputs are concatenated,
        # so the second layer's input width is hid_size * heads[0].
        self.gat_layers.append(
            dglnn.GATConv(
                in_size,
                hid_size,
                heads[0],
                feat_drop=0.6,
                attn_drop=0.6,
                activation=F.elu,
            )
        )
        self.gat_layers.append(
            dglnn.GATConv(
                hid_size * heads[0],
                out_size,
                heads[1],
                feat_drop=0.6,
                attn_drop=0.6,
                activation=None,
            )
        )
        # The final GAT layer emits out_size features per node (heads averaged),
        # so the classifier takes out_size inputs, not hid_size.
        self.classify = nn.Linear(out_size, class_size)

    def forward(self, g):
        h = g.ndata['feature']  # [N, in_size]
        for i, layer in enumerate(self.gat_layers):
            h = layer(g, h)  # [N, num_heads, feat_size]
            if i == len(self.gat_layers) - 1:  # last layer: average the heads
                h = h.mean(1)
            else:  # earlier layers: concatenate the heads
                h = h.flatten(1)
        g.ndata['h'] = h  # store node representations on the graph
        # Mean-pool node representations to obtain one vector per graph.
        hg = dgl.mean_nodes(g, 'h')  # [n, out_size]
        return self.classify(hg)  # [n, class_size]
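

# Minimal sketch: GAT as built by load_model("GAT") expects 100-dim node
# features under g.ndata['feature']. Self-loops are added because GATConv
# rejects zero-in-degree nodes by default.
def _demo_gat():
    bg = dgl.batch([dgl.rand_graph(6, 12), dgl.rand_graph(4, 6)])
    bg = dgl.add_self_loop(bg)
    bg.ndata['feature'] = torch.randn(bg.num_nodes(), 100)
    model = GAT(100, 100, 100, 2, heads=[8, 1])
    model.eval()  # disable feature/attention dropout for inference
    return model(bg)  # [2, 2]: graph-level class logits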