"""Train and evaluate GCN-based malware classifiers on CFG datasets.

Builds train/test manifests from malware/benign control-flow-graph (CFG)
sample lists, trains a GCN model selected by name, and logs/plots per-epoch
loss and accuracy. Checkpoints, CSV metrics and PNG plots are saved every
`save_gap` epochs.
"""

import csv
import random  # kept from the original import list (currently unused here)
import time
from itertools import accumulate

import dgl
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

from detect_model import load_model
from load_dataset import csv_read2, mydataset

# Manifest CSVs produced by generate_train_test_csv() and consumed by
# load_train_test_data().
malware_train_msg = "../cfg_data_with_feature/malware_train_msg.csv"
malware_test_msg = "../cfg_data_with_feature/malware_test_msg.csv"
benign_train_msg = "../cfg_data_with_feature/benign_train_msg.csv"
benign_test_msg = "../cfg_data_with_feature/benign_test_msg.csv"


def collate(samples):
    """DataLoader collate_fn: batch a list of (graph, label, msg_dict) triples.

    Returns a single batched DGLGraph, a LongTensor of labels, and the list
    of per-sample msg dicts (passed through untouched).
    """
    graphs, labels, msg_dicts = map(list, zip(*samples))
    return dgl.batch(graphs), torch.tensor(labels, dtype=torch.long), msg_dicts


def plt_save(msg, save_name):
    """Plot training history and save the figure to `save_name`.

    `msg` is a list of per-epoch dicts with keys "train_loss", "test_loss"
    and "test_acc".
    """
    epochs = range(1, len(msg) + 1)
    train_loss = [m["train_loss"] for m in msg]
    test_loss = [m["test_loss"] for m in msg]
    test_acc = [m["test_acc"] for m in msg]

    plt.subplot(2, 1, 1)
    plt.plot(epochs, train_loss, 'o-', label="Train_loss")
    plt.plot(epochs, test_loss, 'o-', label="Test_loss")
    plt.title(' Loss vs. epoches')
    plt.ylabel('Loss')
    plt.legend(loc='best')

    plt.subplot(2, 1, 2)
    plt.plot(epochs, test_loss, '.-', label="Test_Loss")
    plt.plot(epochs, test_acc, '.-', label="Test_Accuracy")
    plt.xlabel('Test loss & Test Accuracy vs. epoches')
    plt.ylabel('Value')
    plt.legend(loc='best')

    fig = plt.gcf()      # grab the current figure so we can save it...
    fig.savefig(save_name)
    fig.clear()          # ...and release its memory between epochs


