# detect_rep/detect_script/train.py
# Repository listing: 414 lines, 21 KiB, Python; dated 2023-04-05 10:04:49 +08:00.
import dgl
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
from sklearn.metrics import accuracy_score
from load_dataset import mydataset
# from GCN_model import GCN_base,GCN_ASM2VEC_PLUS,GCN_ASM2VEC_BASE,GAT,GCN_ASM2VEC_PLUS2,GCN_ASM2VEC_TEST,GCN_ASM2VEC_TEST2
from torch.utils.data import DataLoader
from tqdm import tqdm
import time
import matplotlib.pyplot as plt
from detect_model import load_model
import csv
from load_dataset import csv_read2
import random
def collate(samples):
    """Merge a list of dataset samples into one mini-batch for a DataLoader.

    Args:
        samples: list of ``(graph, label, msg_dict)`` triples, e.g.
            ``[(graph1, label1, msg1), (graph2, label2, msg2), ...]``.

    Returns:
        ``(batched_graph, labels, msg_dicts)``: the graphs merged with
        ``dgl.batch``, the labels as a ``torch.long`` tensor, and the
        per-sample ``msg_dict`` objects as a plain list.
    """
    # Transpose the list of triples into three parallel lists.
    graphs, labels, msg_dicts = (list(column) for column in zip(*samples))
    batched_graph = dgl.batch(graphs)
    label_tensor = torch.tensor(labels, dtype=torch.long)
    return batched_graph, label_tensor, msg_dicts
def plt_save(msg, save_name):
    """Plot the training history and save it as an image.

    Args:
        msg: list of per-epoch dicts with keys ``"train_loss"``,
            ``"test_loss"`` and ``"test_acc"``; epoch ``i`` (1-based) is
            taken to be the ``i``-th entry.
        save_name: output image path passed to ``Figure.savefig``.

    The top panel shows train and test loss; the bottom panel shows test loss
    and test accuracy. The current pyplot figure is drawn on, saved and then
    cleared so the next call starts from an empty figure.
    """
    epochs = range(1, len(msg) + 1)
    train_losses = [record["train_loss"] for record in msg]
    test_losses = [record["test_loss"] for record in msg]
    test_accs = [record["test_acc"] for record in msg]

    upper = plt.subplot(2, 1, 1)
    upper.plot(epochs, train_losses, 'o-', label="Train_loss")
    upper.plot(epochs, test_losses, 'o-', label="Test_loss")
    upper.set_title(' Loss vs. epoches')
    upper.set_ylabel('Loss')
    upper.legend(loc='best')

    lower = plt.subplot(2, 1, 2)
    lower.plot(epochs, test_losses, '.-', label="Test_Loss")
    lower.plot(epochs, test_accs, '.-', label="Test_Accuracy")
    lower.set_xlabel('Test loss & Test Accuracy vs. epoches')
    lower.set_ylabel('Value')
    lower.legend(loc='best')

    figure = plt.gcf()
    figure.savefig(save_name)
    # Clear the figure so its contents do not carry over into the next call.
    figure.clear()
from torch.utils.data import Subset
from torch._utils import _accumulate
def data_split(dataset, lengths):
    """Split ``dataset`` into consecutive, non-overlapping subsets of given lengths.

    Unlike ``torch.utils.data.random_split``, no shuffling is done: the first
    ``lengths[0]`` items form the first subset, the next ``lengths[1]`` items
    the second, and so on.

    Args:
        dataset: an indexable dataset supporting ``len()``.
        lengths: sequence of subset sizes; must sum to ``len(dataset)``.

    Returns:
        A list of ``torch.utils.data.Subset`` objects, one per entry of ``lengths``.

    Raises:
        ValueError: if ``sum(lengths) != len(dataset)``.
    """
    # Public stdlib equivalent of the private torch._utils._accumulate helper,
    # which is not part of PyTorch's public API.
    from itertools import accumulate

    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
    indices = list(range(sum(lengths)))
    # Each running total is the exclusive end offset of one subset.
    return [Subset(dataset, indices[offset - length:offset])
            for offset, length in zip(accumulate(lengths), lengths)]
