diff --git a/my_data_collator.py b/my_data_collator.py
index fe3f21c..6420dca 100644
--- a/my_data_collator.py
+++ b/my_data_collator.py
@@ -1,15 +1,6 @@
 from dataclasses import dataclass
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    List,
-    NamedTuple,
-    Optional,
-    Sequence,
-    Tuple,
-    Union,
-)
+from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional,
+                    Sequence, Tuple, Union)
 
 import numpy as np
 import tokenizers
@@ -27,6 +18,8 @@ class MyDataCollatorForPreTraining:
     pad_to_multiple_of: Optional[int] = None
 
     def __post_init__(self):
+        # print(self.mlm, self.tokenizer.token_to_id("[MASK]"))
+        # input()
         if self.mlm and self.tokenizer.token_to_id("[MASK]") is None:
             raise ValueError(
                 "This tokenizer does not have a mask token which is necessary for masked language modeling. "
@@ -36,6 +29,7 @@ class MyDataCollatorForPreTraining:
     def __call__(
         self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]],
     ) -> Dict[str, torch.Tensor]:
+        # print(examples)
         # Handle dict or lists with proper padding and conversion to tensor.
         if isinstance(examples[0], (dict, BatchEncoding)):
             batch = pad(
@@ -56,6 +50,9 @@ class MyDataCollatorForPreTraining:
             batch["input_ids"], batch["labels"] = self.mask_tokens(
                 batch["input_ids"], special_tokens_mask=special_tokens_mask
             )
+        else:
+            batch["input_ids"] = torch.squeeze(batch["input_ids"], dim=0)
+            batch["token_type_ids"] = torch.squeeze(batch["token_type_ids"], dim=0)
         return batch
 
     def mask_tokens(
diff --git a/my_run_mlm_no_trainer.py b/my_run_mlm_no_trainer.py
index fbc570e..15261b8 100644
--- a/my_run_mlm_no_trainer.py
+++ b/my_run_mlm_no_trainer.py
@@ -187,8 +187,7 @@ def main():
     #     field="data",
     # )
     train_files = [
-        os.path.join(CURRENT_DATA_BASE, "inst.1.{}.json".format(i))
-        for i in range(0, 128, 2)
+        os.path.join(CURRENT_DATA_BASE, "inst.all.{}.json".format(i)) for i in range(2)
     ]
     valid_file = "/home/ming/malware/inst2vec_bert/data/test_lm/inst.json"
     raw_datasets = load_dataset(
@@ -199,7 +198,7 @@ def main():
 
     # we use the tokenizer previously trained on the dataset above
     tokenizer = tokenizers.Tokenizer.from_file(
-        os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.1.json")
+        os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.all.json")
     )
 
     # NOTE: have to promise the `length` here is consistent with the one used in `train_my_tokenizer.py`
@@ -254,6 +253,7 @@ def main():
         batched=True,
         num_proc=args.preprocessing_num_workers,
         remove_columns=column_names,
+        load_from_cache_file=False,
     )
 
     train_dataset = tokenized_datasets["train"]
diff --git a/obtain_inst_vec.py b/obtain_inst_vec.py
new file mode 100644
index 0000000..174f53b
--- /dev/null
+++ b/obtain_inst_vec.py
@@ -0,0 +1,128 @@
+import argparse
+import logging
+import math
+import os
+import random
+
+import datasets
+import numpy as np
+import tokenizers
+import torch
+import transformers
+from accelerate import Accelerator
+from datasets import load_dataset
+from torch.nn import DataParallel
+from torch.utils.data.dataloader import DataLoader
+from tqdm.auto import tqdm
+from transformers import (
+    CONFIG_MAPPING,
+    MODEL_MAPPING,
+    AdamW,
+    AutoConfig,
+    AutoModelForMaskedLM,
+    AutoTokenizer,
+    BatchEncoding,
+    BertConfig,
+    BertForPreTraining,
+    DataCollatorForLanguageModeling,
+    SchedulerType,
+    get_scheduler,
+    set_seed,
+)
+
+from my_data_collator import MyDataCollatorForPreTraining
+from process_data.utils import CURRENT_DATA_BASE
+
+model_file = os.path.join(CURRENT_DATA_BASE, "bert-L2-H8.bin")
+config_file = os.path.join(CURRENT_DATA_BASE, "bert-L2-H8.config.json")
+tokenizer_file = os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.all.json")
+
+
+def load_model():
+    config = BertConfig.from_json_file(config_file)
+    model = BertForPreTraining(config)
+    state_dict = torch.load(model_file)
+    model.load_state_dict(state_dict)
+    model.eval()
+    print("Load model successfully !")
+
+    tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
+    tokenizer.enable_padding(
+        pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]", length=32
+    )
+    print("Load tokenizer successfully !")
+    return model, tokenizer
+
+
+def process_input(inst, tokenizer):
+    encoded_input = {}
+    if isinstance(inst, str):
+        # make a batch by myself
+        inst = [inst for _ in range(8)]
+    results = tokenizer.encode_batch(inst)
+    encoded_input["input_ids"] = [result.ids for result in results]
+    encoded_input["token_type_ids"] = [result.type_ids for result in results]
+    encoded_input["special_tokens_mask"] = [
+        result.special_tokens_mask for result in results
+    ]
+
+    # print(encoded_input["input_ids"])
+
+    # use `np` rather than `pt` in case of reporting of error
+    batch_output = BatchEncoding(
+        encoded_input, tensor_type="np", prepend_batch_axis=False,
+    )
+
+    # print(batch_output["input_ids"])
+
+    # NOTE: utilize the "special_tokens_mask",
+    # only work if the input consists of single instruction
+    length_mask = 1 - batch_output["special_tokens_mask"]
+
+    data_collator = MyDataCollatorForPreTraining(tokenizer=tokenizer, mlm=False)
+
+    model_input = data_collator([batch_output])
+
+    # print(model_input["input_ids"])
+
+    return model_input, length_mask
+
+
+def generate_inst_vec(inst, method="mean"):
+    model, tokenizer = load_model()
+
+    model_input, length_mask = process_input(inst, tokenizer)
+    length_mask = torch.from_numpy(length_mask).to(model_input["input_ids"].device)
+
+    output = model(**model_input, output_hidden_states=True)
+
+    if method == "cls":
+        if isinstance(inst, str):
+            return output.hidden_states[-1][0][0]
+        elif isinstance(inst, list):
+            return output.hidden_states[-1][:, 0, :]
+    elif method == "mean":
+        result = output.hidden_states[-1] * torch.unsqueeze(length_mask, dim=-1)
+        # print(result.shape)
+        if isinstance(inst, str):
+            result = torch.mean(result[0], dim=0)
+        elif isinstance(inst, list):
+            result = torch.mean(result, dim=1)
+        return result
+    elif method == "max":
+        result = output.hidden_states[-1] * torch.unsqueeze(length_mask, dim=-1)
+        # print(result.shape)
+        if isinstance(inst, str):
+            result = torch.max(result[0], dim=0).values
+        elif isinstance(inst, list):
+            result = torch.max(result, dim=1).values
+        return result
+
+
+def main():
+    inst = ["mov ebp esp" for _ in range(8)]
+    print(generate_inst_vec(inst).shape)
+
+
+if __name__ == "__main__":
+    main()