This commit is contained in:
ryderling 2022-01-10 16:39:55 +08:00
parent bfdb8079d6
commit 8296bcb432
17 changed files with 1100 additions and 0 deletions

27
configs/default.yaml Normal file
View File

@ -0,0 +1,27 @@
Data:
preprocess_root: "../data/processed_dataset/DatasetJSON/"
train_vocab_file: "../data/processed_dataset/train_external_function_name_vocab.jsonl"
max_vocab_size: 10000 # modify according to the result of 1BuildExternalVocab.py
Training:
cuda: True # enable GPU training if cuda is available
dist_backend: "nccl" # if using torch.distributed, the backend to be used
dist_port: "1234"
max_epoches: 10
train_batch_size: 16
test_batch_size: 32
seed: 19920208
only_test_path: 'None'
Model:
ablation_models: "Full" # "Full"
gnn_type: "GraphSAGE" # "GraphSAGE" / "GCN"
pool_type: "global_max_pool" # "global_max_pool" / "global_mean_pool"
acfg_node_init_dims: 11
cfg_filters: "200-200"
fcg_filters: "200-200"
number_classes: 1
drapout_rate: 0.2
Optimizer:
name: "AdamW" # Adam / AdamW
learning_rate: 1e-3 # initial learning rate
weight_decay: 1e-5 # initial weight decay
learning_anneal: 1.1 # Annealing applied to learning rate after each epoch

Binary file not shown.

After

Width:  |  Height:  |  Size: 51 KiB

41
samples/PreProcess.py Normal file
View File

@ -0,0 +1,41 @@
import json
import torch
from torch_geometric.data import Data
from tqdm import tqdm
from utils.Vocabulary import Vocab
def parse_json_list_2_pyg_object(jsonl_file: str, label: int, vocab: Vocab):
    """Convert each JSON line of ``jsonl_file`` into a PyG ``Data`` object on disk.

    Every line is parsed as one sample.  Each entry of its ``acfg_list``
    becomes a ``Data`` graph built from ``block_features`` / ``block_edges``.
    The first ``len(acfg_list)`` entries of ``function_names`` are treated as
    local functions; the remaining names are looked up in ``vocab``.  The
    assembled sample is saved to ``./<k>.pt`` in the current working
    directory, where ``k`` is the 1-based position of the line.

    :param jsonl_file: path of the JSON-lines file to read (UTF-8).
    :param label: value stored as ``targets`` on every saved sample.
    :param vocab: object indexed by external function name (``vocab[name]``).
    """
    saved_count = 0
    with open(jsonl_file, "r", encoding="utf-8") as handle:
        for raw_line in tqdm(handle):
            record = json.loads(raw_line)
            sample_hash = record['hash']

            # One graph per local function: node features + edge index.
            local_graphs = [
                Data(x=torch.tensor(acfg['block_features'], dtype=torch.float),
                     edge_index=torch.tensor(acfg['block_edges'], dtype=torch.long))
                for acfg in record['acfg_list']  # list of dict of acfg
            ]
            n_local = len(local_graphs)

            all_names = record['function_names']
            call_edges = record['function_edges']

            # Local functions come first in `function_names`; externals follow.
            local_names = all_names[:n_local]
            assert n_local == len(local_names), "The length of ACFG_List should be equal to the length of Local_Function_List"
            external_ids = [vocab[name] for name in all_names[n_local:]]

            saved_count += 1
            sample = Data(hash=sample_hash, local_acfgs=local_graphs, external_list=external_ids,
                          function_edges=call_edges, targets=label)
            torch.save(sample, "./{}.pt".format(saved_count))
            print(saved_count)
if __name__ == '__main__':
    # Example run: build the external-function vocabulary, then convert the
    # bundled sample file, storing label 1 on every converted sample.
    max_vocab_size = 10000
    train_vocab_file = "../ReservedDataCode/processed_dataset/train_external_function_name_vocab.jsonl"
    json_path = "./sample.jsonl"

    vocabulary = Vocab(freq_file=train_vocab_file, max_vocab_size=max_vocab_size)
    parse_json_list_2_pyg_object(jsonl_file=json_path, label=1, vocab=vocabulary)

21
samples/README.md Normal file
View File

@ -0,0 +1,21 @@
# Data Preprocessing
### STEP 1: PE Disassemble
We first use IDA Pro 6.4 to disassemble one given portable executable (PE) file, obtaining one function call graph (i.e., FCG, including both external functions and local functions) and corresponding control flow graphs (CFGs) of local functions.
In fact, the FCG can be exported in the Graph Description Language (GDL) file format, and the CFGs can be processed into ACFGs; this processing is mainly built on the GitHub repo https://github.com/qian-feng/Gencoding.
We therefore refer interested readers to this repo for more details.
Taking one PE file as an example, we can use IDA Pro to get the following FCG (25 external functions and 2 local functions)
![system](./FunctionCallGraph.png)
and two CFGs of local functions, i.e., sub_401000 and sub_40103C, as follows.
![system](./sub_401000.png)
![system](./sub_40103C.png)
After that, we can save the above hierarchical graph representation into sample.jsonl as follows.
```
{"function_edges": [[1, 1, ..., 1], [0, 2, ..., 26]], "acfg_list": [{"block_number": 3, "block_edges": [[0, 0, 1, 1], [0, 2, 0, 2]], "block_features": [[0, 2, ...], [0, 2, ...], [1, 0, ...]]}, {"block_number": 29, "block_edges": [[0, 1, ..., 28], [16, 0, ..., 8]], "block_features": [[8, 2, ...], [0, 7, ...], [0, 7, ...], [0, 7, ...], [0, 7, ...], [0, 7,...], [1, 18, ...], [1, 21, ...], [0, 21,...], [0, 24, ...], [1, 26, ...], [1, 2, ...], [5, 4, ...], [4, 11, ...], [2, 14, ...], [3, 17, ...], [1, 1, ...], [0, 14, ...], [3, 17, ...], [0, 17, ...], [2, 28, ...], [0, 11, ...], [0, 0, ...], [1, 1, ...], [12, 27, ...], [0, 0, ...], [2, 9, ...], [2, 14,...], [1, 21, ...]]}], "function_names": ["sub_401000", "start", "GetTempPathW", "GetFileSize", ... , "InternetOpenW"], "hash": "3***5", "function_number": 27}
```
### STEP 2: Convert the resulting json file to PyG data object
However, the resulting JSON object above cannot be fed directly into our model. We therefore convert it into a PyTorch Geometric `Data` object and provide an example Python script, `PreProcess.py`, for interested readers.

1
samples/sample.jsonl Normal file
View File

@ -0,0 +1 @@
{"function_edges": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26]], "acfg_list": [{"block_number": 3, "block_edges": [[0, 0, 1, 1], [0, 2, 0, 2]], "block_features": [[0, 2, 1, 0, 7, 0, 1, 1, 4, 0, 0], [0, 2, 0, 0, 3, 1, 0, 1, 0, 0, 0], [1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0]]}, {"block_number": 29, "block_edges": [[0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 12, 12, 13, 14, 14, 15, 16, 17, 18, 19, 19, 20, 20, 21, 21, 23, 24, 24, 26, 26, 27, 28], [16, 0, 2, 0, 4, 1, 3, 3, 3, 25, 15, 8, 6, 6, 7, 28, 12, 9, 23, 16, 25, 11, 21, 17, 13, 19, 22, 14, 19, 18, 27, 24, 23, 26, 21, 22, 25, 10, 25, 5, 14, 8]], "block_features": [[8, 2, 1, 5, 36, 0, 6, 0, 2, 0, 0], [0, 7, 0, 0, 3, 0, 1, 1, 1, 0, 0], [0, 7, 0, 0, 2, 0, 1, 1, 0, 0, 0], [0, 7, 0, 1, 8, 1, 2, 0, 0, 0, 0], [0, 7, 1, 0, 2, 0, 1, 0, 0, 0, 0], [0, 7, 0, 0, 1, 0, 0, 0, 1, 0, 0], [1, 18, 0, 1, 9, 0, 2, 1, 1, 0, 0], [1, 21, 1, 0, 3, 0, 1, 1, 0, 0, 0], [0, 21, 0, 1, 4, 1, 2, 0, 0, 0, 0], [0, 24, 0, 2, 12, 1, 3, 0, 0, 0, 0], [1, 26, 0, 3, 16, 0, 4, 1, 4, 0, 0], [1, 2, 0, 5, 22, 0, 5, 0, 1, 0, 0], [5, 4, 1, 3, 21, 0, 4, 1, 3, 0, 0], [4, 11, 0, 2, 17, 1, 2, 0, 1, 0, 0], [2, 14, 0, 1, 12, 0, 2, 1, 1, 0, 0], [3, 17, 0, 0, 10, 0, 1, 0, 1, 0, 0], [1, 1, 0, 1, 5, 0, 2, 0, 0, 0, 0], [0, 14, 0, 0, 1, 0, 0, 0, 0, 0, 0], [3, 17, 0, 0, 7, 0, 0, 0, 0, 0, 0], [0, 17, 0, 1, 5, 0, 2, 1, 1, 0, 0], [2, 28, 1, 1, 11, 1, 2, 1, 1, 0, 0], [0, 11, 0, 1, 8, 1, 2, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0], [1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0], [12, 27, 1, 7, 41, 0, 8, 1, 6, 0, 0], [0, 0, 1, 0, 7, 1, 0, 0, 0, 1, 0], [2, 9, 0, 2, 17, 0, 3, 1, 3, 0, 0], [2, 14, 0, 0, 5, 0, 1, 0, 4, 0, 0], [1, 21, 4, 1, 13, 0, 2, 0, 5, 0, 0]]}], "function_names": ["sub_401000", "start", "GetTempPathW", "GetFileSize", "GetCurrentDirectoryW", "DeleteFileW", "CloseHandle", "WriteFile", "lstrcmpW", "ReadFile", "GetModuleHandleW", 
"ExitProcess", "HeapCreate", "HeapAlloc", "GetModuleFileNameW", "CreateFileW", "lstrlenW", "ShellExecuteW", "wsprintfW", "HttpSendRequestW", "InternetSetOptionW", "InternetQueryOptionW", "HttpOpenRequestW", "HttpQueryInfoW", "InternetReadFile", "InternetConnectW", "InternetOpenW"], "hash": "316ebb797d5196020eee013cfe771671fff4da8859adc9f385f52a74e82f4e55", "function_number": 27}