# Per-split sample lists: written by generate_train_test_csv() and read back
# by load_train_test_data() to build the training and test datasets.
(malware_train_msg,
 malware_test_msg,
 benign_train_msg,
 benign_test_msg) = (
    "../cfg_data_with_feature/malware_train_msg.csv",
    "../cfg_data_with_feature/malware_test_msg.csv",
    "../cfg_data_with_feature/benign_train_msg.csv",
    "../cfg_data_with_feature/benign_test_msg.csv",
)
def _write_msg_csv(path, name_column, rows):
    """Write one sample-list CSV: a header (name_column, 'nodes_num', 'edgs_num') then ``rows``."""
    with open(path, 'w', encoding='utf-8', newline='') as fp:
        writer = csv.writer(fp)
        writer.writerow([name_column, 'nodes_num', 'edgs_num'])
        writer.writerows(rows)


def generate_train_test_csv(malware_csv='../CFG_data/malware_msg.csv', benign_csv='../CFG_data/benign_msg.csv',
                            malware_num=500, benign_num=500, train_set_rate=0.75):
    """Read the malware/benign sample lists and write fixed train/test split CSVs.

    The lists come from ``csv_read2`` (called with ``max_nodes=10000`` and
    ``min_nodes=15``; presumably a node-count filter — see load_dataset). Both
    lists are cut at the same index, ``int(malware_num * train_set_rate)``:
    rows before it go to the train CSV and the rest go to the test CSV.

    The output paths are the module-level ``malware_train_msg``,
    ``malware_test_msg``, ``benign_train_msg`` and ``benign_test_msg``. These
    are the same files load_train_test_data() reads, so writer and reader
    cannot drift apart. Existing files are overwritten.

    Args:
        malware_csv / benign_csv: source sample-list CSVs passed to ``csv_read2``.
        malware_num / benign_num: number of samples requested per class.
        train_set_rate: fraction of ``malware_num`` that goes into the train split.
    """
    # NOTE(review): the benign list is split at the malware-based index too;
    # load_train_test_data() relies on the same per-class sizes.
    train_size = int(malware_num * train_set_rate)
    malware_list, benign_list = csv_read2(malware_csv=malware_csv, benign_csv=benign_csv,
                                          malware_num=malware_num, benign_num=benign_num,
                                          max_nodes=10000, min_nodes=15)
    _write_msg_csv(malware_train_msg, 'malware_name', malware_list[:train_size])
    _write_msg_csv(malware_test_msg, 'malware_name', malware_list[train_size:])
    _write_msg_csv(benign_train_msg, 'benign_name', benign_list[:train_size])
    _write_msg_csv(benign_test_msg, 'benign_name', benign_list[train_size:])
# Addresses the sample-randomization issue: instead of splitting a single
# dataset at runtime, the split is first written to fixed train/test CSV files
# and each set is then loaded from its own file.
def load_train_test_data(malware_csv='../CFG_data/malware_msg.csv',
                         benign_csv='../CFG_data/benign_msg.csv',
                         malware_CFG_dir="../CFG_data/malware",
                         bengin_CFG_dir="../CFG_data/benign",
                         node_feature_method="n_gram",
                         malware_num=500,
                         benign_num=500,
                         batch_size=64,
                         train_set_rate=0.75,
                         input_dm=0):
    """Regenerate the split CSVs and return ``(train_loader, test_loader)``.

    Per class, the train set uses ``int(malware_num * train_set_rate)``
    samples and the test set uses the remainder of ``malware_num``. Both
    loaders are unshuffled, keep the last partial batch, run in the main
    process (``num_workers=0``) and batch with ``collate``.

    Args:
        malware_csv / benign_csv: source sample lists passed to generate_train_test_csv().
        malware_CFG_dir / bengin_CFG_dir: CFG directories passed to ``mydataset``.
        node_feature_method, input_dm: forwarded to ``mydataset``.
        malware_num / benign_num: samples requested per class.
        batch_size: DataLoader batch size.
        train_set_rate: fraction of samples used for training.
    """
    generate_train_test_csv(malware_csv=malware_csv, benign_csv=benign_csv,
                            malware_num=malware_num, benign_num=benign_num,
                            train_set_rate=train_set_rate)
    n_train = int(malware_num * train_set_rate)
    n_test = malware_num - n_train

    # Dataset options shared by the train and test sets.
    shared = dict(malware_CFG_dir=malware_CFG_dir,
                  bengin_CFG_dir=bengin_CFG_dir,
                  node_feature_method=node_feature_method,
                  input_dm=input_dm)

    print("正在读取训练样本...")  # "Reading training samples..."
    train_set = mydataset(malware_csv=malware_train_msg, benign_csv=benign_train_msg,
                          malware_num=n_train, benign_num=n_train, **shared)
    print("正在读取测试样本...")  # "Reading test samples..."
    test_set = mydataset(malware_csv=malware_test_msg, benign_csv=benign_test_msg,
                         malware_num=n_test, benign_num=n_test, **shared)

    def make_loader(dataset):
        return DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0,
                          drop_last=False, collate_fn=collate)

    return make_loader(train_set), make_loader(test_set)
