From 4b6c65862be90d75f949cda37e6914403b2f9f20 Mon Sep 17 00:00:00 2001 From: Setra Solofoniaina <60129070+Setra-Solofoniaina@users.noreply.github.com> Date: Fri, 2 Apr 2021 09:54:59 +0300 Subject: [PATCH] first commit --- .gitignore | 7 + src/dataset/__init__.py | 2 + src/dataset/dataset.py | 158 ++++++++++++++++++ src/dataset/tokenizer.py | 46 +++++ src/main.py | 77 +++++++++ src/model/__init__.py | 2 + src/model/bert.py | 64 +++++++ src/model/embedding/__init__.py | 1 + .../__pycache__/__init__.cpython-38.pyc | Bin 0 -> 210 bytes .../embedding/__pycache__/bert.cpython-38.pyc | Bin 0 -> 1698 bytes .../__pycache__/position.cpython-38.pyc | Bin 0 -> 1167 bytes .../__pycache__/segment.cpython-38.pyc | Bin 0 -> 607 bytes .../__pycache__/token.cpython-38.pyc | Bin 0 -> 608 bytes src/model/embedding/bert.py | 36 ++++ src/model/embedding/position.py | 27 +++ src/model/embedding/segment.py | 6 + src/model/embedding/token.py | 6 + src/model/language_model.py | 61 +++++++ src/model/utils/__init__.py | 0 src/model/utils/gelu.py | 12 ++ src/trainer/__init__.py | 1 + src/trainer/optim_schedule.py | 35 ++++ src/trainer/pretrain.py | 152 +++++++++++++++++ 23 files changed, 693 insertions(+) create mode 100644 .gitignore create mode 100644 src/dataset/__init__.py create mode 100644 src/dataset/dataset.py create mode 100644 src/dataset/tokenizer.py create mode 100644 src/main.py create mode 100644 src/model/__init__.py create mode 100644 src/model/bert.py create mode 100644 src/model/embedding/__init__.py create mode 100644 src/model/embedding/__pycache__/__init__.cpython-38.pyc create mode 100644 src/model/embedding/__pycache__/bert.cpython-38.pyc create mode 100644 src/model/embedding/__pycache__/position.cpython-38.pyc create mode 100644 src/model/embedding/__pycache__/segment.cpython-38.pyc create mode 100644 src/model/embedding/__pycache__/token.cpython-38.pyc create mode 100644 src/model/embedding/bert.py create mode 100644 src/model/embedding/position.py create mode 100644 src/model/embedding/segment.py create mode 100644 src/model/embedding/token.py create mode 100644 src/model/language_model.py create mode 100644 src/model/utils/__init__.py create mode 100644 src/model/utils/gelu.py create mode 100644 src/trainer/__init__.py create mode 100644 src/trainer/optim_schedule.py create mode 100644 src/trainer/pretrain.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ddde8a7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +/src/dataset/corpus/* +/src/dataset/tok_model/* +/src/dataset/__pycache__ +/src/model/__pycache__ +/src/output/* +/src/trainer/__pycache__ +/.vscode \ No newline at end of file diff --git a/src/dataset/__init__.py b/src/dataset/__init__.py new file mode 100644 index 0000000..6845a39 --- /dev/null +++ b/src/dataset/__init__.py @@ -0,0 +1,2 @@ +from .tokenizer import BertTokenizer +from .dataset import BERTDataset \ No newline at end of file diff --git a/src/dataset/dataset.py b/src/dataset/dataset.py new file mode 100644 index 0000000..7f05bc0 --- /dev/null +++ b/src/dataset/dataset.py @@ -0,0 +1,158 @@ +"""Dataset Class for Bert""" +import random +import tqdm +import torch +import linecache +from torch.utils.data import Dataset +from .tokenizer import BertTokenizer + + +class BERTDataset(Dataset): + def __init__(self, corpus_path, tokenizer: BertTokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True): + self.tokenizer = tokenizer + self.seq_len = seq_len + + self.on_memory = on_memory + self.corpus_lines = corpus_lines + self.corpus_path = 
corpus_path + self.encoding = encoding + + self.corpus_lines = sum(1 for line in open(self.corpus_path)) + + # with open(corpus_path, "r", encoding=encoding) as f: + # if self.corpus_lines is None and not on_memory: + # for _ in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines): + # self.corpus_lines += 1 + + # if on_memory: + # self.lines = [line[:-1].split("\t") + # for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)] + # self.corpus_lines = len(self.lines) + + # if not on_memory: + # self.file = open(corpus_path, "r", encoding=encoding) + # self.random_file = open(corpus_path, "r", encoding=encoding) + + # for _ in range(random.randint(0, self.corpus_lines if self.corpus_lines < 1000 else 1000)): + # self.random_file.__next__() + + def __len__(self): + return self.corpus_lines + + def __getitem__(self, item): + t1, t2, is_next_label = self.random_sent(item) + t1_random, t1_label = self.random_word(t1) + t2_random, t2_label = self.random_word(t2) + + # [CLS] tag = SOS tag, [SEP] tag = EOS tag + t1 = [self.tokenizer.sos_index] + t1_random + [self.tokenizer.eos_index] + t2 = t2_random + [self.tokenizer.eos_index] + + t1_label = [self.tokenizer.pad_index] + t1_label + [self.tokenizer.pad_index] + t2_label = t2_label + [self.tokenizer.pad_index] + + segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len] + bert_input = (t1 + t2)[:self.seq_len] + bert_label = (t1_label + t2_label)[:self.seq_len] + + padding = [self.tokenizer.pad_index for _ in range(self.seq_len - len(bert_input))] + bert_input.extend(padding) + bert_label.extend(padding) + segment_label.extend(padding) + + output = {"bert_input": bert_input, + "bert_label": bert_label, + "segment_label": segment_label, + "is_next": is_next_label} + + return {key: torch.tensor(value) for key, value in output.items()} #pylint: disable=not-callable + + def random_word(self, sentence): + # tokens = sentence.split() + output_label = [] + tokens = self.tokenizer.tokenize(sentence) + for i, token in enumerate(tokens): + prob = random.random() + if prob < 0.15: + prob /= 0.15 + + if prob < 0.8: + tokens[i] = self.tokenizer.mask_index + elif prob < 0.9: + tokens[i] = self.tokenizer.getRandomTokenID() + else: + tokens[i] = token + output_label.append(token) + else: + tokens[i] = token + output_label.append(0) + return tokens, output_label + + # for i, token in enumerate(tokens): + # prob = random.random() + # if prob < 0.15: + # prob /= 0.15 + + # # 80% randomly change token to mask token + # if prob < 0.8: + # tokens[i] = self.vocab.mask_index + + # # 10% randomly change token to random token + # elif prob < 0.9: + # tokens[i] = random.randrange(len(self.vocab)) + + # # 10% randomly change token to current token + # else: + # tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index) + + # output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index)) + + # else: + # tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index) + # output_label.append(0) + + # return tokens, output_label + + def random_sent(self, index): + t1, t2 = self.get_corpus_line(index) + # t1 = self.tokenizer.tokenize(t1) + # t2 = self.tokenizer.tokenize(t2) + # output_text, label(isNotNext:0, isNext:1) + if random.random() > 0.5: + return t1, t2, 1 + else: + # rand_line = self.tokenizer.tokenize(self.get_random_line()) + return t1, self.get_random_line(), 0 + + # def get_corpus_line(self, item): + # if self.on_memory: + # return self.lines[item][0], self.lines[item][1] + # else: + # line = self.file.__next__() 
+    #         if line is None:
+    #             self.file.close()
+    #             self.file = open(self.corpus_path, "r", encoding=self.encoding)
+    #             line = self.file.__next__()
+
+    #         t1, t2 = line[:-1].split("\t")
+    #         return t1, t2
+    def get_corpus_line(self, item):
+        t1 = linecache.getline(self.corpus_path, item + 1)  # linecache is 1-based, Dataset indices start at 0
+        t2 = linecache.getline(self.corpus_path, min(item + 2, self.corpus_lines))
+        return t1, t2
+
+    # def get_random_line(self):
+    #     if self.on_memory:
+    #         return self.lines[random.randrange(len(self.lines))][1]
+
+    #     line = self.file.__next__()
+    #     if line is None:
+    #         self.file.close()
+    #         self.file = open(self.corpus_path, "r", encoding=self.encoding)
+    #         for _ in range(random.randint(0, self.corpus_lines if self.corpus_lines < 1000 else 1000)):
+    #             self.random_file.__next__()
+    #         line = self.random_file.__next__()
+    #     return line[:-1].split("\t")[1]
+
+    def get_random_line(self):
+        return linecache.getline(self.corpus_path, random.randint(1, self.corpus_lines))
\ No newline at end of file
diff --git a/src/dataset/tokenizer.py b/src/dataset/tokenizer.py
new file mode 100644
index 0000000..3a7249c
--- /dev/null
+++ b/src/dataset/tokenizer.py
@@ -0,0 +1,46 @@
+""" Tokenizer class """
+import os
+import random
+from pathlib import Path
+import tokenizers
+from tokenizers.pre_tokenizers import Whitespace
+from tokenizers.pre_tokenizers import Digits
+
+
+class BertTokenizer():
+    """Bert Tokenizer using the WordPiece tokenizer model"""
+    def __init__(self, path):
+        self.path = path
+        text_paths = [str(x) for x in Path("./dataset/corpus/").glob("**/*.txt")]
+        savedpath = "./dataset/tok_model/MaLaMo-vocab.txt"
+        if os.path.exists(savedpath):
+            self.tokenizer = tokenizers.BertWordPieceTokenizer(
+                "./dataset/tok_model/MaLaMo-vocab.txt",
+            )
+        else:
+            self.tokenizer = tokenizers.BertWordPieceTokenizer()
+            self.tokenizer.train(files=text_paths, special_tokens=[
+                "[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"], vocab_size=14200)
+            self.tokenizer.save_model("./dataset/tok_model", "MaLaMo")
+        self.tokenizer.enable_truncation(max_length=512)
+        self.pretokenizer = tokenizers.pre_tokenizers.Sequence([Whitespace(), Digits(individual_digits=True)])
+        self.vocab = self.tokenizer.get_vocab()
+        self.mask_index = self.vocab.get("[MASK]")
+        self.pad_index = self.vocab.get("[PAD]")
+        self.eos_index = self.vocab.get("[SEP]")
+        self.sos_index = self.vocab.get("[CLS]")
+        self.unk_index = self.vocab.get("[UNK]")
+
+
+    def tokenize(self, sentence: str):
+        return self.tokenizer.encode(sentence).ids
+
+    def getRandomTokenID(self):
+        return random.randint(6, len(self.vocab) - 1)
+
+    def get_vocab(self):
+        return self.tokenizer.get_vocab()
+
+    def get_vocab_size(self):
+        return self.tokenizer.get_vocab_size()
+
\ No newline at end of file
diff --git a/src/main.py b/src/main.py
new file mode 100644
index 0000000..a84dcd5
--- /dev/null
+++ b/src/main.py
@@ -0,0 +1,77 @@
+"""main entry for training"""
+
+import argparse
+
+from torch.utils.data import DataLoader
+
+from model.bert import BERT
+from trainer import BERTTrainer
+from dataset import BERTDataset, BertTokenizer
+
+
+def train():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("-c", "--train_dataset", type=str, default="./dataset/corpus/train.txt", help="train dataset for training BERT")
+    parser.add_argument("-t", "--test_dataset", type=str, default="./dataset/corpus/test.txt", help="test set for evaluating the trained model")
+    #parser.add_argument("-v", "--vocab_path", required=True, type=str, help="built vocab model path with bert-vocab")
+    parser.add_argument("-o", "--output_path", type=str,
default="./output/bert.model", help="ex)output/bert.model") + + parser.add_argument("-hs", "--hidden", type=int, default=256, help="hidden size of transformer model") + parser.add_argument("-l", "--layers", type=int, default=8, help="number of layers") + parser.add_argument("-a", "--attn_heads", type=int, default=8, help="number of attention heads") + parser.add_argument("-s", "--seq_len", type=int, default=512, help="maximum sequence len") + + parser.add_argument("-b", "--batch_size", type=int, default=8, help="number of batch_size") + parser.add_argument("-e", "--epochs", type=int, default=10, help="number of epochs") + parser.add_argument("-w", "--num_workers", type=int, default=1, help="dataloader worker size") + + parser.add_argument("--with_cuda", type=bool, default=True, help="training with CUDA: true, or false") + parser.add_argument("--log_freq", type=int, default=10, help="printing loss every n iter: setting n") + parser.add_argument("--corpus_lines", type=int, default=5110, help="total number of lines in corpus") + parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device ids") + parser.add_argument("--on_memory", type=bool, default=False, help="Loading on memory: true or false") + + parser.add_argument("--lr", type=float, default=1e-3, help="learning rate of adam") + parser.add_argument("--adam_weight_decay", type=float, default=0.01, help="weight_decay of adam") + parser.add_argument("--adam_beta1", type=float, default=0.9, help="adam first beta value") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="adam first beta value") + + args = parser.parse_args() + + print("Loading Vocab") + tokenizer = BertTokenizer("./dataset/corpus") + vocab_size = tokenizer.get_vocab_size() + print("Vocab Size: ", vocab_size) + + print("Loading Train Dataset", args.train_dataset) + train_dataset = BERTDataset(args.train_dataset, tokenizer, seq_len=args.seq_len, + corpus_lines=args.corpus_lines, on_memory=args.on_memory) + + print("Loading Test Dataset", args.test_dataset) + test_dataset = BERTDataset(args.test_dataset, tokenizer, seq_len=args.seq_len, on_memory=args.on_memory) \ + if args.test_dataset is not None else None + + print("Creating Dataloader") + train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers) + test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) \ + if test_dataset is not None else None + + print("Building BERT model") + bert = BERT(vocab_size, hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads) + + print("Creating BERT Trainer") + trainer = BERTTrainer(bert, vocab_size, train_dataloader=train_data_loader, test_dataloader=test_data_loader, + lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, + with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq) + + print("Training Start") + for epoch in range(args.epochs): + trainer.train(epoch) + trainer.save(epoch, args.output_path) + + if test_data_loader is not None: + trainer.test(epoch) + +if __name__ == "__main__": + train() diff --git a/src/model/__init__.py b/src/model/__init__.py new file mode 100644 index 0000000..1a62230 --- /dev/null +++ b/src/model/__init__.py @@ -0,0 +1,2 @@ +from .bert import BERT +from .language_model import BERTLM \ No newline at end of file diff --git a/src/model/bert.py b/src/model/bert.py new file mode 100644 index 0000000..22dc5a8 --- /dev/null +++ b/src/model/bert.py 
@@ -0,0 +1,64 @@
+"""BERT class module"""
+
+from .embedding import BERTEmbedding
+import torch
+import torch.nn as nn
+
+
+class BERT(nn.Module):
+    """
+    BERT model : Bidirectional Encoder Representations from Transformers.
+    """
+
+    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
+        """
+        :param vocab_size: vocab_size of total words
+        :param hidden: BERT model hidden size
+        :param n_layers: number of Transformer blocks (layers)
+        :param attn_heads: number of attention heads
+        :param dropout: dropout rate
+        """
+
+        super().__init__()
+        self.hidden = hidden
+        self.n_layers = n_layers
+        self.attn_heads = attn_heads
+
+        # paper noted they used 4*hidden_size for ff_network_hidden_size
+        self.feed_forward_hidden = hidden * 4
+
+        self.src_mask = None
+
+        # embedding for BERT, sum of positional, segment, token embeddings
+        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)
+
+        # multi-layers transformer blocks, deep network
+        #self.transformer_blocks = nn.ModuleList(
+        #    [TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])
+        encoder_layers = nn.TransformerEncoderLayer(hidden, attn_heads, self.feed_forward_hidden, dropout)
+        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, n_layers)
+
+    def _generate_square_subsequent_mask(self, sz):
+        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
+        return mask
+
+    def forward(self, x, segment_info, has_mask=True):
+        if has_mask:
+            if self.src_mask is None or self.src_mask.size(0) != len(x):
+                mask = self._generate_square_subsequent_mask(len(x))
+                self.src_mask = mask
+        else:
+            self.src_mask = None
+        # attention masking for padded token
+        # torch.ByteTensor([batch_size, 1, seq_len, seq_len)
+        #mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
+        #mask = mask.view(-1, 512, 512)
+
+        #print(x)
+
+        # embedding the indexed sequence to sequence of vectors
+        x = self.embedding(x, segment_info)
+        x = self.transformer_encoder(x, self.src_mask)
+
+        return x
\ No newline at end of file
diff --git a/src/model/embedding/__init__.py b/src/model/embedding/__init__.py
new file mode 100644
index 0000000..d9cc742
--- /dev/null
+++ b/src/model/embedding/__init__.py
@@ -0,0 +1 @@
+from .bert import BERTEmbedding
\ No newline at end of file
[GIT binary patches for the committed compiled files under src/model/embedding/__pycache__/ (__init__, bert, position, segment, token .cpython-38.pyc) omitted.]
[The hunks for src/model/embedding/bert.py, position.py, segment.py, token.py, src/model/language_model.py, src/model/utils/__init__.py, src/model/utils/gelu.py, src/trainer/__init__.py, src/trainer/optim_schedule.py, and the beginning of src/trainer/pretrain.py are truncated in this excerpt; the patch resumes partway through BERTTrainer.__init__.]
+            print("Using %d GPUS for BERT" % torch.cuda.device_count())
+            self.model = nn.DataParallel(self.model, device_ids=cuda_devices)
+
+        # Setting the train and test data loader
+        self.train_data = train_dataloader
+        self.test_data = test_dataloader
+
+        # Setting the Adam optimizer with hyper-param
+        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
+        self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps)
+
+        # Using Negative Log Likelihood Loss function for predicting the masked_token
+        self.criterion = nn.NLLLoss(ignore_index=0)
+
+        self.log_freq = log_freq
+
+        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
+
+    def train(self, epoch):
+        self.iteration(epoch, self.train_data)
+
+    def test(self, epoch):
+        self.iteration(epoch, self.test_data, train=False)
+
+    def iteration(self, epoch, data_loader, train=True):
+        """
+        loop over the data_loader for training or testing
+        if on train status, backward operation is activated
+        and also auto save the model every epoch
+
+        :param epoch: current epoch index
+        :param data_loader: torch.utils.data.DataLoader for iteration
+        :param train: boolean flag, True for train and False for test
+        :return: None
+        """
+        str_code = "train" if train else "test"
+
+        # Setting the tqdm progress bar
+        data_iter = tqdm.tqdm(enumerate(data_loader),
+                              desc="EP_%s:%d" % (str_code, epoch),
+                              total=len(data_loader),
+                              bar_format="{l_bar}{r_bar}")
+
+        avg_loss = 0.0
+        total_correct = 0
+        total_element = 0
+
+        for i, data in data_iter:
+            # 0. batch_data will be sent into the device (GPU or CPU)
+            data = {key: value.to(self.device) for key, value in data.items()}
+
+            # 1. forward the next_sentence_prediction and masked_lm model
+            next_sent_output, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"])
+
+            # 2-1. NLL (negative log likelihood) loss of is_next classification result
+            next_loss = self.criterion(next_sent_output, data["is_next"])
+
+            # 2-2. NLLLoss of predicting masked token word
+            mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"])
+
+            # 2-3. Adding next_loss and mask_loss : 3.4 Pre-training Procedure
+            loss = next_loss + mask_loss
+
+            # 3. backward and optimization only in train
+            if train:
+                self.optim_schedule.zero_grad()
+                loss.backward()
+                self.optim_schedule.step_and_update_lr()
+
+            # next sentence prediction accuracy
+            correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item()
+            avg_loss += loss.item()
+            total_correct += correct
+            total_element += data["is_next"].nelement()
+
+            post_fix = {
+                "epoch": epoch,
+                "iter": i,
+                "avg_loss": avg_loss / (i + 1),
+                "avg_acc": total_correct / total_element * 100,
+                "loss": loss.item()
+            }
+
+            if i % self.log_freq == 0:
+                data_iter.write(str(post_fix))
+
+        print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter), "total_acc=",
+              total_correct * 100.0 / total_element)
+
+    def save(self, epoch, file_path="output/bert_trained.model"):
+        """
+        Saving the current BERT model on file_path
+
+        :param epoch: current epoch number
+        :param file_path: model output path; the saved file is file_path + ".ep%d" % epoch
+        :return: final_output_path
+        """
+        output_path = file_path + ".ep%d" % epoch
+        torch.save(self.bert.cpu(), output_path)
+        self.bert.to(self.device)
+        print("EP:%d Model Saved on:" % epoch, output_path)
+        return output_path
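
Note: the hunk for src/model/embedding/bert.py is truncated above, so the BERTEmbedding that src/model/bert.py imports and calls as self.embedding(x, segment_info) is not shown here. The following is a minimal illustrative sketch, not the committed file: only the class name, the (vocab_size, embed_size) constructor arguments, and the (sequence, segment_label) forward call are taken from the visible code; the internals (learned positional embeddings, three segment ids, dropout, max_len=512) are assumptions about a typical BERT embedding block.

import torch
import torch.nn as nn


class BERTEmbedding(nn.Module):
    """Sketch only: sums token, position and segment embeddings, then applies dropout.

    Names and shapes are assumptions chosen to match the calls in src/model/bert.py;
    the real src/model/embedding/bert.py is not visible in this patch excerpt.
    """

    def __init__(self, vocab_size, embed_size, dropout=0.1, max_len=512):
        super().__init__()
        # token ids -> vectors; id 0 ([PAD]) stays a zero vector
        self.token = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        # learned absolute position embeddings for positions 0 .. max_len-1
        self.position = nn.Embedding(max_len, embed_size)
        # segment ids as produced by BERTDataset: 0 = padding, 1 = sentence A, 2 = sentence B
        self.segment = nn.Embedding(3, embed_size, padding_idx=0)
        self.dropout = nn.Dropout(p=dropout)
        self.embed_size = embed_size

    def forward(self, sequence, segment_label):
        # sequence, segment_label: (batch, seq_len) integer tensors
        positions = torch.arange(sequence.size(1), device=sequence.device)
        return self.dropout(self.token(sequence) + self.position(positions) + self.segment(segment_label))

One related design note: nn.TransformerEncoder in src/model/bert.py expects (seq_len, batch, hidden) input unless batch_first is enabled, so the actual embedding/encoder wiring may transpose the batch and sequence dimensions before the encoder call.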
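
Similarly, src/trainer/optim_schedule.py (the ScheduledOptim used by BERTTrainer) is truncated above. A common implementation behind this interface is the warmup schedule from "Attention Is All You Need"; the sketch below is an assumption that only mirrors the calls visible in src/trainer/pretrain.py (ScheduledOptim(optimizer, d_model, n_warmup_steps), zero_grad(), step_and_update_lr()).

class ScheduledOptim:
    """Sketch only: Noam-style warmup wrapper around an optimizer.

    The interface matches the calls in src/trainer/pretrain.py; the learning-rate
    formula is an assumption, not the committed src/trainer/optim_schedule.py.
    """

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.init_lr = d_model ** -0.5
        self.n_current_steps = 0

    def zero_grad(self):
        self._optimizer.zero_grad()

    def step_and_update_lr(self):
        self._update_learning_rate()
        self._optimizer.step()

    def _get_lr_scale(self):
        # lr scale ~ min(step^-0.5, step * warmup^-1.5): linear warmup, then inverse-sqrt decay
        step = self.n_current_steps
        return min(step ** -0.5, step * self.n_warmup_steps ** -1.5)

    def _update_learning_rate(self):
        self.n_current_steps += 1
        lr = self.init_lr * self._get_lr_scale()
        for param_group in self._optimizer.param_groups:
            param_group["lr"] = lr

Under this assumed schedule, the lr passed to Adam in pretrain.py is overwritten on every step: with hidden=256 and warmup_steps=10000 the learning rate ramps up linearly for the first 10,000 steps and then decays proportionally to step^-0.5.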