"""Dataset utilities that load control-flow graphs (GEXF files) as DGL graphs.

The module reads two CSV index files (malware / benign), selects the graphs
whose node count lies in a configured range, loads each graph with NetworkX,
converts it to DGL and attaches a per-node feature tensor parsed from a node
attribute of the GEXF file.
"""
import csv
import os
import re
import sys

import dgl
import networkx as nx
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

# Legacy (disabled) feature-extraction hooks, kept for reference only:
# sys.path.append(r'../ASM2VEC_base_scripts/')
# import asm2vec
# from func2vec import func2vec,load_model
# sys.path.append(r'../ASM2VEC_scripts/')
# from func2vec import func2vec,load_model
# from node_feature import asm2vec_plus,asm2vec_base,malconv,n_gram,word_frequency
#
# def get_node_feature(dict,node_feature_method="n_gram"):
#     if node_feature_method =="asm2vec_plus":
#         return dict[]
#     elif node_feature_method =="asm2vec_base":
#         return asm2vec_base(hex_asm=hex_asm).tolist()
#     elif node_feature_method =="malconv":
#         return malconv(hex_asm=hex_asm)
#     elif node_feature_method =="n_gram":
#         return n_gram(hex_asm=hex_asm)
#     elif node_feature_method =="word_frequency":
#         return word_frequency(hex_asm=hex_asm)

# Class label assigned to every graph loaded under a given mode (malware is 1).
_MODE_LABELS = {"malware": 1, "benign": 0}


class mydataset(Dataset):
    """In-memory dataset of ``(graph, label, msg)`` samples.

    All samples are loaded eagerly in ``__init__`` through ``load_data``.
    Each item is a triple of:

    * the DGL graph,
    * a one-element ``torch.tensor([label])`` (1 for malware, 0 for benign),
    * a list of ``{'bytes': ...}`` dicts built from the sample's message tuple.
    """

    def __init__(self, malware_csv='../CFG_data/malware_msg.csv', benign_csv='../CFG_data/benign_msg.csv',
                 malware_CFG_dir="../CFG_data/malware", bengin_CFG_dir="../CFG_data/benign",
                 node_feature_method="n_gram", input_dm=0, malware_num=500, benign_num=500,
                 max_nodes=20000, min_nodes=10):
        """Load every selected graph into memory.

        All arguments are forwarded unchanged to ``load_data``; see its
        docstring for their meaning.
        """
        super().__init__()
        samples = load_data(
            node_feature_method=node_feature_method,
            input_dm=input_dm,
            malware_csv=malware_csv,
            benign_csv=benign_csv,
            malware_CFG_dir=malware_CFG_dir,
            bengin_CFG_dir=bengin_CFG_dir,
            malware_num=malware_num,
            benign_num=benign_num,
            max_nodes=max_nodes,
            min_nodes=min_nodes,
        )
        self._x = []
        self._y = []
        self._msg = []
        for graph, label, msg_tuple in samples:
            self._x.append(graph)
            self._y.append(torch.tensor([label]))
            self._msg.append(self.msg_tuple_to_dict(msg_tuple))
        self._len = len(self._x)

    def msg_tuple_to_dict(self, msg_dict):
        """Turn an iterable of message tuples into a list of ``{'bytes': first_element}`` dicts."""
        return [{'bytes': item_msg[0]} for item_msg in msg_dict]

    def __getitem__(self, item):
        """Return the ``(graph, label_tensor, msg_list)`` triple stored at index ``item``."""
        return self._x[item], self._y[item], self._msg[item]

    def __len__(self):
        """Return the total number of loaded samples."""
        return self._len


def str_to_list(num_str):
    """Parse a string such as ``"[1, 2.5, 3]"`` into a list of floats.

    Surrounding whitespace and the enclosing square brackets are removed
    before the comma-separated values are converted.  An empty list
    (``"[]"`` or an empty/blank string) yields ``[]``.

    Raises:
        ValueError: if any comma-separated token is not a valid float.
    """
    body = num_str.strip().strip("[").strip("]").strip()
    if not body:
        return []
    return [float(token) for token in body.split(",")]


def get_data(data_loader, cfg_list, CFG_dir, input_dm=0, mode="malware", node_feature_method="n_gram"):
    """Load the graphs named in ``cfg_list`` and add them to ``data_loader``.

    Args:
        data_loader: container exposing ``add`` (``load_data`` passes a set);
            it is mutated in place and also returned.
        cfg_list: file names, relative to ``CFG_dir``, of the GEXF graphs.
        CFG_dir: directory that holds the GEXF files.
        input_dm: when non-zero, every node feature vector is sliced to
            ``[:input_dm]``; ``0`` keeps the full vector.
        mode: ``"malware"`` (label 1) or ``"benign"`` (label 0).
        node_feature_method: name of the node attribute whose string value
            is parsed by ``str_to_list`` into the feature vector.

    Returns:
        The same ``data_loader`` object, extended with one
        ``(dgl_graph, label, msg_tuple)`` entry per file.

    Raises:
        ValueError: if ``mode`` is not ``"malware"`` or ``"benign"``.  Such a
            mode used to load every graph and then silently discard it.
    """
    if mode not in _MODE_LABELS:
        raise ValueError(
            "mode must be one of %s, got %r" % (sorted(_MODE_LABELS), mode)
        )
    label = _MODE_LABELS[mode]

    for file_name in tqdm(cfg_list):
        # Read the CFG.
        graph = nx.read_gexf(os.path.join(CFG_dir, file_name))

        features = []
        # Collection of per-node "bytes" messages is disabled in this module,
        # so the message tuple attached to each sample is always empty.
        node_msgs = []
        for _node_id, attrs in graph.nodes.items():
            vector = str_to_list(attrs[node_feature_method])
            features.append(vector[:input_dm] if input_dm != 0 else vector)

        dgl_graph = dgl.from_networkx(graph)
        dgl_graph.ndata['feature'] = torch.Tensor(features)
        del graph  # release the NetworkX graph before loading the next one

        # NOTE: an earlier comment here mentioned adding self-loop edges for
        # nodes without edges; no such step is implemented.
        data_loader.add((dgl_graph, label, tuple(node_msgs)))
    return data_loader


