import os
import csv
import time
import torch
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
from .datatype import Tokens, Function, Instruction
from .model import ASM2VEC
from tqdm import tqdm


class AsmDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        return self.x[index], self.y[index]


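# A minimal sketch of how AsmDataset is consumed (toy tensors, not real
# embeddings): preprocess() below returns (context, center) tensor pairs,
# which this Dataset serves to a DataLoader in shuffled mini-batches.
def _demo_asm_dataset():
    x = torch.randn(8, 6)   # 8 fake context vectors of width 6
    y = torch.randn(8, 4)   # 8 fake center vectors of width 4
    loader = DataLoader(AsmDataset(x, y), batch_size=4, shuffle=True)
    for ctx, ctr in loader:
        print(ctx.shape, ctr.shape)  # torch.Size([4, 6]) torch.Size([4, 4])

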
def csv_read(malware_num=500, benign_num=500, max_nodes=20000, min_nodes=10):
    malware_cfg_list = []
    benign_cfg_list = []
    with open('../CFG_data/malware_msg.csv', 'r', encoding='utf-8') as f:
        # reader iterates over the rows of the file
        reader = csv.reader(f)
        # Consume the CSV header row (the column-name line); the cursor now
        # points at the first data row
        header = next(reader)
        print(header)

        for row in reader:
            file_name = row[0]
            nodes_num = row[1]
            edges_num = row[2]
            if min_nodes <= int(nodes_num) <= max_nodes:
                malware_cfg_list.append(file_name + ".gexf")
                if len(malware_cfg_list) == malware_num:
                    break
    # benign_cfg_list is not populated here; return both lists so callers can use them
    return malware_cfg_list, benign_cfg_list


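# Usage sketch for csv_read, assuming ../CFG_data/malware_msg.csv exists with
# the column layout read above (file_name, nodes_num, edges_num):
def _demo_csv_read():
    malware_files, _ = csv_read(malware_num=100, max_nodes=5000, min_nodes=10)
    print(len(malware_files), malware_files[:3])

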
def load_data(paths, test_list=None, limit=None, normal=False):
    print("Loading binary functions as vectors:")
    functions = []
    with open(paths, 'r', encoding='utf-8') as f:
        # reader iterates over the rows of the file
        reader = csv.reader(f)
        # Consume the CSV header row; the cursor now points at the first data row
        header = next(reader)
        print(header)
        for row in tqdm(reader):
            func_name = row[0]
            raw_bytes = row[1]  # renamed from `bytes` to avoid shadowing the builtin
            # If a test list is given, keep only functions whose names appear
            # in it; the membership check must happen before the append, or
            # filtered-out functions would still end up in the result
            if test_list is not None:
                if func_name not in test_list:
                    continue
            fn = Function.load(raw_bytes, normal)
            # Append fn to the list of function objects
            functions.append(fn)
            if test_list is not None and len(functions) == len(test_list):
                break
            if limit and len(functions) >= limit:
                break
    # Each function's tokens are the list of all its operators and operands;
    # return the list of Function objects
    return functions


# Find the context words, the center word, and the function (paragraph) vector
def preprocess(functions):
    # Context window size
    C = 1
    # Each entry of context_vec_list is the function vector concatenated with
    # the previous and the next instruction vectors
    context_vec_list = []
    center_vec_list = []
    for fn in functions:
        hex2vec_list = fn.hex2vec_list
        fun2vec = fn.fun2vec
        # A valid center position needs C neighbours on each side; iterating
        # over this range also skips functions too short to yield any pair
        for j in range(C, len(hex2vec_list) - C):
            center_word = hex2vec_list[j]
            center_vec_list.append(center_word)
            context_words = fun2vec + hex2vec_list[j - C:j][0] + hex2vec_list[j + 1:j + C + 1][0]
            context_vec_list.append(context_words)

    return torch.tensor(context_vec_list), torch.tensor(center_vec_list)


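# A toy walk-through of the windowing above, with tiny hand-made "vectors" so
# the list concatenation is visible (real hex2vec entries are one-hot lists):
# with four instructions and C = 1, the centers are positions 1 and 2, and
# each context is fun2vec + previous vector + next vector.
def _demo_preprocess_window():
    hex2vec_list = [[1, 0], [0, 1], [1, 1], [0, 0]]
    fun2vec = [9, 9]
    C = 1
    for j in range(C, len(hex2vec_list) - C):
        center = hex2vec_list[j]
        context = fun2vec + hex2vec_list[j - C:j][0] + hex2vec_list[j + 1:j + C + 1][0]
        print(center, context)
    # prints: [0, 1] [9, 9, 1, 0, 1, 1]
    #         [1, 1] [9, 9, 0, 1, 0, 0]

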
def train(
    functions,
    model=None,
    embedding_size=100,
    batch_size=1024,
    epochs=10,
    calc_acc=False,
    device='cuda:0',
    mode='train',
    callback=None,
    learning_rate=0.02
):
    # Width of one one-hot instruction vector = vocabulary size
    vocab_size = len(functions[0].hex2vec_list[0])
    if mode == 'train':
        if model is None:
            model = ASM2VEC(vocab_size=vocab_size, function_size=len(functions), embedding_size=embedding_size).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    elif mode == 'test':
        if model is None:
            raise ValueError("test mode requires a pretrained model")
        # In test mode only the function-embedding layer is updated
        optimizer = torch.optim.Adam(model.linear_f.parameters(), lr=learning_rate)
    else:
        raise ValueError("Unknown mode")

    loader = DataLoader(AsmDataset(*preprocess(functions)), batch_size=batch_size, shuffle=True)

    bar = tqdm(range(epochs), position=0)
    for epoch in bar:
        start = time.time()
        loss_sum, loss_count, accs = 0.0, 0, []
        model.train()
        # context_vec holds the context vectors, center_vec the true center words
        for i, (context_vec, center_vec) in enumerate(loader):
            loss = model(context_vec.to(device), center_vec.to(device))
            # .item() detaches the scalar so each batch's computation graph
            # is not kept alive for the whole epoch
            loss_sum, loss_count = loss_sum + loss.item(), loss_count + 1
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i == 0 and calc_acc:
                probs = model.predict(context_vec.to(device), center_vec.to(device))
                accs.append(accuracy(center_vec, probs))

        if callback:
            callback({
                'model': model,
                'epoch': epoch,
                'time': time.time() - start,
                'loss': loss_sum / loss_count,
                'accuracy': torch.tensor(accs).mean() if calc_acc else None
            })
        bar.set_postfix(loss=f"{loss_sum / loss_count:.6f}", epoch=epoch + 1)
    return model


def save_model(path, model):
    torch.save({
        'model_params': (
            model.embeddings.num_embeddings,
            model.embeddings_f.num_embeddings,
            model.embeddings.embedding_dim
        ),
        'model': model.state_dict(),
    }, path)


def load_model(path, device='cpu'):
    checkpoint = torch.load(path, map_location=device)
    model = ASM2VEC(*checkpoint['model_params'])
    model.load_state_dict(checkpoint['model'])
    model = model.to(device)
    return model


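# An end-to-end usage sketch under assumed inputs: 'functions.csv' is a
# hypothetical CSV in the (func_name, bytes) layout that load_data expects,
# and training runs on CPU for brevity.
def _demo_train_save_load():
    functions = load_data('functions.csv', limit=100)
    model = train(functions, epochs=5, device='cpu')
    save_model('asm2vec.pt', model)
    reloaded = load_model('asm2vec.pt', device='cpu')
    print(type(reloaded).__name__)  # ASM2VEC

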
def show_probs(x, y, probs, tokens, limit=None, pretty=False):
    if pretty:
        TL, TR, BL, BR = '┌', '┐', '└', '┘'
        LM, RM, TM, BM = '├', '┤', '┬', '┴'
        H, V = '─', '│'
        arrow = ' ➔'
    else:
        TL = TR = BL = BR = '+'
        LM = RM = TM = BM = '+'
        H, V = '-', '|'
        arrow = '->'
    top = probs.topk(5)
    for i, (xi, yi) in enumerate(zip(x, y)):
        if limit and i >= limit:
            break
        xi, yi = xi.tolist(), yi.tolist()
        print(TL + H * 42 + TR)
        print(f'{V} {str(Instruction(tokens[xi[1]], tokens[xi[2:4]])):37} {V}')
        print(f'{V} {arrow} {str(Instruction(tokens[yi[0]], tokens[yi[1:3]])):37} {V}')
        print(f'{V} {str(Instruction(tokens[xi[4]], tokens[xi[5:7]])):37} {V}')
        print(LM + H * 8 + TM + H * 33 + RM)
        for value, index in zip(top.values[i], top.indices[i]):
            if index in yi:
                colorbegin, colorclear = '\033[92m', '\033[0m'
            else:
                colorbegin, colorclear = '', ''
            print(f'{V} {colorbegin}{value*100:05.2f}%{colorclear} {V} {colorbegin}{tokens[index.item()].name:31}{colorclear} {V}')
        print(BL + H * 8 + BM + H * 33 + BR)


def accuracy(y, probs):
    # y is a multi-hot target; use it as a boolean mask and average, over the
    # batch, the probability mass the model assigns to the true positions
    y = y.type(torch.bool)
    return torch.mean(torch.tensor([torch.sum(probs[i][yi]) for i, yi in enumerate(y)]))
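

# Toy check of accuracy(): with the mask below, row 0 selects probability 0.7
# and row 1 selects 0.6, so the batch mean is 0.65.
def _demo_accuracy():
    y = torch.tensor([[1, 0, 0], [0, 0, 1]])
    probs = torch.tensor([[0.7, 0.2, 0.1], [0.3, 0.1, 0.6]])
    print(accuracy(y, probs))  # tensor(0.6500)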