From 1aa21f4cc66974689b407838861714a8109b33ce Mon Sep 17 00:00:00 2001
From: zyr
Date: Mon, 7 Jun 2021 18:41:29 +0800
Subject: [PATCH] use DataParallel and smaller BERT

---
 my_run_mlm_no_trainer.py | 39 ++++++++++++++++-----------------------
 1 file changed, 16 insertions(+), 23 deletions(-)

diff --git a/my_run_mlm_no_trainer.py b/my_run_mlm_no_trainer.py
index 76dbd10..fbc570e 100644
--- a/my_run_mlm_no_trainer.py
+++ b/my_run_mlm_no_trainer.py
@@ -34,21 +34,11 @@ from datasets import load_dataset
 from torch.nn import DataParallel
 from torch.utils.data.dataloader import DataLoader
 from tqdm.auto import tqdm
-from transformers import (
-    CONFIG_MAPPING,
-    MODEL_MAPPING,
-    AdamW,
-    AutoConfig,
-    AutoModelForMaskedLM,
-    AutoTokenizer,
-    BatchEncoding,
-    BertConfig,
-    BertForPreTraining,
-    DataCollatorForLanguageModeling,
-    SchedulerType,
-    get_scheduler,
-    set_seed,
-)
+from transformers import (CONFIG_MAPPING, MODEL_MAPPING, AdamW, AutoConfig,
+                          AutoModelForMaskedLM, AutoTokenizer, BatchEncoding,
+                          BertConfig, BertForPreTraining,
+                          DataCollatorForLanguageModeling, SchedulerType,
+                          get_scheduler, set_seed)
 from my_data_collator import MyDataCollatorForPreTraining
 from process_data.utils import CURRENT_DATA_BASE
 
@@ -197,7 +187,8 @@ def main():
     #     field="data",
     # )
     train_files = [
-        os.path.join(CURRENT_DATA_BASE, "inst.1.{}.json".format(i)) for i in range(128)
+        os.path.join(CURRENT_DATA_BASE, "inst.1.{}.json".format(i))
+        for i in range(0, 128, 2)
     ]
     valid_file = "/home/ming/malware/inst2vec_bert/data/test_lm/inst.json"
     raw_datasets = load_dataset(
@@ -221,7 +212,7 @@ def main():
     config = BertConfig(
         vocab_size=tokenizer.get_vocab_size(),
         hidden_size=96,
-        num_hidden_layers=12,
+        num_hidden_layers=4,
         num_attention_heads=12,
         intermediate_size=384,
         max_position_embeddings=32,
@@ -230,6 +221,8 @@ def main():
     # initalize a new BERT for pre-training
     model = BertForPreTraining(config)
 
+    model = DataParallel(model)
+
     # Preprocessing the datasets.
     column_names = raw_datasets["train"].column_names
     text_column_name = "text" if "text" in column_names else column_names[0]
@@ -313,14 +306,13 @@ def main():
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
 
     # Prepare everything with our `accelerator`.
-    # model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
-    #     model, optimizer, train_dataloader, eval_dataloader
-    # )
-    model, optimizer, train_dataloader = accelerator.prepare(
-        model, optimizer, train_dataloader
+    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
+        model, optimizer, train_dataloader, eval_dataloader
     )
+    # model, optimizer, train_dataloader = accelerator.prepare(
+    #     model, optimizer, train_dataloader
+    # )
 
-    model = DataParallel(model)
     # model.to("cuda:0")
 
     # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
@@ -374,6 +366,7 @@ def main():
         for step, batch in enumerate(train_dataloader):
             outputs = model(**batch)
             loss = outputs.loss
+            loss = loss.sum()
             loss = loss / args.gradient_accumulation_steps
             accelerator.backward(loss)
             if (
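
Note on the `loss = loss.sum()` hunk: torch.nn.DataParallel scatters each batch across the visible GPUs and gathers the per-replica outputs onto the first device, so a loss that is a scalar on each replica (as `outputs.loss` from BertForPreTraining is) comes back as a 1-D tensor with one element per GPU and must be reduced before calling backward(). The sketch below is illustrative only, not code from my_run_mlm_no_trainer.py; TinyNet, the tensor shapes, and the use of mse_loss are assumptions made for the example.

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical model that returns a scalar loss per replica."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 1)

    def forward(self, x, target):
        pred = self.linear(x).squeeze(-1)
        # reduction="mean" gives a 0-dim loss tensor on each replica
        return nn.functional.mse_loss(pred, target, reduction="mean")


model = TinyNet()
if torch.cuda.is_available():
    model = model.cuda()
if torch.cuda.device_count() > 1:
    # Replicate the module; the per-replica scalar losses are gathered
    # into a tensor of shape [num_gpus] on the first GPU.
    model = nn.DataParallel(model)

x, target = torch.randn(16, 8), torch.randn(16)
if torch.cuda.is_available():
    x, target = x.cuda(), target.cuda()

loss = model(x, target)
# Reduce to a scalar before backward(); the patch uses sum(), while mean()
# would also work but changes the effective gradient scale.
loss = loss.sum() if loss.dim() > 0 else loss
loss.backward()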