detect_rep/ASM2VEC_plus_scripts/asm2vec/utils.py

import os
import time
import csv
import torch
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
from .datatype import Tokens, Function, Instruction
from .model import ASM2VEC
from tqdm import tqdm


class AsmDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        return self.x[index], self.y[index]
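
# A minimal sketch of how AsmDataset is consumed (the tensor shapes are
# assumptions, not fixed by this module): it simply pairs each context vector
# with its center vector so DataLoader can batch them together.
#     ds = AsmDataset(torch.zeros(4, 8), torch.zeros(4, 8))
#     context0, center0 = ds[0]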

def csv_read(malware_num=500, benign_num=500, max_nodes=20000, min_nodes=10):
    malware_cfg_list = []
    benign_cfg_list = []
    with open('../CFG_data/malware_msg.csv', 'r', encoding='utf-8') as f:
        # After this, reader is an iterator over the rows of the file
        reader = csv.reader(f)
        # Consume the CSV header row (the column-name row); the reader now
        # points at the first data row
        header = next(reader)
        print(header)
        for row in reader:
            file_name = row[0]
            nodes_num = row[1]
            edges_num = row[2]
            if min_nodes <= int(nodes_num) <= max_nodes:
                malware_cfg_list.append(file_name + ".gexf")
            if len(malware_cfg_list) == malware_num:
                break
    # Only the malware CSV is read here; benign_cfg_list is returned empty
    return malware_cfg_list, benign_cfg_list
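
# Usage sketch (the CSV layout -- file_name, nodes_num, edges_num -- follows
# the columns read above; the counts are illustrative):
#     malware_files, benign_files = csv_read(malware_num=100, min_nodes=10)
#     # malware_files -> ['<file_name>.gexf', ...] for CFGs within the node bounds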

def load_data(paths, test_list=None, limit=None, normal=False):
    print("Loading binary functions as vectors:")
    functions = []
    with open(paths, 'r', encoding='utf-8') as f:
        # After this, reader is an iterator over the rows of the file
        reader = csv.reader(f)
        # Consume the CSV header row; the reader now points at the first data row
        header = next(reader)
        print(header)
        for row in tqdm(reader):
            func_name = row[0]
            func_bytes = row[1]
            # If a test-function list is given, keep only the functions it
            # names; the membership check must happen before appending
            if test_list is not None and func_name not in test_list:
                continue
            fn = Function.load(func_bytes, normal)
            # Add fn to the list of function objects
            functions.append(fn)
            if test_list is not None and len(functions) == len(test_list):
                break
            if limit and len(functions) >= limit:
                break
    # Each Function carries its own tokens (the operators and operands of its
    # instructions), so only the list of functions is returned
    return functions
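
# Usage sketch (hypothetical CSV path; each row is expected to be
# (func_name, bytes) as read above):
#     fns = load_data('functions.csv', limit=1000, normal=True)
#     contexts, centers = preprocess(fns)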

# Build the (context, center) training pairs: for every instruction vector,
# the context is the function (paragraph) vector plus its neighbours
def preprocess(functions):
    # Context window size
    C = 1
    # Each entry of context_vec_list is made of the function vector, the
    # previous instruction vector, and the next instruction vector
    context_vec_list = []
    center_vec_list = []
    for fn in functions:
        hex2vec_list = fn.hex2vec_list
        fun2vec = fn.fun2vec
        # Functions with fewer than 2*C + 1 instructions cannot fill a window
        if len(hex2vec_list) < 2 * C + 1:
            continue
        j = C
        while True:
            center_word = hex2vec_list[j]
            center_vec_list.append(center_word)
            context_words = fun2vec + hex2vec_list[(j - C):j][0] + hex2vec_list[(j + 1):(j + C + 1)][0]
            context_vec_list.append(context_words)
            j += 1
            if j >= len(hex2vec_list) - 1:
                break
    return torch.tensor(context_vec_list), torch.tensor(center_vec_list)
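
# Worked example of the window above (C = 1; shapes assumed): for a function
# with instruction vectors [v0, v1, v2, v3], the pairs produced are
#     center v1 <- context fun2vec + v0 + v2
#     center v2 <- context fun2vec + v1 + v3
# i.e. each center instruction is predicted from the function vector and its
# immediate neighbours, as in the PV-DM-style objective ASM2VEC uses.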

def train(
    functions,
    model=None,
    embedding_size=100,
    batch_size=1024,
    epochs=10,
    calc_acc=False,
    device='cuda:0',
    mode='train',
    callback=None,
    learning_rate=0.02
):
    vocab_size = len(functions[0].hex2vec_list[0])
    if mode == 'train':
        if model is None:
            model = ASM2VEC(vocab_size=vocab_size, function_size=len(functions), embedding_size=embedding_size).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    elif mode == 'test':
        if model is None:
            raise ValueError("test mode requires a pretrained model")
        # In test mode only the function-embedding layer is re-estimated
        optimizer = torch.optim.Adam(model.linear_f.parameters(), lr=learning_rate)
    else:
        raise ValueError("Unknown mode")
    loader = DataLoader(AsmDataset(*preprocess(functions)), batch_size=batch_size, shuffle=True)
    bar = tqdm(range(epochs), position=0)
    for epoch in bar:
        start = time.time()
        loss_sum, loss_count, accs = 0.0, 0, []
        model.train()
        # context_vec holds the context vectors (function vector plus
        # neighbouring instructions), center_vec the center-word vectors
        for i, (context_vec, center_vec) in enumerate(loader):
            loss = model(context_vec.to(device), center_vec.to(device))
            loss_sum, loss_count = loss_sum + loss, loss_count + 1
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if i == 0 and calc_acc:
                probs = model.predict(context_vec.to(device), center_vec.to(device))
                accs.append(accuracy(center_vec, probs))
        if callback:
            callback({
                'model': model,
                'epoch': epoch,
                'time': time.time() - start,
                'loss': loss_sum / loss_count,
                'accuracy': torch.tensor(accs).mean() if calc_acc else None
            })
    return model
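
# Usage sketch for both modes (parameter values are illustrative):
#     m = train(train_fns, epochs=10, device='cuda:0')         # fresh training
#     m = train(query_fns, model=m, mode='test', epochs=10)    # re-estimate
# In 'test' mode the token embeddings stay frozen and only the function
# embeddings (model.linear_f) are optimized for the new functions.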

def save_model(path, model):
    torch.save({
        'model_params': (
            model.embeddings.num_embeddings,
            model.embeddings_f.num_embeddings,
            model.embeddings.embedding_dim
        ),
        'model': model.state_dict(),
    }, path)

def load_model(path, device='cpu'):
    checkpoint = torch.load(path, map_location=device)
    model = ASM2VEC(*checkpoint['model_params'])
    model.load_state_dict(checkpoint['model'])
    model = model.to(device)
    return model
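
# Round-trip sketch (hypothetical path): the constructor arguments are
# recovered from 'model_params', so no extra configuration is needed on load.
#     save_model('asm2vec.pt', m)
#     m2 = load_model('asm2vec.pt', device='cpu')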

def show_probs(x, y, probs, tokens, limit=None, pretty=False):
    if pretty:
        TL, TR, BL, BR = '┌', '┐', '└', '┘'
        LM, RM, TM, BM = '├', '┤', '┬', '┴'
        H, V = '─', '│'
        arrow = '→'
    else:
        TL = TR = BL = BR = '+'
        LM = RM = TM = BM = '+'
        H, V = '-', '|'
        arrow = '->'
    top = probs.topk(5)
    for i, (xi, yi) in enumerate(zip(x, y)):
        if limit and i >= limit:
            break
        xi, yi = xi.tolist(), yi.tolist()
        print(TL + H * 42 + TR)
        print(f'{V} {str(Instruction(tokens[xi[1]], tokens[xi[2:4]])):37} {V}')
        print(f'{V} {arrow} {str(Instruction(tokens[yi[0]], tokens[yi[1:3]])):37} {V}')
        print(f'{V} {str(Instruction(tokens[xi[4]], tokens[xi[5:7]])):37} {V}')
        print(LM + H * 8 + TM + H * 33 + RM)
        for value, index in zip(top.values[i], top.indices[i]):
            if index in yi:
                colorbegin, colorclear = '\033[92m', '\033[0m'
            else:
                colorbegin, colorclear = '', ''
            print(f'{V} {colorbegin}{value*100:05.2f}%{colorclear} {V} {colorbegin}{tokens[index.item()].name:31}{colorclear} {V}')
        print(BL + H * 8 + BM + H * 33 + BR)
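
# Usage sketch: given batched inputs and model.predict probabilities,
#     show_probs(x, y, probs, tokens, limit=3, pretty=True)
# prints each context/center instruction window in a box together with the
# model's top-5 token guesses, highlighting correct guesses in green.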

def accuracy(y, probs):
    # y is a multi-hot target row per sample; sum the predicted probability
    # mass that falls on the true tokens and average over the batch
    y = y.type(torch.bool)
    return torch.mean(torch.tensor([torch.sum(probs[i][yi]) for i, yi in enumerate(y)]))
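
if __name__ == '__main__':
    # End-to-end sketch under assumptions: 'functions.csv' is a hypothetical
    # (func_name, bytes) CSV in the format load_data expects, and a short CPU
    # run is used just to exercise the pipeline.
    fns = load_data('functions.csv', limit=100)
    m = train(fns, embedding_size=100, epochs=2, device='cpu')
    save_model('asm2vec.pt', m)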