def data_split(dataset, lengths):
    """Split `dataset` into consecutive, non-overlapping Subsets of `lengths`.

    NOTE(review): despite the original docstring claiming a *random* split,
    indices are sequential (no shuffle). Kept as-is to preserve behavior.

    Raises:
        ValueError: if sum(lengths) != len(dataset).
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
    indices = list(range(sum(lengths)))
    # itertools.accumulate replaces the private torch._utils._accumulate,
    # which is not a stable public API.
    return [Subset(dataset, indices[offset - length:offset])
            for offset, length in zip(accumulate(lengths), lengths)]


def _write_msg_csv(path, name_col, rows):
    """Write one train/test manifest CSV: header [name_col, nodes, edges] + rows."""
    with open(path, 'w', encoding='utf-8', newline='') as fp:
        writer = csv.writer(fp)
        writer.writerow([name_col, 'nodes_num', 'edgs_num'])
        writer.writerows(rows)


def generate_train_test_csv(malware_csv='../CFG_data/malware_msg.csv',
                            benign_csv='../CFG_data/benign_msg.csv',
                            malware_num=500, benign_num=500,
                            train_set_rate=0.75):
    """Split the sample manifests into fixed train/test CSV files.

    Reads up to `malware_num`/`benign_num` samples (filtered to 15..10000
    nodes by csv_read2) and writes the four *_train_msg.csv / *_test_msg.csv
    manifests under ../cfg_data_with_feature/.
    """
    train_size = int(malware_num * train_set_rate)
    malware_list, benign_list = csv_read2(malware_csv=malware_csv,
                                          benign_csv=benign_csv,
                                          malware_num=malware_num,
                                          benign_num=benign_num,
                                          max_nodes=10000,
                                          min_nodes=15)
    _write_msg_csv("../cfg_data_with_feature/malware_train_msg.csv",
                   'malware_name', malware_list[:train_size])
    _write_msg_csv("../cfg_data_with_feature/malware_test_msg.csv",
                   'malware_name', malware_list[train_size:])
    _write_msg_csv("../cfg_data_with_feature/benign_train_msg.csv",
                   'benign_name', benign_list[:train_size])
    _write_msg_csv("../cfg_data_with_feature/benign_test_msg.csv",
                   'benign_name', benign_list[train_size:])


def load_train_test_data(malware_csv='../CFG_data/malware_msg.csv',
                         benign_csv='../CFG_data/benign_msg.csv',
                         malware_CFG_dir="../CFG_data/malware",
                         bengin_CFG_dir="../CFG_data/benign",
                         node_feature_method="n_gram",
                         malware_num=500, benign_num=500,
                         batch_size=64, train_set_rate=0.75, input_dm=0):
    """Build train/test DataLoaders over the CFG datasets.

    Regenerates the train/test manifest CSVs first (deterministic split),
    then loads both dataset halves. Returns (train_loader, test_loader);
    batches are assembled by `collate`.
    """
    generate_train_test_csv(malware_csv=malware_csv, benign_csv=benign_csv,
                            malware_num=malware_num, benign_num=benign_num,
                            train_set_rate=train_set_rate)
    train_size = int(malware_num * train_set_rate)
    test_size = malware_num - train_size

    print("正在读取训练样本...")
    train_dataset = mydataset(malware_csv=malware_train_msg,
                              benign_csv=benign_train_msg,
                              malware_CFG_dir=malware_CFG_dir,
                              bengin_CFG_dir=bengin_CFG_dir,
                              malware_num=train_size,
                              benign_num=train_size,
                              node_feature_method=node_feature_method,
                              input_dm=input_dm)
    print("正在读取测试样本...")
    test_dataset = mydataset(malware_csv=malware_test_msg,
                             benign_csv=benign_test_msg,
                             malware_CFG_dir=malware_CFG_dir,
                             bengin_CFG_dir=bengin_CFG_dir,
                             malware_num=test_size,
                             benign_num=test_size,
                             node_feature_method=node_feature_method,
                             input_dm=input_dm)

    # shuffle=False: sample order is fixed by the manifest CSVs above.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False,
                              num_workers=0, drop_last=False, collate_fn=collate)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=0, drop_last=False, collate_fn=collate)
    return train_loader, test_loader


def GCN_train(input_dm=0, model_name="GCN_base", saved_checkpoint="",
              node_feature_method="n_gram", malware_num=500, benign_num=500,
              save_gap=50, epoch_num=100, train_set_rate=0.75):
    """Train `model_name` on the CFG dataset; return (train_acc, test_acc).

    Resumes from `saved_checkpoint` (a .pth dict with 'model'/'epoch'/'msg')
    when given. Every `save_gap` epochs the metrics CSV, a model checkpoint
    and a loss/accuracy plot are written under ./csv, ./GCN_checkpoints and
    ./png respectively.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(model_name, device)
    train_loader, test_loader = load_train_test_data(
        malware_csv='../cfg_data_with_feature/malware_pretreat_msg1.csv',
        benign_csv='../cfg_data_with_feature/benign_pretreat_msg1.csv',
        malware_CFG_dir="../cfg_data_with_feature/malware1",
        bengin_CFG_dir="../cfg_data_with_feature/benign1",
        node_feature_method=node_feature_method,
        malware_num=malware_num,
        benign_num=benign_num,
        batch_size=64,
        train_set_rate=train_set_rate,
        input_dm=input_dm)
    # Used only to tag output file names (kept as float, matching originals).
    train_size = (malware_num + benign_num) * train_set_rate

    msg = []          # per-epoch history for plotting, saved in checkpoints
    start_epoch = 0
    if saved_checkpoint != "":
        checkpoint = torch.load(saved_checkpoint)
        model.load_state_dict(checkpoint['model'])
        start_epoch = checkpoint['epoch']
        msg = checkpoint['msg']  # restore history so plots stay continuous

    loss_func = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    model.train()

    train_msg = []  # rows for the metrics CSV written by save_train_msg()
    now_time = time.strftime("%Y%m%d-%H%M%S")
    start = time.time()
    for epoch in range(start_epoch, epoch_num):
        print("------第 {}/{} 轮训练开始------".format(epoch + 1, epoch_num))
        train_loss, train_acc = train_once(train_loader, optimizer, loss_func, model, device)
        test_loss, test_acc = tes1t_once(test_loader, loss_func, model, device)
        end = time.time()
        train_msg.append([str(epoch + 1), str(round(train_acc, 4)),
                          str(round(test_acc, 4)), str(round(train_loss, 4)),
                          str(round(test_loss, 4)), str(round(end - start, 2))])
        line = '轮数 {}, 训练集acc {:.4f}, 测试集acc {:.4f}, 训练集loss {:.4f}, 测试集loss {:.4f}, 时间 {:.1f}s'.format(
            epoch + 1, train_acc, test_acc, train_loss, test_loss, end - start)
        print(line)
        with open("./log/{}_{}_{}".format(now_time, model_name, train_size) + '.log', mode='a') as log:
            # Bug fix: the original omitted the newline, gluing log entries
            # into one unreadable line.
            log.write(line + '\n')
        msg.append({"epoch": epoch + 1, "train_loss": train_loss,
                    "test_loss": test_loss, "test_acc": test_acc})
        state = {"epoch": epoch + 1, "model": model.state_dict(), "msg": msg}
        if (epoch + 1) % save_gap == 0:
            # Persist metrics CSV, checkpoint and plot for this milestone.
            save_train_msg(csv_save_path="./csv/{}_{}_{}_{}.csv".format(
                time.strftime("%Y%m%d-%H%M%S"), model_name, train_size, epoch + 1),
                csv_data=train_msg)
            torch.save(state, "./GCN_checkpoints/{}_{}_{}_{}.pth".format(
                time.strftime("%Y%m%d-%H%M%S"), model_name, train_size, epoch + 1))
            plt_save(msg, "./png/{}_{}_{}_{}.png".format(
                time.strftime("%Y%m%d-%H%M%S"), model_name, train_size, epoch + 1))
            print("模型已保存")

    model.eval()
    print("训练完测试")
    test_loss, test_acc = tes1t_once(test_loader, loss_func, model, device)
    print("accuracy: ", test_acc)
    return train_acc, test_acc


