2024-04-15 20:01:20 +08:00
|
|
|
import os
|
|
|
|
import numpy as np
|
|
|
|
import tokenizers
|
|
|
|
import torch
|
|
|
|
|
|
|
|
from transformers import (
|
|
|
|
BatchEncoding,
|
|
|
|
BertConfig,
|
|
|
|
BertForPreTraining
|
|
|
|
)
|
|
|
|
|
|
|
|
from .my_data_collator import MyDataCollatorForPreTraining
|
|
|
|
# Locations of the pretrained BERT artifacts, relative to the current
# working directory.
model_file = "./bert/pytorch_model.bin"
tokenizer_file = "./bert/tokenizer-inst.all.json"
config_file = "./bert/bert.json"

# Disable tokenizer multithreading.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
|
2024-04-15 20:01:20 +08:00
|
|
|
# from my_data_collator import MyDataCollatorForPreTraining
|
|
|
|
# model_file = os.path.join("./pytorch_model.bin")
|
|
|
|
# tokenizer_file = os.path.join("./tokenizer-inst.all.json")
|
|
|
|
# config_file = os.path.join('./bert.json')
|
|
|
|
|
|
|
|
|
|
|
|
def load_model():
    """Build the BERT model and its tokenizer from the files on disk.

    The model is created from ``config_file``, populated with the weights
    stored in ``model_file`` (mapped onto the CPU) and switched to eval
    mode. The tokenizer is read from ``tokenizer_file`` and configured to
    pad every encoding with ``[PAD]`` to a length of 50.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    bert = BertForPreTraining(BertConfig.from_json_file(config_file))
    weights = torch.load(model_file, map_location='cpu')
    bert.load_state_dict(weights)
    bert.eval()

    tok = tokenizers.Tokenizer.from_file(tokenizer_file)
    pad_token = "[PAD]"
    tok.enable_padding(
        pad_id=tok.token_to_id(pad_token),
        pad_token=pad_token,
        length=50,
    )
    return bert, tok
|
|
|
|
|
|
|
|
|
|
|
|
def process_input(inst, tokenizer):
    """Tokenize instructions and collate them into model inputs.

    Args:
        inst: Either a single instruction string or a sequence of
            instruction strings. A single string is replicated 8 times so
            that the tokenizer always receives a batch.
        tokenizer: Tokenizer providing ``encode_batch``; it is also handed
            to ``MyDataCollatorForPreTraining``.

    Returns:
        A ``(model_input, length_mask)`` tuple. ``model_input`` is whatever
        the data collator produces for the encoded batch; ``length_mask``
        is ``1 - special_tokens_mask`` of the batch as a NumPy array.
    """
    if isinstance(inst, str):
        # Build a batch by hand out of the single instruction.
        inst = [inst] * 8

    encodings = tokenizer.encode_batch(inst)
    features = {
        "input_ids": [enc.ids for enc in encodings],
        "token_type_ids": [enc.type_ids for enc in encodings],
        "special_tokens_mask": [enc.special_tokens_mask for enc in encodings],
    }

    # NumPy ("np") tensors are requested instead of PyTorch ("pt") ones
    # because the latter reportedly triggers an error here.
    batch = BatchEncoding(features, tensor_type="np", prepend_batch_axis=False)

    # NOTE: relying on "special_tokens_mask" only works when each input
    # consists of a single instruction.
    # The mask is derived before collation so that it reflects the batch
    # exactly as the tokenizer produced it.
    length_mask = 1 - batch["special_tokens_mask"]

    collator = MyDataCollatorForPreTraining(tokenizer=tokenizer, mlm=False)
    return collator([batch]), length_mask
|
|
|
|
|
|
|
|
|
|
|
|
def generate_inst_vec(inst, method="mean"):
    """Embed one instruction or a list of instructions with the BERT model.

    The model and tokenizer are loaded on every call via ``load_model``,
    the input is tokenized with ``process_input`` and the last hidden
    layer of the model is pooled according to ``method``.

    Args:
        inst: A single instruction string or a list of instruction strings.
            A single string is batched internally by ``process_input`` and
            only the first row of that batch is used for the result.
        method: Pooling strategy over the last hidden layer:
            ``"cls"`` takes the hidden state at position 0,
            ``"mean"`` averages the masked hidden states over the sequence
            axis, ``"max"`` takes the maximum of the masked hidden states
            over the sequence axis.

    Returns:
        For ``"cls"`` and ``"mean"`` a tensor: one vector for a string
        input, one vector per instruction for a list input. For ``"max"``
        the ``(values, indices)`` result of ``torch.max`` along the
        sequence axis. For ``"cls"`` with an input that is neither ``str``
        nor ``list`` the result is ``None``.

    Raises:
        ValueError: If ``method`` is not one of ``"cls"``, ``"mean"`` or
            ``"max"``.
    """
    # Fail fast, before paying for model loading, instead of silently
    # returning None for a misspelled pooling method.
    if method not in ("cls", "mean", "max"):
        raise ValueError(
            f"unknown method {method!r}; expected 'cls', 'mean' or 'max'"
        )

    model, tokenizer = load_model()

    model_input, length_mask = process_input(inst, tokenizer)
    length_mask = torch.from_numpy(length_mask).to(model_input["input_ids"].device)

    # Inference only: no autograd graph is needed for the embeddings.
    with torch.no_grad():
        output = model(**model_input, output_hidden_states=True)

    # ``hidden_states`` is a tuple with one tensor per layer, so the layer
    # must be selected first and the tensor sliced afterwards.
    last_hidden = output.hidden_states[-1]

    if method == "cls":
        if isinstance(inst, str):
            return last_hidden[0][0]
        if isinstance(inst, list):
            return last_hidden[:, 0, :]
        return None

    # Zero out the positions flagged in the special-tokens mask.
    masked = last_hidden * torch.unsqueeze(length_mask, dim=-1)

    if method == "mean":
        # NOTE(review): masked positions still count in the denominator,
        # so this is a mean over the full padded length rather than over
        # the real tokens only. Kept as is to preserve existing embeddings.
        if isinstance(inst, str):
            masked = torch.mean(masked[0], dim=0)
        elif isinstance(inst, list):
            masked = torch.mean(masked, dim=1)
        return masked

    # method == "max"
    # NOTE(review): ``torch.max`` with ``dim`` returns a (values, indices)
    # pair, and zeroed positions take part in the maximum. Kept unchanged
    # for backward compatibility.
    if isinstance(inst, str):
        masked = torch.max(masked[0], dim=0)
    elif isinstance(inst, list):
        masked = torch.max(masked, dim=1)
    return masked
|
|
|
|
|
|
|
|
|
|
|
|
def bb2vec(inst):
    """Return a single averaged embedding for a group of instructions.

    The instructions are embedded with ``generate_inst_vec`` using mean
    pooling, the result is averaged along axis 0 and returned as a list of
    floats.

    Args:
        inst: Passed through to ``generate_inst_vec``; presumably the list
            of instructions of one basic block (see the ``__main__`` demo).

    Returns:
        A list containing the float components of the averaged vector.
    """
    inst_vecs = generate_inst_vec(inst, method="mean")
    as_array = inst_vecs.detach().numpy()
    averaged = np.mean(as_array, axis=0)
    return list(averaged.astype(float))
|
2024-04-15 20:01:20 +08:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: embed a small three-instruction block and print the vector.
    sample_block = [
        'adc byte [ ebp - 0x74 ] cl',
        'mov dh 0x79',
        'adc eax 1',
    ]
    print(list(bb2vec(sample_block)))
|
|
|
|
|