detect_rep/detect_script/load_dataset.py

205 lines
8.1 KiB
Python
Raw Permalink Normal View History

2023-04-05 10:04:49 +08:00
from torch.utils.data import Dataset
import csv
import torch
import networkx as nx
import os
from tqdm import tqdm
import dgl
import sys
import re
# sys.path.append(r'../ASM2VEC_base_scripts/')
# import asm2vec
# from func2vec import func2vec,load_model
# sys.path.append(r'../ASM2VEC_scripts/')
# from func2vec import func2vec,load_model
# from node_feature import asm2vec_plus,asm2vec_base,malconv,n_gram,word_frequency
#
#
# def get_node_feature(dict,node_feature_method="n_gram"):
# if node_feature_method =="asm2vec_plus":
# return dict[]
# elif node_feature_method =="asm2vec_base":
# return asm2vec_base(hex_asm=hex_asm).tolist()
# elif node_feature_method =="malconv":
# return malconv(hex_asm=hex_asm)
# elif node_feature_method =="n_gram":
# return n_gram(hex_asm=hex_asm)
# elif node_feature_method =="word_frequency":
# return word_frequency(hex_asm=hex_asm)
class mydataset(Dataset):
    """In-memory dataset of control-flow graphs (CFGs) for malware detection.

    Every sample is loaded eagerly in ``__init__`` through ``load_data``.
    Item ``i`` is the triple ``(graph, label, msg)``:

    * ``graph`` -- the graph object produced by ``load_data``;
    * ``label`` -- a 1-element tensor holding the integer label
      (``get_data`` uses 1 for malware and 0 for benign);
    * ``msg`` -- a list of ``{'bytes': ...}`` dicts built by
      ``msg_tuple_to_dict``.
    """

    def __init__(self, malware_csv='../CFG_data/malware_msg.csv', benign_csv='../CFG_data/benign_msg.csv',
                 malware_CFG_dir="../CFG_data/malware", bengin_CFG_dir="../CFG_data/benign",
                 node_feature_method="n_gram", input_dm=0, malware_num=500, benign_num=500,
                 max_nodes=20000, min_nodes=10):
        """Read and load all selected samples.

        All arguments are forwarded unchanged to ``load_data``. The
        ``bengin_CFG_dir`` spelling is kept so existing keyword callers work.
        """
        data_loader = load_data(node_feature_method=node_feature_method, input_dm=input_dm,
                                malware_csv=malware_csv, benign_csv=benign_csv,
                                malware_CFG_dir=malware_CFG_dir, bengin_CFG_dir=bengin_CFG_dir,
                                malware_num=malware_num, benign_num=benign_num,
                                max_nodes=max_nodes, min_nodes=min_nodes)
        self._x = []    # graphs
        self._y = []    # labels, each wrapped in a 1-element tensor
        self._msg = []  # per-sample lists of {'bytes': ...} dicts
        for graph, label, msg_tuples in data_loader:
            self._x.append(graph)
            self._y.append(torch.tensor([label]))
            self._msg.append(self.msg_tuple_to_dict(msg_tuples))
        # Equal to len(data_loader): exactly one entry is stored per sample.
        self._len = len(self._x)

    def msg_tuple_to_dict(self, msg_dict):
        """Convert message tuples into dicts keyed by ``'bytes'``.

        Args:
            msg_dict: iterable of tuples; only element 0 of each is used.

        Returns:
            list[dict]: one ``{'bytes': item[0]}`` dict per input tuple, in order.
        """
        return [{'bytes': item_msg[0]} for item_msg in msg_dict]

    def __getitem__(self, item):
        """Return the ``(graph, label, msg)`` triple at index ``item``."""
        return self._x[item], self._y[item], self._msg[item]

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self._len
def str_to_list(num_str):
    """Parse a bracketed, comma-separated number string into a list of floats.

    Example: ``"[1, 2.5]"`` -> ``[1.0, 2.5]``. Surrounding whitespace is
    ignored, and an empty list literal (``"[]"``) yields ``[]``.

    Args:
        num_str: text such as ``"[0.1, 0.2, 0.3]"``.

    Returns:
        list[float]: the parsed values, in order.

    Raises:
        ValueError: if any element is not a valid float literal.
    """
    body = num_str.strip().strip("[").strip("]")
    # Without this guard an empty vector splits into [''] and float('') fails.
    if not body.strip():
        return []
    return [float(item) for item in body.split(",")]