BIN
samples/sub_401000.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 196 KiB

BIN
samples/sub_40103C.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 61 KiB

346
src/DistTrainModel.py Normal file
View File

@ -0,0 +1,346 @@
import logging
import math
import os
import random
from datetime import datetime
import hydra
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as torch_mp
import torch.utils.data
import torch.utils.data.distributed
from hydra.utils import to_absolute_path
from omegaconf import DictConfig
from prefetch_generator import BackgroundGenerator
from sklearn.metrics import roc_auc_score, roc_curve
from torch import nn
from torch_geometric.data import DataLoader
from tqdm import tqdm
from models.HierarchicalGraphModel import HierarchicalGraphNeuralNetwork
from utils.FunctionHelpers import write_into, params_print_log, find_threshold_with_fixed_fpr
from utils.ParameterClasses import ModelParams, TrainParams, OptimizerParams, OneEpochResult
from utils.PreProcessedDataset import MalwareDetectionDataset
from utils.RealBatch import create_real_batch_data
from utils.Vocabulary import Vocab
class DataLoaderX(DataLoader):
    """``DataLoader`` whose iterator is wrapped in a ``BackgroundGenerator``.

    NOTE(review): presumably this prefetches batches in a background thread;
    confirm against the ``prefetch_generator`` package.
    """

    def __iter__(self):
        base_iterator = super().__iter__()
        return BackgroundGenerator(base_iterator)
def reduce_sum(tensor):
    """Return a copy of ``tensor`` summed across all processes.

    The input is cloned first, so the caller's tensor is left untouched by
    the in-place ``all_reduce``.
    """
    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)  # noqa
    return summed
