# Inst2Vec/obtain_inst_vec.py
#
# (file listing metadata: 129 lines, 3.7 KiB, Python)

import argparse
import logging
import math
import os
import random
import datasets
import numpy as np
import tokenizers
import torch
import transformers
from accelerate import Accelerator
from datasets import load_dataset
from torch.nn import DataParallel
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
AdamW,
AutoConfig,
AutoModelForMaskedLM,
AutoTokenizer,
BatchEncoding,
BertConfig,
BertForPreTraining,
DataCollatorForLanguageModeling,
SchedulerType,
get_scheduler,
set_seed,
)
from my_data_collator import MyDataCollatorForPreTraining
from process_data.utils import CURRENT_DATA_BASE
# Locations of the pre-trained instruction-BERT artifacts, all resolved
# relative to CURRENT_DATA_BASE: weights, model config, and tokenizer.
model_file, config_file, tokenizer_file = (
    os.path.join(CURRENT_DATA_BASE, file_name)
    for file_name in (
        "bert-L2-H8.bin",
        "bert-L2-H8.config.json",
        "tokenizer-inst.all.json",
    )
)
def load_model():
    """Load the pre-trained instruction BERT model and its tokenizer.

    The config is read from ``config_file``, the weights from ``model_file``
    and the tokenizer from ``tokenizer_file`` (module-level paths). The
    tokenizer is configured to pad every encoding to 32 tokens with ``[PAD]``.

    Returns:
        (model, tokenizer): a ``BertForPreTraining`` instance on the CPU in
        eval mode, and a ``tokenizers.Tokenizer``.
    """
    config = BertConfig.from_json_file(config_file)
    model = BertForPreTraining(config)
    # The model is constructed (and later run) on the CPU. The checkpoint may
    # have been saved from GPU tensors, and without map_location torch.load
    # would then fail on machines without CUDA. Loading to CPU is always safe.
    state_dict = torch.load(model_file, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()  # inference mode (e.g. disables dropout)
    print("Load model successfully !")

    tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
    # Fixed-length padding so every encoding in a batch has the same length.
    tokenizer.enable_padding(
        pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]", length=32
    )
    print("Load tokenizer successfully !")
    return model, tokenizer
def process_input(inst, tokenizer):
    """Tokenize instruction(s) and build model inputs plus a real-token mask.

    Args:
        inst: a single instruction string, or a list of instruction strings.
            A single string is replicated into a batch of 8 identical copies
            (presumably because the collator/model path expects a batch;
            confirm before changing).
        tokenizer: a ``tokenizers.Tokenizer`` (as returned by ``load_model``).

    Returns:
        (model_input, length_mask):
            model_input: the output of ``MyDataCollatorForPreTraining``
                applied to the encoded batch.
            length_mask: numpy array equal to ``1 - special_tokens_mask``,
                i.e. 1 for ordinary tokens and 0 for special/padding tokens.
                This is only meaningful when each sequence holds a single
                instruction.
    """
    batch = [inst] * 8 if isinstance(inst, str) else inst
    encodings = tokenizer.encode_batch(batch)

    fields = {
        "input_ids": [enc.ids for enc in encodings],
        "token_type_ids": [enc.type_ids for enc in encodings],
        "special_tokens_mask": [enc.special_tokens_mask for enc in encodings],
    }
    # Build numpy arrays here ("np") rather than torch tensors ("pt"); the
    # original author notes that "pt" led to an error at this step.
    batch_output = BatchEncoding(fields, tensor_type="np", prepend_batch_axis=False)

    length_mask = 1 - batch_output["special_tokens_mask"]

    collator = MyDataCollatorForPreTraining(tokenizer=tokenizer, mlm=False)
    return collator([batch_output]), length_mask
def generate_inst_vec(inst, method="mean", model=None, tokenizer=None):
    """Embed one instruction (or a list of instructions) with the pre-trained BERT.

    Args:
        inst: a single instruction string, or a list of instruction strings.
        method: pooling applied to the last hidden layer:
            "cls"  -> hidden state at the first position,
            "mean" -> average over non-special (real) tokens,
            "max"  -> element-wise maximum over non-special (real) tokens.
        model, tokenizer: optional pre-loaded pair as returned by
            ``load_model()``. Pass them to avoid reloading from disk on every
            call. If either is None, both are loaded via ``load_model()``.

    Returns:
        A 1-D tensor for a string input (the first row of the internally
        replicated batch), otherwise a 2-D tensor with one row per instruction.

    Raises:
        ValueError: if ``method`` is not "cls", "mean" or "max".
    """
    if method not in ("cls", "mean", "max"):
        raise ValueError(f"unknown pooling method: {method!r}")
    if model is None or tokenizer is None:
        model, tokenizer = load_model()

    model_input, length_mask = process_input(inst, tokenizer)
    single = isinstance(inst, str)

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        output = model(**model_input, output_hidden_states=True)
    # hidden_states is a tuple of per-layer tensors; take the last layer,
    # assumed (batch, seq, hidden) per the Transformers BERT output docs.
    last_hidden = output.hidden_states[-1]

    # length_mask is 1 for real tokens and 0 for special/padding tokens. Add a
    # trailing axis so it broadcasts over the hidden dimension.
    mask = torch.from_numpy(length_mask).to(
        device=last_hidden.device, dtype=last_hidden.dtype
    ).unsqueeze(-1)
    token_counts = mask.sum(dim=1)  # number of real tokens per sequence

    if method == "cls":
        pooled = last_hidden[:, 0, :]
    elif method == "mean":
        # Divide by the real-token count, not the padded length, so padding
        # does not dilute the average. The clamp avoids 0/0 with no real tokens.
        pooled = (last_hidden * mask).sum(dim=1) / token_counts.clamp(min=1)
    else:  # "max"
        # Push masked positions to the dtype minimum so they never win the max
        # (multiplying by 0 would let 0 beat negative activations).
        fill = torch.finfo(last_hidden.dtype).min
        pooled = last_hidden.masked_fill(mask == 0, fill).max(dim=1).values
        # Sequences without any real token yield zeros (as "mean" does).
        pooled = torch.where(token_counts > 0, pooled, torch.zeros_like(pooled))

    return pooled[0] if single else pooled
def main():
    """Smoke test: embed eight copies of one instruction and print the result shape."""
    batch = ["mov ebp esp"] * 8
    vectors = generate_inst_vec(batch)
    print(vectors.shape)


if __name__ == "__main__":
    main()