use DataParallel and smaller BERT

zyr 2021-06-07 18:41:29 +08:00
parent e79dd6c76e
commit 1aa21f4cc6


@@ -34,21 +34,11 @@ from datasets import load_dataset
 from torch.nn import DataParallel
 from torch.utils.data.dataloader import DataLoader
 from tqdm.auto import tqdm
-from transformers import (
-    CONFIG_MAPPING,
-    MODEL_MAPPING,
-    AdamW,
-    AutoConfig,
-    AutoModelForMaskedLM,
-    AutoTokenizer,
-    BatchEncoding,
-    BertConfig,
-    BertForPreTraining,
-    DataCollatorForLanguageModeling,
-    SchedulerType,
-    get_scheduler,
-    set_seed,
-)
+from transformers import (CONFIG_MAPPING, MODEL_MAPPING, AdamW, AutoConfig,
+                          AutoModelForMaskedLM, AutoTokenizer, BatchEncoding,
+                          BertConfig, BertForPreTraining,
+                          DataCollatorForLanguageModeling, SchedulerType,
+                          get_scheduler, set_seed)
 from my_data_collator import MyDataCollatorForPreTraining
 from process_data.utils import CURRENT_DATA_BASE
@@ -197,7 +187,8 @@
     #     field="data",
     # )
     train_files = [
-        os.path.join(CURRENT_DATA_BASE, "inst.1.{}.json".format(i)) for i in range(128)
+        os.path.join(CURRENT_DATA_BASE, "inst.1.{}.json".format(i))
+        for i in range(0, 128, 2)
     ]
     valid_file = "/home/ming/malware/inst2vec_bert/data/test_lm/inst.json"
     raw_datasets = load_dataset(
@@ -221,7 +212,7 @@
     config = BertConfig(
         vocab_size=tokenizer.get_vocab_size(),
         hidden_size=96,
-        num_hidden_layers=12,
+        num_hidden_layers=4,
         num_attention_heads=12,
         intermediate_size=384,
         max_position_embeddings=32,
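The only change in this hunk is num_hidden_layers dropping from 12 to 4, which is what makes the BERT "smaller". A minimal sketch of the resulting configuration, assuming the transformers package; the vocab_size below is a placeholder for tokenizer.get_vocab_size():

from transformers import BertConfig, BertForPreTraining

config = BertConfig(
    vocab_size=2000,            # placeholder; the script uses tokenizer.get_vocab_size()
    hidden_size=96,
    num_hidden_layers=4,        # reduced from 12 in this commit
    num_attention_heads=12,     # 96 / 12 = 8-dimensional attention heads
    intermediate_size=384,
    max_position_embeddings=32,
)
model = BertForPreTraining(config)
print(sum(p.numel() for p in model.parameters()))  # rough parameter count of the smaller model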
@@ -230,6 +221,8 @@
     # initalize a new BERT for pre-training
     model = BertForPreTraining(config)
+    model = DataParallel(model)
+
     # Preprocessing the datasets.
     column_names = raw_datasets["train"].column_names
     text_column_name = "text" if "text" in column_names else column_names[0]
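torch.nn.DataParallel replicates the wrapped module on every visible GPU and splits each input batch along dimension 0, gathering the outputs back on the first device. A minimal sketch, assuming at least one CUDA device (the Linear layer stands in for the BERT model):

import torch
from torch.nn import DataParallel, Linear

model = Linear(96, 2)
if torch.cuda.is_available():
    model = DataParallel(model.cuda())   # one replica per visible GPU
    x = torch.randn(8, 96).cuda()        # the batch of 8 is split across the replicas
    y = model(x)                         # per-replica outputs are gathered on cuda:0
    print(y.shape)                       # torch.Size([8, 2])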
@@ -313,14 +306,13 @@
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
     # Prepare everything with our `accelerator`.
-    # model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
-    #     model, optimizer, train_dataloader, eval_dataloader
-    # )
-    model, optimizer, train_dataloader = accelerator.prepare(
-        model, optimizer, train_dataloader
-    )
-    model = DataParallel(model)
+    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
+        model, optimizer, train_dataloader, eval_dataloader
+    )
+    # model, optimizer, train_dataloader = accelerator.prepare(
+    #     model, optimizer, train_dataloader
+    # )
     # model.to("cuda:0")
     # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
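After this hunk, the model (already wrapped in DataParallel earlier in the script) and both dataloaders are handed to accelerator.prepare, instead of preparing only the training dataloader and wrapping afterwards. A minimal sketch of the new ordering, assuming the accelerate package and toy stand-ins for the real model and data:

import torch
from accelerate import Accelerator
from torch.nn import DataParallel, Linear
from torch.utils.data import DataLoader, TensorDataset

accelerator = Accelerator()
model = DataParallel(Linear(8, 2))                      # wrapped before prepare, as in this commit
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
train_dataloader = DataLoader(TensorDataset(torch.randn(32, 8)), batch_size=4)
eval_dataloader = DataLoader(TensorDataset(torch.randn(8, 8)), batch_size=4)
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)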
@@ -374,6 +366,7 @@
         for step, batch in enumerate(train_dataloader):
             outputs = model(**batch)
             loss = outputs.loss
+            loss = loss.sum()
            loss = loss / args.gradient_accumulation_steps
             accelerator.backward(loss)
             if (
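The added loss.sum() is what makes the DataParallel wrapping work in the loop above: each replica returns its own scalar loss, so the gathered outputs.loss is a 1-D tensor with one entry per GPU and must be reduced to a scalar before backward(). A minimal sketch with hypothetical values:

import torch

# e.g. two replicas -> two per-GPU losses gathered into one tensor
per_gpu_loss = torch.tensor([0.71, 0.69], requires_grad=True)
loss = per_gpu_loss.sum()        # what the commit adds; .mean() would also give a scalar
loss = loss / 4                  # divided by gradient_accumulation_steps (hypothetical value 4)
loss.backward()
print(per_gpu_loss.grad)         # tensor([0.2500, 0.2500])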