import os

import numpy as np
import tokenizers
import torch
from transformers import BatchEncoding, BertConfig, BertForPreTraining

# NOTE: this relative import only works when the file is used as part of a
# package; use the commented absolute import below when running it directly.
from .my_data_collator import MyDataCollatorForPreTraining

model_file = os.path.join("./bert/pytorch_model.bin")
tokenizer_file = os.path.join("./bert/tokenizer-inst.all.json")
config_file = os.path.join("./bert/bert.json")

# Disable tokenizer multithreading
# os.environ["TOKENIZERS_PARALLELISM"] = "false"

# from my_data_collator import MyDataCollatorForPreTraining
# model_file = os.path.join("./pytorch_model.bin")
# tokenizer_file = os.path.join("./tokenizer-inst.all.json")
# config_file = os.path.join("./bert.json")


def load_model():
    """Load the pretrained BERT model and its instruction tokenizer."""
    config = BertConfig.from_json_file(config_file)
    model = BertForPreTraining(config)
    # state_dict = torch.load(model_file)
    state_dict = torch.load(model_file, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
    tokenizer.enable_padding(
        pad_id=tokenizer.token_to_id("[PAD]"),
        pad_token="[PAD]",
        length=50,
    )
    return model, tokenizer


def process_input(inst, tokenizer):
    """Tokenize one instruction (str) or a batch of instructions (list)."""
    encoded_input = {}
    if isinstance(inst, str):
        # replicate the single instruction to build a batch manually
        inst = [inst for _ in range(8)]
    results = tokenizer.encode_batch(inst)
    encoded_input["input_ids"] = [result.ids for result in results]
    encoded_input["token_type_ids"] = [result.type_ids for result in results]
    encoded_input["special_tokens_mask"] = [
        result.special_tokens_mask for result in results
    ]
    # print(encoded_input["input_ids"])

    # use `np` rather than `pt` here to avoid tensor-conversion errors;
    # the data collator converts everything to torch tensors later
    batch_output = BatchEncoding(
        encoded_input,
        tensor_type="np",
        prepend_batch_axis=False,
    )
    # print(batch_output["input_ids"])

    # NOTE: derive a length mask from "special_tokens_mask"; this only
    # works if each input consists of a single instruction
    length_mask = 1 - batch_output["special_tokens_mask"]

    data_collator = MyDataCollatorForPreTraining(tokenizer=tokenizer, mlm=False)
    model_input = data_collator([batch_output])
    # print(model_input["input_ids"])

    return model_input, length_mask


def generate_inst_vec(inst, method="mean"):
    """Embed instruction(s) with the last hidden layer, pooled by `method`."""
    model, tokenizer = load_model()

    model_input, length_mask = process_input(inst, tokenizer)
    length_mask = torch.from_numpy(length_mask).to(model_input["input_ids"].device)

    output = model(**model_input, output_hidden_states=True)

    if method == "cls":
        # use the [CLS] token of the last hidden layer
        if isinstance(inst, str):
            return output.hidden_states[-1][0][0]
        elif isinstance(inst, list):
            # hidden_states is a tuple of tensors, so index it in two steps
            return output.hidden_states[-1][:, 0, :]
    elif method == "mean":
        # zero out special tokens (e.g. [CLS], [SEP], [PAD]) before pooling
        result = output.hidden_states[-1] * torch.unsqueeze(length_mask, dim=-1)
        # print(result.shape)
        if isinstance(inst, str):
            result = torch.mean(result[0], dim=0)
        elif isinstance(inst, list):
            result = torch.mean(result, dim=1)
        return result
    elif method == "max":
        result = output.hidden_states[-1] * torch.unsqueeze(length_mask, dim=-1)
        # print(result.shape)
        # torch.max(..., dim=...) returns (values, indices); keep the values
        if isinstance(inst, str):
            result = torch.max(result[0], dim=0).values
        elif isinstance(inst, list):
            result = torch.max(result, dim=1).values
        return result


def bb2vec(item):
    """Map a basic block {"addr": ..., "opcode": [...]} to (addr, vector)."""
    tmp = generate_inst_vec(item["opcode"], method="mean")
    return item["addr"], list(np.mean(tmp.detach().numpy(), axis=0).astype(float))


if __name__ == "__main__":
    # `bb2vec` expects a dict with "addr" and "opcode" keys;
    # the address here is a placeholder for testing
    temp = bb2vec(
        {
            "addr": 0x0,
            "opcode": ["adc byte [ ebp - 0x74 ] cl", "mov dh 0x79", "adc eax 1"],
        }
    )
    temp = list(temp)
    print(temp)
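
# Usage sketch (an assumption-laden reading aid, not part of the original
# script: it presumes the ./bert/ model, tokenizer, and config files exist;
# the shapes follow from the pooling logic in `generate_inst_vec`):
#
#   vec = generate_inst_vec("mov dh 0x79", method="cls")
#   # -> tensor of shape (hidden_size,) for a single instruction
#
#   vecs = generate_inst_vec(["mov dh 0x79", "adc eax 1"], method="mean")
#   # -> tensor of shape (2, hidden_size), one row per instruction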
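
# Reference sketch (assumption, not part of this module): the imported
# `MyDataCollatorForPreTraining` lives in a sibling file that is not shown
# here. A minimal stand-in, assuming it mirrors the interface of Hugging
# Face's `DataCollatorForLanguageModeling` with `mlm=False` (stack the
# numpy-encoded batch into torch tensors, build an attention mask, and copy
# `input_ids` to `labels`), could look like this; the real collator may
# differ, so treat this only as a guess at its behavior:
#
#   class MyDataCollatorForPreTraining:
#       def __init__(self, tokenizer, mlm=False):
#           self.tokenizer = tokenizer
#           self.mlm = mlm  # masked-LM masking is skipped when False
#
#       def __call__(self, examples):
#           batch = {}
#           for key in ("input_ids", "token_type_ids"):
#               # each example already holds a full (batch, seq_len) array
#               batch[key] = torch.cat(
#                   [torch.as_tensor(ex[key]) for ex in examples], dim=0
#               )
#           pad_id = self.tokenizer.token_to_id("[PAD]")
#           batch["attention_mask"] = (batch["input_ids"] != pad_id).long()
#           if not self.mlm:
#               batch["labels"] = batch["input_ids"].clone()
#           return batch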