refactor

commit d988d3e4a3
parent 42efbf86b6
@@ -19,23 +19,6 @@ class BERTDataset(Dataset):
         self.corpus_lines = sum(1 for line in open(self.corpus_path))
 
-        # with open(corpus_path, "r", encoding=encoding) as f:
-        #     if self.corpus_lines is None and not on_memory:
-        #         for _ in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines):
-        #             self.corpus_lines += 1
-
-        #     if on_memory:
-        #         self.lines = [line[:-1].split("\t")
-        #                       for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)]
-        #         self.corpus_lines = len(self.lines)
-
-        # if not on_memory:
-        #     self.file = open(corpus_path, "r", encoding=encoding)
-        #     self.random_file = open(corpus_path, "r", encoding=encoding)
-
-        #     for _ in range(random.randint(0, self.corpus_lines if self.corpus_lines < 1000 else 1000)):
-        #         self.random_file.__next__()
-
     def __len__(self):
         return self.corpus_lines
 
@@ -88,71 +71,19 @@ class BERTDataset(Dataset):
                 output_label.append(0)
 
         return tokens, output_label
 
-        # for i, token in enumerate(tokens):
-        #     prob = random.random()
-        #     if prob < 0.15:
-        #         prob /= 0.15
-
-        #         # 80% randomly change token to mask token
-        #         if prob < 0.8:
-        #             tokens[i] = self.vocab.mask_index
-
-        #         # 10% randomly change token to random token
-        #         elif prob < 0.9:
-        #             tokens[i] = random.randrange(len(self.vocab))
-
-        #         # 10% randomly change token to current token
-        #         else:
-        #             tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)
-
-        #         output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index))
-
-        #     else:
-        #         tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)
-        #         output_label.append(0)
-
-        # return tokens, output_label
-
     def random_sent(self, index):
         t1, t2 = self.get_corpus_line(index)
-        # t1 = self.tokenizer.tokenize(t1)
-        # t2 = self.tokenizer.tokenize(t2)
         # output_text, label(isNotNext:0, isNext:1)
         if random.random() > 0.5:
             return t1, t2, 1
         else:
-            # rand_line = self.tokenizer.tokenize(self.get_random_line())
             return t1, self.get_random_line(), 0
 
-    # def get_corpus_line(self, item):
-    #     if self.on_memory:
-    #         return self.lines[item][0], self.lines[item][1]
-    #     else:
-    #         line = self.file.__next__()
-    #         if line is None:
-    #             self.file.close()
-    #             self.file = open(self.corpus_path, "r", encoding=self.encoding)
-    #             line = self.file.__next__()
-
-    #         t1, t2 = line[:-1].split("\t")
-    #         return t1, t2
     def get_corpus_line(self, item):
         t1 = linecache.getline(self.corpus_path, item)
         t2 = linecache.getline(self.corpus_path, item+1)
         return t1, t2
 
-    # def get_random_line(self):
-    #     if self.on_memory:
-    #         return self.lines[random.randrange(len(self.lines))][1]
-
-    #     line = self.file.__next__()
-    #     if line is None:
-    #         self.file.close()
-    #         self.file = open(self.corpus_path, "r", encoding=self.encoding)
-    #         for _ in range(random.randint(0, self.corpus_lines if self.corpus_lines < 1000 else 1000)):
-    #             self.random_file.__next__()
-    #     line = self.random_file.__next__()
-    #     return line[:-1].split("\t")[1]
-
     def get_random_line(self):
         return linecache.getline(self.corpus_path, random.randint(1, self.corpus_lines))
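Note on the hunk above: the commented-out file-handle and on-memory loading paths are dropped in favour of linecache lookups. A minimal sketch of how linecache behaves (standard library; the temporary corpus file below is made up for illustration) — getline is 1-based and returns an empty string for out-of-range line numbers:

import linecache
import tempfile
import os

# Hypothetical tiny corpus, only to illustrate the lookup semantics.
tmp = tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False)
tmp.write("first sentence\nsecond sentence\nthird sentence\n")
tmp.close()

print(linecache.getline(tmp.name, 1))    # "first sentence\n" -- indexing starts at 1, not 0
print(linecache.getline(tmp.name, 99))   # ""                 -- missing lines come back empty

linecache.clearcache()
os.unlink(tmp.name)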
@@ -12,16 +12,16 @@ class BertTokenizer():
     def __init__(self, path):
         self.path = path
         text_paths = [str(x) for x in Path("./dataset/corpus/").glob("**/*.txt")]
-        savedpath = "./dataset/tok_model/MaLaMo-vocab.txt"
+        savedpath = "./dataset/tok_model/MALBERT-vocab.txt"
         if os.path.exists(savedpath):
             self.tokenizer = tokenizers.BertWordPieceTokenizer(
-                "./dataset/tok_model/MaLaMo-vocab.txt",
+                "./dataset/tok_model/MALBERT-vocab.txt",
             )
         else:
             self.tokenizer = tokenizers.BertWordPieceTokenizer()
             self.tokenizer.train(files=text_paths, special_tokens=[
                 "[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"], vocab_size=14200)
-            self.tokenizer.save_model("./dataset/tok_model", "MaLaMo")
+            self.tokenizer.save_model("./dataset/tok_model", "MALBERT")
         self.tokenizer.enable_truncation(max_length=512)
         self.pretokenizer = tokenizers.pre_tokenizers.Sequence([Whitespace(), Digits(individual_digits=True)])
         self.vocab = self.tokenizer.get_vocab()
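The hunk above only renames the saved vocabulary from MaLaMo to MALBERT; the surrounding logic (load the WordPiece vocab if it exists, otherwise train and save it) is unchanged. A minimal usage sketch with the tokenizers library, assuming the renamed vocab file already exists at the path used in the diff:

import tokenizers

# Assumed path, mirroring the one written by save_model in the diff.
tok = tokenizers.BertWordPieceTokenizer("./dataset/tok_model/MALBERT-vocab.txt")
tok.enable_truncation(max_length=512)

enc = tok.encode("a short example sentence")
print(enc.tokens)  # WordPiece pieces, with the [CLS]/[SEP] template this tokenizer applies by default
print(enc.ids)     # integer ids aligned one-to-one with enc.tokens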
@@ -33,8 +33,6 @@ class BERT(nn.Module):
         self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)
 
         # multi-layers transformer blocks, deep network
-        #self.transformer_blocks = nn.ModuleList(
-        #    [TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])
         encoder_layers = nn.TransformerEncoderLayer(hidden, attn_heads, self.feed_forward_hidden, dropout, activation="gelu")
         self.transformer_encoder = nn.TransformerEncoder(encoder_layers, n_layers)
 
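For context on the change above: the commented-out hand-rolled TransformerBlock list is dropped in favour of PyTorch's built-in encoder stack. A minimal sketch with assumed hyperparameters (not taken from this repository) showing how the two calls fit together:

import torch
import torch.nn as nn

hidden, attn_heads, n_layers, dropout = 256, 8, 4, 0.1

# One encoder layer, then n_layers stacked copies of it.
layer = nn.TransformerEncoderLayer(hidden, attn_heads, hidden * 4, dropout, activation="gelu")
encoder = nn.TransformerEncoder(layer, n_layers)

# Without batch_first the expected input layout is (seq_len, batch, hidden).
x = torch.randn(128, 2, hidden)
out = encoder(x)
print(out.shape)  # torch.Size([128, 2, 256])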
@@ -50,12 +48,6 @@ class BERT(nn.Module):
             self.src_mask = mask
         else:
             self.src_mask = None
-        # attention masking for padded token
-        # torch.ByteTensor([batch_size, 1, seq_len, seq_len)
-        #mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
-        #mask = mask.view(-1, 512, 512)
-
-        #print(x)
 
         # embedding the indexed sequence to sequence of vectors
         x = self.embedding(x, segment_info)
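Aside, not part of the commit: the deleted comments sketched a boolean padding mask built from token ids. With the nn.TransformerEncoder used in this refactor, the conventional way to ignore padded positions is a src_key_padding_mask; a hedged sketch assuming pad id 0 and made-up dimensions:

import torch
import torch.nn as nn

hidden, heads, layers = 256, 8, 4
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(hidden, heads, hidden * 4, 0.1, activation="gelu"),
    layers,
)

x_ids = torch.tensor([[5, 9, 3, 0, 0]])   # one sequence of token ids, padded with id 0
padding_mask = (x_ids == 0)               # True where a position should be ignored, shape (batch, seq)
emb = torch.randn(5, 1, hidden)           # stand-in for the embedded sequence, shape (seq, batch, hidden)
out = encoder(emb, src_key_padding_mask=padding_mask)
print(out.shape)  # torch.Size([5, 1, 256])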
@@ -17,8 +17,6 @@ class BERTTrainer:
         1. Masked Language Model : 3.3.1 Task #1: Masked LM
         2. Next Sentence prediction : 3.3.2 Task #2: Next Sentence Prediction
 
-        please check the details on README.md with simple example.
-
     """
 
     def __init__(self, bert: BERT, vocab_size: int,