# detect_rep/ASM2VEC_base_scripts/func2vec.py

import torch

import asm2vec
from asm2vec.get_opcode_vector import get_asm_input_vector, str_hex_to_bytes


def cosine_similarity(v1, v2):
    # Cosine similarity between two 1-D embedding tensors.
    return (v1 @ v2 / (v1.norm() * v2.norm())).item()


def load_model(path="./asm2vec_checkpoints/model.pt"):
    # Load a trained ASM2VEC checkpoint onto GPU if available, otherwise CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = asm2vec.utils.load_model(path, device=device)
    return model


def func2vec1(model, hex_asm_list=["56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
                                   "56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"]):
    # Embed a batch of functions, each given as a hex-encoded machine-code string.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    fun2vec_origin_list = []
    for hex_asm in hex_asm_list:
        hex2vec_list = str_hex_to_bytes(hex_asm)
        hex2vec_list, opcode_oprand_seq = get_asm_input_vector(hex2vec_list)
        # Average the per-instruction vectors to obtain the raw function vector.
        fun2vec_origin = [0.0] * len(hex2vec_list[0])
        for instr_vec in hex2vec_list:
            for j in range(len(instr_vec)):
                fun2vec_origin[j] += instr_vec[j]
        opcode_seq_len = len(hex2vec_list)
        for i in range(len(fun2vec_origin)):
            fun2vec_origin[i] /= opcode_seq_len
        fun2vec_origin_list.append(torch.tensor(fun2vec_origin))
    # Stack the raw vectors into a batch and project them through the model's
    # function-embedding layer.
    fun2vec_origin_batch = torch.stack(fun2vec_origin_list).to(device)
    embedding_func_vec = model.to(device).linear_f(fun2vec_origin_batch).clone().detach().requires_grad_(True)
    return embedding_func_vec
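

# Hedged helper sketch (not in the original script): pairwise cosine similarity
# over the (N, D) batch returned by func2vec1, using plain torch operations.
def batch_cosine_similarity(embeddings):
    # Normalize each row, then the Gram matrix gives all pairwise similarities.
    normed = embeddings / embeddings.norm(dim=1, keepdim=True)
    return normed @ normed.t()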


def func2vec(model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    # Embed a single function given as a hex-encoded machine-code string.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    hex2vec_list = str_hex_to_bytes(hex_asm)
    hex2vec_list, opcode_oprand_seq = get_asm_input_vector(hex2vec_list)
    # Average the per-instruction vectors to obtain the raw function vector.
    fun2vec_origin = [0.0] * len(hex2vec_list[0])
    for instr_vec in hex2vec_list:
        for j in range(len(instr_vec)):
            fun2vec_origin[j] += instr_vec[j]
    opcode_seq_len = len(hex2vec_list)
    for i in range(len(fun2vec_origin)):
        fun2vec_origin[i] /= opcode_seq_len
    # Project the raw vector through the model's function-embedding layer.
    fun2vec_origin = torch.tensor(fun2vec_origin).to(device)
    embedding_func_vec = model.to(device).linear_f(fun2vec_origin)
    return embedding_func_vec


if __name__ == '__main__':
    model = load_model(path="./asm2vec_checkpoints/model_100.pt")
    # func2vec(model, hex_asm="f044014910488b81e00000004885c07404")
    func2vec(model)
    # func2vec(model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3")
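
    # Hedged usage sketch (not in the original script): compare two functions by
    # the cosine similarity of their embeddings. The hex strings below reuse the
    # defaults above; any pair of hex-encoded function bodies would work.
    emb_a = func2vec(model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3")
    emb_b = func2vec(model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3")
    print("cosine similarity:", cosine_similarity(emb_a, emb_b))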