# Source listing metadata: 205 lines, 8.1 KiB, Python.
from torch.utils.data import Dataset
|
||
import csv
|
||
import torch
|
||
import networkx as nx
|
||
import os
|
||
from tqdm import tqdm
|
||
import dgl
|
||
import sys
|
||
import re
|
||
# sys.path.append(r'../ASM2VEC_base_scripts/')
# import asm2vec
# from func2vec import func2vec,load_model
# sys.path.append(r'../ASM2VEC_scripts/')
# from func2vec import func2vec,load_model

# from node_feature import asm2vec_plus,asm2vec_base,malconv,n_gram,word_frequency
#
#
# def get_node_feature(dict,node_feature_method="n_gram"):
#     if node_feature_method =="asm2vec_plus":
#         return dict[]
#     elif node_feature_method =="asm2vec_base":
#         return asm2vec_base(hex_asm=hex_asm).tolist()
#     elif node_feature_method =="malconv":
#         return malconv(hex_asm=hex_asm)
#     elif node_feature_method =="n_gram":
#         return n_gram(hex_asm=hex_asm)
#     elif node_feature_method =="word_frequency":
#         return word_frequency(hex_asm=hex_asm)
class mydataset(Dataset):
    """In-memory dataset of labelled control-flow graphs.

    Every sample is loaded eagerly in ``__init__`` through ``load_data`` and
    kept in three parallel lists: the graph, its label wrapped in a
    one-element tensor, and a list of per-node message dicts.
    """

    def __init__(
        self,
        malware_csv='../CFG_data/malware_msg.csv',
        benign_csv='../CFG_data/benign_msg.csv',
        malware_CFG_dir="../CFG_data/malware",
        bengin_CFG_dir="../CFG_data/benign",
        node_feature_method="n_gram",
        input_dm=0,
        malware_num=500,
        benign_num=500,
        max_nodes=20000,
        min_nodes=10,
    ):
        """Read and load all samples.

        Every argument is forwarded unchanged to ``load_data``; see that
        function for their meaning.
        """
        super().__init__()
        data_loader = load_data(
            node_feature_method=node_feature_method,
            input_dm=input_dm,
            malware_csv=malware_csv,
            benign_csv=benign_csv,
            malware_CFG_dir=malware_CFG_dir,
            bengin_CFG_dir=bengin_CFG_dir,
            malware_num=malware_num,
            benign_num=benign_num,
            max_nodes=max_nodes,
            min_nodes=min_nodes,
        )
        self._x = []    # graphs
        self._y = []    # labels, each a tensor of shape (1,)
        self._msg = []  # per-sample list of {'bytes': ...} dicts
        for graph, label, msg_dict in data_loader:
            self._x.append(graph)
            self._y.append(torch.tensor([label]))
            self._msg.append(self.msg_tuple_to_dict(msg_dict))
        # One entry is appended per loaded sample, so this equals
        # len(data_loader).
        self._len = len(self._x)

    def msg_tuple_to_dict(self, msg_dict):
        """Convert an iterable of message tuples into a list of dicts.

        Only the first element of each tuple is kept, stored under the key
        ``'bytes'``.
        """
        return [{'bytes': item_msg[0]} for item_msg in msg_dict]

    def __getitem__(self, item):
        """Return the ``(graph, label, messages)`` triple at index ``item``."""
        return self._x[item], self._y[item], self._msg[item]

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self._len
|
||
|
||
def str_to_list(num_str):
    """Parse a bracketed, comma-separated string of numbers into floats.

    Args:
        num_str: Text such as ``"[1, 2.5, 3]"``. The surrounding brackets and
            any surrounding whitespace are optional.

    Returns:
        A list of floats; an empty list when there are no numbers between the
        brackets (``"[]"`` or an empty string).

    Raises:
        ValueError: If any comma-separated piece is not a valid float.
    """
    body = num_str.strip().strip("[]").strip()
    if not body:
        # "".split(",") would yield [""], and float("") raises, so an empty
        # list has to be handled explicitly.
        return []
    return [float(piece) for piece in body.split(",")]
