detect_rep/ASM2VEC_base_scripts/func2vec.py

import torch
import torch.nn as nn
import click
import asm2vec
from utils2 import sigmoid, get_batches, compute_pca, get_dict
from matplotlib import pyplot
import numpy as np
import os
import random
from asm2vec.get_opcode_vector import get_asm_input_vector,str_hex_to_bytes

def cosine_similarity(v1, v2):
    return (v1 @ v2 / (v1.norm() * v2.norm())).item()

def load_model(path="./asm2vec_checkpoints/model.pt"):
    # load a trained asm2vec checkpoint onto the GPU if available, otherwise the CPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = asm2vec.utils.load_model(path, device=device)
    return model

def func2vec1(model, hex_asm_list=["56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
                                   "56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"]):
    # embed a batch of functions, each given as a hex string of raw machine code
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    fun2vec_origin_list = []
    for hex_asm in hex_asm_list:
        hex2vec_list = str_hex_to_bytes(hex_asm)
        hex2vec_list, opcode_oprand_seq = get_asm_input_vector(hex2vec_list)
        fun2vec_origin = [0.0] * len(hex2vec_list[0])
        # average the per-instruction vectors to get the raw function vector
        for i in hex2vec_list:
            for j in range(len(i)):
                fun2vec_origin[j] += i[j]
        opcode_seq_len = len(hex2vec_list)
        for i in range(len(fun2vec_origin)):
            fun2vec_origin[i] = fun2vec_origin[i] / opcode_seq_len
        fun2vec_origin_list.append(torch.tensor(fun2vec_origin).to(device))
    # stack the per-function vectors into a batch on the selected device
    fun2vec_origin_batch = torch.stack(fun2vec_origin_list).to(device)
    # project the averaged vectors through the model's function embedding layer
    embedding_func_vec = model.to(device).linear_f(fun2vec_origin_batch).clone().detach().requires_grad_(True)
    return embedding_func_vec

def func2vec(model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    # embed a single function given as a hex string of raw machine code
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    hex2vec_list = str_hex_to_bytes(hex_asm)
    hex2vec_list, opcode_oprand_seq = get_asm_input_vector(hex2vec_list)
    fun2vec_origin = [0.0] * len(hex2vec_list[0])
    # average the per-instruction vectors to get the raw function vector
    for i in hex2vec_list:
        for j in range(len(i)):
            fun2vec_origin[j] += i[j]
    opcode_seq_len = len(hex2vec_list)
    for i in range(len(fun2vec_origin)):
        fun2vec_origin[i] = fun2vec_origin[i] / opcode_seq_len
    fun2vec_origin = torch.tensor(fun2vec_origin).to(device)
    # project the averaged vector through the model's function embedding layer
    embedding_func_vec = model.to(device).linear_f(fun2vec_origin)
    return embedding_func_vec

if __name__ == '__main__':
    model = load_model(path="./asm2vec_checkpoints/model_100.pt")
    func2vec(model)
    # func2vec(model, hex_asm="f044014910488b81e00000004885c07404")
    # func2vec(model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3")
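    # Minimal similarity sketch (assumes func2vec1 returns one embedding row per input hex
    # string, as the code above suggests): embed the default pair of identical payloads and
    # score them with cosine_similarity; identical inputs should score close to 1.0.
    batch_vecs = func2vec1(model)
    print(cosine_similarity(batch_vecs[0], batch_vecs[1]))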