124 lines
3.8 KiB
Python
124 lines
3.8 KiB
Python
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import click
|
||
|
import asm2vec
|
||
|
from utils2 import sigmoid, get_batches, compute_pca, get_dict
|
||
|
from matplotlib import pyplot
|
||
|
import numpy as np
|
||
|
import os
|
||
|
import random
|
||
|
def cosine_similarity(v1, v2):
    """Return the cosine similarity of two tensors as a Python float.

    Computed as ``(v1 @ v2) / (||v1|| * ||v2||)``. The ``.item()`` call
    requires the ratio to be a single-element tensor, so 1-D inputs are
    expected. There is no epsilon guard: a zero-norm input produces NaN.
    """
    numerator = v1 @ v2
    denominator = v1.norm() * v2.norm()
    ratio = numerator / denominator
    return ratio.item()
|
||
|
|
||
|
# @click.command()
|
||
|
# @click.option('-i1', '--input1', 'ipath1', help='target function 1', required=True)
|
||
|
# @click.option('-i2', '--input2', 'ipath2', help='target function 2', required=True)
|
||
|
# @click.option('-m', '--model', 'mpath', help='model path', required=True)
|
||
|
# @click.option('-e', '--epochs', default=10, help='training epochs', show_default=True)
|
||
|
# @click.option('-c', '--device', default='auto', help='hardware device to be used: cpu / cuda / auto', show_default=True)
|
||
|
# @click.option('-lr', '--learning-rate', 'lr', default=0.02, help="learning rate", show_default=True)
|
||
|
|
||
|
|
||
|
# Prefix prepended to every library name to form the symbol key passed to
# asm2vec.utils.load_data.
_SYMBOL_PREFIX = "sym.MSVCRT20.dll_"

# Library functions whose embeddings are plotted by default. Each entry is
# also used as the point's label on the plot.
_FUNCTION_NAMES = (
    "_fsopen", "_fstat", "_ftime", "_futime", "_fdopen",
    "_getch", "_getche",
    "_strcmpi", "_strupr", "_strnset",
    "fgetwc", "fgetws", "fgetc", "fgets", "fprintf", "fputc", "fputs",
    "fread", "fclose", "free", "fflush",
    "public _class_ostream_____thiscall_ostream operator___void_const",
    "public _class_ostream_____thiscall_ostream operator___unsigned_short_int",
    "public _class_ostream_____thiscall_ostream operator___long_int",
    "vprintf", "vsprintf", "vswprintf",
    "fputwc", "fputws",
    "wcslen", "wcsncat", "wcsncmp", "wcsncpy", "wcsrchr", "wcspbrk",
    "strstr", "strncmp", "strncat", "strlen", "strcpy", "strcmp", "strchr",
    "time",
)


def _resolve_device(device):
    """Map ``'auto'`` to ``'cuda:0'`` if CUDA is available, else ``'cpu'``.

    Any other value is returned unchanged.
    """
    if device == "auto":
        return "cuda:0" if torch.cuda.is_available() else "cpu"
    return device


def _embed_functions(model, functions, device):
    """Project every function's ``fun2vec`` feature through ``model.linear_f``.

    Returns a float64 NumPy array with one row per function. The features are
    assumed to be equal-length numeric sequences that ``torch.tensor`` can
    stack (TODO: confirm against asm2vec's function type).
    """
    features = torch.tensor([fn.fun2vec for fn in functions], device=device)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        embeddings = model.linear_f(features)
    # float64 matches the previous np.array(tensor.tolist()) conversion.
    return embeddings.cpu().numpy().astype(np.float64)


def _plot_embeddings(vectors, labels):
    """Scatter the 2-component PCA projection of *vectors* and label each point."""
    points = compute_pca(vectors, 2)
    pyplot.scatter(points[:, 0], points[:, 1])
    # zip stops at the shorter sequence, so a count mismatch cannot index
    # past the end of either.
    for label, (x, y) in zip(labels, points[:, :2]):
        pyplot.annotate(label, xy=(x, y))
    pyplot.show()


def cli(
    mpath="./asm2vec_checkpoints/model.pt",
    device="auto",
    file_dir="../asm_func/asm_hex/func_bytes.csv",
    names=None,
):
    """Embed a set of library functions with a trained asm2vec model and plot them.

    Args:
        mpath: path of the saved model passed to ``asm2vec.utils.load_model``.
        device: torch device string, or ``'auto'`` to pick CUDA when available.
        file_dir: data source passed to ``asm2vec.utils.load_data``.
        names: library function names to plot. Defaults to ``_FUNCTION_NAMES``.
            Duplicates are dropped, keeping first-seen order.

    Embeddings come straight from ``model.linear_f`` applied to each function's
    ``fun2vec`` feature. No training or fine-tuning pass is run here.
    """
    device = _resolve_device(device)

    model = asm2vec.utils.load_model(mpath, device=device)
    # Keep the weights on the same device as the input features.
    model = model.to(device)

    labels = list(dict.fromkeys(_FUNCTION_NAMES if names is None else names))
    print(labels)

    symbols = [_SYMBOL_PREFIX + name for name in labels]
    functions = list(asm2vec.utils.load_data(file_dir, symbols))
    # NOTE(review): labels are paired with rows by position, which assumes
    # load_data returns one function per requested symbol, in request order.
    # Verify this against asm2vec.utils.
    if len(functions) != len(labels):
        print(
            f"warning: requested {len(labels)} functions but load_data "
            f"returned {len(functions)}; plot labels may not line up"
        )

    vectors = _embed_functions(model, functions, device)
    print(vectors)

    _plot_embeddings(vectors, labels)
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Script entry point: build and show the embedding plot.
    cli()
|