备份
This commit is contained in:
parent
b4b131fc61
commit
facd48a0b7
@ -3,19 +3,19 @@ Data:
|
|||||||
train_vocab_file: "/home/king/python/data/fun_name_sort.jsonl"
|
train_vocab_file: "/home/king/python/data/fun_name_sort.jsonl"
|
||||||
max_vocab_size: 10000 # modify according to the result of 1BuildExternalVocab.py
|
max_vocab_size: 10000 # modify according to the result of 1BuildExternalVocab.py
|
||||||
Training:
|
Training:
|
||||||
cuda: False # enable GPU training if cuda is available
|
cuda: True # enable GPU training if cuda is available
|
||||||
dist_backend: "nccl" # if using torch.distributed, the backend to be used
|
dist_backend: "nccl" # if using torch.distributed, the backend to be used
|
||||||
dist_port: "1234"
|
dist_port: "1234"
|
||||||
max_epoches: 10
|
max_epoches: 10
|
||||||
train_batch_size: 16
|
train_batch_size: 16
|
||||||
test_batch_size: 32
|
test_batch_size: 8
|
||||||
seed: 19920208
|
seed: 19920208
|
||||||
only_test_path: 'None'
|
only_test_path: 'None'
|
||||||
Model:
|
Model:
|
||||||
ablation_models: "Full" # "Full"
|
ablation_models: "Full" # "Full"
|
||||||
gnn_type: "GraphSAGE" # "GraphSAGE" / "GCN"
|
gnn_type: "GraphSAGE" # "GraphSAGE" / "GCN"
|
||||||
pool_type: "global_max_pool" # "global_max_pool" / "global_mean_pool"
|
pool_type: "global_max_pool" # "global_max_pool" / "global_mean_pool"
|
||||||
acfg_node_init_dims: 11
|
acfg_node_init_dims: 32
|
||||||
cfg_filters: "200-200"
|
cfg_filters: "200-200"
|
||||||
fcg_filters: "200-200"
|
fcg_filters: "200-200"
|
||||||
number_classes: 1
|
number_classes: 1
|
||||||
|
@ -39,7 +39,12 @@ def json_to_pt(file: str, label: int, vocab: Vocab, save_path: str, file_type: s
|
|||||||
os.makedirs(save_path+f"{train_type}_{file_type}/")
|
os.makedirs(save_path+f"{train_type}_{file_type}/")
|
||||||
with open(file, "r", encoding="utf-8") as item:
|
with open(file, "r", encoding="utf-8") as item:
|
||||||
line = item.readline()
|
line = item.readline()
|
||||||
|
try:
|
||||||
item = json.loads(line)
|
item = json.loads(line)
|
||||||
|
except json.decoder.JSONDecodeError as e:
|
||||||
|
print(e)
|
||||||
|
print(file)
|
||||||
|
return False
|
||||||
item_hash = item['hash']
|
item_hash = item['hash']
|
||||||
acfg_list = []
|
acfg_list = []
|
||||||
for one_acfg in item['acfg_list']: # list of dict of acfg
|
for one_acfg in item['acfg_list']: # list of dict of acfg
|
||||||
@ -74,7 +79,7 @@ if __name__ == '__main__':
|
|||||||
file_type = ["malware", "benign"]
|
file_type = ["malware", "benign"]
|
||||||
max_vocab_size = 10000
|
max_vocab_size = 10000
|
||||||
vocabulary = Vocab(freq_file=train_vocab_file, max_vocab_size=max_vocab_size)
|
vocabulary = Vocab(freq_file=train_vocab_file, max_vocab_size=max_vocab_size)
|
||||||
parse_json_list_2_pyg_object(jsonl_file=malware_json_path, label=1, vocab=vocabulary, save_path=save_vocab_file,
|
# parse_json_list_2_pyg_object(jsonl_file=malware_json_path, label=1, vocab=vocabulary, save_path=save_vocab_file,
|
||||||
file_type=file_type[0])
|
# file_type=file_type[0])
|
||||||
parse_json_list_2_pyg_object(jsonl_file=benign_json_path, label=0, vocab=vocabulary, save_path=save_vocab_file,
|
parse_json_list_2_pyg_object(jsonl_file=benign_json_path, label=0, vocab=vocabulary, save_path=save_vocab_file,
|
||||||
file_type=file_type[1])
|
file_type=file_type[1])
|
||||||
|
@ -16,7 +16,7 @@ from omegaconf import DictConfig
|
|||||||
from prefetch_generator import BackgroundGenerator
|
from prefetch_generator import BackgroundGenerator
|
||||||
from sklearn.metrics import roc_auc_score, roc_curve
|
from sklearn.metrics import roc_auc_score, roc_curve
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch_geometric.loader import DataLoader
|
from torch_geometric.data import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from models.HierarchicalGraphModel import HierarchicalGraphNeuralNetwork
|
from models.HierarchicalGraphModel import HierarchicalGraphNeuralNetwork
|
||||||
@ -122,7 +122,7 @@ def train_one_epoch(local_rank, train_loader, valid_loader, model, criterion, op
|
|||||||
def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distributed, nprocs, original_valid_length, result_file, details):
|
def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distributed, nprocs, original_valid_length, result_file, details):
|
||||||
model.eval()
|
model.eval()
|
||||||
if distributed:
|
if distributed:
|
||||||
local_device = torch.device("cuda", local_rank)
|
local_device = torch.device("cpu", local_rank)
|
||||||
else:
|
else:
|
||||||
local_device = torch.device("cuda")
|
local_device = torch.device("cuda")
|
||||||
|
|
||||||
@ -252,9 +252,18 @@ def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, m
|
|||||||
if local_rank == 0:
|
if local_rank == 0:
|
||||||
write_into(log_result_file, "\n{} start of {}-epoch, init best_auc={}, start time={} {}".format("-" * 50, epoch, best_auc, time_start.strftime("%Y-%m-%d@%H:%M:%S"), "-" * 50))
|
write_into(log_result_file, "\n{} start of {}-epoch, init best_auc={}, start time={} {}".format("-" * 50, epoch, best_auc, time_start.strftime("%Y-%m-%d@%H:%M:%S"), "-" * 50))
|
||||||
|
|
||||||
smooth_avg_reduced_loss_list, best_auc = train_one_epoch(local_rank=local_rank, train_loader=train_loader, valid_loader=valid_loader, model=model, criterion=criterion,
|
smooth_avg_reduced_loss_list, best_auc = train_one_epoch(local_rank=local_rank,
|
||||||
optimizer=optimizer, nprocs=nprocs, idx_epoch=epoch, best_auc=best_auc, best_model_file=best_model_path,
|
train_loader=train_loader,
|
||||||
original_valid_length=ori_valid_length, result_file=log_result_file)
|
valid_loader=valid_loader,
|
||||||
|
model=model,
|
||||||
|
criterion=criterion,
|
||||||
|
optimizer=optimizer,
|
||||||
|
nprocs=nprocs,
|
||||||
|
idx_epoch=epoch,
|
||||||
|
best_auc=best_auc,
|
||||||
|
best_model_file=best_model_path,
|
||||||
|
original_valid_length=ori_valid_length,
|
||||||
|
result_file=log_result_file)
|
||||||
all_batch_avg_smooth_loss_list.extend(smooth_avg_reduced_loss_list)
|
all_batch_avg_smooth_loss_list.extend(smooth_avg_reduced_loss_list)
|
||||||
|
|
||||||
# adjust learning rate
|
# adjust learning rate
|
||||||
|
Loading…
Reference in New Issue
Block a user