def reduce_mean(tensor, nprocs):
    """Return a copy of ``tensor`` summed across all processes and divided by ``nprocs``.

    :param tensor: tensor to reduce; it is cloned, not modified.
    :param nprocs: divisor applied to the summed result.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)  # noqa
    averaged /= nprocs
    return averaged
def all_gather_concat(tensor):
    """Gather ``tensor`` from every process and concatenate the results along dim 0.

    One receive buffer shaped like ``tensor`` is allocated per process in the
    group before the (blocking) ``all_gather`` call.
    """
    world_size = torch.distributed.get_world_size()
    receive_buffers = []
    for _ in range(world_size):
        receive_buffers.append(torch.ones_like(tensor))
    dist.all_gather(receive_buffers, tensor, async_op=False)
    return torch.cat(receive_buffers, dim=0)
def train_one_epoch(local_rank, train_loader, valid_loader, model, criterion, optimizer, nprocs, idx_epoch, best_auc, best_model_file, original_valid_length, result_file):
    """Train ``model`` for one epoch on this rank, validating periodically.

    For every batch the loss is back-propagated, then averaged over all
    processes with ``reduce_mean`` and folded into a running mean.  Validation
    runs (via ``validate``) whenever the batch index is a non-zero multiple of
    ``ceil(len(train_loader) / 3)`` and after the last batch; if the
    validation AUC beats ``best_auc`` the wrapped module's ``state_dict`` is
    saved to ``best_model_file``.

    :param local_rank: CUDA device index / rank of this process.
    :param train_loader: loader yielding training batches.
    :param valid_loader: loader passed through to ``validate``.
    :param model: model exposing ``.module`` (its ``state_dict`` is what gets saved).
    :param criterion: loss applied to ``(prediction, target)``.
    :param optimizer: optimizer stepped once per batch.
    :param nprocs: number of processes, used to average the loss.
    :param idx_epoch: epoch index, used only for log messages and flags.
    :param best_auc: best validation AUC seen so far.
    :param best_model_file: path the improved model weights are saved to.
    :param original_valid_length: forwarded to ``validate``.
    :param result_file: log file written through ``write_into``.
    :return: ``(running_mean_losses, best_auc)`` where the list holds one
        running-mean loss per processed batch.
    """
    model.train()
    local_device = torch.device("cuda", local_rank)
    write_into(file_name_path=result_file, log_str="The local device = {} among {} nprocs in the {}-th epoch.".format(local_device, nprocs, idx_epoch))

    until_sum_reduced_loss = 0.0
    smooth_avg_reduced_loss_list = []

    # Loop-invariant validation schedule.
    num_batches = len(train_loader)
    validate_every = math.ceil(num_batches / 3)
    last_batch_idx = int(num_batches - 1)

    for _idx_bt, _batch in enumerate(tqdm(train_loader, desc="reading _batch from local_rank={}".format(local_rank))):
        # validate() switches the model to eval mode, so re-enable training mode each step.
        model.train()
        _real_batch, _position, _hash, _external_list, _function_edges, _true_classes = create_real_batch_data(one_batch=_batch)
        if _real_batch is None:
            write_into(result_file, "{}\n_real_batch is None in creating the real batch data of training ... ".format("*-" * 100))
            continue

        _real_batch = _real_batch.to(local_device)
        _position = torch.tensor(_position, dtype=torch.long).cuda(local_rank, non_blocking=True)
        _true_classes = _true_classes.float().cuda(local_rank, non_blocking=True)

        train_batch_pred = model(real_local_batch=_real_batch, real_bt_positions=_position, bt_external_names=_external_list, bt_all_function_edges=_function_edges, local_device=local_device)
        # Drop only the trailing dimension (as validate() does).  A bare
        # squeeze() would also remove the batch dimension when the batch holds
        # a single sample, so the prediction would no longer match the target.
        train_batch_pred = train_batch_pred.squeeze(-1)

        loss = criterion(train_batch_pred, _true_classes)

        torch.distributed.barrier()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        reduced_loss = reduce_mean(loss, nprocs)
        until_sum_reduced_loss += reduced_loss.item()
        smooth_avg_reduced_loss_list.append(until_sum_reduced_loss / (_idx_bt + 1))

        if _idx_bt != 0 and (_idx_bt % validate_every == 0 or _idx_bt == last_batch_idx):
            val_start_time = datetime.now()
            if local_rank == 0:
                write_into(result_file, "\nIn {}-th epoch, {}-th batch, we start to validate ... ".format(idx_epoch, _idx_bt))

            _eval_flag = "Valid_In_Train_Epoch_{}_Batch_{}".format(idx_epoch, _idx_bt)
            valid_result = validate(local_rank=local_rank, valid_loader=valid_loader, model=model, criterion=criterion, evaluate_flag=_eval_flag, distributed=True, nprocs=nprocs,
                                    original_valid_length=original_valid_length, result_file=result_file, details=False)

            if best_auc < valid_result.ROC_AUC_Score:
                _info = "[AUC Increased!] In evaluation of epoch-{} / batch-{}: AUC increased from {:.5f} < {:.5f}! Saving the model into {}".format(idx_epoch,
                                                                                                                                                      _idx_bt,
                                                                                                                                                      best_auc,
                                                                                                                                                      valid_result.ROC_AUC_Score,
                                                                                                                                                      best_model_file)
                best_auc = valid_result.ROC_AUC_Score
                # Save the wrapped module so the weights load without the parallel wrapper.
                torch.save(model.module.state_dict(), best_model_file)
            else:
                _info = "[AUC NOT Increased!] AUC decreased from {:.5f} to {:.5f}!".format(best_auc, valid_result.ROC_AUC_Score)

            if local_rank == 0:
                write_into(result_file, valid_result.__str__())
                write_into(result_file, _info)
                write_into(result_file, "[#One Validation Time#] Consume about {} time period for one validation.".format(datetime.now() - val_start_time))

    return smooth_avg_reduced_loss_list, best_auc
def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distributed, nprocs, original_valid_length, result_file, details):
    """Evaluate ``model`` on ``valid_loader`` and return a ``OneEpochResult``.

    Runs under ``torch.no_grad()`` with the model in eval mode, accumulating
    the per-batch loss, the number of predictions, and all labels and
    predicted probabilities.  In distributed mode the loss is averaged and
    the labels / probabilities are gathered across processes, then truncated
    to the first ``original_valid_length`` entries.  ROC-AUC and the ROC
    curve are computed from the collected labels and probabilities.

    :param local_rank: CUDA device index (may be ``None`` when not distributed).
    :param valid_loader: loader yielding evaluation batches.
    :param model: model to evaluate.
    :param criterion: loss applied to ``(prediction, target)``.
    :param evaluate_flag: label stored as ``Epoch_Flag`` in the result.
    :param distributed: whether to reduce / gather across processes.
    :param nprocs: number of processes, used to average the loss.
    :param original_valid_length: number of gathered entries to keep.
    :param result_file: log file written through ``write_into``.
    :param details: when exactly ``True``, also compute the fixed-FPR
        threshold information at 0.01 and 0.001.
    :return: the populated ``OneEpochResult``.
    """
    model.eval()
    local_device = torch.device("cuda", local_rank) if distributed else torch.device("cuda")

    loss_total = torch.tensor(0.0, dtype=torch.float, device=local_device)
    sample_count = torch.tensor(0, dtype=torch.int, device=local_device)
    label_chunks, prob_chunks = [], []

    with torch.no_grad():
        for idx_batch, data in enumerate(tqdm(valid_loader)):
            real_batch, positions, _hash, external_names, function_edges, labels = create_real_batch_data(one_batch=data)
            if real_batch is None:
                write_into(result_file, "{}\n_real_batch is None in creating the real batch data of validation ... ".format("*-" * 100))
                continue

            real_batch = real_batch.to(local_device)
            positions = torch.tensor(positions, dtype=torch.long).cuda(local_rank, non_blocking=True)
            labels = labels.float().cuda(local_rank, non_blocking=True)

            probs = model(real_local_batch=real_batch, real_bt_positions=positions, bt_external_names=external_names,
                          bt_all_function_edges=function_edges, local_device=local_device)
            probs = probs.squeeze(-1)

            batch_loss = criterion(probs, labels)
            loss_total += batch_loss.item()
            sample_count += len(probs)

            label_chunks.append(labels)
            prob_chunks.append(probs)

    # The divisor is the number of batches iterated, including skipped ones.
    mean_loss = loss_total / (idx_batch + 1)
    labels_all = torch.cat(label_chunks, dim=0)
    probs_all = torch.cat(prob_chunks, dim=0)

    if distributed:
        torch.distributed.barrier()
        final_count = reduce_sum(sample_count)
        final_loss = reduce_mean(mean_loss, nprocs)
        y_true = all_gather_concat(labels_all).detach().cpu().numpy()
        y_score = all_gather_concat(probs_all).detach().cpu().numpy()
        # NOTE(review): presumably meant to drop samples duplicated by the
        # distributed sampler; confirm those duplicates really sit at the tail.
        y_true = y_true[:original_valid_length]
        y_score = y_score[:original_valid_length]
    else:
        final_count = sample_count
        final_loss = mean_loss
        y_true = labels_all.detach().cpu().numpy()
        y_score = probs_all.detach().cpu().numpy()

    # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
    # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html
    auc_value = roc_auc_score(y_true=y_true, y_score=y_score)
    fpr_values, tpr_values, threshold_values = roc_curve(y_true=y_true, y_score=y_score)

    if details is True:
        info_at_100 = find_threshold_with_fixed_fpr(y_true=y_true, y_pred=y_score, fpr_target=0.01)
        info_at_1000 = find_threshold_with_fixed_fpr(y_true=y_true, y_pred=y_score, fpr_target=0.001)
    else:
        info_at_100, info_at_1000 = "None", "None"

    return OneEpochResult(Epoch_Flag=evaluate_flag, Number_Samples=final_count, Avg_Loss=final_loss, Info_100=info_at_100, Info_1000=info_at_1000,
                          ROC_AUC_Score=auc_value, Thresholds=threshold_values, TPRs=tpr_values, FPRs=fpr_values)
def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, model_params: ModelParams, optimizer_params: OptimizerParams, global_log: logging.Logger,
                      log_result_file: str):
    """Per-process training entry point for distributed data-parallel training.

    Joins the NCCL process group (rendezvous via ``env://``), builds the model,
    loss, optimizer and the distributed train / validation loaders, then runs
    ``train_params.max_epochs`` epochs.  After every epoch each parameter
    group's learning rate is divided by ``optimizer_params.learning_anneal``.

    :param local_rank: rank of this process; also used as its CUDA device index.
    :param nprocs: total number of processes (world size).
    :param train_params: data paths, batch sizes, epoch count and vocabulary settings.
    :param model_params: model hyper-parameters; only ``ablation_models == "full"`` is supported.
    :param optimizer_params: optimizer name (``adam`` / ``adamw``), learning rate, weight decay, anneal factor.
    :param global_log: logger handed to the model.
    :param log_result_file: log file written through ``write_into`` (rank 0 only here).
    :raises NotImplementedError: for an unsupported model variant or optimizer name.
    """
    dist.init_process_group(backend='nccl', init_method='env://', world_size=nprocs, rank=local_rank)
    torch.cuda.set_device(local_rank)
    is_main_rank = local_rank == 0

    # model configure
    vocab = Vocab(freq_file=train_params.external_func_vocab_file, max_vocab_size=train_params.max_vocab_size)
    if model_params.ablation_models.lower() != "full":
        raise NotImplementedError
    model = HierarchicalGraphNeuralNetwork(model_params=model_params, external_vocab=vocab, global_log=global_log)
    model.cuda(local_rank)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    if is_main_rank:
        write_into(file_name_path=log_result_file, log_str=model.__str__())

    # loss function
    criterion = nn.BCELoss().cuda(local_rank)

    # optimizer
    lr = optimizer_params.lr
    optimizer_kind = optimizer_params.optimizer_name.lower()
    if optimizer_kind == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    elif optimizer_kind == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=optimizer_params.weight_decay)
    else:
        raise NotImplementedError

    max_epochs = train_params.max_epochs
    dataset_root_path = train_params.processed_files_path
    train_batch_size = train_params.train_bs
    test_batch_size = train_params.test_bs

    # training dataset & dataloader
    train_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="train")
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = DataLoaderX(dataset=train_dataset, batch_size=train_batch_size, shuffle=False, num_workers=0, pin_memory=True, sampler=train_sampler)

    # validation dataset & dataloader
    valid_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="valid")
    valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset)
    valid_loader = DataLoaderX(dataset=valid_dataset, batch_size=test_batch_size, pin_memory=True, sampler=valid_sampler)

    if is_main_rank:
        write_into(file_name_path=log_result_file, log_str="Training dataset={}, sampler={}, loader={}".format(len(train_dataset), len(train_sampler), len(train_loader)))
        write_into(file_name_path=log_result_file, log_str="Validation dataset={}, sampler={}, loader={}".format(len(valid_dataset), len(valid_sampler), len(valid_loader)))

    best_auc = 0.0
    ori_valid_length = len(valid_dataset)
    # Every rank keeps its own checkpoint file.
    best_model_path = os.path.join(os.getcwd(), 'LocalRank_{}_best_model.pt'.format(local_rank))
    all_batch_avg_smooth_loss_list = []

    for epoch in range(max_epochs):
        train_sampler.set_epoch(epoch)
        valid_sampler.set_epoch(epoch)

        # train for one epoch
        time_start = datetime.now()
        if is_main_rank:
            write_into(log_result_file, "\n{} start of {}-epoch, init best_auc={}, start time={} {}".format("-" * 50, epoch, best_auc, time_start.strftime("%Y-%m-%d@%H:%M:%S"), "-" * 50))

        epoch_losses, best_auc = train_one_epoch(local_rank=local_rank, train_loader=train_loader, valid_loader=valid_loader, model=model, criterion=criterion,
                                                 optimizer=optimizer, nprocs=nprocs, idx_epoch=epoch, best_auc=best_auc, best_model_file=best_model_path,
                                                 original_valid_length=ori_valid_length, result_file=log_result_file)
        all_batch_avg_smooth_loss_list.extend(epoch_losses)

        # adjust learning rate
        for group in optimizer.param_groups:
            group['lr'] = group['lr'] / optimizer_params.learning_anneal

        time_end = datetime.now()
        if is_main_rank:
            write_into(log_result_file, "\n{} end of {}-epoch, with best_auc={}, len of loss={}, end time={}, and time period={} {}".format("*" * 50, epoch, best_auc,
                                                                                                                                          len(epoch_losses),
                                                                                                                                          time_end.strftime("%Y-%m-%d@%H:%M:%S"),
                                                                                                                                          time_end - time_start, "*" * 50))
# https://hydra.cc/docs/tutorials/basic/your_first_app/defaults#overriding-a-config-group-default
@hydra.main(config_path="../configs/", config_name="default.yaml")
def main_app(config: DictConfig):
    """Hydra entry point: optionally train, then re-validate and test the best model.

    When ``config.Training.only_test_path`` is the string ``'None'``, one
    training worker is spawned per visible GPU and rank 0's checkpoint in the
    current working directory is evaluated afterwards; otherwise the
    checkpoint at ``only_test_path`` is evaluated directly.

    :param config: Hydra configuration with ``Data``, ``Training``, ``Model``
        and ``Optimizer`` sections.
    :raises NotImplementedError: when the configured model variant is not ``"full"``.
    """
    # set seed for determinism for reproduction
    seed = config.Training.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # setting hyper-parameter for Training / Model / Optimizer
    _train_params = TrainParams(
        processed_files_path=to_absolute_path(config.Data.preprocess_root),
        max_epochs=config.Training.max_epoches,
        train_bs=config.Training.train_batch_size,
        test_bs=config.Training.test_batch_size,
        external_func_vocab_file=to_absolute_path(config.Data.train_vocab_file),
        max_vocab_size=config.Data.max_vocab_size,
    )
    _model_params = ModelParams(
        gnn_type=config.Model.gnn_type,
        pool_type=config.Model.pool_type,
        acfg_init_dims=config.Model.acfg_node_init_dims,
        cfg_filters=config.Model.cfg_filters,
        fcg_filters=config.Model.fcg_filters,
        number_classes=config.Model.number_classes,
        dropout_rate=config.Model.drapout_rate,
        ablation_models=config.Model.ablation_models,
    )
    _optim_params = OptimizerParams(
        optimizer_name=config.Optimizer.name,
        lr=config.Optimizer.learning_rate,
        weight_decay=config.Optimizer.weight_decay,
        learning_anneal=config.Optimizer.learning_anneal,
    )

    # logging
    log = logging.getLogger("DistTrainModel.py")
    log.setLevel("DEBUG")
    working_dir = os.getcwd()
    log.warning("Hydra's Current Working Directory: {}".format(working_dir))

    # setting for the log directory
    result_file = '{}_{}_{}_ACFG_{}_FCG_{}_Epoch_{}_TrainBS_{}_LR_{}_Time_{}.txt'.format(_model_params.ablation_models, _model_params.gnn_type, _model_params.pool_type,
                                                                                         _model_params.cfg_filters, _model_params.fcg_filters, _train_params.max_epochs,
                                                                                         _train_params.train_bs, _optim_params.lr, datetime.now().strftime("%Y%m%d_%H%M%S"))
    log_result_file = os.path.join(working_dir, result_file)

    _other_params = {"Hydra's Current Working Directory": working_dir, "seed": seed, "log result file": log_result_file, "only_test_path": config.Training.only_test_path}
    for param_dict in (_train_params.__dict__, _model_params.__dict__, _optim_params.__dict__, _other_params):
        params_print_log(param_dict, log_result_file)

    if config.Training.only_test_path == 'None':
        # Rendezvous address / port read by the workers' env:// initialisation.
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = str(config.Training.dist_port)

        num_gpus = torch.cuda.device_count()
        log.info("Total number of GPUs = {}".format(num_gpus))
        torch_mp.spawn(main_train_worker, nprocs=num_gpus, args=(num_gpus, _train_params, _model_params, _optim_params, log, log_result_file,))

        best_model_file = os.path.join(os.getcwd(), 'LocalRank_{}_best_model.pt'.format(0))
    else:
        best_model_file = config.Training.only_test_path

    # model re-init and loading
    log.info("\n\nstarting to load the model & re-validation & testing from the file of \"{}\" \n".format(best_model_file))
    device = torch.device('cuda')
    vocab = Vocab(freq_file=_train_params.external_func_vocab_file, max_vocab_size=_train_params.max_vocab_size)

    if _model_params.ablation_models.lower() != "full":
        raise NotImplementedError
    model = HierarchicalGraphNeuralNetwork(model_params=_model_params, external_vocab=vocab, global_log=log)
    model.to(device)
    model.load_state_dict(torch.load(best_model_file, map_location=device))
    criterion = nn.BCELoss().cuda()

    test_batch_size = config.Training.test_batch_size
    dataset_root_path = _train_params.processed_files_path

    # validation dataset & dataloader
    valid_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="valid")
    valid_dataloader = DataLoaderX(dataset=valid_dataset, batch_size=test_batch_size, shuffle=False)
    log.info("Total number of all validation samples = {} ".format(len(valid_dataset)))

    # testing dataset & dataloader
    test_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="test")
    test_dataloader = DataLoaderX(dataset=test_dataset, batch_size=test_batch_size, shuffle=False)
    log.info("Total number of all testing samples = {} ".format(len(test_dataset)))

    _valid_result = validate(valid_loader=valid_dataloader, model=model, criterion=criterion, evaluate_flag="DoubleCheckValidation", distributed=False, local_rank=None,
                             nprocs=None, original_valid_length=len(valid_dataset), result_file=log_result_file, details=True)
    log.warning("\n\n" + _valid_result.__str__())

    _test_result = validate(valid_loader=test_dataloader, model=model, criterion=criterion, evaluate_flag="FinalTestingResult", distributed=False, local_rank=None,
                            nprocs=None, original_valid_length=len(test_dataset), result_file=log_result_file, details=True)
    log.warning("\n\n" + _test_result.__str__())
if __name__ == "__main__":
    # Script entry point; Hydra supplies the `config` argument.
    main_app()

View File

@ -0,0 +1,213 @@
import logging
import sys
import torch
from torch import nn
from torch.nn import Linear
from torch.nn import functional as pt_f
from torch_geometric.data import Batch, Data
from torch_geometric.nn.conv import GCNConv, SAGEConv
from torch_geometric.nn.glob import global_max_pool, global_mean_pool
sys.path.append("..")
from utils.ParameterClasses import ModelParams # noqa
from utils.Vocabulary import Vocab # noqa
def div_with_small_value(n, d, eps=1e-8):
    """Divide ``n`` by ``d``, substituting ``eps`` for every entry of ``d`` that is not greater than ``eps``.

    :param n: numerator tensor.
    :param d: denominator tensor.
    :param eps: floor applied to the denominator.
    :return: ``n`` divided by the floored denominator.
    """
    keep_mask = (d > eps).float()
    floor_mask = (d <= eps).float()
    safe_denominator = d * keep_mask + eps * floor_mask
    return n / safe_denominator
def padding_tensors(tensor_list):
    """Stack variable-length tensors into one zero-padded tensor plus a mask.

    Each tensor may differ in its first dimension; the output has shape
    ``(len(tensor_list), max_first_dim, *trailing_dims)`` where the trailing
    dims are taken from the first tensor.

    :param tensor_list: non-empty sequence of tensors.
    :return: ``(padded, mask)``; ``mask`` has the same shape as ``padded`` and
        is 1 on positions copied from an input tensor, 0 on padding.
    """
    template = tensor_list[0]
    longest = max(item.shape[0] for item in tensor_list)
    padded_shape = (len(tensor_list), longest, *template.shape[1:])

    # Same dtype / device as the first tensor, detached from its graph.
    padded = template.data.new_zeros(padded_shape)
    mask = template.data.new_zeros(padded_shape)

    for row, item in enumerate(tensor_list):
        n_valid = item.size(0)
        padded[row, :n_valid] = item
        mask[row, :n_valid] = 1
    return padded, mask
def inverse_padding_tensors(tensors, masks):
    """Undo ``padding_tensors``: keep the masked rows and report their sample index.

    Assumes every row of ``masks`` is either all ones or all zeros (as
    produced by ``padding_tensors``).

    :param tensors: padded tensor whose last dimension is the feature dimension.
    :param masks: tensor of the same shape, equal to 1 on valid entries.
    :return: ``(rows, batch_idx_list)`` where ``rows`` holds the selected
        entries reshaped to ``(-1, feature_dim)`` and ``batch_idx_list`` gives,
        for each kept row, the index of the sample it came from.
    """
    feature_dim = tensors.size(-1)
    kept_rows = torch.masked_select(tensors, (masks == 1)).view(-1, feature_dim)

    # Per-row flag (1 for a valid row, 0 for padding), then valid rows per sample.
    row_flags = torch.sum(masks, dim=-1) / masks.size(-1)
    rows_per_sample = torch.sum(row_flags, dim=-1)

    batch_idx_list = []
    for sample_idx, row_count in enumerate(rows_per_sample):
        batch_idx_list.extend([sample_idx] * int(row_count))
    return kept_rows, batch_idx_list
class HierarchicalGraphNeuralNetwork(nn.Module):
def __init__(self, model_params: ModelParams, external_vocab: Vocab, global_log: logging.Logger): # device=torch.device('cuda')
super(HierarchicalGraphNeuralNetwork, self).__init__()
self.conv = model_params.gnn_type.lower()
if self.conv not in ['graphsage', 'gcn']:
raise NotImplementedError
self.pool = model_params.pool_type.lower()
if self.pool not in ["global_max_pool", "global_mean_pool"]:
raise NotImplementedError
# self.device = device
self.global_log = global_log
# Hierarchical 1: Control Flow Graph (CFG) embedding and pooling
print(type(model_params.cfg_filters), model_params.cfg_filters)
if type(model_params.cfg_filters) == str:
cfg_filter_list = [int(number_filter) for number_filter in model_params.cfg_filters.split("-")]
else:
cfg_filter_list = [int(model_params.cfg_filters)]
cfg_filter_list.insert(0, model_params.acfg_init_dims)
self.cfg_filter_length = len(cfg_filter_list)
cfg_graphsage_params = [dict(in_channels=cfg_filter_list[i], out_channels=cfg_filter_list[i + 1], bias=True) for i in range(self.cfg_filter_length - 1)] # GraphSAGE for cfg
cfg_gcn_params = [dict(in_channels=cfg_filter_list[i], out_channels=cfg_filter_list[i + 1], cached=False, bias=True) for i in range(self.cfg_filter_length - 1)] # GCN for cfg
cfg_conv_layer_constructor = {
'graphsage': dict(constructor=SAGEConv, kwargs=cfg_graphsage_params),
'gcn': dict(constructor=GCNConv, kwargs=cfg_gcn_params)
}
cfg_conv = cfg_conv_layer_constructor[self.conv]
cfg_constructor = cfg_conv['constructor']
for i in range(self.cfg_filter_length - 1):
setattr(self, 'CFG_gnn_{}'.format(i + 1), cfg_constructor(**cfg_conv['kwargs'][i]))
# self.dropout = nn.Dropout(p=model_params.dropout_rate).to(self.device)
self.dropout = nn.Dropout(p=model_params.dropout_rate)
# Hierarchical 2: Function Call Graph (FCG) embedding and pooling
self.external_embedding_layer = nn.Embedding(num_embeddings=external_vocab.max_vocab_size + 2, embedding_dim=cfg_filter_list[-1], padding_idx=external_vocab.pad_idx)
print(type(model_params.fcg_filters), model_params.fcg_filters)
if type(model_params.fcg_filters) == str:
fcg_filter_list = [int(number_filter) for number_filter in model_params.fcg_filters.split("-")]
else:
fcg_filter_list = [int(model_params.fcg_filters)]
fcg_filter_list.insert(0, cfg_filter_list[-1])
self.fcg_filter_length = len(fcg_filter_list)
fcg_graphsage_params = [dict(in_channels=fcg_filter_list[i], out_channels=fcg_filter_list[i + 1], bias=True) for i in range(self.fcg_filter_length - 1)] # GraphSAGE for fcg
fcg_gcn_params = [dict(in_channels=fcg_filter_list[i], out_channels=fcg_filter_list[i + 1], cached=False, bias=True) for i in range(self.fcg_filter_length - 1)] # GCN for fcg
fcg_conv_layer_constructor = {
'graphsage': dict(constructor=SAGEConv, kwargs=fcg_graphsage_params),
'gcn': dict(constructor=GCNConv, kwargs=fcg_gcn_params)
}
fcg_conv = fcg_conv_layer_constructor[self.conv]
fcg_constructor = fcg_conv['constructor']
for i in range(self.fcg_filter_length - 1):
setattr(self, 'FCG_gnn_{}'.format(i + 1), fcg_constructor(**fcg_conv['kwargs'][i]))
# Last Projection Function: gradually project with more linear layers
self.pj1 = Linear(in_features=fcg_filter_list[-1], out_features=int(fcg_filter_list[-1] / 2))
self.pj2 = Linear(in_features=int(fcg_filter_list[-1] / 2), out_features=int(fcg_filter_list[-1] / 4))
self.pj3 = Linear(in_features=int(fcg_filter_list[-1] / 4), out_features=1)
self.last_activation = nn.Sigmoid()
# self.last_activation = nn.Softmax(dim=1)
# self.last_activation = nn.LogSoftmax(dim=1)
def forward_cfg_gnn(self, local_batch: Batch):
in_x, edge_index = local_batch.x, local_batch.edge_index
for i in range(self.cfg_filter_length - 1):
out_x = getattr(self, 'CFG_gnn_{}'.format(i + 1))(x=in_x, edge_index=edge_index)
out_x = pt_f.relu(out_x, inplace=True)
out_x = self.dropout(out_x)
in_x = out_x
local_batch.x = in_x
return local_batch
def aggregate_cfg_batch_pooling(self, local_batch: Batch):
if self.pool == 'global_max_pool':
x_pool = global_max_pool(x=local_batch.x, batch=local_batch.batch)
elif self.pool == 'global_mean_pool':
x_pool = global_mean_pool(x=local_batch.x, batch=local_batch.batch)
else:
raise NotImplementedError
return x_pool
def forward_fcg_gnn(self, function_batch: Batch):
in_x, edge_index = function_batch.x, function_batch.edge_index
for i in range(self.fcg_filter_length - 1):
out_x = getattr(self, 'FCG_gnn_{}'.format(i + 1))(x=in_x, edge_index=edge_index)
out_x = pt_f.relu(out_x, inplace=True)
out_x = self.dropout(out_x)
in_x = out_x
function_batch.x = in_x
return function_batch
def aggregate_fcg_batch_pooling(self, function_batch: Batch):
if self.pool == 'global_max_pool':
x_pool = global_max_pool(x=function_batch.x, batch=function_batch.batch)
elif self.pool == 'global_mean_pool':
x_pool = global_mean_pool(x=function_batch.x, batch=function_batch.batch)
else:
raise NotImplementedError
return x_pool
def aggregate_final_skip_pooling(self, x, batch):
if self.pool == 'global_max_pool':
x_pool = global_max_pool(x=x, batch=batch)
elif self.pool == 'global_mean_pool':
x_pool = global_mean_pool(x=x, batch=batch)
else:
raise NotImplementedError
return x_pool
@staticmethod
def cosine_attention(mtx1, mtx2):
v1_norm = mtx1.norm(p=2, dim=2, keepdim=True)
v2_norm = mtx2.norm(p=2, dim=2, keepdim=True).permute(0, 2, 1)
a = torch.bmm(mtx1, mtx2.permute(0, 2, 1))
d = v1_norm * v2_norm
return div_with_small_value(a, d)
def forward(self, real_local_batch: Batch, real_bt_positions: list, bt_external_names: list, bt_all_function_edges: list, local_device: torch.device):
    """Hierarchical forward pass: CFG embedding, FCG embedding, then projection.

    Args:
        real_local_batch: batch holding every CFG of every sample in the mini-batch.
        real_bt_positions: cumulative offsets into the pooled CFG embeddings;
            sample ``i`` owns rows ``[positions[i], positions[i + 1])``, so the
            list has ``batch_size + 1`` entries.
        bt_external_names: one list of embedding indices per sample, fed to
            ``self.external_embedding_layer`` (presumably vocabulary ids of
            external functions -- confirm against the data pipeline).
        bt_all_function_edges: one edge list per sample, used as the
            ``edge_index`` of that sample's function call graph.
        local_device: device on which the per-sample graphs are built.

    Returns:
        The output of ``self.last_activation`` applied to the projected
        graph-level embeddings.

    Raises:
        AssertionError: the three per-sample inputs disagree on the batch size.
    """
    # Hierarchical 1: Control Flow Graph (CFG) embedding and pooling
    rtn_local_batch = self.forward_cfg_gnn(local_batch=real_local_batch)
    x_cfg_pool = self.aggregate_cfg_batch_pooling(local_batch=rtn_local_batch)

    # build the Function Call Graph (FCG) Data/Batch datasets
    assert len(real_bt_positions) - 1 == len(bt_external_names), "all should be equal to the batch size ... "
    assert len(real_bt_positions) - 1 == len(bt_all_function_edges), "all should be equal to the batch size ... "

    fcg_list = []
    for idx_batch in range(len(real_bt_positions) - 1):
        start_pos, end_pos = real_bt_positions[idx_batch: idx_batch + 2]
        # Node features of internal functions: the pooled CFG embeddings of this sample.
        idx_x_cfg = x_cfg_pool[start_pos: end_pos]

        # Node features of external functions: looked up in the embedding table.
        # The index list is wrapped in an extra dimension and squeezed again after the lookup.
        idx_x_external = self.external_embedding_layer(torch.tensor([bt_external_names[idx_batch]], dtype=torch.long).to(local_device))
        idx_x_external = idx_x_external.squeeze(dim=0)

        # Internal nodes come first, external nodes after them.
        idx_x_total = torch.cat([idx_x_cfg, idx_x_external], dim=0)
        idx_function_edge = torch.tensor(bt_all_function_edges[idx_batch], dtype=torch.long)
        idx_graph_data = Data(x=idx_x_total, edge_index=idx_function_edge).to(local_device)
        fcg_list.append(idx_graph_data)
    fcg_batch = Batch.from_data_list(fcg_list)

    # Hierarchical 2: Function Call Graph (FCG) embedding and pooling
    rtn_fcg_batch = self.forward_fcg_gnn(function_batch=fcg_batch)
    x_fcg_pool = self.aggregate_fcg_batch_pooling(function_batch=rtn_fcg_batch)  # one row per sample
    batch_final = x_fcg_pool

    # last step: project to the number of classes (binary)
    bt_final_embed = self.pj3(self.pj2(self.pj1(batch_final)))
    bt_pred = self.last_activation(bt_final_embed)
    return bt_pred

0
src/models/__init__.py Normal file
View File

83
src/requirement_conda.txt Normal file
View File

@ -0,0 +1,83 @@
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: linux-64
_libgcc_mutex=0.1=main
antlr4-python3-runtime=4.8=pypi_0
ase=3.21.1=pypi_0
ca-certificates=2021.1.19=h06a4308_1
cached-property=1.5.2=pypi_0
certifi=2020.12.5=py37h06a4308_0
cffi=1.14.5=pypi_0
chardet=4.0.0=pypi_0
cmake=3.18.4.post1=pypi_0
cycler=0.10.0=pypi_0
dataclasses=0.6=pypi_0
decorator=4.4.2=pypi_0
future=0.18.2=pypi_0
googledrivedownloader=0.4=pypi_0
h5py=3.2.1=pypi_0
hydra-core=1.0.6=pypi_0
idna=2.10=pypi_0
importlib-resources=5.1.2=pypi_0
intel-openmp=2021.1.2=pypi_0
isodate=0.6.0=pypi_0
jinja2=2.11.3=pypi_0
joblib=1.0.1=pypi_0
kiwisolver=1.3.1=pypi_0
ld_impl_linux-64=2.33.1=h53a641e_7
libedit=3.1.20191231=h14c3975_1
libffi=3.3=he6710b0_2
libgcc-ng=9.1.0=hdf63c60_0
libstdcxx-ng=9.1.0=hdf63c60_0
llvmlite=0.35.0=pypi_0
magma-cuda112=2.5.2=1
markupsafe=1.1.1=pypi_0
matplotlib=3.3.4=pypi_0
mkl=2021.1.1=pypi_0
mkl-include=2021.1.1=pypi_0
ncurses=6.2=he6710b0_1
networkx=2.5=pypi_0
ninja=1.10.0.post2=pypi_0
numba=0.52.0=pypi_0
numpy=1.20.1=pypi_0
omegaconf=2.0.6=pypi_0
openssl=1.1.1j=h27cfd23_0
pandas=1.2.3=pypi_0
pillow=8.1.2=pypi_0
pip=21.0.1=py37h06a4308_0
prefetch-generator=1.0.1=pypi_0
pycparser=2.20=pypi_0
pyparsing=2.4.7=pypi_0
python=3.7.9=h7579374_0
python-dateutil=2.8.1=pypi_0
python-louvain=0.15=pypi_0
pytz=2021.1=pypi_0
pyyaml=5.4.1=pypi_0
rdflib=5.0.0=pypi_0
readline=8.1=h27cfd23_0
requests=2.25.1=pypi_0
scikit-learn=0.24.1=pypi_0
scipy=1.6.1=pypi_0
seaborn=0.11.1=pypi_0
setuptools=52.0.0=py37h06a4308_0
six=1.15.0=pypi_0
sqlite=3.33.0=h62c20be_0
tbb=2021.1.1=pypi_0
texttable=1.6.3=pypi_0
threadpoolctl=2.1.0=pypi_0
tk=8.6.10=hbc83047_0
torch=1.8.0+cu111=pypi_0
torch-cluster=1.5.9=pypi_0
torch-geometric=1.6.3=pypi_0
torch-scatter=2.0.6=pypi_0
torch-sparse=0.6.9=pypi_0
torch-spline-conv=1.2.1=pypi_0
torchaudio=0.8.0=pypi_0
torchvision=0.9.0+cu111=pypi_0
tqdm=4.59.0=pypi_0
typing-extensions=3.7.4.3=pypi_0
urllib3=1.26.3=pypi_0
wheel=0.36.2=pyhd3eb1b0_0
xz=5.2.5=h7b6447c_0
zipp=3.4.1=pypi_0
zlib=1.2.11=h7b6447c_3

View File

@ -0,0 +1,108 @@
import logging
import os
import re
from collections import Counter
from dataclasses import dataclass
from typing import Dict
from sklearn.metrics import auc, confusion_matrix, balanced_accuracy_score
from texttable import Texttable
from datetime import datetime
def only_get_fpr(y_true, y_pred):
    """Return the false-positive rate computed directly from the arrays.

    Samples with ``y_true == 0`` count as benign; a benign sample with
    ``y_pred == 1`` is a false positive. Both arguments must support
    element-wise comparison and boolean-mask indexing.
    """
    benign_mask = y_true == 0
    false_alarms = (y_pred[benign_mask] == 1).sum()
    benign_total = benign_mask.sum()
    return float(false_alarms) / float(benign_total)
def get_fpr(y_true, y_pred):
    """Return the false-positive rate ``FP / (FP + TN)`` from the confusion matrix."""
    matrix = confusion_matrix(y_true=y_true, y_pred=y_pred)
    # ravel() of the confusion matrix unpacks as tn, fp, fn, tp.
    tn, fp, _fn, _tp = matrix.ravel()
    return float(fp) / float(fp + tn)
def find_threshold_with_fixed_fpr(y_true, y_pred, fpr_target):
    """Scan thresholds upward until the FPR drops to ``fpr_target`` and summarise.

    The threshold starts at 0.0 and grows in steps of 0.0001 while the
    false-positive rate of ``y_pred > threshold`` is still above
    ``fpr_target`` and the threshold has not exceeded 1.0.

    Returns:
        A one-line string with the threshold, confusion-matrix counts,
        TPR, FPR, accuracy, balanced accuracy and the elapsed search time.
    """
    start_time = datetime.now()
    threshold = 0.0
    while True:
        current_fpr = only_get_fpr(y_true, y_pred > threshold)
        if not (current_fpr > fpr_target and threshold <= 1.0):
            break
        threshold += 0.0001

    # Binarise once and reuse for every metric below.
    decisions = y_pred > threshold
    tn, fp, fn, tp = confusion_matrix(y_true=y_true, y_pred=decisions).ravel()
    tpr = tp / (tp + fn)
    fpr = fp / (fp + tn)
    acc = (tp + tn) / (tn + fp + fn + tp)  # equal to accuracy_score(y_true=y_true, y_pred=decisions)
    balanced_acc = balanced_accuracy_score(y_true=y_true, y_pred=decisions)

    _info = "Threshold: {:.6f}, TN: {}, FP: {}, FN: {}, TP: {}, TPR: {:.6f}, FPR: {:.6f}, ACC: {:.6f}, Balanced_ACC: {:.6f}. consume about {} time in find threshold".format(
        threshold, tn, fp, fn, tp, tpr, fpr, acc, balanced_acc, datetime.now() - start_time)
    return _info
def alphabet_lower_strip(str1):
    """Drop every character that is not an ASCII letter and lower-case the rest."""
    letters_only = re.sub("[^A-Za-z]", "", str1)
    return letters_only.lower()
def filter_counter_with_threshold(counter: Counter, min_threshold):
    """Return a plain dict of the entries whose count is at least ``min_threshold``."""
    return {key: count for key, count in counter.items() if count >= min_threshold}
def create_dir_if_not_exists(new_dir: str, log: logging.Logger):
    """Create ``new_dir`` (including parents) and log whether it was created.

    The directory is created first and an already-existing path is detected
    from the resulting exception, so there is no gap between "check" and
    "create" in which another process could create the directory.

    Args:
        new_dir: path of the directory to create.
        log: logger that receives one info message describing the outcome.
    """
    try:
        os.makedirs(new_dir)
    except FileExistsError:
        log.info('We CANNOT creat the dir of \"{}\" as it is already exists.'.format(new_dir))
    else:
        log.info('We are creating the dir of \"{}\" '.format(new_dir))
def get_jsonl_files_from_path(file_path: str, log: logging.Logger):
    """Recursively collect every ``*.jsonl`` file below ``file_path``.

    Returns:
        The sorted list of matching paths; the list is also written to ``log``.
    """
    file_list = sorted(
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(file_path)
        for name in names
        if name.endswith(".jsonl")
    )
    log.info("{}\nFrom the path of {}, we obtain a list of {} files as follows.".format("-" * 50, file_path, len(file_list)))
    log.info("\n" + '\n'.join(str(f) for f in file_list))
    return file_list
def write_into(file_name_path: str, log_str: str, print_flag=True):
    """Append ``log_str`` plus a newline to ``file_name_path``, optionally echoing it.

    Args:
        file_name_path: log file; created when it does not exist yet.
        log_str: text to record; ``None`` is written as the string ``'None'``.
        print_flag: when true, ``log_str`` is also printed to stdout.
    """
    if print_flag:
        print(log_str)
    text = 'None' if log_str is None else log_str
    # Existing regular files are appended to; anything else is opened for writing.
    mode = 'a+' if os.path.isfile(file_name_path) else 'w+'
    with open(file_name_path, mode) as log_file:
        log_file.write(text + '\n')
def params_print_log(param_dict: Dict, log_path: str):
    """Render ``param_dict`` as a three-column text table and log it.

    Rows are ordered by sorted key and numbered from 0; values are shown via
    ``str``. The drawn table is passed to ``write_into`` with ``log_path``.
    """
    ordered_names = sorted(param_dict)
    table = Texttable()
    table.set_precision(6)
    table.set_cols_align(["l", "l", "c"])
    table.add_row(["Index", "Parameters", "Values"])
    for row_no, name in enumerate(ordered_names):
        table.add_row([row_no, name, str(param_dict[name])])
    rendered = table.draw()
    write_into(file_name_path=log_path, log_str=rendered)
def dataclasses_to_string(ins: dataclass):
    """Format an object as ``<ClassName>:`` followed by one ``=>name = value`` line per attribute.

    Attributes are taken from ``vars(ins)`` in their stored order.
    """
    name = type(ins).__name__
    var_list = []
    for key, value in vars(ins).items():
        var_list.append(f"{key} = {value}")
    var_str = '\n=>'.join(var_list)
    return f"{name}:\n=>{var_str}\n"
# Nothing is executed when this module is run as a script; it only provides helpers.
if __name__ == '__main__':
    pass

View File

@ -0,0 +1,60 @@
from dataclasses import dataclass
@dataclass
class TrainParams:
    """Data locations and training-loop sizes, grouped into one record.

    NOTE(review): the field names resemble the ``Data`` and ``Training``
    sections of configs/default.yaml, but the code that fills this class is
    not in this file -- confirm the mapping at the call site.
    """
    processed_files_path: str  # presumably the root of the pre-processed dataset -- TODO confirm
    # train_test_split_file: str
    max_epochs: int  # number of training epochs
    train_bs: int  # presumably the training batch size -- TODO confirm
    test_bs: int  # presumably the evaluation batch size -- TODO confirm
    external_func_vocab_file: str  # presumably the frequency file read by utils.Vocabulary.Vocab -- TODO confirm
    max_vocab_size: int  # upper bound on vocabulary entries
@dataclass
class OptimizerParams:
    """Optimizer settings grouped into one record.

    NOTE(review): field names resemble the ``Optimizer`` section of
    configs/default.yaml; how they are consumed is not visible here.
    """
    optimizer_name: str  # the config lists "Adam" / "AdamW" as choices -- verify against the trainer
    lr: float  # learning rate
    weight_decay: float
    learning_anneal: float  # the config describes this as annealing applied to the learning rate after each epoch
@dataclass
class ModelParams:
    """Model hyper-parameters grouped into one record.

    NOTE(review): field names resemble the ``Model`` section of
    configs/default.yaml; how the model reads them is not visible here.
    """
    gnn_type: str  # the config lists "GraphSAGE" / "GCN"
    pool_type: str  # the config lists "global_max_pool" / "global_mean_pool"
    acfg_init_dims: int  # presumably the input feature size of a CFG node -- TODO confirm
    cfg_filters: str  # dash-separated sizes in the config, e.g. "200-200"
    fcg_filters: str  # dash-separated sizes in the config, e.g. "200-200"
    number_classes: int
    dropout_rate: float
    ablation_models: str  # the config uses "Full"
@dataclass
class OneEpochResult:
    """Metrics collected for one evaluated epoch, with a readable text summary.

    ``str(result)`` lists the scalar fields one per line, each introduced by
    the ``=>`` marker; the list-valued fields are not part of the summary.
    """
    Epoch_Flag: str  # label identifying the epoch / split this result belongs to
    Number_Samples: int
    Avg_Loss: float
    Info_100: str  # presumably the threshold summary at one fixed FPR target -- TODO confirm
    Info_1000: str  # presumably the threshold summary at a second FPR target -- TODO confirm
    ROC_AUC_Score: float
    Thresholds: list
    TPRs: list
    FPRs: list

    def __str__(self):
        """Return the multi-line summary of the scalar fields."""
        # Every field line uses the same "=>" prefix (the first one used to be a bare "=").
        s = "\nResult of \"{}\":\n=>Epoch_Flag = {}\n=>Number of samples = {}\n=>Avg_Loss = {}\n=>Info_100 = {}\n=>Info_1000 = {}\n=>ROC_AUC_score = {}\n".format(
            self.Epoch_Flag,
            self.Epoch_Flag,
            self.Number_Samples,
            self.Avg_Loss,
            self.Info_100,
            self.Info_1000,
            self.ROC_AUC_Score)
        return s
# Nothing is executed when this module is run as a script; it only declares dataclasses.
if __name__ == '__main__':
    pass

View File

@ -0,0 +1,85 @@
import os
import os.path as osp
from datetime import datetime
import torch
from torch_geometric.data import Dataset, DataLoader
from utils.RealBatch import create_real_batch_data # noqa
class MalwareDetectionDataset(Dataset):
    """Dataset of pre-processed samples stored as one ``.pt`` file per sample.

    Samples live in ``<root>/<flag>_malware`` and ``<root>/<flag>_benign``,
    where ``flag`` is the lower-cased ``train_or_test`` argument. Indices
    ``[0, n_malware)`` address malware samples and the remaining indices
    address benign samples. Files are expected to be named
    ``malware_<i>.pt`` / ``benign_<j>.pt`` with consecutive numbering from 0.
    """

    def __init__(self, root, train_or_test, transform=None, pre_transform=None):
        """Locate both sample directories and list their ``.pt`` files.

        Args:
            root: directory that contains the ``*_malware`` / ``*_benign`` folders.
            train_or_test: split name used as the folder prefix (case-insensitive).
            transform: forwarded to the base ``Dataset``.
            pre_transform: forwarded to the base ``Dataset``.
        """
        super(MalwareDetectionDataset, self).__init__(None, transform, pre_transform)
        self.flag = train_or_test.lower()
        self.malware_root = os.path.join(root, "{}_malware".format(self.flag))
        self.benign_root = os.path.join(root, "{}_benign".format(self.flag))
        # Only ``.pt`` files are samples; counting stray entries (hidden files,
        # logs, ...) would make the dataset longer than what ``get`` can load.
        self.malware_files = self._list_files_for_pt(self.malware_root)
        self.benign_files = self._list_files_for_pt(self.benign_root)

    @staticmethod
    def _list_files_for_pt(the_path):
        """Return the names of the entries in ``the_path`` whose extension is ``.pt``."""
        return [name for name in os.listdir(the_path) if os.path.splitext(name)[-1] == '.pt']

    def len(self):
        """Return the number of samples (malware plus benign)."""
        return len(self.malware_files) + len(self.benign_files)

    def __len__(self):
        """Return the number of samples (malware plus benign)."""
        return self.len()

    def get(self, idx):
        """Load and return the sample stored for ``idx``.

        NOTE: ``torch.load`` unpickles the file, so this must only be pointed
        at trusted, locally generated data.
        """
        split = len(self.malware_files)
        if idx < split:
            return torch.load(osp.join(self.malware_root, 'malware_{}.pt'.format(idx)))
        benign_idx = idx - split
        return torch.load(osp.join(self.benign_root, "benign_{}.pt".format(benign_idx)))
def _simulating(_dataset, _batch_size: int):
    """Load the first three shuffled batches of ``_dataset`` and print timing information.

    Each batch is converted with ``create_real_batch_data``; the batch, its
    hashes and its position list are printed. Start time, end time and the
    elapsed time are printed around the loop.
    """
    stamp_format = "%Y-%m-%d@%H:%M:%S"
    print("\nBatch size = {}".format(_batch_size))

    began = datetime.now()
    print("start time: " + began.strftime(stamp_format))

    # Background on shared-memory limits with worker processes:
    # https://github.com/pytorch/fairseq/issues/1560
    # https://github.com/pytorch/pytorch/issues/973#issuecomment-459398189
    # increasing the shared memory: ulimit -SHn 51200
    loader = DataLoader(dataset=_dataset, batch_size=_batch_size, shuffle=True)  # default of prefetch_factor = 2  # num_workers=4

    for batch_no, batch in enumerate(loader):
        if batch_no >= 3:
            break
        converted = create_real_batch_data(one_batch=batch)
        _real_batch, _position, _hash, _external_list, _function_edges, _true_classes = converted
        print(batch)
        print("Hash: ", _hash)
        print("Position: ", _position)
        print("\n")

    finished = datetime.now()
    print("end time: " + finished.strftime(stamp_format))
    print("All time = {}\n\n".format(finished - began))
# Manual check: for each split, print its folders and sizes, then load a few batches.
if __name__ == '__main__':
    root_path: str = '/home/xiang/MalGraph/data/processed_dataset/DatasetJSON/'
    i_batch_size = 2

    for split_name in ('train', 'valid', 'test'):
        split_dataset = MalwareDetectionDataset(root=root_path, train_or_test=split_name)
        print(split_dataset.malware_root, split_dataset.benign_root)
        print(len(split_dataset.malware_files), len(split_dataset.benign_files), len(split_dataset))
        _simulating(_dataset=split_dataset, _batch_size=i_batch_size)

24
src/utils/RealBatch.py Normal file
View File

@ -0,0 +1,24 @@
import torch
from torch_geometric.data import Batch
from torch_geometric.data import DataLoader
from pprint import pprint
def create_real_batch_data(one_batch: Batch):
    """Flatten the per-sample ACFG lists of ``one_batch`` into a single batch.

    Args:
        one_batch: loader batch exposing ``local_acfgs``, ``external_list``,
            ``function_edges``, ``hash`` and ``targets``; the first four must
            have the same length.

    Returns:
        A 6-tuple ``(real_batch, position, hash, external_list,
        function_edges, targets)``. ``real_batch`` is the batch built from
        every ACFG of every sample and ``position`` holds the cumulative
        offsets, so sample ``i`` owns the graphs
        ``[position[i], position[i + 1])``. When the batch consists of one
        sample without any ACFG, a tuple of six ``None`` values is returned
        instead.

    Raises:
        AssertionError: the per-sample components differ in length.
    """
    assert len(one_batch.external_list) == len(one_batch.function_edges) == len(one_batch.local_acfgs) == len(one_batch.hash), "size of each component must be equal to each other"

    if len(one_batch.local_acfgs) == 1 and len(one_batch.local_acfgs[0]) == 0:
        # A real tuple (not a generator), so both return paths have the same
        # type and the result can be indexed or unpacked more than once.
        return (None,) * 6

    real = []
    position = [0]
    for item in one_batch.local_acfgs:
        real.extend(item)
        position.append(position[-1] + len(item))

    real_batch = Batch.from_data_list(real)
    return real_batch, position, one_batch.hash, one_batch.external_list, one_batch.function_edges, one_batch.targets

91
src/utils/Vocabulary.py Normal file
View File

@ -0,0 +1,91 @@
import json
import os
from collections import Counter
from tqdm import tqdm
class Vocab:
    """Token <-> index vocabulary built from a JSON-lines frequency file.

    Index layout: the unknown token (if any), the padding token (if any), the
    optional special tokens, then up to ``max_vocab_size`` names from the
    frequency file in order of decreasing count.
    """

    def __init__(self, freq_file: str, max_vocab_size: int, min_freq: int = 1, unk_token: str = '<unk>', pad_token: str = '<pad>', special_tokens: list = None):
        """Load ``freq_file`` and build both lookup tables.

        Args:
            freq_file: JSON-lines file; each line has ``f_name`` and ``count``.
            max_vocab_size: maximum number of names taken from the file.
            min_freq: names with a smaller count are not loaded.
            unk_token: token returned for unknown names, or ``None`` to disable.
            pad_token: padding token, or ``None`` to disable.
            special_tokens: extra reserved tokens placed after unk/pad.

        Raises:
            AssertionError: ``freq_file`` does not exist.
        """
        self.max_vocab_size = max_vocab_size
        self.min_freq = min_freq
        self.unk_token = unk_token
        self.pad_token = pad_token
        self.special_tokens = special_tokens

        assert os.path.exists(freq_file), "The file of {} is not exist".format(freq_file)
        freq_counter = self.load_freq_counter_from_file(file_path=freq_file, min_freq=self.min_freq)

        self.token_2_index, self.index_2_token = self.create_vocabulary(freq_counter=freq_counter)

        self.unk_idx = None if self.unk_token is None else self.token_2_index[self.unk_token]
        self.pad_idx = None if self.pad_token is None else self.token_2_index[self.pad_token]

    def __len__(self):
        """Return the number of entries, reserved tokens included."""
        return len(self.index_2_token)

    def __getitem__(self, item: str):
        """Return the index of ``item``, falling back to the unknown-token index.

        Raises:
            KeyError: ``item`` is unknown and no unknown token is configured.
        """
        assert isinstance(item, str)
        if item in self.token_2_index:
            return self.token_2_index[item]
        if self.unk_token is not None:
            return self.token_2_index[self.unk_token]
        raise KeyError("{} is not in the vocabulary, and self.unk_token is None".format(item))

    def create_vocabulary(self, freq_counter: Counter):
        """Build ``(token_2_index, index_2_token)`` from the frequency counter."""
        index_2_token = []  # list: index -> token

        if self.unk_token is not None:
            index_2_token.append(self.unk_token)
        if self.pad_token is not None:
            index_2_token.append(self.pad_token)
        if self.special_tokens is not None:
            index_2_token.extend(self.special_tokens)

        # Set mirror of the list so the duplicate check is O(1) per name.
        seen = set(index_2_token)
        for f_name, _count in tqdm(freq_counter.most_common(self.max_vocab_size), desc="creating vocab ... "):
            if f_name in seen:
                print("trying to add {} to the vocabulary, but it already exists !!!".format(f_name))
                continue
            index_2_token.append(f_name)
            seen.add(f_name)

        # reverse mapping: token -> index
        token_2_index = {token: index for index, token in enumerate(index_2_token)}

        return token_2_index, index_2_token

    @staticmethod
    def load_freq_counter_from_file(file_path: str, min_freq: int):
        """Read the frequency file into a ``Counter`` of ``f_name -> count``.

        Reading stops at the first entry whose count is below ``min_freq``.
        NOTE(review): this relies on the file being sorted by decreasing
        count -- verify against the script that writes it.

        Raises:
            AssertionError: the same ``f_name`` appears twice in the file.
        """
        freq_dict = {}
        with open(file_path, 'r') as f:
            for line in tqdm(f, desc="Load frequency list from the file of {} ... ".format(file_path)):
                line = json.loads(line)
                f_name = line["f_name"]
                count = int(line["count"])

                assert f_name not in freq_dict, "trying to add {} to the vocabulary, but it already exists !!!".format(f_name)
                if count < min_freq:
                    print(line, "break")
                    break
                freq_dict[f_name] = count
        return Counter(freq_dict)
# Manual smoke test: build a vocabulary from the training frequency file and print a few lookups.
if __name__ == '__main__':
    max_vocab_size = 1000
    vocab = Vocab(freq_file="../../data/processed_dataset/train_external_function_name_vocab.jsonl", max_vocab_size=max_vocab_size)
    # Both mappings together with their sizes.
    print(len(vocab.token_2_index), vocab.token_2_index)
    print(len(vocab.index_2_token), vocab.index_2_token)
    # Reserved tokens and their indices.
    print(vocab.unk_token, vocab.unk_idx)
    print(vocab.pad_token, vocab.pad_idx)
    # Lookups are exact-match; a name missing from the vocabulary yields the unknown-token index.
    print(vocab['queryperformancecounter'])
    print(vocab['EmptyClipboard'])
    print(vocab[str.lower('EmptyClipboard')])
    print(vocab['X_Y_Z'])

0
src/utils/__init__.py Normal file
View File