changed masking
parent d101df2f57
commit f00f7386f4
@@ -37,11 +37,14 @@ class BERTDataset(Dataset):
         segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]
         bert_input = (t1 + t2)[:self.seq_len]
         bert_label = (t1_label + t2_label)[:self.seq_len]
+        input_mask = ([True for _ in range(len(bert_input))])[:self.seq_len]

         padding = [self.tokenizer.pad_index for _ in range(self.seq_len - len(bert_input))]
+        padding_mask = [False for _ in range(self.seq_len - len(bert_input))]
         bert_input.extend(padding)
         bert_label.extend(padding)
         segment_label.extend(padding)
+        input_mask.extend(padding_mask)

         output = {"bert_input": bert_input,
                   "bert_label": bert_label,
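The three added dataset lines build a boolean mask that is True over real tokens and False over the padding appended afterwards, kept in step with bert_input: both are truncated to seq_len and then extended to it. A minimal standalone sketch of how the fields line up, using made-up token ids, seq_len, and pad_index (illustrative values, not taken from the repository):

    # Illustrative only: the t1/t2 token ids, seq_len, and pad_index are invented.
    seq_len, pad_index = 8, 0
    t1, t2 = [5, 12, 7], [9, 3]

    bert_input = (t1 + t2)[:seq_len]                      # [5, 12, 7, 9, 3]
    input_mask = [True for _ in range(len(bert_input))]   # True marks a real token

    padding = [pad_index for _ in range(seq_len - len(bert_input))]
    padding_mask = [False for _ in range(seq_len - len(bert_input))]

    bert_input.extend(padding)        # [5, 12, 7, 9, 3, 0, 0, 0]
    input_mask.extend(padding_mask)   # [True]*5 + [False]*3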
@@ -58,7 +58,7 @@ def train():
         if test_dataset is not None else None

     print("Building BERT model")
-    bert = BERT(vocab_size, hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads)
+    bert = BERT(vocab_size, tokenizer.pad_index, hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads)

     print("Creating BERT Trainer")
     trainer = BERTTrainer(bert, vocab_size, train_dataloader=train_data_loader, test_dataloader=test_data_loader,
@@ -10,7 +10,7 @@ class BERT(nn.Module):
     BERT model : Bidirectional Encoder Representations from Transformers.
     """

-    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
+    def __init__(self, vocab_size, pad_index, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
         """
         :param vocab_size: vocab_size of total words
         :param hidden: BERT model hidden size
@@ -23,6 +23,7 @@ class BERT(nn.Module):
         self.hidden = hidden
         self.n_layers = n_layers
         self.attn_heads = attn_heads
+        self.pad_index = pad_index

         # paper noted they used 4*hidden_size for ff_network_hidden_size
         self.feed_forward_hidden = hidden * 4
@@ -41,16 +42,23 @@ class BERT(nn.Module):
         mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
         return mask

+    def make_src_mask(self, src):
+        return src.transpose(0, 1) == self.pad_index
+
+
     def forward(self, x, segment_info, has_mask=True):
-        if has_mask:
-            if self.src_mask is None or self.src_mask.size(0) != len(x):
-                mask = self._generate_square_subsequent_mask(len(x))
-                self.src_mask = mask
-        else:
-            self.src_mask = None
+        # self.src_mask = mask
+        # if has_mask:
+        #     if self.src_mask is None or self.src_mask.size(0) != len(x):
+        #         mask = self._generate_square_subsequent_mask(len(x))
+        #         self.src_mask = mask
+        # else:
+        #     self.src_mask = None
+
+        self.src_mask = self.make_src_mask(x)

         # embedding the indexed sequence to sequence of vectors
         x = self.embedding(x, segment_info)
-        x = self.transformer_encoder(x, self.src_mask)
+        x = self.transformer_encoder(x, src_key_padding_mask=self.src_mask)

         return x
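In torch.nn.TransformerEncoder, the second positional argument of forward is the (seq_len, seq_len) additive attention mask (the -inf/0.0 kind that _generate_square_subsequent_mask builds), while src_key_padding_mask is a boolean (batch, seq_len) tensor whose True entries are excluded from attention. make_src_mask produces the latter, hence the switch to passing it by keyword. A minimal standalone sketch of that call convention, with made-up sizes and a pad index of 0 (not the repository's model):

    import torch
    import torch.nn as nn

    pad_index = 0
    embedding = nn.Embedding(100, 16, padding_idx=pad_index)
    encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=16, nhead=4), num_layers=2)

    # Two padded sequences of length 5, laid out (seq_len, batch) as the encoder
    # expects with the default batch_first=False.
    src = torch.tensor([[5, 9],
                        [12, 3],
                        [7, 0],
                        [9, 0],
                        [3, 0]])

    # Boolean (batch, seq_len) mask, True where the token is padding,
    # mirroring what make_src_mask returns for a (seq_len, batch) input.
    key_padding_mask = src.transpose(0, 1) == pad_index

    out = encoder(embedding(src), src_key_padding_mask=key_padding_mask)
    print(out.shape)  # torch.Size([5, 2, 16])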