def GCN_train(input_dm=0, model_name="GCN_base", saved_checkpoint="", node_feature_method="n_gram",
              malware_num=500, benign_num=500, save_gap=50, epoch_num=100, train_set_rate=0.75):
    """Train a detection model, periodically saving metrics, checkpoints and plots.

    Args:
        input_dm: node-feature dimension forwarded to load_train_test_data().
        model_name: passed to ``load_model``; also used in output file names.
        saved_checkpoint: checkpoint path to resume from ("" starts fresh).
        node_feature_method: forwarded to load_train_test_data().
        malware_num / benign_num: samples requested per class.
        save_gap: every ``save_gap`` epochs, write a CSV report, a checkpoint and a plot.
        epoch_num: total epoch count, including epochs already done in a resumed checkpoint.
        train_set_rate: fraction of samples used for training.

    Returns:
        ``(train_acc, test_acc)``. ``train_acc`` is the training accuracy of
        the last epoch run in this call, or NaN if no epoch ran (e.g. when
        resuming at or after ``epoch_num``). ``test_acc`` comes from a final
        evaluation pass.

    Side effects: appends to ./log and writes to ./csv, ./GCN_checkpoints and
    ./png; these directories are created if missing.
    """
    import os

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(model_name, device)
    train_loader, test_loader = load_train_test_data(
        malware_csv='../cfg_data_with_feature/malware_pretreat_msg1.csv',
        benign_csv='../cfg_data_with_feature/benign_pretreat_msg1.csv',
        malware_CFG_dir="../cfg_data_with_feature/malware1",
        bengin_CFG_dir="../cfg_data_with_feature/benign1",
        node_feature_method=node_feature_method,
        malware_num=malware_num,
        benign_num=benign_num,
        batch_size=64,
        train_set_rate=train_set_rate,
        input_dm=input_dm)
    # Only used as a tag in output file names. It is a float (e.g. "2250.0"),
    # kept as-is so artifact names stay compatible with earlier runs.
    train_size = (malware_num + benign_num) * train_set_rate

    loss_func = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Per-epoch history used for plotting; stored in and restored from checkpoints.
    msg = []
    start_epoch = 0
    if saved_checkpoint != "":
        # map_location lets a checkpoint saved on a GPU be resumed on a CPU-only machine.
        checkpoint = torch.load(saved_checkpoint, map_location=device)
        model.load_state_dict(checkpoint['model'])
        start_epoch = checkpoint['epoch']
        msg = checkpoint['msg']
        # Older checkpoints have no optimizer state; Adam then restarts from scratch.
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])

    for directory in ("./log", "./csv", "./GCN_checkpoints", "./png"):
        os.makedirs(directory, exist_ok=True)

    train_msg = []  # rows for the CSV report written by save_train_msg()
    now_time = time.strftime("%Y%m%d-%H%M%S")
    log_path = "./log/{}_{}_{}".format(now_time, model_name, train_size) + '.log'
    start = time.time()
    train_acc = float("nan")  # stays NaN if the loop below runs zero epochs
    for epoch in range(start_epoch, epoch_num):
        print("------第 {}/{} 轮训练开始------".format(epoch + 1, epoch_num))
        # Switch modes explicitly every epoch. Previously the model stayed in
        # train mode for the per-epoch test pass, which matters if the model
        # has dropout or batch-norm layers.
        model.train()
        train_loss, train_acc = train_once(train_loader, optimizer, loss_func, model, device)
        model.eval()
        test_loss, test_acc = tes1t_once(test_loader, loss_func, model, device)
        elapsed = time.time() - start
        train_msg.append([str(epoch + 1), str(round(train_acc, 4)), str(round(test_acc, 4)),
                          str(round(train_loss, 4)), str(round(test_loss, 4)), str(round(elapsed, 2))])
        summary = '轮数 {}, 训练集acc {:.4f}, 测试集acc {:.4f}, 训练集loss {:.4f}, 测试集loss {:.4f}, 时间 {:.1f}s'.format(
            epoch + 1, train_acc, test_acc, train_loss, test_loss, elapsed)
        print(summary)
        # One line per epoch; without the newline every epoch ended up on a single line.
        with open(log_path, mode='a', encoding='utf-8') as log:
            log.write(summary + '\n')
        msg.append({"epoch": epoch + 1, "train_loss": train_loss, "test_loss": test_loss, "test_acc": test_acc})
        if (epoch + 1) % save_gap == 0:
            # One timestamp per save so the CSV, checkpoint and plot share a name.
            tag = "{}_{}_{}_{}".format(time.strftime("%Y%m%d-%H%M%S"), model_name, train_size, epoch + 1)
            save_train_msg(csv_save_path="./csv/{}.csv".format(tag), csv_data=train_msg)
            state = {"epoch": epoch + 1, "model": model.state_dict(),
                     "optimizer": optimizer.state_dict(), "msg": msg}
            torch.save(state, "./GCN_checkpoints/{}.pth".format(tag))
            plt_save(msg, "./png/{}.png".format(tag))
            print("模型已保存")
    model.eval()
    print("训练完测试")
    test_loss, test_acc = tes1t_once(test_loader, loss_func, model, device)
    print("accuracy: ", test_acc)
    return train_acc, test_acc
