detect_rep/detect_script/detect_model.py
2023-04-05 10:04:49 +08:00

180 lines
7.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch.nn as nn
from dgl.nn.pytorch import GraphConv
import torch.nn.functional as F
import dgl
import dgl.nn as dglnn
import torch
import sys
# Input node-feature size for every "GCN_<feature family>_<dim>" model name.
# Each of these names is built as GCN_node_feature(dim, 32, 2) by load_model.
GCN_NODE_FEATURE_DIMS = {
    f"GCN_{family}_{dim}": dim
    for family, dims in (
        ("malconv", (16, 32, 64, 128, 256)),
        ("n_gram", (16, 32, 64, 128, 256, 512)),
        ("word_frequency", (16, 32, 64, 128, 256)),
        ("asm2vec_base", (16, 32, 64, 128, 256)),
        ("asm2vec_s_base", (16, 32, 64, 128, 256)),
        ("asm2vec_s368_base", (16, 32, 64, 128, 256)),
        ("asm2vec_plus", (16, 32, 64, 128, 256)),
    )
    for dim in dims
}


def load_model(model_name="", device="cpu"):
    """Build the detection model called ``model_name`` and move it to ``device``.

    Supported names:

    * ``"GCN_base"``: ``GCN_base(1, 256, 2)``. It uses node in-degree as its
      only node feature.
    * Any key of ``GCN_NODE_FEATURE_DIMS`` (e.g. ``"GCN_n_gram_512"``):
      ``GCN_node_feature(dim, 32, 2)``, where ``dim`` is the mapped value.
    * ``"GAT"``: ``GAT(100, 100, 100, 2, heads=[8, 1])``.

    Args:
        model_name: one of the supported names above.
        device: target passed to ``nn.Module.to`` (e.g. ``"cpu"``, ``"cuda:0"``).

    Returns:
        The constructed model, moved to ``device``.

    Raises:
        ValueError: if ``model_name`` is not supported. Previously the
            function silently returned the string ``""`` in that case.
    """
    if model_name == "GCN_base":
        model = GCN_base(1, 256, 2)
    elif model_name in GCN_NODE_FEATURE_DIMS:
        model = GCN_node_feature(GCN_NODE_FEATURE_DIMS[model_name], 32, 2)
    elif model_name == "GAT":
        model = GAT(100, 100, 100, 2, heads=[8, 1])
    else:
        known = ["GCN_base", "GAT", *sorted(GCN_NODE_FEATURE_DIMS)]
        raise ValueError(
            f"unknown model_name {model_name!r}; expected one of {known}"
        )
    return model.to(device)
class GCN_base(nn.Module):
    """Two-layer GCN graph classifier whose only node feature is the in-degree.

    Args:
        in_dim: input node-feature size. ``forward`` builds a 1-column degree
            feature, so this should be 1 (the value ``load_model`` uses).
        hidden_dim: output size of both graph-convolution layers.
        n_classes: number of output classes.
    """

    def __init__(self, in_dim, hidden_dim, n_classes):
        super().__init__()
        # First graph-convolution layer.
        self.conv1 = GraphConv(in_dim, hidden_dim)
        # Permit nodes with zero in-degree. DGL's GraphConv rejects such graphs
        # by default (see the GraphConv documentation).
        self.conv1.set_allow_zero_in_degree(True)
        # Second graph-convolution layer.
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.conv2.set_allow_zero_in_degree(True)
        # Linear classifier over the pooled graph representation.
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        """Return class logits for every graph in ``g``.

        ``g`` is a (possibly batched) DGL graph with N nodes in total across n
        graphs. The result has shape [n, n_classes].
        """
        # Node in-degree is the initial feature; for undirected graphs the
        # in-degree equals the out-degree.
        h = g.in_degrees().view(-1, 1).float()  # [N, 1]
        h = F.relu(self.conv1(g, h))  # [N, hidden_dim]
        h = F.relu(self.conv2(g, h))  # [N, hidden_dim]
        # local_scope discards the temporary 'h' feature on exit instead of
        # leaving it (with its autograd history) on the caller's graph.
        with g.local_scope():
            g.ndata['h'] = h
            # Mean-pool node representations into one vector per graph.
            hg = dgl.mean_nodes(g, 'h')  # [n, hidden_dim]
        return self.classify(hg)  # [n, n_classes]
