changing segment embedding size

This commit is contained in:
Setra Solofoniaina 2021-04-02 15:48:10 +03:00
parent d988d3e4a3
commit b910eeb4d0
4 changed files with 3 additions and 5 deletions

View File

@ -24,13 +24,11 @@ class BERTEmbedding(nn.Module):
super().__init__() super().__init__()
self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size) self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
self.position = PositionalEmbedding(d_model=self.token.embedding_dim) self.position = PositionalEmbedding(d_model=self.token.embedding_dim)
#self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim) self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim)
self.segment = nn.Embedding(8, self.token.embedding_dim, padding_idx=0)
self.dropout = nn.Dropout(p=dropout) self.dropout = nn.Dropout(p=dropout)
self.embed_size = embed_size self.embed_size = embed_size
def forward(self, sequence, segment_label): def forward(self, sequence, segment_label):
#print(segment_label.shape)
#segmented = self.segment(segment_label)
x = self.token(sequence) + self.position(sequence) + self.segment(segment_label) x = self.token(sequence) + self.position(sequence) + self.segment(segment_label)
print(x.shape)
return self.dropout(x) return self.dropout(x)

View File

@ -3,4 +3,4 @@ import torch.nn as nn
class SegmentEmbedding(nn.Embedding): class SegmentEmbedding(nn.Embedding):
def __init__(self, embed_size=512): def __init__(self, embed_size=512):
super().__init__(3, embed_size, padding_idx=0) super().__init__(4, embed_size, padding_idx=0)