# MalGraph/src/DistTrainModel_dual.py
import logging
import math
import os
import random
from datetime import datetime

import hydra
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as torch_mp
import torch.utils.data
import torch.utils.data.distributed
from hydra.utils import to_absolute_path
from omegaconf import DictConfig
from prefetch_generator import BackgroundGenerator
from sklearn.metrics import roc_auc_score, roc_curve
from torch import nn
from torch_geometric.data import DataLoader
from tqdm import tqdm

from models.HierarchicalGraphModel import HierarchicalGraphNeuralNetwork
from utils.FunctionHelpers import write_into, params_print_log, find_threshold_with_fixed_fpr
from utils.ParameterClasses import ModelParams, TrainParams, OptimizerParams, OneEpochResult
from utils.PreProcessedDataset import MalwareDetectionDataset
from utils.RealBatch import create_real_batch_data
from utils.Vocabulary import Vocab


class DataLoaderX(DataLoader):
    """DataLoader that prefetches batches in a background thread via prefetch_generator."""

    def __iter__(self):
        return BackgroundGenerator(super().__iter__())


def reduce_sum(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)  # noqa
    return rt


def reduce_mean(tensor, nprocs):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)  # noqa
    rt /= nprocs
    return rt


def all_gather_concat(tensor):
    tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
    dist.all_gather(tensors_gather, tensor, async_op=False)
    output = torch.cat(tensors_gather, dim=0)
    return output
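
# NOTE: DistributedSampler pads each rank's shard so that all ranks process the same number of
# samples, so the tensors produced by all_gather_concat() during distributed validation can end
# with a few duplicated samples. validate() below therefore truncates the gathered arrays back to
# `original_valid_length` before computing ROC metrics.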


def train_one_epoch(local_rank, train_loader, valid_loader, model, criterion, optimizer, nprocs, idx_epoch, best_auc, best_model_file, original_valid_length, result_file):
    # print(train_loader.dataset.__dict__)
    model.train()
    local_device = torch.device("cuda", local_rank)
    write_into(file_name_path=result_file, log_str="The local device = {} among {} nprocs in the {}-th epoch.".format(local_device, nprocs, idx_epoch))

    until_sum_reduced_loss = 0.0
    smooth_avg_reduced_loss_list = []
    for _idx_bt, _batch in enumerate(tqdm(train_loader, desc="reading _batch from local_rank={}".format(local_rank))):
        model.train()
        _real_batch, _position, _hash, _external_list, _function_edges, _true_classes = create_real_batch_data(one_batch=_batch)
        if _real_batch is None:
            write_into(result_file, "{}\n_real_batch is None in creating the real batch data of training ... ".format("*-" * 100))
            continue
        _real_batch = _real_batch.to(local_device)
        _position = torch.tensor(_position, dtype=torch.long).cuda(local_rank, non_blocking=True)
        _true_classes = _true_classes.float().cuda(local_rank, non_blocking=True)

        train_batch_pred = model(real_local_batch=_real_batch,
                                 real_bt_positions=_position,
                                 bt_external_names=_external_list,
                                 bt_all_function_edges=_function_edges,
                                 local_device=local_device)
        train_batch_pred = train_batch_pred.squeeze()
        loss = criterion(train_batch_pred, _true_classes)

        torch.distributed.barrier()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        reduced_loss = reduce_mean(loss, nprocs)
        until_sum_reduced_loss += reduced_loss.item()
        smooth_avg_reduced_loss_list.append(until_sum_reduced_loss / (_idx_bt + 1))

        # validate roughly three times per epoch, and once more on the last batch
        if _idx_bt != 0 and (_idx_bt % math.ceil(len(train_loader) / 3) == 0 or _idx_bt == int(len(train_loader) - 1)):
            val_start_time = datetime.now()
            if local_rank == 0:
                write_into(result_file, "\nIn {}-th epoch, {}-th batch, we start to validate ... ".format(idx_epoch, _idx_bt))
            _eval_flag = "Valid_In_Train_Epoch_{}_Batch_{}".format(idx_epoch, _idx_bt)
            valid_result = validate(local_rank=local_rank,
                                    valid_loader=valid_loader,
                                    model=model,
                                    criterion=criterion,
                                    evaluate_flag=_eval_flag,
                                    distributed=True,  # distributed evaluation
                                    nprocs=nprocs,
                                    original_valid_length=original_valid_length,
                                    result_file=result_file,
                                    details=True)  # report fixed-FPR threshold details
            if best_auc < valid_result.ROC_AUC_Score:
                _info = "[AUC Increased!] In evaluation of epoch-{} / batch-{}: AUC increased from {:.5f} to {:.5f}! Saving the model into {}".format(
                    idx_epoch, _idx_bt, best_auc, valid_result.ROC_AUC_Score, best_model_file)
                best_auc = valid_result.ROC_AUC_Score
                torch.save(model.module.state_dict(), best_model_file)
            else:
                _info = "[AUC NOT Increased!] AUC did not improve: best so far {:.5f}, current {:.5f}.".format(best_auc, valid_result.ROC_AUC_Score)
            if local_rank == 0:
                write_into(result_file, valid_result.__str__())
                write_into(result_file, _info)
                write_into(result_file, "[#One Validation Time#] One validation took about {}.".format(datetime.now() - val_start_time))

    return smooth_avg_reduced_loss_list, best_auc
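
# NOTE: every rank saves its own checkpoint ('LocalRank_{rank}_best_model.pt' in the Hydra run
# directory) whenever the validation AUC improves; main_app() below reloads only the rank-0 file
# for the final re-validation and testing pass.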


def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distributed, nprocs, original_valid_length, result_file, details):
    model.eval()
    if distributed:
        local_device = torch.device("cuda", local_rank)
    else:
        local_device = torch.device("cuda")

    sum_loss = torch.tensor(0.0, dtype=torch.float, device=local_device)
    n_samples = torch.tensor(0, dtype=torch.int, device=local_device)
    all_true_classes = []
    all_positive_probs = []

    with torch.no_grad():
        for idx_batch, data in enumerate(tqdm(valid_loader)):
            _real_batch, _position, _hash, _external_list, _function_edges, _true_classes = create_real_batch_data(one_batch=data)
            if _real_batch is None:
                write_into(result_file, "{}\n_real_batch is None in creating the real batch data of validation ... ".format("*-" * 100))
                continue
            _real_batch = _real_batch.to(local_device)
            _position = torch.tensor(_position, dtype=torch.long).cuda(local_rank, non_blocking=True)
            _true_classes = _true_classes.float().cuda(local_rank, non_blocking=True)

            batch_pred = model(real_local_batch=_real_batch,
                               real_bt_positions=_position,
                               bt_external_names=_external_list,
                               bt_all_function_edges=_function_edges,
                               local_device=local_device)
            batch_pred = batch_pred.squeeze(-1)
            loss = criterion(batch_pred, _true_classes)

            sum_loss += loss.item()
            n_samples += len(batch_pred)
            all_true_classes.append(_true_classes)
            all_positive_probs.append(batch_pred)

        avg_loss = sum_loss / (idx_batch + 1)
        all_true_classes = torch.cat(all_true_classes, dim=0)
        all_positive_probs = torch.cat(all_positive_probs, dim=0)

        if distributed:
            torch.distributed.barrier()
            reduced_n_samples = reduce_sum(n_samples)
            reduced_avg_loss = reduce_mean(avg_loss, nprocs)
            gather_true_classes = all_gather_concat(all_true_classes).detach().cpu().numpy()
            gather_positive_prods = all_gather_concat(all_positive_probs).detach().cpu().numpy()
            gather_true_classes = gather_true_classes[:original_valid_length]
            gather_positive_prods = gather_positive_prods[:original_valid_length]
        else:
            reduced_n_samples = n_samples
            reduced_avg_loss = avg_loss
            gather_true_classes = all_true_classes.detach().cpu().numpy()
            gather_positive_prods = all_positive_probs.detach().cpu().numpy()

    # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
    # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html
    _roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_prods)
    _fpr, _tpr, _thresholds = roc_curve(y_true=gather_true_classes, y_score=gather_positive_prods)

    if details is True:
        _100_info = find_threshold_with_fixed_fpr(y_true=gather_true_classes, y_pred=gather_positive_prods, fpr_target=0.01)
        _1000_info = find_threshold_with_fixed_fpr(y_true=gather_true_classes, y_pred=gather_positive_prods, fpr_target=0.001)
    else:
        _100_info, _1000_info = "None", "None"

    _eval_result = OneEpochResult(Epoch_Flag=evaluate_flag,
                                  Number_Samples=reduced_n_samples,
                                  Avg_Loss=reduced_avg_loss,
                                  Info_100=_100_info,
                                  Info_1000=_1000_info,
                                  ROC_AUC_Score=_roc_auc_score,
                                  Thresholds=_thresholds,
                                  TPRs=_tpr,
                                  FPRs=_fpr)
    return _eval_result


def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, model_params: ModelParams, optimizer_params: OptimizerParams, global_log: logging.Logger,
                      log_result_file: str):
    # dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:12345', world_size=nprocs, rank=local_rank)
    dist.init_process_group(backend='nccl', init_method='env://', world_size=nprocs, rank=local_rank)
    torch.cuda.set_device(local_rank)

    # model configuration
    vocab = Vocab(freq_file=train_params.external_func_vocab_file, max_vocab_size=train_params.max_vocab_size)
    if model_params.ablation_models.lower() == "full":
        model = HierarchicalGraphNeuralNetwork(model_params=model_params, external_vocab=vocab, global_log=global_log)
    else:
        raise NotImplementedError
    model.cuda(local_rank)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    if local_rank == 0:
        write_into(file_name_path=log_result_file, log_str=model.__str__())

    # loss function and optimizer
    criterion = nn.BCELoss().cuda(local_rank)
    lr = optimizer_params.lr
    if optimizer_params.optimizer_name.lower() == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    elif optimizer_params.optimizer_name.lower() == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=optimizer_params.weight_decay)
    else:
        raise NotImplementedError

    max_epochs = train_params.max_epochs
    dataset_root_path = train_params.processed_files_path
    train_batch_size = train_params.train_bs
    test_batch_size = train_params.test_bs

    # training dataset & dataloader
    train_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="train")
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = DataLoaderX(dataset=train_dataset, batch_size=train_batch_size, shuffle=False, num_workers=0, pin_memory=True, sampler=train_sampler)

    # validation dataset & dataloader
    valid_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="valid")
    valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset)
    valid_loader = DataLoaderX(dataset=valid_dataset, batch_size=test_batch_size, pin_memory=True, sampler=valid_sampler)

    if local_rank == 0:
        write_into(file_name_path=log_result_file, log_str="Training dataset={}, sampler={}, loader={}".format(len(train_dataset), len(train_sampler), len(train_loader)))
        write_into(file_name_path=log_result_file, log_str="Validation dataset={}, sampler={}, loader={}".format(len(valid_dataset), len(valid_sampler), len(valid_loader)))

    best_auc = 0.0
    ori_valid_length = len(valid_dataset)
    best_model_path = os.path.join(os.getcwd(), 'LocalRank_{}_best_model.pt'.format(local_rank))
    all_batch_avg_smooth_loss_list = []
    for epoch in range(max_epochs):
        train_sampler.set_epoch(epoch)
        valid_sampler.set_epoch(epoch)

        # train for one epoch
        time_start = datetime.now()
        if local_rank == 0:
            write_into(log_result_file, "\n{} start of {}-epoch, init best_auc={}, start time={} {}".format("-" * 50, epoch, best_auc, time_start.strftime("%Y-%m-%d@%H:%M:%S"), "-" * 50))
        smooth_avg_reduced_loss_list, best_auc = train_one_epoch(local_rank=local_rank,
                                                                 train_loader=train_loader,
                                                                 valid_loader=valid_loader,
                                                                 model=model,
                                                                 criterion=criterion,
                                                                 optimizer=optimizer,
                                                                 nprocs=nprocs,
                                                                 idx_epoch=epoch,
                                                                 best_auc=best_auc,
                                                                 best_model_file=best_model_path,
                                                                 original_valid_length=ori_valid_length,
                                                                 result_file=log_result_file)
        all_batch_avg_smooth_loss_list.extend(smooth_avg_reduced_loss_list)

        # anneal the learning rate after each epoch
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] / optimizer_params.learning_anneal

        time_end = datetime.now()
        if local_rank == 0:
            write_into(log_result_file, "\n{} end of {}-epoch, with best_auc={}, len of loss={}, end time={}, and time period={} {}".format(
                "*" * 50, epoch, best_auc, len(smooth_avg_reduced_loss_list),
                time_end.strftime("%Y-%m-%d@%H:%M:%S"), time_end - time_start, "*" * 50))


# https://hydra.cc/docs/tutorials/basic/your_first_app/defaults#overriding-a-config-group-default
@hydra.main(config_path="../configs/", config_name="default.yaml")
def main_app(config: DictConfig):
    # set seeds for reproducibility
    random.seed(config.Training.seed)
    np.random.seed(config.Training.seed)
    torch.manual_seed(config.Training.seed)
    torch.cuda.manual_seed(config.Training.seed)
    torch.cuda.manual_seed_all(config.Training.seed)

    # hyper-parameters for Training / Model / Optimizer
    _train_params = TrainParams(processed_files_path=to_absolute_path(config.Data.preprocess_root),
                                max_epochs=config.Training.max_epoches,
                                train_bs=config.Training.train_batch_size,
                                test_bs=config.Training.test_batch_size,
                                external_func_vocab_file=to_absolute_path(config.Data.train_vocab_file),
                                max_vocab_size=config.Data.max_vocab_size)
    _model_params = ModelParams(gnn_type=config.Model.gnn_type,
                                pool_type=config.Model.pool_type,
                                acfg_init_dims=config.Model.acfg_node_init_dims,
                                cfg_filters=config.Model.cfg_filters,
                                fcg_filters=config.Model.fcg_filters,
                                number_classes=config.Model.number_classes,
                                dropout_rate=config.Model.drapout_rate,
                                ablation_models=config.Model.ablation_models)
    _optim_params = OptimizerParams(optimizer_name=config.Optimizer.name,
                                    lr=config.Optimizer.learning_rate,
                                    weight_decay=config.Optimizer.weight_decay,
                                    learning_anneal=config.Optimizer.learning_anneal)

    # logging
    log = logging.getLogger("DistTrainModel.py")
    log.setLevel("DEBUG")
    log.warning("Hydra's Current Working Directory: {}".format(os.getcwd()))

    # setting for the result log file
    result_file = '{}_{}_{}_ACFG_{}_FCG_{}_Epoch_{}_TrainBS_{}_LR_{}_Time_{}.txt'.format(
        _model_params.ablation_models, _model_params.gnn_type, _model_params.pool_type,
        _model_params.cfg_filters, _model_params.fcg_filters, _train_params.max_epochs,
        _train_params.train_bs, _optim_params.lr, datetime.now().strftime("%Y%m%d_%H%M%S"))
    log_result_file = os.path.join(os.getcwd(), result_file)

    _other_params = {"Hydra's Current Working Directory": os.getcwd(), "seed": config.Training.seed,
                     "log result file": log_result_file, "only_test_path": config.Training.only_test_path}
    params_print_log(_train_params.__dict__, log_result_file)
    params_print_log(_model_params.__dict__, log_result_file)
    params_print_log(_optim_params.__dict__, log_result_file)
    params_print_log(_other_params, log_result_file)

    if config.Training.only_test_path == 'None':
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = str(config.Training.dist_port)
        # num_gpus = 1
        num_gpus = torch.cuda.device_count()
        log.info("Total number of GPUs = {}".format(num_gpus))
        torch_mp.spawn(main_train_worker, nprocs=num_gpus, args=(num_gpus, _train_params, _model_params, _optim_params, log, log_result_file,))
        best_model_file = os.path.join(os.getcwd(), 'LocalRank_{}_best_model.pt'.format(0))
    else:
        best_model_file = config.Training.only_test_path

    # model re-initialization and loading
    log.info("\n\nstarting to load the model & re-validation & testing from the file of \"{}\" \n".format(best_model_file))
    device = torch.device('cuda')
    train_vocab_path = _train_params.external_func_vocab_file
    vocab = Vocab(freq_file=train_vocab_path, max_vocab_size=_train_params.max_vocab_size)
    if _model_params.ablation_models.lower() == "full":
        model = HierarchicalGraphNeuralNetwork(model_params=_model_params, external_vocab=vocab, global_log=log)
    else:
        raise NotImplementedError
    model.to(device)
    model.load_state_dict(torch.load(best_model_file, map_location=device))
    criterion = nn.BCELoss().cuda()

    test_batch_size = config.Training.test_batch_size
    dataset_root_path = _train_params.processed_files_path

    # validation dataset & dataloader
    valid_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="valid")
    valid_dataloader = DataLoaderX(dataset=valid_dataset, batch_size=test_batch_size, shuffle=False)
    log.info("Total number of all validation samples = {} ".format(len(valid_dataset)))

    # testing dataset & dataloader
    test_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="test")
    test_dataloader = DataLoaderX(dataset=test_dataset, batch_size=test_batch_size, shuffle=False)
    log.info("Total number of all testing samples = {} ".format(len(test_dataset)))

    _valid_result = validate(valid_loader=valid_dataloader, model=model, criterion=criterion, evaluate_flag="DoubleCheckValidation",
                             distributed=False, local_rank=None, nprocs=None, original_valid_length=len(valid_dataset),
                             result_file=log_result_file, details=True)
    log.warning("\n\n" + _valid_result.__str__())

    _test_result = validate(valid_loader=test_dataloader, model=model, criterion=criterion, evaluate_flag="FinalTestingResult",
                            distributed=False, local_rank=None, nprocs=None, original_valid_length=len(test_dataset),
                            result_file=log_result_file, details=True)
    log.warning("\n\n" + _test_result.__str__())


if __name__ == '__main__':
    main_app()