changing segment embedding size
This commit is contained in:
parent
d988d3e4a3
commit
b910eeb4d0
Binary file not shown.
Binary file not shown.
@ -24,13 +24,11 @@ class BERTEmbedding(nn.Module):
|
||||
super().__init__()
|
||||
self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
|
||||
self.position = PositionalEmbedding(d_model=self.token.embedding_dim)
|
||||
#self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim)
|
||||
self.segment = nn.Embedding(8, self.token.embedding_dim, padding_idx=0)
|
||||
self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim)
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
self.embed_size = embed_size
|
||||
|
||||
def forward(self, sequence, segment_label):
|
||||
#print(segment_label.shape)
|
||||
#segmented = self.segment(segment_label)
|
||||
x = self.token(sequence) + self.position(sequence) + self.segment(segment_label)
|
||||
print(x.shape)
|
||||
return self.dropout(x)
|
@ -3,4 +3,4 @@ import torch.nn as nn
|
||||
|
||||
class SegmentEmbedding(nn.Embedding):
|
||||
def __init__(self, embed_size=512):
|
||||
super().__init__(3, embed_size, padding_idx=0)
|
||||
super().__init__(4, embed_size, padding_idx=0)
|
Loading…
Reference in New Issue
Block a user