def get_data(data_loader, cfg_list, CFG_dir, input_dm=0, mode="malware", node_feature_method="n_gram"):
    """Load each CFG in ``cfg_list`` as a DGL graph and add it to ``data_loader``.

    Each file is read with ``networkx.read_gexf`` from ``CFG_dir``. The node
    attribute named ``node_feature_method`` is parsed with ``str_to_list``
    and stored as ``g.ndata['feature']``. The tuple ``(graph, label, ())`` is
    then added to ``data_loader``, with label 1 for ``mode="malware"`` and 0
    for ``mode="benign"``.

    Args:
        data_loader: container that receives the samples. A ``list`` is
            appended to, which preserves ``cfg_list`` order. Any other object
            must provide ``add`` (e.g. the ``set`` passed by ``load_data``).
        cfg_list: GEXF file names relative to ``CFG_dir``.
        CFG_dir: directory containing the GEXF files.
        input_dm: if non-zero, keep only ``[:input_dm]`` of each node's vector.
        mode: ``"malware"`` or ``"benign"``; selects the label.
        node_feature_method: name of the node attribute holding the feature string.

    Returns:
        The same ``data_loader`` object, with the new samples added.

    Raises:
        ValueError: if ``mode`` is neither ``"malware"`` nor ``"benign"``.
    """
    labels = {"malware": 1, "benign": 0}
    if mode not in labels:
        # Previously an unknown mode loaded every file and then silently dropped it.
        raise ValueError("mode must be 'malware' or 'benign', got %r" % (mode,))
    label = labels[mode]
    add_sample = data_loader.append if isinstance(data_loader, list) else data_loader.add

    for file_name in tqdm(cfg_list):
        # dgl.from_networkx relabels nodes to consecutive integers itself, sorting
        # the node ids. GEXF ids are strings, so '10' sorts before '2' and feature
        # rows could be paired with the wrong nodes. Relabelling to 0..n-1 in
        # G.nodes order first keeps both orders identical.
        # NOTE(review): confirm DGL's relabel ordering for the pinned DGL version.
        G = nx.convert_node_labels_to_integers(nx.read_gexf(os.path.join(CFG_dir, file_name)))
        G_feature = []
        for _, node_attrs in G.nodes.items():
            values = str_to_list(node_attrs[node_feature_method])
            G_feature.append(values[:input_dm] if input_dm != 0 else values)
        g_dgl = dgl.from_networkx(G)
        g_dgl.ndata['feature'] = torch.Tensor(G_feature)
        del G  # drop the networkx copy before reading the next file
        # Byte-level messages are no longer collected, so msg is always an empty tuple.
        add_sample((g_dgl, label, ()))
    return data_loader
def load_data(node_feature_method="n_gram", input_dm=0, malware_csv='../CFG_data/malware_msg.csv',
              benign_csv='../CFG_data/benign_msg.csv', malware_CFG_dir="../CFG_data/malware",
              bengin_CFG_dir="../CFG_data/benign", malware_num=500, benign_num=500,
              max_nodes=20000, min_nodes=10):
    """Select CFG files from the index CSVs and load them as labelled samples.

    ``csv_read`` selects files whose node count lies in ``[min_nodes, max_nodes]``,
    capped by ``malware_num`` / ``benign_num``. ``get_data`` then loads the
    malware files from ``malware_CFG_dir`` (label 1) and the benign files from
    ``bengin_CFG_dir`` (label 0). Progress and the selected counts are printed.

    Args:
        node_feature_method: node attribute used as the feature vector.
        input_dm: if non-zero, per-node feature vectors are truncated to this length.
        malware_csv, benign_csv: index CSVs passed to ``csv_read``.
        malware_CFG_dir, bengin_CFG_dir: directories holding the GEXF files.
        malware_num, benign_num, max_nodes, min_nodes: selection limits for ``csv_read``.

    Returns:
        set: the ``(graph, label, msg)`` tuples added by ``get_data``.
    """
    selected_malware, selected_benign = csv_read(malware_csv=malware_csv, benign_csv=benign_csv,
                                                 malware_num=malware_num, benign_num=benign_num,
                                                 max_nodes=max_nodes, min_nodes=min_nodes)
    # NOTE(review): a set's iteration order follows the hashes of the (graph-object)
    # tuples, not CSV order, and can differ between runs. Downstream code may rely
    # on that accidental mixing, so switch to an ordered container only together
    # with an explicit shuffle/split.
    samples = set()
    phases = (
        ("载入恶意样本cfg...", selected_malware, malware_CFG_dir, "malware"),
        ("载入良性样本cfg...", selected_benign, bengin_CFG_dir, "benign"),
    )
    for banner, cfg_names, cfg_dir, label_mode in phases:
        print(banner)
        samples = get_data(samples, cfg_names, cfg_dir, input_dm=input_dm,
                           mode=label_mode, node_feature_method=node_feature_method)
    for count, kind in ((len(selected_malware), "恶意"), (len(selected_benign), "良性")):
        print("读取" + str(count) + "个" + kind + "函数cfg")
    return samples
