diff --git a/configs/default.yaml b/configs/default.yaml index e3ecef9..e545a15 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -1,5 +1,5 @@ Data: - preprocess_root: "/home/king/python/data/processed_dataset/DatasetJSON" + preprocess_root: "/home/king/python/data/processed_dataset/DatasetJSON_remake" train_vocab_file: "/home/king/python/data/processed_dataset/train_external_function_name_vocab.jsonl" max_vocab_size: 10000 # modify according to the result of 1BuildExternalVocab.py Training: diff --git a/req_pip.txt b/req_pip.txt new file mode 100644 index 0000000..68c8934 --- /dev/null +++ b/req_pip.txt @@ -0,0 +1,25 @@ +antlr4-python3-runtime==4.8 +ase==3.21.1 +cmake==3.18.4.post1 +dataclasses==0.6 +googledrivedownloader==0.4 +hydra-core==1.0.6 +importlib-resources==5.1.2 +intel-openmp==2021.1.2 +magma-cuda112==2.5.2 +mkl==2021.1.1 +mkl-include==2021.1.1 +ninja==1.10.0.post2 +omegaconf==2.0.6 +prefetch-generator==1.0.1 +rdflib==5.0.0 +tbb==2021.1.1 +texttable==1.6.3 +torch==1.8.0+cu111 +torch-cluster==1.5.9 +torch-geometric==1.6.3 +torch-scatter==2.0.6 +torch-sparse==0.6.9 +torch-spline-conv==1.2.1 +torchaudio==0.8.0 +torchvision==0.9.0+cu111 diff --git a/samples/PreProcess.py b/samples/PreProcess.py index 6a5ba3d..bac6057 100644 --- a/samples/PreProcess.py +++ b/samples/PreProcess.py @@ -1,10 +1,10 @@ import json import os - +import sys import torch from torch_geometric.data import Data from tqdm import tqdm - +sys.path.append(os.path.dirname(sys.path[0])) from src.utils.Vocabulary import Vocab @@ -35,6 +35,8 @@ def parse_json_list_2_pyg_object(jsonl_file: str, label: int, vocab: Vocab, save def json_to_pt(file: str, label: int, vocab: Vocab, save_path: str, file_type: str, train_type: str, index: int): + if not os.path.exists(save_path+f"{train_type}_{file_type}/"): + os.mkdir(save_path+f"{train_type}_{file_type}/") with open(file, "r", encoding="utf-8") as item: line = item.readline() item = json.loads(line) @@ -65,11 +67,14 @@ def json_to_pt(file: str, label: int, vocab: Vocab, save_path: str, file_type: s if __name__ == '__main__': - json_path = "./jsonl/infected_jsonl/" - train_vocab_file = "../data/processed_dataset/train_external_function_name_vocab.jsonl" - save_vocab_file = "../data/processed_dataset/DatasetJSON/" - file_type = "malware" + malware_json_path = "/home/king/python/data/jsonl/infected_jsonl/" + benign_json_path = "/home/king/python/data/jsonl/refind_jsonl/" + train_vocab_file = "/home/king/python/data/processed_dataset/train_external_function_name_vocab.jsonl" + save_vocab_file = "/home/king/python/data/processed_dataset/DatasetJSON_remake/" + file_type = ["malware", "benign"] max_vocab_size = 10000 vocabulary = Vocab(freq_file=train_vocab_file, max_vocab_size=max_vocab_size) - parse_json_list_2_pyg_object(jsonl_file=json_path, label=1, vocab=vocabulary, save_path=save_vocab_file, - file_type=file_type) + parse_json_list_2_pyg_object(jsonl_file=malware_json_path, label=1, vocab=vocabulary, save_path=save_vocab_file, + file_type=file_type[0]) + parse_json_list_2_pyg_object(jsonl_file=benign_json_path, label=0, vocab=vocabulary, save_path=save_vocab_file, + file_type=file_type[1]) diff --git a/samples/funCount.py b/samples/funCount.py index 034bd9b..3bc6dc0 100644 --- a/samples/funCount.py +++ b/samples/funCount.py @@ -6,8 +6,8 @@ import heapq from tqdm import tqdm if __name__ == '__main__': - mal_file_name = './jsonl/infected_jsonl/' - ben_file_name = './jsonl/refind_jsonl/' + mal_file_name = '/home/king/python/data/jsonl/infected_jsonl/' + ben_file_name = '/home/king/python/data/jsonl/refind_jsonl/' fun_name_dict = {} for file in tqdm(os.listdir(mal_file_name)): with open(mal_file_name + file, 'r') as item: @@ -29,7 +29,7 @@ if __name__ == '__main__': fun_name_dict[fun_name] += 1 else: fun_name_dict[fun_name] = 1 - with open('./res.jsonl', 'w') as file: + with open('/home/king/python/data/processed_dataset/train_external_function_name_vocab.jsonl', 'w') as file: largest_10000_items = heapq.nlargest(10000, fun_name_dict.items(), key=lambda item: item[1]) for key, value in largest_10000_items: temp = {"f_name": key, "count": value} diff --git a/src/DistTrainModel.py b/src/DistTrainModel.py index 91f09ee..db47597 100644 --- a/src/DistTrainModel.py +++ b/src/DistTrainModel.py @@ -170,6 +170,7 @@ def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distribu # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html + _roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_prods) _fpr, _tpr, _thresholds = roc_curve(y_true=gather_true_classes, y_score=gather_positive_prods) if details is True: @@ -185,8 +186,8 @@ def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distribu def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, model_params: ModelParams, optimizer_params: OptimizerParams, global_log: logging.Logger, log_result_file: str): - dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:12345', world_size=nprocs, rank=local_rank) - # dist.init_process_group(backend='nccl', init_method='env://', world_size=nprocs, rank=local_rank) + # dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:12345', world_size=nprocs, rank=local_rank) + dist.init_process_group(backend='nccl', init_method='env://', world_size=nprocs, rank=local_rank) torch.cuda.set_device(local_rank) # model configure @@ -304,7 +305,7 @@ def main_app(config: DictConfig): num_gpus = torch.cuda.device_count() log.info("Total number of GPUs = {}".format(num_gpus)) torch_mp.spawn(main_train_worker, nprocs=num_gpus, args=(num_gpus, _train_params, _model_params, _optim_params, log, log_result_file,)) - # main_train_worker(num_gpus, _train_params, _model_params, _optim_params, log, log_result_file, "") + best_model_file = os.path.join(os.getcwd(), 'LocalRank_{}_best_model.pt'.format(0)) else: @@ -312,7 +313,7 @@ def main_app(config: DictConfig): # model re-init and loading log.info("\n\nstarting to load the model & re-validation & testing from the file of \"{}\" \n".format(best_model_file)) - device = torch.device('cuda:0') + device = torch.device('cuda') train_vocab_path = _train_params.external_func_vocab_file vocab = Vocab(freq_file=train_vocab_path, max_vocab_size=_train_params.max_vocab_size) @@ -343,4 +344,4 @@ def main_app(config: DictConfig): if __name__ == '__main__': - main_app() + main_app() \ No newline at end of file diff --git a/src/DistTrainModel_dual.py b/src/DistTrainModel_dual.py index fff2073..06d168e 100644 --- a/src/DistTrainModel_dual.py +++ b/src/DistTrainModel_dual.py @@ -190,16 +190,6 @@ def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distribu # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html - # try: - # _roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_prods) - # except ValueError: - # _roc_auc_score = 0.001 - print(gather_true_classes) - print(gather_positive_prods) - # try: - # _roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_prods) - # except ValueError: - # pass _roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_prods) _fpr, _tpr, _thresholds = roc_curve(y_true=gather_true_classes, y_score=gather_positive_prods) if details is True: diff --git a/src/utils/PreProcessedDataset.py b/src/utils/PreProcessedDataset.py index 87e7010..174e5f5 100644 --- a/src/utils/PreProcessedDataset.py +++ b/src/utils/PreProcessedDataset.py @@ -55,7 +55,6 @@ def _simulating(_dataset, _batch_size: int): if index >= 3: break _real_batch, _position, _hash, _external_list, _function_edges, _true_classes = create_real_batch_data(one_batch=data) - print(data) print("Hash: ", _hash) print("Position: ", _position) print("\n") diff --git a/src/utils/RealBatch.py b/src/utils/RealBatch.py index 39313c8..99b5a2e 100644 --- a/src/utils/RealBatch.py +++ b/src/utils/RealBatch.py @@ -8,7 +8,6 @@ def create_real_batch_data(one_batch: Batch): real = [] position = [0] count = 0 - assert len(one_batch.external_list) == len(one_batch.function_edges) == len(one_batch.local_acfgs) == len(one_batch.hash), "size of each component must be equal to each other" for item in one_batch.local_acfgs: