changing segment embedding size
parent d988d3e4a3
commit b910eeb4d0
Binary file not shown.
Binary file not shown.
@@ -24,13 +24,11 @@ class BERTEmbedding(nn.Module):
         super().__init__()
         self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
         self.position = PositionalEmbedding(d_model=self.token.embedding_dim)
-        #self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim)
-        self.segment = nn.Embedding(8, self.token.embedding_dim, padding_idx=0)
+        self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim)
         self.dropout = nn.Dropout(p=dropout)
         self.embed_size = embed_size

     def forward(self, sequence, segment_label):
-        #print(segment_label.shape)
-        #segmented = self.segment(segment_label)
         x = self.token(sequence) + self.position(sequence) + self.segment(segment_label)
+        print(x.shape)
         return self.dropout(x)
@@ -3,4 +3,4 @@ import torch.nn as nn

 class SegmentEmbedding(nn.Embedding):
     def __init__(self, embed_size=512):
-        super().__init__(3, embed_size, padding_idx=0)
+        super().__init__(4, embed_size, padding_idx=0)