def _select_cfg_files(csv_path, limit, max_nodes, min_nodes):
    """Return file names from one index CSV whose node count is in range.

    Column 0 is the file name and column 1 the node count. The header row is
    printed. Blank lines are skipped, and an empty file yields ``[]``.

    Args:
        csv_path: path of the UTF-8 index CSV.
        limit: stop once this many names are selected. A value below 1 is
            never reached, so every in-range file is returned.
        max_nodes, min_nodes: inclusive node-count bounds.

    Returns:
        list[str]: selected file names, in CSV order.
    """
    selected = []
    # newline='' is what the csv module requires for correct line handling.
    with open(csv_path, 'r', encoding='utf-8', newline='') as f:
        reader = csv.reader(f)
        # The first row holds column names. next() without a default would leak
        # a bare StopIteration on an empty file.
        header = next(reader, None)
        if header is None:
            return selected
        print(header)
        for row in reader:
            if not row:
                continue  # blank line: would otherwise fail on row[0]
            file_name, nodes_num = row[0], int(row[1])
            if min_nodes <= nodes_num <= max_nodes:
                selected.append(file_name)
                if len(selected) == limit:
                    break
    return selected


def csv_read(malware_csv='../CFG_data/malware_msg.csv', benign_csv='../CFG_data/benign_msg.csv',
             malware_num=500, benign_num=500, max_nodes=20000, min_nodes=10):
    """Select malware and benign CFG file names from their index CSVs.

    Args:
        malware_csv, benign_csv: index CSVs (file name, node count, ...).
        malware_num: cap on malware files; 0 (or negative) means no cap.
        benign_num: cap on benign files; 0 skips the benign CSV entirely.
        max_nodes, min_nodes: inclusive node-count bounds for both lists.

    Returns:
        tuple[list[str], list[str]]: ``(malware_cfg_list, benign_cfg_list)``.
    """
    malware_cfg_list = _select_cfg_files(malware_csv, malware_num, max_nodes, min_nodes)
    # Historical asymmetry kept for callers: benign_num == 0 means "no benign
    # samples" (the file is not even opened), whereas malware_num == 0 means
    # "no cap".
    benign_cfg_list = []
    if benign_num != 0:
        benign_cfg_list = _select_cfg_files(benign_csv, benign_num, max_nodes, min_nodes)
    return malware_cfg_list, benign_cfg_list
def _select_cfg_rows(csv_path, limit, max_nodes, min_nodes):
    """Return ``[file_name, nodes_num, edgs_num]`` rows whose node count is in range.

    Columns 0-2 of the CSV are file name, node count and edge count. They are
    returned exactly as read (strings). The header row is printed. Blank lines
    are skipped, and an empty file yields ``[]``.

    Args:
        csv_path: path of the UTF-8 index CSV.
        limit: stop once this many rows are selected. A value below 1 is
            never reached, so every in-range row is returned.
        max_nodes, min_nodes: inclusive node-count bounds.

    Returns:
        list[list[str]]: selected rows, in CSV order.
    """
    selected = []
    # newline='' is what the csv module requires for correct line handling.
    with open(csv_path, 'r', encoding='utf-8', newline='') as f:
        reader = csv.reader(f)
        # The first row holds column names. next() without a default would leak
        # a bare StopIteration on an empty file.
        header = next(reader, None)
        if header is None:
            return selected
        print(header)
        for row in reader:
            if not row:
                continue  # blank line: would otherwise fail on row[0]
            file_name, nodes_num, edgs_num = row[0], row[1], row[2]
            if min_nodes <= int(nodes_num) <= max_nodes:
                selected.append([file_name, nodes_num, edgs_num])
                if len(selected) == limit:
                    break
    return selected


def csv_read2(malware_csv='../CFG_data/malware_msg.csv', benign_csv='../CFG_data/benign_msg.csv',
              malware_num=500, benign_num=500, max_nodes=20000, min_nodes=10):
    """Like ``csv_read`` but keep node and edge counts alongside each file name.

    Args:
        malware_csv, benign_csv: index CSVs (file name, node count, edge count, ...).
        malware_num: cap on malware rows; 0 (or negative) means no cap.
        benign_num: cap on benign rows; 0 skips the benign CSV entirely.
        max_nodes, min_nodes: inclusive node-count bounds for both lists.

    Returns:
        tuple[list, list]: ``(malware_cfg_list, benign_cfg_list)``. Each entry is
        ``[file_name, nodes_num, edgs_num]`` with the raw string values.
    """
    malware_cfg_list = _select_cfg_rows(malware_csv, malware_num, max_nodes, min_nodes)
    # Historical asymmetry kept for callers: benign_num == 0 means "no benign
    # samples", whereas malware_num == 0 means "no cap".
    benign_cfg_list = []
    if benign_num != 0:
        benign_cfg_list = _select_cfg_rows(benign_csv, benign_num, max_nodes, min_nodes)
    return malware_cfg_list, benign_cfg_list