85 lines
3.5 KiB
Python
85 lines
3.5 KiB
Python
import os
|
|
import os.path as osp
|
|
from datetime import datetime
|
|
|
|
import torch
|
|
from torch_geometric.data import Dataset, DataLoader
|
|
from utils.RealBatch import create_real_batch_data # noqa
|
|
|
|
|
|
class MalwareDetectionDataset(Dataset):
|
|
def __init__(self, root, train_or_test, transform=None, pre_transform=None):
|
|
super(MalwareDetectionDataset, self).__init__(None, transform, pre_transform)
|
|
self.flag = train_or_test.lower()
|
|
self.malware_root = os.path.join(root, "{}_malware".format(self.flag))
|
|
self.benign_root = os.path.join(root, "{}_benign".format(self.flag))
|
|
self.malware_files = os.listdir(self.malware_root)
|
|
self.benign_files = os.listdir(self.benign_root)
|
|
|
|
@staticmethod
|
|
def _list_files_for_pt(the_path):
|
|
files = []
|
|
for name in os.listdir(the_path):
|
|
if os.path.splitext(name)[-1] == '.pt':
|
|
files.append(name)
|
|
return files
|
|
|
|
def __len__(self):
|
|
# def len(self):
|
|
# return 201
|
|
return len(self.malware_files) + len(self.benign_files)
|
|
|
|
def get(self, idx):
|
|
split = len(self.malware_files)
|
|
# split = 100
|
|
if idx < split:
|
|
idx_data = torch.load(osp.join(self.malware_root, 'malware_{}.pt'.format(idx)))
|
|
else:
|
|
over_fit_idx = idx - split
|
|
idx_data = torch.load(osp.join(self.benign_root, "benign_{}.pt".format(over_fit_idx)))
|
|
return idx_data
|
|
|
|
|
|
def _simulating(_dataset, _batch_size: int):
|
|
print("\nBatch size = {}".format(_batch_size))
|
|
time_start = datetime.now()
|
|
print("start time: " + time_start.strftime("%Y-%m-%d@%H:%M:%S"))
|
|
|
|
# https://github.com/pytorch/fairseq/issues/1560
|
|
# https://github.com/pytorch/pytorch/issues/973#issuecomment-459398189
|
|
# loaders_1 = DataLoader(dataset=benign_exe_dataset, batch_size=10, shuffle=True, num_workers=0)
|
|
# increasing the shared memory: ulimit -SHn 51200
|
|
loader = DataLoader(dataset=_dataset, batch_size=_batch_size, shuffle=True) # default of prefetch_factor = 2 # num_workers=4
|
|
|
|
for index, data in enumerate(loader):
|
|
if index >= 3:
|
|
break
|
|
_real_batch, _position, _hash, _external_list, _function_edges, _true_classes = create_real_batch_data(one_batch=data)
|
|
print(data)
|
|
print("Hash: ", _hash)
|
|
print("Position: ", _position)
|
|
print("\n")
|
|
|
|
time_end = datetime.now()
|
|
print("end time: " + time_end.strftime("%Y-%m-%d@%H:%M:%S"))
|
|
print("All time = {}\n\n".format(time_end - time_start))
|
|
|
|
|
|
if __name__ == '__main__':
|
|
root_path: str = '/home/xiang/MalGraph/data/processed_dataset/DatasetJSON/'
|
|
i_batch_size = 2
|
|
|
|
train_dataset = MalwareDetectionDataset(root=root_path, train_or_test='train')
|
|
print(train_dataset.malware_root, train_dataset.benign_root)
|
|
print(len(train_dataset.malware_files), len(train_dataset.benign_files), len(train_dataset))
|
|
_simulating(_dataset=train_dataset, _batch_size=i_batch_size)
|
|
|
|
valid_dataset = MalwareDetectionDataset(root=root_path, train_or_test='valid')
|
|
print(valid_dataset.malware_root, valid_dataset.benign_root)
|
|
print(len(valid_dataset.malware_files), len(valid_dataset.benign_files), len(valid_dataset))
|
|
_simulating(_dataset=valid_dataset, _batch_size=i_batch_size)
|
|
|
|
test_dataset = MalwareDetectionDataset(root=root_path, train_or_test='test')
|
|
print(test_dataset.malware_root, test_dataset.benign_root)
|
|
print(len(test_dataset.malware_files), len(test_dataset.benign_files), len(test_dataset))
|
|
_simulating(_dataset=test_dataset, _batch_size=i_batch_size) |