detect_rep/detect_script/load_dataset.py
2023-04-05 10:04:49 +08:00

205 lines
8.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from torch.utils.data import Dataset
import csv
import torch
import networkx as nx
import os
from tqdm import tqdm
import dgl
import sys
import re
# sys.path.append(r'../ASM2VEC_base_scripts/')
# import asm2vec
# from func2vec import func2vec,load_model
# sys.path.append(r'../ASM2VEC_scripts/')
# from func2vec import func2vec,load_model
# from node_feature import asm2vec_plus,asm2vec_base,malconv,n_gram,word_frequency
#
#
# def get_node_feature(dict,node_feature_method="n_gram"):
# if node_feature_method =="asm2vec_plus":
# return dict[]
# elif node_feature_method =="asm2vec_base":
# return asm2vec_base(hex_asm=hex_asm).tolist()
# elif node_feature_method =="malconv":
# return malconv(hex_asm=hex_asm)
# elif node_feature_method =="n_gram":
# return n_gram(hex_asm=hex_asm)
# elif node_feature_method =="word_frequency":
# return word_frequency(hex_asm=hex_asm)
class mydataset(Dataset):
    """Map-style dataset of CFG graphs for malware/benign classification.

    Every sample is a ``(graph, label, msg)`` triple:

    * ``graph`` -- a DGL graph produced by ``load_data``;
    * ``label`` -- a one-element tensor wrapping the integer label supplied by
      ``load_data`` (``get_data`` in this file uses 1 for malware, 0 for benign);
    * ``msg``   -- a list of ``{'bytes': ...}`` dicts built from the sample's
      message tuples (``get_data`` in this file currently passes an empty tuple,
      so this list is empty).

    All samples are loaded eagerly in ``__init__`` and kept in memory.
    """

    def __init__(self, malware_csv='../CFG_data/malware_msg.csv', benign_csv='../CFG_data/benign_msg.csv',
                 malware_CFG_dir="../CFG_data/malware", bengin_CFG_dir="../CFG_data/benign",
                 node_feature_method="n_gram", input_dm=0, malware_num=500, benign_num=500,
                 max_nodes=20000, min_nodes=10):
        """Read and load all CFG samples via ``load_data``.

        Args:
            malware_csv: index CSV of malware CFG files (name, node count, edge count).
            benign_csv: index CSV of benign CFG files.
            malware_CFG_dir: directory containing the malware CFG files.
            bengin_CFG_dir: directory containing the benign CFG files
                (misspelled name kept for backward compatibility with callers).
            node_feature_method: node attribute that holds the feature vector.
            input_dm: if non-zero, feature vectors are truncated to this length.
            malware_num: per-class sample cap forwarded to the CSV reader.
            benign_num: per-class sample cap forwarded to the CSV reader.
            max_nodes: inclusive upper bound on a CFG's node count.
            min_nodes: inclusive lower bound on a CFG's node count.
        """
        data_loader = load_data(node_feature_method=node_feature_method, input_dm=input_dm,
                                malware_csv=malware_csv, benign_csv=benign_csv,
                                malware_CFG_dir=malware_CFG_dir, bengin_CFG_dir=bengin_CFG_dir,
                                malware_num=malware_num, benign_num=benign_num,
                                max_nodes=max_nodes, min_nodes=min_nodes)
        self._x = []
        self._y = []
        self._msg = []
        # No loop index is needed (the original one also shadowed the builtin ``iter``).
        for graph, label, msg_tuples in data_loader:
            self._x.append(graph)
            self._y.append(torch.tensor([label]))
            self._msg.append(self.msg_tuple_to_dict(msg_tuples))
        self._len = len(data_loader)

    def msg_tuple_to_dict(self, msg_dict):
        """Turn each message tuple into ``{'bytes': <first element>}``.

        Args:
            msg_dict: iterable of indexable items; only ``item[0]`` is used.

        Returns:
            A list with one dict per input item, in input order.
        """
        return [{'bytes': item_msg[0]} for item_msg in msg_dict]

    def __getitem__(self, item):
        """Return the ``(graph, label, msg)`` triple at index ``item``."""
        return self._x[item], self._y[item], self._msg[item]

    def __len__(self):
        """Return the total number of samples in the dataset."""
        return self._len
def str_to_list(num_str):
    """Parse a bracketed, comma-separated number string into a list of floats.

    Example: ``"[0.1, 2, 3.5]"`` -> ``[0.1, 2.0, 3.5]``. Surrounding whitespace
    and brackets are ignored, and an empty list such as ``"[]"`` yields ``[]``
    instead of raising.

    Args:
        num_str: the string form of a numeric list.

    Returns:
        A list of floats in the order they appear in ``num_str``.

    Raises:
        ValueError: if any comma-separated element is not a valid float.
    """
    # Strip whitespace first so a trailing newline/space doesn't shield the "]".
    body = num_str.strip().strip("[]")
    if not body.strip():
        # "".split(",") would give [""], and float("") raises.
        return []
    return [float(token) for token in body.split(",")]
