changed masking
This commit is contained in:
parent
d101df2f57
commit
f00f7386f4
@ -37,11 +37,14 @@ class BERTDataset(Dataset):
|
|||||||
segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]
|
segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]
|
||||||
bert_input = (t1 + t2)[:self.seq_len]
|
bert_input = (t1 + t2)[:self.seq_len]
|
||||||
bert_label = (t1_label + t2_label)[:self.seq_len]
|
bert_label = (t1_label + t2_label)[:self.seq_len]
|
||||||
|
input_mask = ([True for _ in range(len(bert_input))])[:self.seq_len]
|
||||||
|
|
||||||
padding = [self.tokenizer.pad_index for _ in range(self.seq_len - len(bert_input))]
|
padding = [self.tokenizer.pad_index for _ in range(self.seq_len - len(bert_input))]
|
||||||
|
padding_mask = [False for _ in range(self.seq_len - len(bert_input))]
|
||||||
bert_input.extend(padding)
|
bert_input.extend(padding)
|
||||||
bert_label.extend(padding)
|
bert_label.extend(padding)
|
||||||
segment_label.extend(padding)
|
segment_label.extend(padding)
|
||||||
|
input_mask.extend(padding_mask)
|
||||||
|
|
||||||
output = {"bert_input": bert_input,
|
output = {"bert_input": bert_input,
|
||||||
"bert_label": bert_label,
|
"bert_label": bert_label,
|
||||||
|
@ -58,7 +58,7 @@ def train():
|
|||||||
if test_dataset is not None else None
|
if test_dataset is not None else None
|
||||||
|
|
||||||
print("Building BERT model")
|
print("Building BERT model")
|
||||||
bert = BERT(vocab_size, hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads)
|
bert = BERT(vocab_size, tokenizer.pad_index, hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads)
|
||||||
|
|
||||||
print("Creating BERT Trainer")
|
print("Creating BERT Trainer")
|
||||||
trainer = BERTTrainer(bert, vocab_size, train_dataloader=train_data_loader, test_dataloader=test_data_loader,
|
trainer = BERTTrainer(bert, vocab_size, train_dataloader=train_data_loader, test_dataloader=test_data_loader,
|
||||||
|
@ -10,7 +10,7 @@ class BERT(nn.Module):
|
|||||||
BERT model : Bidirectional Encoder Representations from Transformers.
|
BERT model : Bidirectional Encoder Representations from Transformers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
|
def __init__(self, vocab_size, pad_index, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
|
||||||
"""
|
"""
|
||||||
:param vocab_size: vocab_size of total words
|
:param vocab_size: vocab_size of total words
|
||||||
:param hidden: BERT model hidden size
|
:param hidden: BERT model hidden size
|
||||||
@ -23,6 +23,7 @@ class BERT(nn.Module):
|
|||||||
self.hidden = hidden
|
self.hidden = hidden
|
||||||
self.n_layers = n_layers
|
self.n_layers = n_layers
|
||||||
self.attn_heads = attn_heads
|
self.attn_heads = attn_heads
|
||||||
|
self.pad_index = pad_index
|
||||||
|
|
||||||
# paper noted they used 4*hidden_size for ff_network_hidden_size
|
# paper noted they used 4*hidden_size for ff_network_hidden_size
|
||||||
self.feed_forward_hidden = hidden * 4
|
self.feed_forward_hidden = hidden * 4
|
||||||
@ -40,17 +41,24 @@ class BERT(nn.Module):
|
|||||||
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
||||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
def make_src_mask(self, src):
|
||||||
|
return src.transpose(0, 1) == self.pad_index
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x, segment_info, has_mask=True):
|
def forward(self, x, segment_info, has_mask=True):
|
||||||
if has_mask:
|
# self.src_mask = mask
|
||||||
if self.src_mask is None or self.src_mask.size(0) != len(x):
|
# if has_mask:
|
||||||
mask = self._generate_square_subsequent_mask(len(x))
|
# if self.src_mask is None or self.src_mask.size(0) != len(x):
|
||||||
self.src_mask = mask
|
# mask = self._generate_square_subsequent_mask(len(x))
|
||||||
else:
|
# self.src_mask = mask
|
||||||
self.src_mask = None
|
# else:
|
||||||
|
# self.src_mask = None
|
||||||
|
|
||||||
|
self.src_mask = self.make_src_mask(x)
|
||||||
|
|
||||||
# embedding the indexed sequence to sequence of vectors
|
# embedding the indexed sequence to sequence of vectors
|
||||||
x = self.embedding(x, segment_info)
|
x = self.embedding(x, segment_info)
|
||||||
x = self.transformer_encoder(x, self.src_mask)
|
x = self.transformer_encoder(x, src_key_padding_mask=self.src_mask)
|
||||||
|
|
||||||
return x
|
return x
|
Binary file not shown.
Loading…
Reference in New Issue
Block a user