complete interface for downstream task

This commit is contained in:
zyr 2021-06-08 15:43:57 +08:00
parent b4810738b8
commit fb61bb2a7b
3 changed files with 139 additions and 14 deletions

my_data_collator.py

View File

@@ -1,15 +1,6 @@
 from dataclasses import dataclass
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    List,
-    NamedTuple,
-    Optional,
-    Sequence,
-    Tuple,
-    Union,
-)
+from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional,
+                    Sequence, Tuple, Union)
 
 import numpy as np
 import tokenizers
@@ -27,6 +18,8 @@ class MyDataCollatorForPreTraining:
     pad_to_multiple_of: Optional[int] = None
 
     def __post_init__(self):
+        # print(self.mlm, self.tokenizer.token_to_id("[MASK]"))
+        # input()
         if self.mlm and self.tokenizer.token_to_id("[MASK]") is None:
             raise ValueError(
                 "This tokenizer does not have a mask token which is necessary for masked language modeling. "
@@ -36,6 +29,7 @@ class MyDataCollatorForPreTraining:
     def __call__(
         self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]],
     ) -> Dict[str, torch.Tensor]:
+        # print(examples)
         # Handle dict or lists with proper padding and conversion to tensor.
         if isinstance(examples[0], (dict, BatchEncoding)):
             batch = pad(
@@ -56,6 +50,9 @@ class MyDataCollatorForPreTraining:
             batch["input_ids"], batch["labels"] = self.mask_tokens(
                 batch["input_ids"], special_tokens_mask=special_tokens_mask
             )
+        else:
+            batch["input_ids"] = torch.squeeze(batch["input_ids"], dim=0)
+            batch["token_type_ids"] = torch.squeeze(batch["token_type_ids"], dim=0)
         return batch
 
     def mask_tokens(
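The new else branch exists because the downstream interface (obtain_inst_vec.py, added below) hands the collator a single, already-encoded batch wrapped in a one-element list, so padding yields an extra leading dimension that BERT does not expect. A minimal sketch of the shape fix (the tensor sizes are illustrative, not from the source):

import torch

# pad() over a one-element list of pre-encoded batches yields (1, batch, seq_len)
input_ids = torch.zeros(1, 8, 32, dtype=torch.long)
input_ids = torch.squeeze(input_ids, dim=0)  # -> (8, 32), what the model expects
print(input_ids.shape)  # torch.Size([8, 32])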

View File

@@ -187,8 +187,7 @@ def main():
     #     field="data",
     # )
     train_files = [
-        os.path.join(CURRENT_DATA_BASE, "inst.1.{}.json".format(i))
-        for i in range(0, 128, 2)
+        os.path.join(CURRENT_DATA_BASE, "inst.all.{}.json".format(i)) for i in range(2)
     ]
     valid_file = "/home/ming/malware/inst2vec_bert/data/test_lm/inst.json"
     raw_datasets = load_dataset(
@@ -199,7 +198,7 @@ def main():
     # we use the tokenizer previously trained on the dataset above
     tokenizer = tokenizers.Tokenizer.from_file(
-        os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.1.json")
+        os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.all.json")
     )
 
     # NOTE: the `length` here must be consistent with the one used in `train_my_tokenizer.py`
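The consistency the NOTE asks for shows up again in obtain_inst_vec.py below, which hard-codes the same padding length of 32. A sketch of the loading pattern (mirrors that file; the value 32 is taken from it):

import os
import tokenizers
from process_data.utils import CURRENT_DATA_BASE

tokenizer = tokenizers.Tokenizer.from_file(
    os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.all.json")
)
tokenizer.enable_padding(
    pad_id=tokenizer.token_to_id("[PAD]"),
    pad_token="[PAD]",
    length=32,  # must match the length used when training the tokenizer
)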
@@ -254,6 +253,7 @@ def main():
         batched=True,
         num_proc=args.preprocessing_num_workers,
         remove_columns=column_names,
+        load_from_cache_file=False,
     )
 
     train_dataset = tokenized_datasets["train"]
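Adding load_from_cache_file=False matters here: datasets fingerprints map() results and would otherwise reuse tokenized data produced with the old tokenizer-inst.1.json. A hedged sketch of the idea (tokenize_function is an assumed name, not from this diff; the other names appear above):

tokenized_datasets = raw_datasets.map(
    tokenize_function,  # assumed: the script's per-batch tokenization fn
    batched=True,
    num_proc=args.preprocessing_num_workers,
    remove_columns=column_names,
    load_from_cache_file=False,  # bypass the fingerprint cache after the tokenizer changed
)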

obtain_inst_vec.py Normal file (128 additions)

View File

@@ -0,0 +1,128 @@
import argparse
import logging
import math
import os
import random

import datasets
import numpy as np
import tokenizers
import torch
import transformers
from accelerate import Accelerator
from datasets import load_dataset
from torch.nn import DataParallel
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm
from transformers import (
    CONFIG_MAPPING,
    MODEL_MAPPING,
    AdamW,
    AutoConfig,
    AutoModelForMaskedLM,
    AutoTokenizer,
    BatchEncoding,
    BertConfig,
    BertForPreTraining,
    DataCollatorForLanguageModeling,
    SchedulerType,
    get_scheduler,
    set_seed,
)
from my_data_collator import MyDataCollatorForPreTraining
from process_data.utils import CURRENT_DATA_BASE

model_file = os.path.join(CURRENT_DATA_BASE, "bert-L2-H8.bin")
config_file = os.path.join(CURRENT_DATA_BASE, "bert-L2-H8.config.json")
tokenizer_file = os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.all.json")


def load_model():
    config = BertConfig.from_json_file(config_file)
    model = BertForPreTraining(config)
    state_dict = torch.load(model_file)
    model.load_state_dict(state_dict)
    model.eval()
    print("Loaded model successfully!")

    tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
    tokenizer.enable_padding(
        pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]", length=32
    )
    print("Loaded tokenizer successfully!")

    return model, tokenizer


def process_input(inst, tokenizer):
    encoded_input = {}
    if isinstance(inst, str):
        # build a batch manually by repeating the single instruction
        inst = [inst for _ in range(8)]
    results = tokenizer.encode_batch(inst)
    encoded_input["input_ids"] = [result.ids for result in results]
    encoded_input["token_type_ids"] = [result.type_ids for result in results]
    encoded_input["special_tokens_mask"] = [
        result.special_tokens_mask for result in results
    ]
    # print(encoded_input["input_ids"])
    # use `np` rather than `pt` here to avoid a tensor-conversion error
    batch_output = BatchEncoding(
        encoded_input, tensor_type="np", prepend_batch_axis=False,
    )
    # print(batch_output["input_ids"])
    # NOTE: derive a padding mask from the "special_tokens_mask";
    # this only works if each input is a single instruction
    length_mask = 1 - batch_output["special_tokens_mask"]
    data_collator = MyDataCollatorForPreTraining(tokenizer=tokenizer, mlm=False)
    model_input = data_collator([batch_output])
    # print(model_input["input_ids"])
    return model_input, length_mask


def generate_inst_vec(inst, method="mean"):
    model, tokenizer = load_model()
    model_input, length_mask = process_input(inst, tokenizer)
    length_mask = torch.from_numpy(length_mask).to(model_input["input_ids"].device)
    # inference only, so no gradients are needed
    with torch.no_grad():
        output = model(**model_input, output_hidden_states=True)
    if method == "cls":
        if isinstance(inst, str):
            return output.hidden_states[-1][0][0]
        elif isinstance(inst, list):
            # hidden_states is a tuple of per-layer tensors, so index the layer first
            return output.hidden_states[-1][:, 0, :]
    elif method == "mean":
        # zero out padding/special-token positions before pooling
        result = output.hidden_states[-1] * torch.unsqueeze(length_mask, dim=-1)
        # print(result.shape)
        if isinstance(inst, str):
            result = torch.mean(result[0], dim=0)
        elif isinstance(inst, list):
            result = torch.mean(result, dim=1)
        return result
    elif method == "max":
        result = output.hidden_states[-1] * torch.unsqueeze(length_mask, dim=-1)
        # print(result.shape)
        if isinstance(inst, str):
            # torch.max with dim returns (values, indices); keep the values
            result = torch.max(result[0], dim=0).values
        elif isinstance(inst, list):
            result = torch.max(result, dim=1).values
        return result


def main():
    inst = ["mov ebp esp" for _ in range(8)]
    print(generate_inst_vec(inst).shape)


if __name__ == "__main__":
    main()
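For reference, a minimal usage sketch of the new interface (the pooling names follow the `method` argument above; "push ebp" is an illustrative instruction, and the module is assumed importable from the working directory):

from obtain_inst_vec import generate_inst_vec

# single instruction -> one pooled vector of shape (hidden_size,)
vec = generate_inst_vec("mov ebp esp", method="cls")
print(vec.shape)

# list of instructions -> one vector per instruction, shape (batch, hidden_size)
vecs = generate_inst_vec(["mov ebp esp", "push ebp"], method="max")
print(vecs.shape)

Note that generate_inst_vec reloads the checkpoint on every call; callers embedding many instructions may want to hoist load_model out of their loop.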