class GCN_node_feature(nn.Module):
    """Two-layer GCN graph classifier over the node features in ``g.ndata['feature']``.

    Args:
        in_dim: size of each node's ``'feature'`` vector.
        hidden_dim: output size of both graph-convolution layers.
        n_classes: number of output classes.
    """

    def __init__(self, in_dim, hidden_dim, n_classes):
        super().__init__()
        # First graph-convolution layer.
        self.conv1 = GraphConv(in_dim, hidden_dim)
        # Permit nodes with zero in-degree. DGL's GraphConv rejects such graphs
        # by default (see the GraphConv documentation).
        self.conv1.set_allow_zero_in_degree(True)
        # Second graph-convolution layer.
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.conv2.set_allow_zero_in_degree(True)
        # Linear classifier over the pooled graph representation.
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        """Return class logits for every graph in ``g``.

        ``g`` is a (possibly batched) DGL graph with N nodes in total across n
        graphs. Each node must carry a ``'feature'`` tensor of size ``in_dim``.
        The result has shape [n, n_classes].
        """
        # Use the precomputed node features (not the node degree) as input.
        h = g.ndata['feature']  # [N, in_dim]
        h = F.relu(self.conv1(g, h))  # [N, hidden_dim]
        h = F.relu(self.conv2(g, h))  # [N, hidden_dim]
        # local_scope discards the temporary 'h' feature on exit instead of
        # leaving it (with its autograd history) on the caller's graph.
        with g.local_scope():
            g.ndata['h'] = h
            # Mean-pool node representations into one vector per graph.
            hg = dgl.mean_nodes(g, 'h')  # [n, hidden_dim]
        return self.classify(hg)  # [n, n_classes]
class GAT(nn.Module):
    """Two-layer GAT graph classifier over the node features in ``g.ndata['feature']``.

    Args:
        in_size: size of each node's ``'feature'`` vector.
        hid_size: per-head output size of the first GAT layer. That layer's
            heads are concatenated.
        out_size: per-head output size of the second GAT layer. That layer's
            heads are averaged.
        class_size: number of output classes.
        heads: number of attention heads per layer, e.g. ``[8, 1]``.
    """

    def __init__(self, in_size, hid_size, out_size, class_size, heads):
        super().__init__()
        self.gat_layers = nn.ModuleList()
        # Layer 1: in_size -> heads[0] heads of hid_size (concatenated in forward).
        self.gat_layers.append(
            dglnn.GATConv(
                in_size,
                hid_size,
                heads[0],
                feat_drop=0.6,
                attn_drop=0.6,
                activation=F.elu,
            )
        )
        # Layer 2: takes the concatenated heads of layer 1 as input.
        self.gat_layers.append(
            dglnn.GATConv(
                hid_size * heads[0],
                out_size,
                heads[1],
                feat_drop=0.6,
                attn_drop=0.6,
                activation=None,
            )
        )
        # The last layer's heads are averaged in forward, so the pooled graph
        # representation has out_size features. This was previously sized with
        # hid_size, which only worked when hid_size == out_size.
        self.classify = nn.Linear(out_size, class_size)

    def forward(self, g):
        """Return class logits of shape [n, class_size] for the n graphs in ``g``."""
        h = g.ndata['feature']
        last = len(self.gat_layers) - 1
        for i, layer in enumerate(self.gat_layers):
            h = layer(g, h)  # [N, num_heads, per-head size]
            if i == last:
                h = h.mean(1)  # average heads -> [N, out_size]
            else:
                h = h.flatten(1)  # concatenate heads -> [N, hid_size * heads[0]]
        # local_scope discards the temporary 'h' feature on exit instead of
        # leaving it (with its autograd history) on the caller's graph.
        with g.local_scope():
            g.ndata['h'] = h
            # Mean-pool node representations into one vector per graph.
            hg = dgl.mean_nodes(g, 'h')  # [n, out_size]
        return self.classify(hg)  # [n, class_size]