|
||
|
||
def get_data(data_loader, cfg_list, CFG_dir, input_dm=0, mode="malware", node_feature_method="n_gram"):
    """Load GEXF control-flow graphs and add them, labelled, to ``data_loader``.

    Args:
        data_loader: Collection exposing ``add``; each loaded sample is added
            as a ``(graph, label, messages)`` tuple.
        cfg_list: File names of the GEXF graphs to load, relative to
            ``CFG_dir``.
        CFG_dir: Directory containing the files named in ``cfg_list``.
        input_dm: If non-zero, every node feature vector is truncated to its
            first ``input_dm`` entries.
        mode: ``"malware"`` (label 1) or ``"benign"`` (label 0).
        node_feature_method: Name of the node attribute holding the feature
            vector, stored as a string that ``str_to_list`` can parse.

    Returns:
        The same ``data_loader`` object, after the samples have been added.

    Raises:
        ValueError: If ``mode`` is neither ``"malware"`` nor ``"benign"``.
    """
    labels = {"malware": 1, "benign": 0}  # malware samples are labelled 1
    if mode not in labels:
        # Fail before any file is read; previously an unknown mode parsed
        # every graph and then silently discarded it.
        raise ValueError("mode must be 'malware' or 'benign', got %r" % (mode,))
    label = labels[mode]

    for file in tqdm(cfg_list):
        # Read the CFG.
        G = nx.read_gexf(os.path.join(CFG_dir, file))

        G_feature = []
        for _, attrs in G.nodes.items():
            feature = str_to_list(attrs[node_feature_method])
            if input_dm != 0:
                feature = feature[:input_dm]
            G_feature.append(feature)

        g_dgl = dgl.from_networkx(G)
        g_dgl.ndata['feature'] = torch.Tensor(G_feature)
        # NOTE(review): an earlier comment here mentioned adding self-loop
        # edges for nodes without edges, but no such step is performed.

        # Per-node messages (the "bytes" attribute) are not collected, so the
        # third element of every sample is an empty tuple.
        data_loader.add((g_dgl, label, ()))

    return data_loader
|
||
|
||
def load_data(node_feature_method="n_gram", input_dm=0, malware_csv='../CFG_data/malware_msg.csv', benign_csv='../CFG_data/benign_msg.csv', malware_CFG_dir="../CFG_data/malware", bengin_CFG_dir="../CFG_data/benign", malware_num=500, benign_num=500, max_nodes=20000, min_nodes=10):
    """Select CFG files via ``csv_read`` and load them with ``get_data``.

    The CSV/limit arguments are forwarded to ``csv_read``; the directory,
    ``input_dm`` and ``node_feature_method`` arguments are forwarded to
    ``get_data``. Malware samples are loaded first, then benign ones.

    Returns:
        A set of the samples produced by ``get_data`` for both classes.
    """
    malware_files, benign_files = csv_read(
        malware_csv=malware_csv,
        benign_csv=benign_csv,
        malware_num=malware_num,
        benign_num=benign_num,
        max_nodes=max_nodes,
        min_nodes=min_nodes,
    )

    # (progress banner, file list, directory, mode) for each sample class;
    # the banners read "loading malicious / benign sample cfg...".
    stages = (
        ("载入恶意样本cfg...", malware_files, malware_CFG_dir, "malware"),
        ("载入良性样本cfg...", benign_files, bengin_CFG_dir, "benign"),
    )

    samples = set()
    for banner, files, directory, sample_mode in stages:
        print(banner)
        samples = get_data(
            samples,
            files,
            directory,
            input_dm=input_dm,
            mode=sample_mode,
            node_feature_method=node_feature_method,
        )

    # Report how many malicious / benign CFG files were selected.
    print("读取" + str(len(malware_files)) + "个恶意函数cfg")
    print("读取" + str(len(benign_files)) + "个良性函数cfg")
    return samples
|
||
|
||
def _collect_cfg_names(csv_path, limit, max_nodes, min_nodes):
    """Return file names from ``csv_path`` whose node count is within bounds."""
    selected = []
    with open(csv_path, 'r', encoding='utf-8') as f:
        reader = csv.reader(f)
        # The first row is the column header; an empty file has none.
        header = next(reader, None)
        if header is None:
            return selected
        print(header)

        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines; skip them.
                continue
            file_name = row[0]
            nodes_num = int(row[1])
            if min_nodes <= nodes_num <= max_nodes:
                selected.append(file_name)
                if len(selected) == limit:
                    break
    return selected


