diff --git a/src/model/embedding/bert.py b/src/model/embedding/bert.py
index 5984460..8d13edd 100644
--- a/src/model/embedding/bert.py
+++ b/src/model/embedding/bert.py
@@ -24,13 +24,10 @@ class BERTEmbedding(nn.Module):
         super().__init__()
         self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
         self.position = PositionalEmbedding(d_model=self.token.embedding_dim)
-        #self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim)
-        self.segment = nn.Embedding(8, self.token.embedding_dim, padding_idx=0)
+        self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim)
         self.dropout = nn.Dropout(p=dropout)
         self.embed_size = embed_size
 
     def forward(self, sequence, segment_label):
-        #print(segment_label.shape)
-        #segmented = self.segment(segment_label)
         x = self.token(sequence) + self.position(sequence) + self.segment(segment_label)
         return self.dropout(x)
\ No newline at end of file
diff --git a/src/model/embedding/segment.py b/src/model/embedding/segment.py
index 110a5bf..315de43 100644
--- a/src/model/embedding/segment.py
+++ b/src/model/embedding/segment.py
@@ -3,4 +3,4 @@ import torch.nn as nn
 
 class SegmentEmbedding(nn.Embedding):
     def __init__(self, embed_size=512):
-        super().__init__(3, embed_size, padding_idx=0)
\ No newline at end of file
+        super().__init__(4, embed_size, padding_idx=0)
\ No newline at end of file
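
For reference, a minimal sketch of what the reverted wiring implies: segment labels index into a 4-row embedding table (row 0 reserved for padding), so any label in 0..3 is now a valid lookup, whereas the old 3-row table would raise an index error for label 3. The `vocab_size`, the batch shapes, and the plain `nn.Embedding` stand-in for `TokenEmbedding` below are illustrative assumptions, not the repo's actual classes; the positional term is omitted for brevity.

import torch
import torch.nn as nn

class SegmentEmbedding(nn.Embedding):
    # Mirrors src/model/embedding/segment.py after this change:
    # 4 segment ids (0 = padding, 1/2 = sentence A/B, 3 = extra segment
    # id, assumed here to occur in the training data).
    def __init__(self, embed_size=512):
        super().__init__(4, embed_size, padding_idx=0)

embed_size = 512
vocab_size = 30000  # hypothetical vocabulary size

token = nn.Embedding(vocab_size, embed_size, padding_idx=0)  # stand-in for TokenEmbedding
segment = SegmentEmbedding(embed_size=embed_size)

sequence = torch.randint(1, vocab_size, (2, 16))  # (batch, seq_len) token ids
segment_label = torch.randint(0, 4, (2, 16))      # labels 0..3 are all valid lookups now

x = token(sequence) + segment(segment_label)      # summed per-position embeddings
print(x.shape)  # torch.Size([2, 16, 512])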