def load_data(node_feature_method="n_gram", input_dm=0, malware_csv='../CFG_data/malware_msg.csv',
              benign_csv='../CFG_data/benign_msg.csv', malware_CFG_dir="../CFG_data/malware",
              bengin_CFG_dir="../CFG_data/benign", malware_num=500, benign_num=500,
              max_nodes=20000, min_nodes=10):
    """Select graphs via the CSV indexes and load them as labelled DGL graphs.

    Args:
        node_feature_method: node attribute used as the feature source.
        input_dm: feature truncation length (``0`` keeps the full vector).
        malware_csv / benign_csv: index CSV files passed to ``csv_read``.
        malware_CFG_dir / bengin_CFG_dir: directories holding the GEXF files.
        malware_num / benign_num: maximum number of graphs per class.
        max_nodes / min_nodes: inclusive node-count range for selection.

    Returns:
        A ``set`` of ``(dgl_graph, label, msg_tuple)`` tuples.  Being a set,
        its iteration order is arbitrary.
    """
    malware_cfg_list, benign_cfg_list = csv_read(
        malware_csv=malware_csv,
        benign_csv=benign_csv,
        malware_num=malware_num,
        benign_num=benign_num,
        max_nodes=max_nodes,
        min_nodes=min_nodes,
    )

    data_loader = set()
    print("载入恶意样本cfg...")
    data_loader = get_data(data_loader, malware_cfg_list, malware_CFG_dir, input_dm=input_dm,
                           mode="malware", node_feature_method=node_feature_method)
    print("载入良性样本cfg...")
    data_loader = get_data(data_loader, benign_cfg_list, bengin_CFG_dir, input_dm=input_dm,
                           mode="benign", node_feature_method=node_feature_method)

    print("读取" + str(len(malware_cfg_list)) + "个恶意函数cfg")
    print("读取" + str(len(benign_cfg_list)) + "个良性函数cfg")
    return data_loader


def _select_cfg_rows(csv_path, limit, max_nodes, min_nodes):
    """Return the rows of ``csv_path`` whose node count is within range.

    The first row is treated as a header: it is printed and skipped.  For the
    remaining rows, column 1 is parsed as an integer node count and the row
    is kept when ``min_nodes <= count <= max_nodes``.  Reading stops as soon
    as ``limit`` rows have been kept; a ``limit`` of ``0`` (or a negative
    value) never triggers that cut-off, so every matching row is returned.
    An empty file yields ``[]`` and blank rows are skipped.
    """
    selected = []
    with open(csv_path, 'r', encoding='utf-8', newline='') as f:
        reader = csv.reader(f)
        # Take the header row off first so the loop only sees data rows.
        header = next(reader, None)
        if header is None:
            return selected
        print(header)
        for row in reader:
            if not row:
                continue
            if min_nodes <= int(row[1]) <= max_nodes:
                selected.append(row)
                if len(selected) == limit:
                    break
    return selected


def csv_read(malware_csv='../CFG_data/malware_msg.csv', benign_csv='../CFG_data/benign_msg.csv',
             malware_num=500, benign_num=500, max_nodes=20000, min_nodes=10):
    """Pick the graph file names to load from the two index CSV files.

    Each CSV has a header row followed by rows whose first column is the
    file name and whose second column is the node count.  At most
    ``malware_num`` / ``benign_num`` names with a node count in
    ``[min_nodes, max_nodes]`` are returned per class.

    ``benign_num == 0`` skips the benign CSV entirely.  ``malware_num == 0``
    does *not* skip the malware CSV: the cut-off is never reached, so all
    matching malware rows are returned.

    Returns:
        ``(malware_cfg_list, benign_cfg_list)`` — two lists of file names.
    """
    malware_cfg_list = [
        row[0] for row in _select_cfg_rows(malware_csv, malware_num, max_nodes, min_nodes)
    ]
    benign_cfg_list = []
    if benign_num != 0:
        benign_cfg_list = [
            row[0] for row in _select_cfg_rows(benign_csv, benign_num, max_nodes, min_nodes)
        ]
    return malware_cfg_list, benign_cfg_list


def csv_read2(malware_csv='../CFG_data/malware_msg.csv', benign_csv='../CFG_data/benign_msg.csv',
              malware_num=500, benign_num=500, max_nodes=20000, min_nodes=10):
    """Like ``csv_read`` but keep the node and edge counts of each graph.

    Selection rules are identical to ``csv_read`` (including the handling of
    ``benign_num == 0`` and ``malware_num == 0``).

    Returns:
        ``(malware_cfg_list, benign_cfg_list)`` where every entry is
        ``[file_name, nodes_num, edgs_num]`` taken from columns 0-2 of the
        CSV row; the counts are left as the strings read from the file.

    Raises:
        IndexError: if a selected row has fewer than three columns.
    """
    malware_cfg_list = [
        [row[0], row[1], row[2]]
        for row in _select_cfg_rows(malware_csv, malware_num, max_nodes, min_nodes)
    ]
    benign_cfg_list = []
    if benign_num != 0:
        benign_cfg_list = [
            [row[0], row[1], row[2]]
            for row in _select_cfg_rows(benign_csv, benign_num, max_nodes, min_nodes)
        ]
    return malware_cfg_list, benign_cfg_list