Backup

commit facd48a0b7
parent b4b131fc61

@@ -3,19 +3,19 @@ Data:
   train_vocab_file: "/home/king/python/data/fun_name_sort.jsonl"
   max_vocab_size: 10000 # modify according to the result of 1BuildExternalVocab.py
 Training:
-  cuda: False # enable GPU training if cuda is available
+  cuda: True # enable GPU training if cuda is available
   dist_backend: "nccl" # if using torch.distribution, the backend to be used
   dist_port: "1234"
   max_epoches: 10
   train_batch_size: 16
-  test_batch_size: 32
+  test_batch_size: 8
   seed: 19920208
   only_test_path: 'None'
 Model:
   ablation_models: "Full" # "Full"
   gnn_type: "GraphSAGE" # "GraphSAGE" / "GCN"
   pool_type: "global_max_pool" # "global_max_pool" / "global_mean_pool"
-  acfg_node_init_dims: 11
+  acfg_node_init_dims: 32
   cfg_filters: "200-200"
   fcg_filters: "200-200"
   number_classes: 1

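For orientation, these YAML sections are typically read back with OmegaConf (the training script below imports DictConfig). A minimal loading sketch, assuming a hypothetical file config.yaml that contains exactly the keys shown in this hunk; the loader code itself is not part of this commit:

from omegaconf import OmegaConf

# Hypothetical config path; only the key names mirror the hunk above.
cfg = OmegaConf.load("config.yaml")
use_cuda = cfg.Training.cuda                 # True after this change
test_bs = cfg.Training.test_batch_size       # 8 after this change
node_dims = cfg.Model.acfg_node_init_dims    # 32 after this change
print(use_cuda, test_bs, node_dims)
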
@@ -39,7 +39,12 @@ def json_to_pt(file: str, label: int, vocab: Vocab, save_path: str, file_type: s
         os.makedirs(save_path+f"{train_type}_{file_type}/")
     with open(file, "r", encoding="utf-8") as item:
         line = item.readline()
-        item = json.loads(line)
+        try:
+            item = json.loads(line)
+        except json.decoder.JSONDecodeError as e:
+            print(e)
+            print(file)
+            return False
         item_hash = item['hash']
         acfg_list = []
         for one_acfg in item['acfg_list']: # list of dict of acfg

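The added try/except reports a malformed JSONL line (the decode error plus the offending file) and makes json_to_pt return False instead of crashing. A self-contained sketch of the same guard applied to every line of a .jsonl file; the helper name and the per-line skip behaviour are illustrative, not code from this repository:

import json

def read_jsonl_safely(path):
    """Parse a JSONL file line by line, reporting lines that fail to decode."""
    records = []
    with open(path, "r", encoding="utf-8") as fp:
        for line_no, line in enumerate(fp, start=1):
            try:
                records.append(json.loads(line))
            except json.decoder.JSONDecodeError as e:
                # The commit aborts the whole file (return False); this sketch
                # only skips the bad line and keeps the remaining records.
                print(f"{path}:{line_no}: {e}")
    return records
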
@@ -74,7 +79,7 @@ if __name__ == '__main__':
     file_type = ["malware", "benign"]
     max_vocab_size = 10000
     vocabulary = Vocab(freq_file=train_vocab_file, max_vocab_size=max_vocab_size)
-    parse_json_list_2_pyg_object(jsonl_file=malware_json_path, label=1, vocab=vocabulary, save_path=save_vocab_file,
-                                 file_type=file_type[0])
+    # parse_json_list_2_pyg_object(jsonl_file=malware_json_path, label=1, vocab=vocabulary, save_path=save_vocab_file,
+    #                              file_type=file_type[0])
     parse_json_list_2_pyg_object(jsonl_file=benign_json_path, label=0, vocab=vocabulary, save_path=save_vocab_file,
                                  file_type=file_type[1])

@@ -16,7 +16,7 @@ from omegaconf import DictConfig
 from prefetch_generator import BackgroundGenerator
 from sklearn.metrics import roc_auc_score, roc_curve
 from torch import nn
-from torch_geometric.loader import DataLoader
+from torch_geometric.data import DataLoader
 from tqdm import tqdm
 
 from models.HierarchicalGraphModel import HierarchicalGraphNeuralNetwork

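This hunk swaps the DataLoader import from torch_geometric.loader to torch_geometric.data. Both module paths exist depending on the installed PyTorch Geometric version (torch_geometric.loader is the newer location, torch_geometric.data the older one), so a version-tolerant import is a common workaround; the fallback below is a suggestion rather than part of the commit:

# Prefer the newer import path, fall back to the legacy one on older PyG releases.
try:
    from torch_geometric.loader import DataLoader
except ImportError:
    from torch_geometric.data import DataLoader
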
@@ -122,7 +122,7 @@ def train_one_epoch(local_rank, train_loader, valid_loader, model, criterion, op
 def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distributed, nprocs, original_valid_length, result_file, details):
     model.eval()
     if distributed:
-        local_device = torch.device("cuda", local_rank)
+        local_device = torch.device("cpu", local_rank)
     else:
         local_device = torch.device("cuda")
 

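The distributed branch now builds torch.device("cpu", local_rank); PyTorch accepts an index on a CPU device, although the index has no effect there. A hedged helper sketch that keeps the per-rank structure while checking CUDA availability; this function is illustrative and does not exist in the repository:

import torch

def pick_device(distributed, local_rank):
    """Illustrative device selection, not part of the commit."""
    if distributed and torch.cuda.is_available():
        # One process per rank: bind the process to its own GPU.
        return torch.device("cuda", local_rank)
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
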
@@ -252,9 +252,18 @@ def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, m
         if local_rank == 0:
             write_into(log_result_file, "\n{} start of {}-epoch, init best_auc={}, start time={} {}".format("-" * 50, epoch, best_auc, time_start.strftime("%Y-%m-%d@%H:%M:%S"), "-" * 50))
 
-        smooth_avg_reduced_loss_list, best_auc = train_one_epoch(local_rank=local_rank, train_loader=train_loader, valid_loader=valid_loader, model=model, criterion=criterion,
-                                                                 optimizer=optimizer, nprocs=nprocs, idx_epoch=epoch, best_auc=best_auc, best_model_file=best_model_path,
-                                                                 original_valid_length=ori_valid_length, result_file=log_result_file)
+        smooth_avg_reduced_loss_list, best_auc = train_one_epoch(local_rank=local_rank,
+                                                                 train_loader=train_loader,
+                                                                 valid_loader=valid_loader,
+                                                                 model=model,
+                                                                 criterion=criterion,
+                                                                 optimizer=optimizer,
+                                                                 nprocs=nprocs,
+                                                                 idx_epoch=epoch,
+                                                                 best_auc=best_auc,
+                                                                 best_model_file=best_model_path,
+                                                                 original_valid_length=ori_valid_length,
+                                                                 result_file=log_result_file)
         all_batch_avg_smooth_loss_list.extend(smooth_avg_reduced_loss_list)
 
         # adjust learning rate

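main_train_worker takes a local_rank and nprocs, and the config above sets dist_backend: "nccl" and dist_port: "1234", which points at a torch.multiprocessing spawn plus torch.distributed setup. A minimal sketch of how such a per-rank worker is usually launched and joined to a process group; the rendezvous address, worker body, and process count are assumptions, not code from this commit:

import torch.distributed as dist
import torch.multiprocessing as mp

def demo_worker(local_rank, nprocs, backend, port):
    # Illustrative only: every spawned process joins the same process group.
    # The nccl backend assumes one visible GPU per process.
    dist.init_process_group(backend=backend,
                            init_method="tcp://127.0.0.1:" + port,
                            world_size=nprocs,
                            rank=local_rank)
    # ... build model/loaders here and run something like train_one_epoch(...) ...
    dist.destroy_process_group()

if __name__ == "__main__":
    nprocs = 2  # assumed GPU count for the sketch
    mp.spawn(demo_worker, nprocs=nprocs, args=(nprocs, "nccl", "1234"))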