"""Turn raw x86 machine-code hex strings into asm2vec function embeddings.

Each function is decoded into a sequence of per-instruction vectors, the
vectors are averaged element-wise, and the average is passed through the
model's ``linear_f`` layer.
"""
import os
import random

import click
import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot

import asm2vec
from asm2vec.get_opcode_vector import get_asm_input_vector, str_hex_to_bytes
from utils2 import sigmoid, get_batches, compute_pca, get_dict

# Sample function body (hex-encoded machine code) used as the default input.
_DEFAULT_HEX_ASM = (
    "56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"
)


def _default_device() -> str:
    """Return ``'cuda'`` when a CUDA device is available, else ``'cpu'``."""
    return 'cuda' if torch.cuda.is_available() else 'cpu'


def _mean_instruction_vector(hex_asm):
    """Decode *hex_asm* and average its instruction vectors element-wise.

    Args:
        hex_asm: Hex string handed to ``str_hex_to_bytes``.

    Returns:
        A list with one averaged value per vector component; its length is
        the length of the first decoded vector.

    Raises:
        IndexError: If ``get_asm_input_vector`` yields no vectors (the same
            exception type the unguarded ``[0]`` lookup used to raise).
    """
    # The second return value (the opcode/operand sequence) is not needed.
    instruction_vectors, _ = get_asm_input_vector(str_hex_to_bytes(hex_asm))
    count = len(instruction_vectors)
    if count == 0:
        raise IndexError(
            f"no instruction vectors were decoded from hex_asm={hex_asm!r}"
        )

    # Average the per-instruction vectors to obtain the function-level
    # vector. Accumulation starts from 0.0 and runs in input order so the
    # floating-point result does not depend on the summation algorithm.
    totals = [0.0] * len(instruction_vectors[0])
    for vector in instruction_vectors:
        for index, value in enumerate(vector):
            totals[index] += value
    return [total / count for total in totals]


def cosine_similarity(v1, v2):
    """Return the cosine similarity of two 1-D tensors as a Python number.

    A zero-norm input makes the denominator zero, so the result is then not
    a finite similarity value.
    """
    return (v1 @ v2 / (v1.norm() * v2.norm())).item()


def load_model(path="./asm2vec_checkpoints/model.pt"):
    """Load an asm2vec checkpoint from *path* onto the default device."""
    return asm2vec.utils.load_model(path, device=_default_device())


def func2vec1(model, hex_asm_list=(_DEFAULT_HEX_ASM, _DEFAULT_HEX_ASM)):
    """Embed a batch of functions given as hex strings.

    Args:
        model: Module exposing ``.to(device)`` and a callable ``linear_f``.
            It is moved to the selected device as a side effect.
        hex_asm_list: Iterable of hex strings, one per function. All
            functions must produce averaged vectors of the same length.

    Returns:
        The output of ``model.linear_f`` for the stacked averaged vectors
        (one row per input function), detached from the model's graph and
        marked as requiring gradients.
    """
    device = _default_device()
    # Build the batch on the CPU and move it once. Using ``.to(device)``
    # rather than ``.cuda()`` keeps this working on CUDA-less machines.
    batch = torch.stack(
        [torch.tensor(_mean_instruction_vector(hex_asm))
         for hex_asm in hex_asm_list]
    ).to(device)
    projected = model.to(device).linear_f(batch)
    return projected.clone().detach().requires_grad_(True)


def func2vec(model, hex_asm=_DEFAULT_HEX_ASM):
    """Embed a single function given as a hex string.

    Args:
        model: Module exposing ``.to(device)`` and a callable ``linear_f``.
            It is moved to the selected device as a side effect.
        hex_asm: Hex string of the function's machine code.

    Returns:
        The output of ``model.linear_f`` for the averaged instruction
        vector. Unlike ``func2vec1`` the result is not detached.
    """
    device = _default_device()
    averaged = torch.tensor(_mean_instruction_vector(hex_asm)).to(device)
    return model.to(device).linear_f(averaged)


if __name__ == '__main__':
    model = load_model(path="./asm2vec_checkpoints/model_100.pt")
    func2vec(model)