def csv_read(malware_csv='../CFG_data/malware_msg.csv', benign_csv='../CFG_data/benign_msg.csv', malware_num=500, benign_num=500, max_nodes=20000, min_nodes=10):
    """Select CFG file names from the malware and benign index CSVs.

    Each CSV has a header row followed by rows whose first column is the
    file name and whose second column is the node count. Rows are taken in
    file order when ``min_nodes <= node count <= max_nodes``, until the
    requested number has been collected.

    Args:
        malware_csv: Path of the malware index CSV.
        benign_csv: Path of the benign index CSV.
        malware_num: Maximum number of malware file names to return.
        benign_num: Maximum number of benign file names to return; when 0 the
            benign CSV is not opened and the benign list is empty.
        max_nodes: Inclusive upper bound on the node count.
        min_nodes: Inclusive lower bound on the node count.

    Returns:
        A ``(malware_cfg_list, benign_cfg_list)`` tuple of file-name lists.

    Raises:
        ValueError: If a node-count cell is not an integer.
        IndexError: If a non-blank row has fewer than two columns.
    """
    # NOTE(review): unlike benign_num, malware_num == 0 is not special-cased;
    # the limit is only checked after a row is accepted, so 0 never stops the
    # scan. Confirm whether that is intended.
    malware_cfg_list = _collect_cfg_names(malware_csv, malware_num, max_nodes, min_nodes)

    benign_cfg_list = []
    if benign_num != 0:
        benign_cfg_list = _collect_cfg_names(benign_csv, benign_num, max_nodes, min_nodes)

    return malware_cfg_list, benign_cfg_list
|
||
|
||
def _collect_cfg_rows(csv_path, limit, max_nodes, min_nodes):
    """Return ``[file_name, nodes_num, edgs_num]`` rows within the node bounds."""
    selected = []
    with open(csv_path, 'r', encoding='utf-8') as f:
        reader = csv.reader(f)
        # The first row is the column header; an empty file has none.
        header = next(reader, None)
        if header is None:
            return selected
        print(header)

        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines; skip them.
                continue
            file_name = row[0]
            nodes_num = row[1]
            edgs_num = row[2]
            if min_nodes <= int(nodes_num) <= max_nodes:
                # The counts are kept as the strings read from the CSV.
                selected.append([file_name, nodes_num, edgs_num])
                if len(selected) == limit:
                    break
    return selected


def csv_read2(malware_csv='../CFG_data/malware_msg.csv', benign_csv='../CFG_data/benign_msg.csv', malware_num=500, benign_num=500, max_nodes=20000, min_nodes=10):
    """Select CFG entries, with their counts, from the two index CSVs.

    Each CSV has a header row followed by rows whose first three columns are
    the file name, the node count and the edge count. Rows are taken in file
    order when ``min_nodes <= node count <= max_nodes``, until the requested
    number has been collected.

    Args:
        malware_csv: Path of the malware index CSV.
        benign_csv: Path of the benign index CSV.
        malware_num: Maximum number of malware entries to return.
        benign_num: Maximum number of benign entries to return; when 0 the
            benign CSV is not opened and the benign list is empty.
        max_nodes: Inclusive upper bound on the node count.
        min_nodes: Inclusive lower bound on the node count.

    Returns:
        A ``(malware_cfg_list, benign_cfg_list)`` tuple; every entry is a
        ``[file_name, nodes_num, edgs_num]`` list of strings.

    Raises:
        ValueError: If a node-count cell is not an integer.
        IndexError: If a non-blank row has fewer than three columns.
    """
    # NOTE(review): unlike benign_num, malware_num == 0 is not special-cased;
    # the limit is only checked after a row is accepted, so 0 never stops the
    # scan. Confirm whether that is intended.
    malware_cfg_list = _collect_cfg_rows(malware_csv, malware_num, max_nodes, min_nodes)

    benign_cfg_list = []
    if benign_num != 0:
        benign_cfg_list = _collect_cfg_rows(benign_csv, benign_num, max_nodes, min_nodes)

    return malware_cfg_list, benign_cfg_list