def train_once(train_loader, optimizer, loss_func, model, device):
    """Run one training epoch over ``train_loader``.

    Puts the model in train mode, then for each ``(graph_batch, labels, msg_dicts)``
    batch does a forward pass, computes the loss, backpropagates and steps the optimizer.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch loss values
        (0.0 for an empty loader), and ``accuracy_score`` of the argmax
        predictions against the labels over all samples seen.
    """
    model.train()
    train_pred, train_label = [], []
    epoch_loss = 0.0
    num_batches = 0
    for batchg, label, _msg_dict in tqdm(train_loader):
        batchg = batchg.to(device)
        label = label.to(device)
        # nn.CrossEntropyLoss applies log-softmax itself and expects raw scores.
        # Feeding it softmax probabilities (as before) applies softmax twice and
        # flattens the gradients. Assumes the model returns unnormalised scores
        # — confirm in detect_model.
        logits = model(batchg)
        del batchg
        loss = loss_func(logits, label)
        # The argmax of the logits equals the argmax of their softmax.
        pred = logits.argmax(dim=1).view(-1)
        train_pred += pred.detach().cpu().tolist()
        train_label += label.cpu().tolist()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1
    epoch_loss /= max(num_batches, 1)
    return epoch_loss, accuracy_score(train_label, train_pred)
def tes1t_once(test_loader, loss_func, model, device):
    """Evaluate the model on ``test_loader`` without gradient tracking.

    The model is switched to eval mode for the pass, and its previous
    train/eval mode is restored afterwards (even if an exception occurs).

    Returns:
        ``(mean_loss, accuracy)``: the mean per-batch loss as a Python float
        (0.0 for an empty loader), and ``accuracy_score`` of the argmax
        predictions against the labels.
    """
    was_training = model.training
    model.eval()
    test_pred, test_label = [], []
    total_test_loss = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for batchg, label, _msg_dict in tqdm(test_loader):
                batchg = batchg.to(device)
                label = label.to(device)
                # Raw scores go straight to CrossEntropyLoss, which applies
                # log-softmax itself; this matches train_once's loss computation.
                logits = model(batchg)
                del batchg
                total_test_loss += loss_func(logits, label).item()
                test_pred += logits.argmax(dim=1).view(-1).cpu().tolist()
                test_label += label.cpu().tolist()
                num_batches += 1
    finally:
        model.train(was_training)
    total_test_loss /= max(num_batches, 1)
    return total_test_loss, accuracy_score(test_label, test_pred)
def save_train_msg(csv_save_path, csv_data):
    """Write per-epoch training records to ``csv_save_path`` as a UTF-8 CSV.

    A header row (epoch, train acc, test acc, train loss, test loss, time;
    the column titles are in Chinese) is followed by the rows of
    ``csv_data``. The file is overwritten, and a completion message is
    printed afterwards.
    """
    columns = ['轮数', '训练集acc', '测试集acc', '训练集loss', '测试集loss', '时间']
    with open(csv_save_path, 'w', encoding='utf-8', newline='') as handle:
        out = csv.writer(handle)
        out.writerow(columns)
        out.writerows(csv_data)
    print("成功结束")
