This commit is contained in:
TinyCaviar 2023-09-09 15:50:54 +08:00
parent 4d5f0f80ab
commit 161d8a80d2
3 changed files with 8 additions and 5 deletions

View File

@ -1,6 +1,6 @@
Data:
preprocess_root: "../data/processed_dataset/DatasetJSON/"
train_vocab_file: "../data/processed_dataset/train_external_function_name_vocab.jsonl"
preprocess_root: "/root/autodl-tmp/"
train_vocab_file: "/root/autodl-tmp/train_external_function_name_vocab.jsonl"
max_vocab_size: 10000 # modify according to the result of 1BuildExternalVocab.py
Training:
cuda: True # enable GPU training if CUDA is available

View File

@ -265,7 +265,7 @@ def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, m
# https://hydra.cc/docs/tutorials/basic/your_first_app/defaults#overriding-a-config-group-default
@hydra.main(config_path="../configs/", config_name="default.yaml")
@hydra.main(config_path="../configs/", config_name="default.yaml", version_base=None)
def main_app(config: DictConfig):
# set seed for determinism for reproduction
random.seed(config.Training.seed)

View File

@ -23,12 +23,15 @@ class MalwareDetectionDataset(Dataset):
if os.path.splitext(name)[-1] == '.pt':
files.append(name)
return files
def __len__(self):
# def len(self):
# return 201
return len(self.malware_files) + len(self.benign_files)
def len(self):
return len(self.malware_files) + len(self.benign_files)
def get(self, idx):
split = len(self.malware_files)
# split = 100