125 lines
3.3 KiB
Python
125 lines
3.3 KiB
Python
|
import argparse
|
||
|
import os
|
||
|
from itertools import chain
|
||
|
|
||
|
from datasets import load_dataset
|
||
|
from tokenizers import Tokenizer
|
||
|
from tokenizers.models import WordLevel
|
||
|
from tokenizers.pre_tokenizers import Whitespace
|
||
|
from tokenizers.processors import TemplateProcessing
|
||
|
from tokenizers.trainers import WordLevelTrainer
|
||
|
|
||
|
from process_data.utils import CURRENT_DATA_BASE, ORIGINAL_DATA_BASE, read_file
|
||
|
|
||
|
BASE_PATH = "/home/ming/malware/inst2vec_bert/bert/"
|
||
|
|
||
|
|
||
|
def parse_args():
|
||
|
parser = argparse.ArgumentParser(
|
||
|
description="Train a word level tokenizer for ASM_BERT"
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"--vocab_size",
|
||
|
type=int,
|
||
|
default=2000,
|
||
|
help="The size of vocabulary used to train the tokenizer.",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"--padding_length",
|
||
|
type=int,
|
||
|
default=32,
|
||
|
help="The length will be padded to by the tokenizer.",
|
||
|
)
|
||
|
args = parser.parse_args()
|
||
|
|
||
|
return args
|
||
|
|
||
|
|
||
|
def train_tokenizer(args, dataset):
|
||
|
tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
|
||
|
tokenizer.pre_tokenizer = Whitespace()
|
||
|
|
||
|
trainer = WordLevelTrainer(
|
||
|
vocab_size=args.vocab_size,
|
||
|
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
|
||
|
)
|
||
|
|
||
|
# def batch_iterator(batch_size=1000):
|
||
|
# for i in range(0, len(dataset), batch_size):
|
||
|
# yield dataset[i : i + batch_size]["text"]
|
||
|
|
||
|
# tokenizer.train_from_iterator(
|
||
|
# batch_iterator(), trainer=trainer, length=len(dataset)
|
||
|
# )
|
||
|
|
||
|
tokenizer.train_from_iterator(dataset, trainer)
|
||
|
|
||
|
return tokenizer
|
||
|
|
||
|
|
||
|
def save_tokenizer(tokenizer, tokenizer_file):
|
||
|
tokenizer.save(tokenizer_file)
|
||
|
|
||
|
|
||
|
def load_tokenizer(tokenizer_file):
|
||
|
if not os.path.exists(tokenizer_file):
|
||
|
print("{} doesn't exist, will be retrained...".format(tokenizer_file))
|
||
|
return None
|
||
|
print("The tokenizer has already been trained.")
|
||
|
return Tokenizer.from_file(tokenizer_file)
|
||
|
|
||
|
|
||
|
def post_process(tokenizer):
|
||
|
tokenizer.post_processor = TemplateProcessing(
|
||
|
single="[CLS] $A [SEP]",
|
||
|
pair="[CLS] $A [SEP] $B:1 [SEP]:1",
|
||
|
special_tokens=[
|
||
|
("[CLS]", tokenizer.token_to_id("[CLS]")),
|
||
|
("[SEP]", tokenizer.token_to_id("[SEP]")),
|
||
|
],
|
||
|
)
|
||
|
return tokenizer
|
||
|
|
||
|
|
||
|
def tokenizer_encode(tokenizer, data):
|
||
|
return tokenizer.encode_batch(data)
|
||
|
|
||
|
|
||
|
def main(tokenizer_file=""):
|
||
|
args = parse_args()
|
||
|
|
||
|
tokenizer = load_tokenizer(tokenizer_file)
|
||
|
|
||
|
if tokenizer is not None:
|
||
|
return
|
||
|
|
||
|
# json_files = [
|
||
|
# os.path.join(CURRENT_DATA_BASE, "inst.1.{}.json".format(i)) for i in range(128)
|
||
|
# ]
|
||
|
# dataset = load_dataset("json", data_files=json_files, field="data")
|
||
|
|
||
|
text_files = [
|
||
|
os.path.join(ORIGINAL_DATA_BASE, "inst.1.{}.txt".format(group))
|
||
|
for group in ["pos", "neg"]
|
||
|
]
|
||
|
|
||
|
dataset = []
|
||
|
for f in text_files:
|
||
|
dataset += read_file(f)
|
||
|
|
||
|
dataset = [tuple(sent[:-1].split("\t")) for sent in dataset]
|
||
|
|
||
|
print("Trainging tokenizer...")
|
||
|
tokenizer = train_tokenizer(args, chain.from_iterable(dataset))
|
||
|
tokenizer = post_process(tokenizer)
|
||
|
tokenizer.enable_padding(
|
||
|
pad_id=tokenizer.token_to_id("[PAD]"),
|
||
|
pad_token="[PAD]",
|
||
|
length=args.padding_length,
|
||
|
)
|
||
|
save_tokenizer(tokenizer, tokenizer_file)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
main(os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.1.json"))
|