detect_rep/detect_script/detect_model.py

180 lines
7.5 KiB
Python
Raw Normal View History

2023-04-05 10:04:49 +08:00
import torch.nn as nn
from dgl.nn.pytorch import GraphConv
import torch.nn.functional as F
import dgl
import dgl.nn as dglnn
import torch
import sys
# Input feature dimension of every GCN_node_feature variant, keyed by the
# model name accepted by load_model(). Names have the form "<family>_<dim>"
# and every variant uses in_dim=<dim>, hidden_dim=32, n_classes=2.
GCN_NODE_FEATURE_IN_DIMS = {
    f"{family}_{dim}": dim
    for family, dims in (
        ("GCN_malconv", (16, 32, 64, 128, 256)),
        ("GCN_n_gram", (16, 32, 64, 128, 256, 512)),
        ("GCN_word_frequency", (16, 32, 64, 128, 256)),
        ("GCN_asm2vec_base", (16, 32, 64, 128, 256)),
        ("GCN_asm2vec_s_base", (16, 32, 64, 128, 256)),
        ("GCN_asm2vec_s368_base", (16, 32, 64, 128, 256)),
        ("GCN_asm2vec_plus", (16, 32, 64, 128, 256)),
    )
    for dim in dims
}


def load_model(model_name="", device="cpu"):
    """Construct the model registered under ``model_name`` and move it to ``device``.

    Supported names:
      * ``"GCN_base"``: ``GCN_base(1, 256, 2)``; uses node in-degree as the
        only node feature.
      * ``"GAT"``: ``GAT(100, 100, 100, 2, heads=[8, 1])``.
      * any key of ``GCN_NODE_FEATURE_IN_DIMS``: ``GCN_node_feature(in_dim, 32, 2)``,
        where ``in_dim`` is the numeric suffix of the name (e.g.
        ``"GCN_n_gram_512"`` -> 512).

    Args:
        model_name: One of the supported names listed above.
        device: Target passed to ``nn.Module.to`` (e.g. ``"cpu"`` or ``"cuda"``).

    Returns:
        A newly constructed model instance placed on ``device``.

    Raises:
        ValueError: If ``model_name`` is not a supported name. Previously an
            empty string was returned, which only failed later with an
            unrelated error when the caller tried to use it as a model.
    """
    if model_name == "GCN_base":
        model = GCN_base(1, 256, 2)
    elif model_name == "GAT":
        model = GAT(100, 100, 100, 2, heads=[8, 1])
    elif model_name in GCN_NODE_FEATURE_IN_DIMS:
        model = GCN_node_feature(GCN_NODE_FEATURE_IN_DIMS[model_name], 32, 2)
    else:
        supported = ["GCN_base", "GAT", *sorted(GCN_NODE_FEATURE_IN_DIMS)]
        raise ValueError(
            f"Unknown model name {model_name!r}; expected one of: "
            f"{', '.join(supported)}"
        )
    return model.to(device)
class GCN_base(nn.Module):
    """Two-layer GCN graph classifier that uses node in-degree as its input feature.

    Shape notation: N is the total number of nodes in the batched graph and
    n is the number of graphs in the batch.
    """

    def __init__(self, in_dim, hidden_dim, n_classes):
        super().__init__()
        # Two stacked graph convolutions. Zero-in-degree nodes are explicitly
        # allowed (DGL's GraphConv otherwise rejects graphs containing them).
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        for conv in (self.conv1, self.conv2):
            conv.set_allow_zero_in_degree(True)
        # Linear head mapping the pooled graph embedding to class scores.
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        """Return class logits of shape [n, n_classes] for the batched graph ``g``.

        Side effect: the final node embeddings are stored in ``g.ndata['h']``.
        """
        # Each node's in-degree as a single float feature column: [N, 1].
        # For an undirected graph this equals the out-degree.
        node_feats = g.in_degrees().view(-1, 1).float()
        for conv in (self.conv1, self.conv2):
            node_feats = F.relu(conv(g, node_feats))  # [N, hidden_dim]
        g.ndata['h'] = node_feats
        # Mean-pool node embeddings within each graph: one vector per graph.
        graph_feats = dgl.mean_nodes(g, 'h')  # [n, hidden_dim]
        return self.classify(graph_feats)  # [n, n_classes]
