use DataParallel and smaller BERT

parent e79dd6c76e
commit 1aa21f4cc6
@@ -34,21 +34,11 @@ from datasets import load_dataset
 from torch.nn import DataParallel
 from torch.utils.data.dataloader import DataLoader
 from tqdm.auto import tqdm
-from transformers import (
-    CONFIG_MAPPING,
-    MODEL_MAPPING,
-    AdamW,
-    AutoConfig,
-    AutoModelForMaskedLM,
-    AutoTokenizer,
-    BatchEncoding,
-    BertConfig,
-    BertForPreTraining,
-    DataCollatorForLanguageModeling,
-    SchedulerType,
-    get_scheduler,
-    set_seed,
-)
+from transformers import (CONFIG_MAPPING, MODEL_MAPPING, AdamW, AutoConfig,
+                          AutoModelForMaskedLM, AutoTokenizer, BatchEncoding,
+                          BertConfig, BertForPreTraining,
+                          DataCollatorForLanguageModeling, SchedulerType,
+                          get_scheduler, set_seed)
 
 from my_data_collator import MyDataCollatorForPreTraining
 from process_data.utils import CURRENT_DATA_BASE
@@ -197,7 +187,8 @@ def main():
     # field="data",
     # )
     train_files = [
-        os.path.join(CURRENT_DATA_BASE, "inst.1.{}.json".format(i)) for i in range(128)
+        os.path.join(CURRENT_DATA_BASE, "inst.1.{}.json".format(i))
+        for i in range(0, 128, 2)
     ]
     valid_file = "/home/ming/malware/inst2vec_bert/data/test_lm/inst.json"
     raw_datasets = load_dataset(
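Note on the hunk above: the shard list now takes every second file, so only 64 of the 128 inst.1.*.json shards are used for training. A minimal sketch of the resulting list, with a placeholder path standing in for the repo's CURRENT_DATA_BASE constant:

import os

# Placeholder; the real value comes from process_data.utils.CURRENT_DATA_BASE.
CURRENT_DATA_BASE = "/path/to/current_data"

# Even-numbered shards only: inst.1.0.json, inst.1.2.json, ..., inst.1.126.json
train_files = [
    os.path.join(CURRENT_DATA_BASE, "inst.1.{}.json".format(i))
    for i in range(0, 128, 2)
]
assert len(train_files) == 64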
@@ -221,7 +212,7 @@ def main():
     config = BertConfig(
         vocab_size=tokenizer.get_vocab_size(),
         hidden_size=96,
-        num_hidden_layers=12,
+        num_hidden_layers=4,
         num_attention_heads=12,
         intermediate_size=384,
         max_position_embeddings=32,
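Note on the hunk above: with num_hidden_layers dropped from 12 to 4 (and hidden_size=96, intermediate_size=384 already far below bert-base), this is a much smaller BERT. A standalone sketch of the resulting config, with a fixed vocabulary size standing in for tokenizer.get_vocab_size():

from transformers import BertConfig, BertForPreTraining

VOCAB_SIZE = 30000  # stand-in; the script reads this from its own tokenizer

config = BertConfig(
    vocab_size=VOCAB_SIZE,
    hidden_size=96,              # bert-base uses 768
    num_hidden_layers=4,         # reduced from 12 in this commit
    num_attention_heads=12,      # per-head size is 96 / 12 = 8
    intermediate_size=384,       # 4 * hidden_size, same ratio as bert-base
    max_position_embeddings=32,
)
model = BertForPreTraining(config)
print(sum(p.numel() for p in model.parameters()))  # rough parameter count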
@@ -230,6 +221,8 @@ def main():
     # initalize a new BERT for pre-training
     model = BertForPreTraining(config)
 
+    model = DataParallel(model)
+
     # Preprocessing the datasets.
     column_names = raw_datasets["train"].column_names
     text_column_name = "text" if "text" in column_names else column_names[0]
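Note on the hunk above: the model is now wrapped in torch.nn.DataParallel right after construction, so forward passes are replicated across all visible GPUs and each batch is split along dimension 0. A hedged sketch of the pattern with a toy module in place of BertForPreTraining; note that the original module then lives under model.module (relevant when saving checkpoints):

import torch
from torch.nn import DataParallel, Linear

model = Linear(8, 2)  # stand-in for BertForPreTraining(config)

if torch.cuda.device_count() > 1:
    model = DataParallel(model)  # replicate across GPUs; inputs are scattered on dim 0
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

# Unwrap before saving so the state-dict keys are not prefixed with "module."
to_save = model.module if isinstance(model, DataParallel) else model
torch.save(to_save.state_dict(), "checkpoint.pt")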
@@ -313,14 +306,13 @@ def main():
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
 
     # Prepare everything with our `accelerator`.
-    # model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
-    #     model, optimizer, train_dataloader, eval_dataloader
-    # )
-    model, optimizer, train_dataloader = accelerator.prepare(
-        model, optimizer, train_dataloader
+    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
+        model, optimizer, train_dataloader, eval_dataloader
     )
+    # model, optimizer, train_dataloader = accelerator.prepare(
+    #     model, optimizer, train_dataloader
+    # )
 
-    model = DataParallel(model)
     # model.to("cuda:0")
 
     # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
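Note on the hunk above: the four-argument accelerator.prepare call (including eval_dataloader) is re-enabled, the three-argument variant is commented out, and the DataParallel wrap has moved ahead of prepare (see the earlier hunk). A minimal, hypothetical sketch of the prepare pattern with toy data in place of the real model and datasets:

import torch
from torch.nn import Linear
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

model = Linear(8, 2)  # stand-in for the (wrapped) BERT model
optimizer = AdamW(model.parameters(), lr=5e-5)
train_dataloader = DataLoader(TensorDataset(torch.randn(32, 8)), batch_size=4)
eval_dataloader = DataLoader(TensorDataset(torch.randn(8, 8)), batch_size=4)

# prepare() moves everything to the right device(s) and wraps the dataloaders,
# so their length should be read only after this call, as the script's own note says.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)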
@@ -374,6 +366,7 @@ def main():
         for step, batch in enumerate(train_dataloader):
             outputs = model(**batch)
             loss = outputs.loss
+            loss = loss.sum()
             loss = loss / args.gradient_accumulation_steps
             accelerator.backward(loss)
             if (
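Note on the hunk above: when the model is wrapped in DataParallel, each replica returns its own loss and the wrapper gathers them into a 1-D tensor on the primary device, so outputs.loss is no longer a scalar; the added loss = loss.sum() reduces it before backward (loss.mean() would be the averaging alternative). A small standalone sketch of that reduction:

import torch

gradient_accumulation_steps = 1  # placeholder for args.gradient_accumulation_steps

# Under DataParallel, outputs.loss has shape (num_replicas,): one loss per GPU.
per_replica_loss = torch.tensor([0.71, 0.69], requires_grad=True)

loss = per_replica_loss.sum()        # the reduction this commit adds
loss = loss / gradient_accumulation_steps
loss.backward()                      # backward needs a scalar, which sum() guarantees
print(per_replica_loss.grad)         # tensor([1., 1.])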