diff --git a/src/dataset/dataset.py b/src/dataset/dataset.py index 12302b7..3aef1a5 100644 --- a/src/dataset/dataset.py +++ b/src/dataset/dataset.py @@ -37,11 +37,14 @@ class BERTDataset(Dataset): segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len] bert_input = (t1 + t2)[:self.seq_len] bert_label = (t1_label + t2_label)[:self.seq_len] + input_mask = ([True for _ in range(len(bert_input))])[:self.seq_len] padding = [self.tokenizer.pad_index for _ in range(self.seq_len - len(bert_input))] + padding_mask = [False for _ in range(self.seq_len - len(bert_input))] bert_input.extend(padding) bert_label.extend(padding) segment_label.extend(padding) + input_mask.extend(padding_mask) output = {"bert_input": bert_input, "bert_label": bert_label, diff --git a/src/main.py b/src/main.py index a84dcd5..1af5aa4 100644 --- a/src/main.py +++ b/src/main.py @@ -58,7 +58,7 @@ def train(): if test_dataset is not None else None print("Building BERT model") - bert = BERT(vocab_size, hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads) + bert = BERT(vocab_size, tokenizer.pad_index, hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads) print("Creating BERT Trainer") trainer = BERTTrainer(bert, vocab_size, train_dataloader=train_data_loader, test_dataloader=test_data_loader, diff --git a/src/model/bert.py b/src/model/bert.py index 476013f..e099c8e 100644 --- a/src/model/bert.py +++ b/src/model/bert.py @@ -10,7 +10,7 @@ class BERT(nn.Module): BERT model : Bidirectional Encoder Representations from Transformers. """ - def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1): + def __init__(self, vocab_size, pad_index, hidden=768, n_layers=12, attn_heads=12, dropout=0.1): """ :param vocab_size: vocab_size of total words :param hidden: BERT model hidden size @@ -23,6 +23,7 @@ class BERT(nn.Module): self.hidden = hidden self.n_layers = n_layers self.attn_heads = attn_heads + self.pad_index = pad_index # paper noted they used 4*hidden_size for ff_network_hidden_size self.feed_forward_hidden = hidden * 4 @@ -40,17 +41,24 @@ class BERT(nn.Module): mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) return mask + + def make_src_mask(self, src): + return src.transpose(0, 1) == self.pad_index + def forward(self, x, segment_info, has_mask=True): - if has_mask: - if self.src_mask is None or self.src_mask.size(0) != len(x): - mask = self._generate_square_subsequent_mask(len(x)) - self.src_mask = mask - else: - self.src_mask = None + # self.src_mask = mask + # if has_mask: + # if self.src_mask is None or self.src_mask.size(0) != len(x): + # mask = self._generate_square_subsequent_mask(len(x)) + # self.src_mask = mask + # else: + # self.src_mask = None + + self.src_mask = self.make_src_mask(x) # embedding the indexed sequence to sequence of vectors x = self.embedding(x, segment_info) - x = self.transformer_encoder(x, self.src_mask) + x = self.transformer_encoder(x, src_key_padding_mask=self.src_mask) return x \ No newline at end of file diff --git a/src/model/embedding/__pycache__/bert.cpython-38.pyc b/src/model/embedding/__pycache__/bert.cpython-38.pyc index 724ef75..69a157f 100644 Binary files a/src/model/embedding/__pycache__/bert.cpython-38.pyc and b/src/model/embedding/__pycache__/bert.cpython-38.pyc differ