changed masking

Setra Solofoniaina 2021-04-02 16:51:46 +03:00
parent d101df2f57
commit f00f7386f4
4 changed files with 20 additions and 9 deletions


@@ -37,11 +37,14 @@ class BERTDataset(Dataset):
         segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]
         bert_input = (t1 + t2)[:self.seq_len]
         bert_label = (t1_label + t2_label)[:self.seq_len]
+        input_mask = ([True for _ in range(len(bert_input))])[:self.seq_len]

         padding = [self.tokenizer.pad_index for _ in range(self.seq_len - len(bert_input))]
+        padding_mask = [False for _ in range(self.seq_len - len(bert_input))]
         bert_input.extend(padding)
         bert_label.extend(padding)
         segment_label.extend(padding)
+        input_mask.extend(padding_mask)

         output = {"bert_input": bert_input,
                   "bert_label": bert_label,


@@ -58,7 +58,7 @@ def train():
         if test_dataset is not None else None

     print("Building BERT model")
-    bert = BERT(vocab_size, hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads)
+    bert = BERT(vocab_size, tokenizer.pad_index, hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads)

     print("Creating BERT Trainer")
     trainer = BERTTrainer(bert, vocab_size, train_dataloader=train_data_loader, test_dataloader=test_data_loader,
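
The point of this change is that the model now needs to know which token id means "padding" so it can derive its own key-padding mask in forward(). A hedged sketch of the invariant this relies on, assuming train_dataset is the BERTDataset instance built earlier in train():

    # Illustrative check, not code from the repo: the id the dataset pads with
    # must be the same id handed to the model, or make_src_mask() will flag
    # the wrong positions.
    assert train_dataset.tokenizer.pad_index == bert.pad_index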


@@ -10,7 +10,7 @@ class BERT(nn.Module):
    BERT model : Bidirectional Encoder Representations from Transformers.
    """

-    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
+    def __init__(self, vocab_size, pad_index, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
        """
        :param vocab_size: vocab_size of total words
        :param hidden: BERT model hidden size
@@ -23,6 +23,7 @@ class BERT(nn.Module):
        self.hidden = hidden
        self.n_layers = n_layers
        self.attn_heads = attn_heads
+        self.pad_index = pad_index

        # paper noted they used 4*hidden_size for ff_network_hidden_size
        self.feed_forward_hidden = hidden * 4
@@ -40,17 +41,24 @@ class BERT(nn.Module):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

+    def make_src_mask(self, src):
+        return src.transpose(0, 1) == self.pad_index
+
    def forward(self, x, segment_info, has_mask=True):
-        if has_mask:
-            if self.src_mask is None or self.src_mask.size(0) != len(x):
-                mask = self._generate_square_subsequent_mask(len(x))
-                self.src_mask = mask
-        else:
-            self.src_mask = None
+        # self.src_mask = mask
+        # if has_mask:
+        #     if self.src_mask is None or self.src_mask.size(0) != len(x):
+        #         mask = self._generate_square_subsequent_mask(len(x))
+        #         self.src_mask = mask
+        # else:
+        #     self.src_mask = None
+        self.src_mask = self.make_src_mask(x)

        # embedding the indexed sequence to sequence of vectors
        x = self.embedding(x, segment_info)

-        x = self.transformer_encoder(x, self.src_mask)
+        x = self.transformer_encoder(x, src_key_padding_mask=self.src_mask)

        return x
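
For readers following the change: the causal mask from _generate_square_subsequent_mask is replaced by a key-padding mask, which is what padding-aware bidirectional encoding needs. Below is a minimal, self-contained sketch of how a mask of this shape interacts with nn.TransformerEncoder; the sizes, pad_index value, and the plain nn.Embedding stand-in are assumptions for illustration, not this repository's code:

    import torch
    import torch.nn as nn

    # Illustrative values, not the repo's defaults.
    pad_index = 0
    seq_len, batch_size, hidden = 6, 2, 16

    # Token ids shaped (seq_len, batch): the encoder here is sequence-first,
    # which is why make_src_mask() transposes before comparing to pad_index.
    x = torch.randint(1, 100, (seq_len, batch_size))
    x[4:, 0] = pad_index   # pad the tail of the first sequence
    x[3:, 1] = pad_index   # pad the tail of the second sequence

    # Same computation as make_src_mask: (seq_len, batch) -> (batch, seq_len),
    # True where the position is padding.
    src_mask = x.transpose(0, 1) == pad_index

    encoder_layer = nn.TransformerEncoderLayer(d_model=hidden, nhead=4)
    encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

    # Stand-in for the BERT embedding: just map ids to vectors.
    emb = nn.Embedding(100, hidden, padding_idx=pad_index)

    # src_key_padding_mask stops attention from attending to padded positions,
    # while every real token can still attend in both directions.
    out = encoder(emb(x), src_key_padding_mask=src_mask)
    print(out.shape)  # torch.Size([6, 2, 16])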