This commit is contained in:
TinyCaviar 2023-09-09 15:50:54 +08:00
parent 4d5f0f80ab
commit 161d8a80d2
3 changed files with 8 additions and 5 deletions

View File

@ -1,6 +1,6 @@
Data: Data:
preprocess_root: "../data/processed_dataset/DatasetJSON/" preprocess_root: "/root/autodl-tmp/"
train_vocab_file: "../data/processed_dataset/train_external_function_name_vocab.jsonl" train_vocab_file: "/root/autodl-tmp/train_external_function_name_vocab.jsonl"
max_vocab_size: 10000 # modify according to the result of 1BuildExternalVocab.py max_vocab_size: 10000 # modify according to the result of 1BuildExternalVocab.py
Training: Training:
cuda: True # enable GPU training if cuda is available cuda: True # enable GPU training if cuda is available

View File

@ -265,7 +265,7 @@ def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, m
# https://hydra.cc/docs/tutorials/basic/your_first_app/defaults#overriding-a-config-group-default # https://hydra.cc/docs/tutorials/basic/your_first_app/defaults#overriding-a-config-group-default
@hydra.main(config_path="../configs/", config_name="default.yaml") @hydra.main(config_path="../configs/", config_name="default.yaml", version_base=None)
def main_app(config: DictConfig): def main_app(config: DictConfig):
# set seed for determinism for reproduction # set seed for determinism for reproduction
random.seed(config.Training.seed) random.seed(config.Training.seed)

View File

@ -29,6 +29,9 @@ class MalwareDetectionDataset(Dataset):
# return 201 # return 201
return len(self.malware_files) + len(self.benign_files) return len(self.malware_files) + len(self.benign_files)
def len(self):
return len(self.malware_files) + len(self.benign_files)
def get(self, idx): def get(self, idx):
split = len(self.malware_files) split = len(self.malware_files)
# split = 100 # split = 100