def train_once(train_loader, optimizer, loss_func, model, device):
    """Run one training epoch; return (mean_batch_loss, accuracy)."""
    train_pred, train_label = [], []
    epoch_loss = 0
    batch_idx = 0
    for batch_idx, (batchg, label, msg_dict) in tqdm(enumerate(train_loader)):
        batchg = batchg.to(device)
        label = label.to(device)
        # NOTE(review): the model output is softmaxed *before*
        # nn.CrossEntropyLoss, which applies log-softmax internally — i.e. a
        # double softmax. Kept to preserve the behavior of already-trained
        # checkpoints; consider feeding raw logits to the loss instead.
        prediction = torch.softmax(model(batchg), 1)
        del batchg  # drop the batched graph before backward to cut peak memory
        loss = loss_func(prediction, label)
        pred = torch.max(prediction, 1)[1].view(-1)
        train_pred += pred.detach().cpu().numpy().tolist()
        train_label += label.cpu().numpy().tolist()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (batch_idx + 1)
    return epoch_loss, accuracy_score(train_label, train_pred)


def tes1t_once(test_loader, loss_func, model, device):
    """Evaluate on `test_loader`; return (mean_batch_loss, accuracy).

    (Name kept as-is — including the "tes1t" typo — for backward
    compatibility with external callers.)
    """
    test_pred, test_label = [], []
    total_test_loss = 0
    with torch.no_grad():
        batch_idx = 0
        for batch_idx, (batchg, label, msg_dict) in tqdm(enumerate(test_loader)):
            batchg = batchg.to(device)
            label = label.to(device)
            pred = torch.softmax(model(batchg), 1)
            del batchg
            loss = loss_func(pred, label)
            total_test_loss = total_test_loss + loss
            pred = torch.max(pred, 1)[1].view(-1)
            test_pred += pred.detach().cpu().numpy().tolist()
            test_label += label.cpu().numpy().tolist()
    total_test_loss /= (batch_idx + 1)
    return total_test_loss.detach().item(), accuracy_score(test_label, test_pred)


