From b4b131fc61ccdaa9a3bc722ed25128c74681f046 Mon Sep 17 00:00:00 2001 From: huihun <781165206@qq.com> Date: Mon, 29 Apr 2024 17:31:04 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A4=87=E4=BB=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- configs/default.yaml | 6 +- requirement_conda.txt | 160 +++++++++++++++---------------- samples/PreProcess.py | 12 +-- src/DistTrainModel.py | 12 ++- src/utils/PreProcessedDataset.py | 2 +- torch_test.py | 9 +- 6 files changed, 105 insertions(+), 96 deletions(-) diff --git a/configs/default.yaml b/configs/default.yaml index e545a15..a1d320e 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -1,9 +1,9 @@ Data: - preprocess_root: "/home/king/python/data/processed_dataset/DatasetJSON_remake" - train_vocab_file: "/home/king/python/data/processed_dataset/train_external_function_name_vocab.jsonl" + preprocess_root: "/home/king/python/data/DatasetJSON_remake" + train_vocab_file: "/home/king/python/data/fun_name_sort.jsonl" max_vocab_size: 10000 # modify according to the result of 1BuildExternalVocab.py Training: - cuda: True # enable GPU training if cuda is available + cuda: False # enable GPU training if cuda is available dist_backend: "nccl" # if using torch.distribution, the backend to be used dist_port: "1234" max_epoches: 10 diff --git a/requirement_conda.txt b/requirement_conda.txt index f9d3094..d089548 100644 --- a/requirement_conda.txt +++ b/requirement_conda.txt @@ -1,83 +1,83 @@ # This file may be used to create an environment using: # $ conda create --name --file # platform: linux-64 -_libgcc_mutex=0.1=main -antlr4-python3-runtime=4.8=pypi_0 -ase=3.21.1=pypi_0 -ca-certificates=2021.1.19=h06a4308_1 -cached-property=1.5.2=pypi_0 -certifi=2020.12.5=py37h06a4308_0 -cffi=1.14.5=pypi_0 -chardet=4.0.0=pypi_0 -cmake=3.18.4.post1=pypi_0 -cycler=0.10.0=pypi_0 -dataclasses=0.6=pypi_0 -decorator=4.4.2=pypi_0 -future=0.18.2=pypi_0 -googledrivedownloader=0.4=pypi_0 -h5py=3.2.1=pypi_0 -hydra-core=1.0.6=pypi_0 -idna=2.10=pypi_0 -importlib-resources=5.1.2=pypi_0 -intel-openmp=2021.1.2=pypi_0 -isodate=0.6.0=pypi_0 -jinja2=2.11.3=pypi_0 -joblib=1.0.1=pypi_0 -kiwisolver=1.3.1=pypi_0 -ld_impl_linux-64=2.33.1=h53a641e_7 -libedit=3.1.20191231=h14c3975_1 -libffi=3.3=he6710b0_2 -libgcc-ng=9.1.0=hdf63c60_0 -libstdcxx-ng=9.1.0=hdf63c60_0 -llvmlite=0.35.0=pypi_0 -magma-cuda112=2.5.2=1 -markupsafe=1.1.1=pypi_0 -matplotlib=3.3.4=pypi_0 -mkl=2021.1.1=pypi_0 -mkl-include=2021.1.1=pypi_0 -ncurses=6.2=he6710b0_1 -networkx=2.5=pypi_0 -ninja=1.10.0.post2=pypi_0 -numba=0.52.0=pypi_0 -numpy=1.20.1=pypi_0 -omegaconf=2.0.6=pypi_0 -openssl=1.1.1j=h27cfd23_0 -pandas=1.2.3=pypi_0 -pillow=8.1.2=pypi_0 -pip=21.0.1=py37h06a4308_0 -prefetch-generator=1.0.1=pypi_0 -pycparser=2.20=pypi_0 -pyparsing=2.4.7=pypi_0 -python=3.7.9=h7579374_0 -python-dateutil=2.8.1=pypi_0 -python-louvain=0.15=pypi_0 -pytz=2021.1=pypi_0 -pyyaml=5.4.1=pypi_0 -rdflib=5.0.0=pypi_0 -readline=8.1=h27cfd23_0 -requests=2.25.1=pypi_0 -scikit-learn=0.24.1=pypi_0 -scipy=1.6.1=pypi_0 -seaborn=0.11.1=pypi_0 -setuptools=52.0.0=py37h06a4308_0 -six=1.15.0=pypi_0 -sqlite=3.33.0=h62c20be_0 -tbb=2021.1.1=pypi_0 -texttable=1.6.3=pypi_0 -threadpoolctl=2.1.0=pypi_0 -tk=8.6.10=hbc83047_0 -torch=1.8.0+cu111=pypi_0 -torch-cluster=1.5.9=pypi_0 -torch-geometric=1.6.3=pypi_0 -torch-scatter=2.0.6=pypi_0 -torch-sparse=0.6.9=pypi_0 -torch-spline-conv=1.2.1=pypi_0 -torchaudio=0.8.0=pypi_0 -torchvision=0.9.0+cu111=pypi_0 -tqdm=4.59.0=pypi_0 -typing-extensions=3.7.4.3=pypi_0 -urllib3=1.26.3=pypi_0 -wheel=0.36.2=pyhd3eb1b0_0 -xz=5.2.5=h7b6447c_0 -zipp=3.4.1=pypi_0 -zlib=1.2.11=h7b6447c_3 +_libgcc_mutex=0.1 +antlr4-python3-runtime=4.8 +ase=3.21.1 +ca-certificates=2021.1.19 +cached-property=1.5.2 +certifi=2020.12.5 +cffi=1.14.5 +chardet=4.0.0 +cmake=3.18.4.post1 +cycler=0.10.0 +dataclasses=0.6 +decorator=4.4.2 +future=0.18.2 +googledrivedownloader=0.4 +h5py=3.2.1 +hydra-core=1.0.6 +idna=2.10 +importlib-resources=5.1.2 +intel-openmp=2021.1.2 +isodate=0.6.0 +jinja2=2.11.3 +joblib=1.0.1 +kiwisolver=1.3.1 +ld_impl_linux-64=2.33.1 +libedit=3.1.20191231 +libffi=3.3 +libgcc-ng=9.1.0 +libstdcxx-ng=9.1.0 +llvmlite=0.35.0 +magma-cuda112=2.5.2 +markupsafe=1.1.1 +matplotlib=3.3.4 +mkl=2021.1.1 +mkl-include=2021.1.1 +ncurses=6.2 +networkx=2.5 +ninja=1.10.0.post2 +numba=0.52.0 +numpy=1.20.1 +omegaconf=2.0.6 +openssl=1.1.1j +pandas=1.2.3 +pillow=8.1.2 +pip=21.0.1 +prefetch-generator=1.0.1 +pycparser=2.20 +pyparsing=2.4.7 +python=3.7.9 +python-dateutil=2.8.1 +python-louvain=0.15 +pytz=2021.1 +pyyaml=5.4.1 +rdflib=5.0.0 +readline=8.1 +requests=2.25.1 +scikit-learn=0.24.1 +scipy=1.6.1 +seaborn=0.11.1 +setuptools=52.0.0 +six=1.15.0 +sqlite=3.33.0 +tbb=2021.1.1 +texttable=1.6.3 +threadpoolctl=2.1.0 +tk=8.6.10 +torch=1.8.0+cu111 +torch-cluster=1.5.9 +torch-geometric=1.6.3 +torch-scatter=2.0.6 +torch-sparse=0.6.9 +torch-spline-conv=1.2.1 +torchaudio=0.8.0 +torchvision=0.9.0+cu111 +tqdm=4.59.0 +typing-extensions=3.7.4.3 +urllib3=1.26.3 +wheel=0.36.2 +xz=5.2.5 +zipp=3.4.1 +zlib=1.2.11 diff --git a/samples/PreProcess.py b/samples/PreProcess.py index bac6057..24fae78 100644 --- a/samples/PreProcess.py +++ b/samples/PreProcess.py @@ -18,7 +18,7 @@ def parse_json_list_2_pyg_object(jsonl_file: str, label: int, vocab: Vocab, save test_flag = True file_len = len(os.listdir(jsonl_file)) - for file in tqdm(os.listdir(jsonl_file)): + for file in tqdm(os.listdir(jsonl_file), desc=file_type): if index >= file_len * 0.8 and valid_flag: type_index += 1 valid_flag = False @@ -36,7 +36,7 @@ def parse_json_list_2_pyg_object(jsonl_file: str, label: int, vocab: Vocab, save def json_to_pt(file: str, label: int, vocab: Vocab, save_path: str, file_type: str, train_type: str, index: int): if not os.path.exists(save_path+f"{train_type}_{file_type}/"): - os.mkdir(save_path+f"{train_type}_{file_type}/") + os.makedirs(save_path+f"{train_type}_{file_type}/") with open(file, "r", encoding="utf-8") as item: line = item.readline() item = json.loads(line) @@ -67,10 +67,10 @@ def json_to_pt(file: str, label: int, vocab: Vocab, save_path: str, file_type: s if __name__ == '__main__': - malware_json_path = "/home/king/python/data/jsonl/infected_jsonl/" - benign_json_path = "/home/king/python/data/jsonl/refind_jsonl/" - train_vocab_file = "/home/king/python/data/processed_dataset/train_external_function_name_vocab.jsonl" - save_vocab_file = "/home/king/python/data/processed_dataset/DatasetJSON_remake/" + malware_json_path = "/home/king/python/data/jsonl/malware/" + benign_json_path = "/home/king/python/data/jsonl/benign/" + train_vocab_file = "/home/king/python/data/fun_name_sort.jsonl" + save_vocab_file = "/home/king/python/data/DatasetJSON_remake/" file_type = ["malware", "benign"] max_vocab_size = 10000 vocabulary = Vocab(freq_file=train_vocab_file, max_vocab_size=max_vocab_size) diff --git a/src/DistTrainModel.py b/src/DistTrainModel.py index db47597..e9ce468 100644 --- a/src/DistTrainModel.py +++ b/src/DistTrainModel.py @@ -16,7 +16,7 @@ from omegaconf import DictConfig from prefetch_generator import BackgroundGenerator from sklearn.metrics import roc_auc_score, roc_curve from torch import nn -from torch_geometric.data import DataLoader +from torch_geometric.loader import DataLoader from tqdm import tqdm from models.HierarchicalGraphModel import HierarchicalGraphNeuralNetwork @@ -26,6 +26,10 @@ from utils.PreProcessedDataset import MalwareDetectionDataset from utils.RealBatch import create_real_batch_data from utils.Vocabulary import Vocab +os.environ['TORCH_USE_CUDA_DSA'] = "1" +os.environ['CUDA_LAUNCH_BLOCKING'] = "1" + + class DataLoaderX(DataLoader): def __iter__(self): @@ -304,8 +308,12 @@ def main_app(config: DictConfig): # num_gpus = 1 num_gpus = torch.cuda.device_count() log.info("Total number of GPUs = {}".format(num_gpus)) + # try: + # torch_mp.spawn(main_train_worker, nprocs=num_gpus, args=(num_gpus, _train_params, _model_params, _optim_params, log, log_result_file,)) + # except Exception as e: + # print(e) torch_mp.spawn(main_train_worker, nprocs=num_gpus, args=(num_gpus, _train_params, _model_params, _optim_params, log, log_result_file,)) - + # main_train_worker(0, num_gpus, _train_params, _model_params, _optim_params, log, log_result_file) best_model_file = os.path.join(os.getcwd(), 'LocalRank_{}_best_model.pt'.format(0)) else: diff --git a/src/utils/PreProcessedDataset.py b/src/utils/PreProcessedDataset.py index 174e5f5..2e72826 100644 --- a/src/utils/PreProcessedDataset.py +++ b/src/utils/PreProcessedDataset.py @@ -24,7 +24,7 @@ class MalwareDetectionDataset(Dataset): files.append(name) return files - def __len__(self): + def len(self): # def len(self): # return 201 return len(self.malware_files) + len(self.benign_files) diff --git a/torch_test.py b/torch_test.py index 9044dc9..845a5dc 100644 --- a/torch_test.py +++ b/torch_test.py @@ -1,8 +1,9 @@ import torch_geometric import torch if __name__ == '__main__': - # print(torch.__version__) - # print(torch.cuda.device_count()) - # print(torch.cuda.get_device_name()) - print(torch.cuda.nccl.is_available()) + print(torch.__version__) + print(torch.cuda.device_count()) + print(torch.cuda.get_device_name()) + print(torch.cuda.is_available()) + # print(torch.cuda.nccl.is_available()) print(torch.cuda.nccl.version())