def get_data(data_loader, cfg_list, CFG_dir, input_dm=0, mode="malware", node_feature_method="n_gram"):
    """Load each CFG file in ``cfg_list`` and add it to ``data_loader`` as a DGL sample.

    Each file under ``CFG_dir`` is read with ``nx.read_gexf``. The
    ``node_feature_method`` attribute of every node is parsed with
    ``str_to_list`` (and truncated to ``input_dm`` values when ``input_dm`` is
    non-zero). The resulting feature matrix is stored as ``ndata['feature']`` of
    the converted DGL graph.

    Args:
        data_loader: container exposing ``add``; receives
            ``(dgl_graph, label, msg_tuple)`` tuples.
        cfg_list: CFG file names relative to ``CFG_dir``.
        CFG_dir: directory holding the CFG files.
        input_dm: if non-zero, keep only the first ``input_dm`` feature values.
        mode: ``"malware"`` (label 1) or ``"benign"`` (label 0).
        node_feature_method: node attribute name holding the feature string.

    Returns:
        The same ``data_loader`` object that was passed in.

    Raises:
        ValueError: if ``mode`` is neither ``"malware"`` nor ``"benign"``.
            Previously such calls silently added nothing.
    """
    labels = {"malware": 1, "benign": 0}
    if mode not in labels:
        raise ValueError(f"unknown mode {mode!r}; expected 'malware' or 'benign'")
    label = labels[mode]

    for cfg_name in tqdm(cfg_list):
        G = nx.read_gexf(os.path.join(CFG_dir, cfg_name))
        # The feature rows below follow G.nodes iteration order, but DGL's
        # from_networkx renumbers nodes in sorted-ID order (NOTE(review): per
        # DGL's source, which relabels with ordering="sorted"; confirm for the
        # pinned version). read_gexf yields string IDs, so "10" sorts before "2"
        # and features could land on the wrong nodes. Relabelling to 0..n-1 in
        # iteration order makes both orders agree, whatever DGL does internally.
        G = nx.convert_node_labels_to_integers(G)
        G_feature = []
        for _, attrs in G.nodes.items():
            features = str_to_list(attrs[node_feature_method])
            G_feature.append(features[:input_dm] if input_dm != 0 else features)
        g_dgl = dgl.from_networkx(G)
        g_dgl.ndata['feature'] = torch.Tensor(G_feature)
        del G  # drop the networkx copy before loading the next file
        # Per-node messages are not collected at the moment, so msg is empty.
        data_loader.add((g_dgl, label, tuple()))
    return data_loader
def load_data(node_feature_method="n_gram", input_dm=0, malware_csv='../CFG_data/malware_msg.csv',
              benign_csv='../CFG_data/benign_msg.csv', malware_CFG_dir="../CFG_data/malware",
              bengin_CFG_dir="../CFG_data/benign", malware_num=500, benign_num=500,
              max_nodes=20000, min_nodes=10):
    """Select CFGs from the index CSVs and load malware and benign samples.

    File names are chosen by ``csv_read`` (node-count filter and per-class
    caps). Each selected file is then loaded by ``get_data``: malware files
    from ``malware_CFG_dir`` and benign files from ``bengin_CFG_dir``.

    Returns:
        A ``set`` filled by ``get_data`` with ``(dgl_graph, label, msg)`` tuples.

    NOTE(review): because the container is a ``set`` of tuples holding graph
    objects, iteration order depends on object hashes and is not reproducible
    between runs. Confirm whether callers rely on that before switching to a list.
    """
    malware_files, benign_files = csv_read(
        malware_csv=malware_csv,
        benign_csv=benign_csv,
        malware_num=malware_num,
        benign_num=benign_num,
        max_nodes=max_nodes,
        min_nodes=min_nodes,
    )
    # Options shared by both get_data passes.
    shared = {"input_dm": input_dm, "node_feature_method": node_feature_method}

    samples = set()
    print("载入恶意样本cfg...")  # "loading malware CFGs..."
    samples = get_data(samples, malware_files, malware_CFG_dir, mode="malware", **shared)
    print("载入良性样本cfg...")  # "loading benign CFGs..."
    samples = get_data(samples, benign_files, bengin_CFG_dir, mode="benign", **shared)

    # Report how many files of each class were selected.
    print(f"读取{len(malware_files)}个恶意函数cfg")
    print(f"读取{len(benign_files)}个良性函数cfg")
    return samples
def _read_cfg_names(csv_path, limit, max_nodes, min_nodes):
    """Return file names from one CFG index CSV whose node count is in range.

    The first row is a header and is printed. Every other non-blank row starts
    with ``file_name, nodes_num, ...``. Rows are scanned in file order and
    scanning stops as soon as ``limit`` names have been collected. Because that
    check is an equality test made after each append, ``limit <= 0`` never stops
    early and every qualifying row is returned.
    """
    names = []
    # newline='' is what the csv module requires for correct line handling.
    with open(csv_path, 'r', encoding='utf-8', newline='') as f:
        reader = csv.reader(f)
        header = next(reader, None)
        if header is None:  # empty file (previously raised StopIteration)
            return names
        print(header)
        for row in reader:
            if not row:  # csv.reader yields [] for blank lines
                continue
            file_name, nodes_num = row[0], int(row[1])
            if min_nodes <= nodes_num <= max_nodes:
                names.append(file_name)
                if len(names) == limit:
                    break
    return names


def csv_read(malware_csv='../CFG_data/malware_msg.csv', benign_csv='../CFG_data/benign_msg.csv',
             malware_num=500, benign_num=500, max_nodes=20000, min_nodes=10):
    """Choose which malware and benign CFG files to load.

    Args:
        malware_csv: index CSV for malware CFGs (header row, then rows of
            ``file_name, nodes_num, edges_num``).
        benign_csv: index CSV for benign CFGs, same layout.
        malware_num: maximum malware files to take. 0 (or a negative value)
            means no limit.
        benign_num: maximum benign files to take. 0 means the benign CSV is
            not read at all.
        max_nodes: inclusive upper bound on a CFG's node count.
        min_nodes: inclusive lower bound on a CFG's node count.

    Returns:
        ``(malware_cfg_list, benign_cfg_list)``, each a list of file names in
        CSV order.

    NOTE(review): 0 means "unlimited" for malware but "none" for benign. This
    asymmetry is preserved from the original; confirm it is intended.
    """
    malware_cfg_list = _read_cfg_names(malware_csv, malware_num, max_nodes, min_nodes)
    benign_cfg_list = []
    if benign_num != 0:
        benign_cfg_list = _read_cfg_names(benign_csv, benign_num, max_nodes, min_nodes)
    return malware_cfg_list, benign_cfg_list
def _read_cfg_rows(csv_path, limit, max_nodes, min_nodes):
    """Return ``[file_name, nodes_num, edges_num]`` rows whose node count is in range.

    The values are kept as the raw strings read from the CSV. The first row is
    a header and is printed; blank rows are skipped. Scanning stops once
    ``limit`` rows are collected. Because that check is an equality test made
    after each append, ``limit <= 0`` returns every qualifying row.
    """
    rows = []
    # newline='' is what the csv module requires for correct line handling.
    with open(csv_path, 'r', encoding='utf-8', newline='') as f:
        reader = csv.reader(f)
        header = next(reader, None)
        if header is None:  # empty file (previously raised StopIteration)
            return rows
        print(header)
        for row in reader:
            if not row:  # csv.reader yields [] for blank lines
                continue
            file_name, nodes_num, edgs_num = row[0], row[1], row[2]
            if min_nodes <= int(nodes_num) <= max_nodes:
                rows.append([file_name, nodes_num, edgs_num])
                if len(rows) == limit:
                    break
    return rows


def csv_read2(malware_csv='../CFG_data/malware_msg.csv', benign_csv='../CFG_data/benign_msg.csv',
              malware_num=500, benign_num=500, max_nodes=20000, min_nodes=10):
    """Like ``csv_read``, but keep each row's node and edge counts too.

    Args:
        malware_csv: index CSV for malware CFGs (header row, then rows of
            ``file_name, nodes_num, edges_num``).
        benign_csv: index CSV for benign CFGs, same layout.
        malware_num: maximum malware rows. 0 (or a negative value) means no limit.
        benign_num: maximum benign rows. 0 means the benign CSV is not read.
        max_nodes: inclusive upper bound on a CFG's node count.
        min_nodes: inclusive lower bound on a CFG's node count.

    Returns:
        ``(malware_rows, benign_rows)``, each a list of
        ``[file_name, nodes_num, edges_num]`` string lists in CSV order.

    NOTE(review): 0 means "unlimited" for malware but "none" for benign. This
    asymmetry is preserved from the original; confirm it is intended.
    """
    malware_cfg_list = _read_cfg_rows(malware_csv, malware_num, max_nodes, min_nodes)
    benign_cfg_list = []
    if benign_num != 0:
        benign_cfg_list = _read_cfg_rows(benign_csv, benign_num, max_nodes, min_nodes)
    return malware_cfg_list, benign_cfg_list