import torch
import torch.nn as nn
import torch.nn.functional as F

# Shared, stateless loss/activation modules. They stay at module level under
# their original names so that any existing import of them keeps working.
bce, sigmoid, softmax = nn.BCELoss(), nn.Sigmoid(), nn.Softmax(dim=1)


def _scaled_uniform(num_rows, embedding_size):
    """Return a ``(num_rows, 2 * embedding_size)`` CPU tensor drawn uniformly
    from ``[-0.25 / embedding_size, 0.25 / embedding_size)``."""
    return (torch.rand(num_rows, 2 * embedding_size) - 0.5) / embedding_size / 2


class ASM2VEC(nn.Module):
    """Model that predicts a vocabulary-sized output from a three-part context.

    The trainable path used by ``v``/``forward``/``predict`` is purely linear:
    the context input is cut into three equal-width slices along dim 1, each
    slice is projected to ``embedding_size``, the three projections are
    averaged, and ``linear_output`` maps the average back to ``vocab_size``.

    The module also owns three embedding tables:

    * ``embeddings``   -- ``(vocab_size, embedding_size)``, zero-initialised.
    * ``embeddings_f`` -- ``(function_size, 2 * embedding_size)``, random init;
      read by ``get_func_feature``.
    * ``embeddings_r`` -- ``(vocab_size, 2 * embedding_size)``, random init.

    ``embeddings`` and ``embeddings_r`` are maintained by ``update`` but are
    not read by ``v``, ``forward`` or ``predict`` in this file.
    """

    def __init__(self, vocab_size, function_size, embedding_size):
        """Build the linear layers and embedding tables.

        Args:
            vocab_size: width of each context slice and of the output layer;
                also the number of rows in the token/context tables.
            function_size: number of rows in the function embedding table.
            embedding_size: hidden width of the linear path; the function and
                context tables are twice this wide.
        """
        super().__init__()

        # NOTE: construction order determines how the global RNG is consumed;
        # keep it stable so that seeded runs stay reproducible.
        self.linear_f = nn.Linear(vocab_size, embedding_size, bias=True)
        # One layer shared by the second and third context slices.
        self.linear_context = nn.Linear(vocab_size, embedding_size, bias=True)
        self.linear_output = nn.Linear(embedding_size, vocab_size, bias=True)

        # Token embedding vectors.
        self.embeddings = nn.Embedding(
            vocab_size,
            embedding_size,
            _weight=torch.zeros(vocab_size, embedding_size),
        )
        # Function embedding vectors.
        self.embeddings_f = nn.Embedding(
            function_size,
            2 * embedding_size,
            _weight=_scaled_uniform(function_size, embedding_size),
        )
        # Context-vector embeddings.
        self.embeddings_r = nn.Embedding(
            vocab_size,
            2 * embedding_size,
            _weight=_scaled_uniform(vocab_size, embedding_size),
        )

    def update(self, function_size_new, vocab_size_new):
        """Resize the embedding tables for a new function set / vocabulary.

        When the vocabulary size changes, the existing rows of ``embeddings``
        and ``embeddings_r`` are kept and new rows are appended (zeros and
        random values respectively). ``embeddings_f`` is always replaced by a
        freshly initialised table with ``function_size_new`` rows.

        Args:
            function_size_new: number of rows of the new function table.
            vocab_size_new: target number of rows of the token/context tables.
                A value smaller than the current size is not supported (the
                negative tensor size is rejected by torch).

        NOTE(review): ``linear_f``, ``linear_context`` and ``linear_output``
        are also shaped by ``vocab_size`` but are not resized here -- confirm
        whether callers ever change the vocabulary size after construction.
        """
        print("into update")
        device = self.embeddings.weight.device
        vocab_size = self.embeddings.num_embeddings
        embedding_size = self.embeddings.embedding_dim

        if vocab_size_new != vocab_size:
            extra_rows = vocab_size_new - vocab_size

            token_weight = torch.cat([
                self.embeddings.weight,
                torch.zeros(extra_rows, embedding_size).to(device),
            ])
            self.embeddings = nn.Embedding(
                vocab_size_new, embedding_size, _weight=token_weight
            )

            context_weight = torch.cat([
                self.embeddings_r.weight,
                _scaled_uniform(extra_rows, embedding_size).to(device),
            ])
            self.embeddings_r = nn.Embedding(
                vocab_size_new, 2 * embedding_size, _weight=context_weight
            )

        self.embeddings_f = nn.Embedding(
            function_size_new,
            2 * embedding_size,
            _weight=_scaled_uniform(function_size_new, embedding_size).to(device),
        )

    def get_func_feature(self, func_vec):
        """Return the L2-normalised (along dim 1) function embeddings looked
        up for the indices in ``func_vec``."""
        return F.normalize(self.embeddings_f(func_vec), p=2, dim=1)

    def v(self, context_vec):
        """Project the three context slices and return their mean.

        Args:
            context_vec: tensor whose dim 1 is split into three consecutive
                slices. Each slice is fed to a ``Linear(vocab_size, ...)``
                layer, so for a 2-D input the width must be
                ``3 * vocab_size``.
                Presumably the slices describe the function, the previous
                instruction and the next instruction, in that order -- verify
                against the data pipeline.

        Returns:
            Tensor with ``embedding_size`` as its last dimension: the mean of
            ``linear_f(slice0)``, ``linear_context(slice1)`` and
            ``linear_context(slice2)``.
        """
        # Read the width from the shape rather than from ``context_vec[0]`` so
        # that an empty batch does not raise IndexError.
        width = context_vec.shape[1]
        first_cut = width // 3
        second_cut = (2 * width) // 3

        v_f = self.linear_f(context_vec[:, :first_cut])
        v_prev = self.linear_context(context_vec[:, first_cut:second_cut])
        v_next = self.linear_context(context_vec[:, second_cut:width])

        return (v_f + v_prev + v_next) / 3

    def forward(self, context_vec, center_vec):
        """Return the training loss for one batch.

        The averaged context representation is mapped to ``vocab_size``
        scores, passed through a softmax over dim 1, and compared against
        ``center_vec`` with binary cross-entropy.

        Args:
            context_vec: see ``v``.
            center_vec: target tensor; ``nn.BCELoss`` requires it to have the
                same shape as the prediction and values in ``[0, 1]``.

        Returns:
            Scalar loss tensor (``nn.BCELoss`` default mean reduction).
        """
        scores = self.linear_output(self.v(context_vec))
        return bce(softmax(scores), center_vec)

    def predict(self, context_vec, center_vec):
        """Return the softmax (over dim 1) of the output scores.

        Args:
            context_vec: see ``v``.
            center_vec: unused; kept so the signature mirrors ``forward``.
        """
        scores = self.linear_output(self.v(context_vec))
        return softmax(scores)