def save_train_msg(csv_save_path, csv_data):
    """Save per-epoch training metric rows to one CSV file."""
    header = ['轮数', '训练集acc', '测试集acc', '训练集loss', '测试集loss', '时间']
    with open(csv_save_path, 'w', encoding='utf-8', newline='') as fp:
        writer = csv.writer(fp)
        writer.writerow(header)
        writer.writerows(csv_data)
    print("成功结束")


def tes1t_acc(model_name="GCN_base", saved_checkpoint="", malware_num=500,
              benign_num=500, train_set_rate=0.75):
    """Load a checkpoint and print test-set accuracy (no training).

    (Name kept as-is — including the "tes1t" typo — for backward
    compatibility with external callers.)
    """
    train_loader, test_loader = load_train_test_data(
        malware_csv='../CFG_data/malware_n_gram.csv',
        benign_csv='../CFG_data/benign_n_gram.csv',
        malware_CFG_dir="../CFG_data/malware_n_gram",
        bengin_CFG_dir="../CFG_data/benign_n_gram",
        malware_num=malware_num,
        benign_num=benign_num,
        batch_size=64,
        train_set_rate=train_set_rate)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(model_name, device)
    if saved_checkpoint != "":
        checkpoint = torch.load(saved_checkpoint)
        model.load_state_dict(checkpoint['model'])
    model.eval()
    print("训练完测试")
    test_pred, test_label = [], []
    with torch.no_grad():
        for it, (batchg, label, msg_dict) in enumerate(test_loader):
            batchg = batchg.to(device)  # (duplicate .to(device) call removed)
            label = label.to(device)
            pred1 = torch.softmax(model(batchg), 1)
            print(pred1.tolist())
            pred = torch.max(pred1, 1)[1].view(-1)
            test_pred += pred.detach().cpu().numpy().tolist()
            test_label += label.cpu().numpy().tolist()
    print("accuracy: ", accuracy_score(test_label, test_pred))


def current_train(malware_num=1500, benign_num=1500):
    """Retrain (up to 100 times) until a run lands below the target accuracies.

    Bug fix: `malware_num`/`benign_num` were free variables that existed only
    when the module ran as __main__ (NameError otherwise); they are now
    parameters defaulting to the same values the script used.
    """
    for _ in range(100):
        train_acc, test_acc = GCN_train(model_name="GCN_asm2vec_base_16",
                                        malware_num=malware_num,
                                        benign_num=benign_num,
                                        save_gap=50,
                                        node_feature_method="asm2vec_base_small",
                                        epoch_num=100)
        if train_acc < 0.91 and test_acc < 0.86:
            print("找到结果。")
            break
    print("结束")


if __name__ == '__main__':
    malware_num = 1500
    benign_num = 1500
    # NOTE(review): a long list of commented-out experiment variants (other
    # feature methods at dims 16/32/64/256, and current_train()) was removed;
    # recover them from version-control history if needed.
    GCN_train(model_name="GCN_malconv_128", input_dm=128,
              malware_num=malware_num, benign_num=benign_num, save_gap=100,
              node_feature_method="malconv", epoch_num=100)
    GCN_train(model_name="GCN_n_gram_128", input_dm=128,
              malware_num=malware_num, benign_num=benign_num, save_gap=100,
              node_feature_method="n_gram", epoch_num=100)
    GCN_train(model_name="GCN_word_frequency_128", input_dm=128,
              malware_num=malware_num, benign_num=benign_num, save_gap=100,
              node_feature_method="word_frequency", epoch_num=100)
    GCN_train(model_name="GCN_asm2vec_s368_base_128",
              malware_num=malware_num, benign_num=benign_num, save_gap=100,
              node_feature_method="asm2vec_s368_base_128", epoch_num=100)
    GCN_train(model_name="GCN_asm2vec_plus_128",
              malware_num=malware_num, benign_num=benign_num, save_gap=100,
              node_feature_method="asm2vec_plus_128", epoch_num=100)