class GCN_node_feature(nn.Module):
    """Two-layer GCN graph classifier over precomputed node features.

    Node features are read from ``g.ndata['feature']``; they are expected to
    have ``in_dim`` columns to match the first convolution. Shape notation:
    N is the total number of nodes in the batched graph and n is the number
    of graphs in the batch.
    """

    def __init__(self, in_dim, hidden_dim, n_classes):
        super().__init__()
        # Two stacked graph convolutions. Zero-in-degree nodes are explicitly
        # allowed (DGL's GraphConv otherwise rejects graphs containing them).
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        for conv in (self.conv1, self.conv2):
            conv.set_allow_zero_in_degree(True)
        # Linear head mapping the pooled graph embedding to class scores.
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        """Return class logits of shape [n, n_classes] for the batched graph ``g``.

        Side effect: the final node embeddings are stored in ``g.ndata['h']``.
        """
        # Use the node features attached to the graph (not node degrees).
        node_feats = g.ndata['feature']
        for conv in (self.conv1, self.conv2):
            node_feats = F.relu(conv(g, node_feats))  # [N, hidden_dim]
        g.ndata['h'] = node_feats
        # Mean-pool node embeddings within each graph: one vector per graph.
        graph_feats = dgl.mean_nodes(g, 'h')  # [n, hidden_dim]
        return self.classify(graph_feats)  # [n, n_classes]
class GAT(nn.Module):
    """Two-layer graph attention network with mean-pooling and a linear classifier.

    Node features are read from ``g.ndata['feature']``.

    Args:
        in_size: Dimension of each node's input feature vector.
        hid_size: Per-head output size of the first GAT layer.
        out_size: Per-head output size of the second GAT layer. This is also
            the size of the pooled graph embedding fed to the classifier.
        class_size: Number of output classes.
        heads: Two-element sequence giving the number of attention heads of
            the first and second layer.
    """

    def __init__(self, in_size, hid_size, out_size, class_size, heads):
        super().__init__()
        self.gat_layers = nn.ModuleList()
        # Layer 1: heads[0] heads of hid_size features each. Their outputs
        # are concatenated (flattened) in forward().
        self.gat_layers.append(
            dglnn.GATConv(
                in_size,
                hid_size,
                heads[0],
                feat_drop=0.6,
                attn_drop=0.6,
                activation=F.elu,
            )
        )
        # Layer 2: consumes the concatenated layer-1 heads. Its heads are
        # averaged in forward().
        self.gat_layers.append(
            dglnn.GATConv(
                hid_size * heads[0],
                out_size,
                heads[1],
                feat_drop=0.6,
                attn_drop=0.6,
                activation=None,
            )
        )
        # The head-averaged output of the last layer has out_size features,
        # so the classifier input must be out_size. The previous hid_size
        # broke whenever hid_size != out_size.
        self.classify = nn.Linear(out_size, class_size)

    def forward(self, g):
        """Return class logits of shape [n, class_size] for the batched graph ``g``.

        N is the total number of nodes and n the number of graphs in ``g``.
        Side effect: the final node embeddings are stored in ``g.ndata['h']``.
        """
        h = g.ndata['feature']
        last_index = len(self.gat_layers) - 1
        for i, layer in enumerate(self.gat_layers):
            h = layer(g, h)  # [N, num_heads, per_head_size]
            if i == last_index:
                h = h.mean(1)  # average heads -> [N, out_size]
            else:
                h = h.flatten(1)  # concatenate heads -> [N, hid_size * heads[0]]
        g.ndata['h'] = h
        # Mean-pool node embeddings within each graph: one vector per graph.
        hg = dgl.mean_nodes(g, 'h')  # [n, out_size]
        return self.classify(hg)  # [n, class_size]