diff --git a/configs/default.yaml b/configs/default.yaml index 08ce646..f941dd5 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -1,6 +1,6 @@ Data: - preprocess_root: "../data/processed_dataset/DatasetJSON/" - train_vocab_file: "../data/processed_dataset/train_external_function_name_vocab.jsonl" + preprocess_root: "/root/autodl-tmp/" + train_vocab_file: "/root/autodl-tmp/train_external_function_name_vocab.jsonl" max_vocab_size: 10000 # modify according to the result of 1BuildExternalVocab.py Training: cuda: True # enable GPU training if cuda is available diff --git a/src/DistTrainModel.py b/src/DistTrainModel.py index 9869479..180a3d8 100644 --- a/src/DistTrainModel.py +++ b/src/DistTrainModel.py @@ -265,7 +265,7 @@ def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, m # https://hydra.cc/docs/tutorials/basic/your_first_app/defaults#overriding-a-config-group-default -@hydra.main(config_path="../configs/", config_name="default.yaml") +@hydra.main(config_path="../configs/", config_name="default.yaml", version_base=None) def main_app(config: DictConfig): # set seed for determinism for reproduction random.seed(config.Training.seed) diff --git a/src/utils/PreProcessedDataset.py b/src/utils/PreProcessedDataset.py index dba7991..898c785 100644 --- a/src/utils/PreProcessedDataset.py +++ b/src/utils/PreProcessedDataset.py @@ -23,12 +23,15 @@ class MalwareDetectionDataset(Dataset): if os.path.splitext(name)[-1] == '.pt': files.append(name) return files - + def __len__(self): # def len(self): # return 201 return len(self.malware_files) + len(self.benign_files) - + + def len(self): + return len(self.malware_files) + len(self.benign_files) + def get(self, idx): split = len(self.malware_files) # split = 100