def tes1t_acc(model_name="GCN_base", saved_checkpoint="", malware_num=500, benign_num=500, train_set_rate=0.75,
              node_feature_method="n_gram", input_dm=0):
    """Load a model, optionally from a checkpoint, and report its test accuracy.

    Prints each batch's softmax probabilities and then the overall accuracy.
    Returns nothing.

    Args:
        model_name: passed to ``load_model``.
        saved_checkpoint: checkpoint path whose ``'model'`` state is loaded
            ("" evaluates the freshly built model).
        malware_num / benign_num / train_set_rate: forwarded to load_train_test_data().
        node_feature_method / input_dm: forwarded to load_train_test_data(). The
            defaults equal that function's own defaults, which were used implicitly before.
    """
    _train_loader, test_loader = load_train_test_data(
        malware_csv='../CFG_data/malware_n_gram.csv',
        benign_csv='../CFG_data/benign_n_gram.csv',
        malware_CFG_dir="../CFG_data/malware_n_gram",
        bengin_CFG_dir="../CFG_data/benign_n_gram",
        node_feature_method=node_feature_method,
        malware_num=malware_num,
        benign_num=benign_num,
        batch_size=64,
        train_set_rate=train_set_rate,
        input_dm=input_dm)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(model_name, device)
    if saved_checkpoint != "":
        # map_location lets a GPU-saved checkpoint be evaluated on a CPU-only machine.
        checkpoint = torch.load(saved_checkpoint, map_location=device)
        model.load_state_dict(checkpoint['model'])
    model.eval()
    print("训练完测试")
    test_pred, test_label = [], []
    with torch.no_grad():
        for batchg, label, _msg_dict in test_loader:
            batchg = batchg.to(device)
            label = label.to(device)
            pred1 = torch.softmax(model(batchg), 1)
            print(pred1.tolist())
            pred = torch.max(pred1, 1)[1].view(-1)
            test_pred += pred.detach().cpu().tolist()
            test_label += label.cpu().tolist()
    print("accuracy: ", accuracy_score(test_label, test_pred))
def current_train(malware_num=None, benign_num=None, max_runs=100):
    """Repeat a fixed GCN_asm2vec_base_16 training run until a stopping condition holds.

    Stops after the first run that ends with ``train_acc < 0.91`` and
    ``test_acc < 0.86``, or after ``max_runs`` runs. Prints "结束" ("finished")
    when done.

    Args:
        malware_num / benign_num: samples per class. When omitted, the
            module-level globals of the same names are used; these are set in
            the ``__main__`` section and were previously read implicitly.
        max_runs: maximum number of training runs.

    Raises:
        ValueError: if a sample count is neither given nor available as a module global.
    """
    if malware_num is None:
        malware_num = globals().get("malware_num")
    if benign_num is None:
        benign_num = globals().get("benign_num")
    if malware_num is None or benign_num is None:
        raise ValueError("malware_num and benign_num must be passed or set as module-level globals")
    for _ in range(max_runs):
        train_acc, test_acc = GCN_train(model_name="GCN_asm2vec_base_16", malware_num=malware_num,
                                        benign_num=benign_num, save_gap=50,
                                        node_feature_method="asm2vec_base_small", epoch_num=100)
        # NOTE(review): this stops on *low* accuracy; confirm that is the intended search criterion.
        if train_acc < 0.91 and test_acc < 0.86:
            print("找到结果。")
            break
    print("结束")
if __name__ == '__main__':
    # Samples per class. Kept as module-level names because current_train()
    # relies on module globals with these names.
    malware_num = benign_num = 1500

    # Experiment grid (every run used save_gap=100, epoch_num=100); edit the
    # active call below to rerun a different configuration:
    #   for dim in (16, 32, 64, 128, 256):
    #     raw-feature models, width set via input_dm:
    #       GCN_train(model_name=f"GCN_{feat}_{dim}", input_dm=dim, node_feature_method=feat, ...)
    #       for feat in ("malconv", "n_gram", "word_frequency")
    #     asm2vec models, width encoded in the feature-method name (no input_dm):
    #       GCN_train(model_name=f"GCN_{feat}_{dim}", node_feature_method=f"{feat}_{dim}", ...)
    #       for feat in ("asm2vec_base", "asm2vec_s_base", "asm2vec_s368_base", "asm2vec_plus")
    #   Also tried: GCN_train(model_name="GCN_base", input_dm=16, node_feature_method="malconv", ...)
    #   and current_train().
    active_run = dict(model_name="GCN_asm2vec_plus_128",
                      malware_num=malware_num,
                      benign_num=benign_num,
                      save_gap=100,
                      node_feature_method="asm2vec_plus_128",
                      epoch_num=100)
    GCN_train(**active_run)