complete interface for downstream task
This commit is contained in:
parent
b4810738b8
commit
fb61bb2a7b
@ -1,15 +1,6 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (
|
from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional,
|
||||||
TYPE_CHECKING,
|
Sequence, Tuple, Union)
|
||||||
Any,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
NamedTuple,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
Tuple,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tokenizers
|
import tokenizers
|
||||||
@ -27,6 +18,8 @@ class MyDataCollatorForPreTraining:
|
|||||||
pad_to_multiple_of: Optional[int] = None
|
pad_to_multiple_of: Optional[int] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
# print(self.mlm, self.tokenzier.token_to_id("[MASK]"))
|
||||||
|
# input()
|
||||||
if self.mlm and self.tokenizer.token_to_id("[MASK]") is None:
|
if self.mlm and self.tokenizer.token_to_id("[MASK]") is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
|
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
|
||||||
@ -36,6 +29,7 @@ class MyDataCollatorForPreTraining:
|
|||||||
def __call__(
|
def __call__(
|
||||||
self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]],
|
self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]],
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> Dict[str, torch.Tensor]:
|
||||||
|
# print(examples)
|
||||||
# Handle dict or lists with proper padding and conversion to tensor.
|
# Handle dict or lists with proper padding and conversion to tensor.
|
||||||
if isinstance(examples[0], (dict, BatchEncoding)):
|
if isinstance(examples[0], (dict, BatchEncoding)):
|
||||||
batch = pad(
|
batch = pad(
|
||||||
@ -56,6 +50,9 @@ class MyDataCollatorForPreTraining:
|
|||||||
batch["input_ids"], batch["labels"] = self.mask_tokens(
|
batch["input_ids"], batch["labels"] = self.mask_tokens(
|
||||||
batch["input_ids"], special_tokens_mask=special_tokens_mask
|
batch["input_ids"], special_tokens_mask=special_tokens_mask
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
batch["input_ids"] = torch.squeeze(batch["input_ids"], dim=0)
|
||||||
|
batch["token_type_ids"] = torch.squeeze(batch["token_type_ids"], dim=0)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def mask_tokens(
|
def mask_tokens(
|
||||||
|
@ -187,8 +187,7 @@ def main():
|
|||||||
# field="data",
|
# field="data",
|
||||||
# )
|
# )
|
||||||
train_files = [
|
train_files = [
|
||||||
os.path.join(CURRENT_DATA_BASE, "inst.1.{}.json".format(i))
|
os.path.join(CURRENT_DATA_BASE, "inst.all.{}.json".format(i)) for i in range(2)
|
||||||
for i in range(0, 128, 2)
|
|
||||||
]
|
]
|
||||||
valid_file = "/home/ming/malware/inst2vec_bert/data/test_lm/inst.json"
|
valid_file = "/home/ming/malware/inst2vec_bert/data/test_lm/inst.json"
|
||||||
raw_datasets = load_dataset(
|
raw_datasets = load_dataset(
|
||||||
@ -199,7 +198,7 @@ def main():
|
|||||||
|
|
||||||
# we use the tokenizer previously trained on the dataset above
|
# we use the tokenizer previously trained on the dataset above
|
||||||
tokenizer = tokenizers.Tokenizer.from_file(
|
tokenizer = tokenizers.Tokenizer.from_file(
|
||||||
os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.1.json")
|
os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.all.json")
|
||||||
)
|
)
|
||||||
|
|
||||||
# NOTE: have to promise the `length` here is consistent with the one used in `train_my_tokenizer.py`
|
# NOTE: have to promise the `length` here is consistent with the one used in `train_my_tokenizer.py`
|
||||||
@ -254,6 +253,7 @@ def main():
|
|||||||
batched=True,
|
batched=True,
|
||||||
num_proc=args.preprocessing_num_workers,
|
num_proc=args.preprocessing_num_workers,
|
||||||
remove_columns=column_names,
|
remove_columns=column_names,
|
||||||
|
load_from_cache_file=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
train_dataset = tokenized_datasets["train"]
|
train_dataset = tokenized_datasets["train"]
|
||||||
|
128
obtain_inst_vec.py
Normal file
128
obtain_inst_vec.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
import datasets
|
||||||
|
import numpy as np
|
||||||
|
import tokenizers
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from accelerate import Accelerator
|
||||||
|
from datasets import load_dataset
|
||||||
|
from torch.nn import DataParallel
|
||||||
|
from torch.utils.data.dataloader import DataLoader
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from transformers import (
|
||||||
|
CONFIG_MAPPING,
|
||||||
|
MODEL_MAPPING,
|
||||||
|
AdamW,
|
||||||
|
AutoConfig,
|
||||||
|
AutoModelForMaskedLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
BatchEncoding,
|
||||||
|
BertConfig,
|
||||||
|
BertForPreTraining,
|
||||||
|
DataCollatorForLanguageModeling,
|
||||||
|
SchedulerType,
|
||||||
|
get_scheduler,
|
||||||
|
set_seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
from my_data_collator import MyDataCollatorForPreTraining
|
||||||
|
from process_data.utils import CURRENT_DATA_BASE
|
||||||
|
|
||||||
|
model_file = os.path.join(CURRENT_DATA_BASE, "bert-L2-H8.bin")
|
||||||
|
config_file = os.path.join(CURRENT_DATA_BASE, "bert-L2-H8.config.json")
|
||||||
|
tokenizer_file = os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.all.json")
|
||||||
|
|
||||||
|
|
||||||
|
def load_model():
|
||||||
|
config = BertConfig.from_json_file(config_file)
|
||||||
|
model = BertForPreTraining(config)
|
||||||
|
state_dict = torch.load(model_file)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
model.eval()
|
||||||
|
print("Load model successfully !")
|
||||||
|
|
||||||
|
tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
|
||||||
|
tokenizer.enable_padding(
|
||||||
|
pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]", length=32
|
||||||
|
)
|
||||||
|
print("Load tokenizer successfully !")
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def process_input(inst, tokenizer):
|
||||||
|
encoded_input = {}
|
||||||
|
if isinstance(inst, str):
|
||||||
|
# make a batch by myself
|
||||||
|
inst = [inst for _ in range(8)]
|
||||||
|
results = tokenizer.encode_batch(inst)
|
||||||
|
encoded_input["input_ids"] = [result.ids for result in results]
|
||||||
|
encoded_input["token_type_ids"] = [result.type_ids for result in results]
|
||||||
|
encoded_input["special_tokens_mask"] = [
|
||||||
|
result.special_tokens_mask for result in results
|
||||||
|
]
|
||||||
|
|
||||||
|
# print(encoded_input["input_ids"])
|
||||||
|
|
||||||
|
# use `np` rather than `pt` in case of reporting of error
|
||||||
|
batch_output = BatchEncoding(
|
||||||
|
encoded_input, tensor_type="np", prepend_batch_axis=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# print(batch_output["input_ids"])
|
||||||
|
|
||||||
|
# NOTE: utilize the "special_tokens_mask",
|
||||||
|
# only work if the input consists of single instruction
|
||||||
|
length_mask = 1 - batch_output["special_tokens_mask"]
|
||||||
|
|
||||||
|
data_collator = MyDataCollatorForPreTraining(tokenizer=tokenizer, mlm=False)
|
||||||
|
|
||||||
|
model_input = data_collator([batch_output])
|
||||||
|
|
||||||
|
# print(model_input["input_ids"])
|
||||||
|
|
||||||
|
return model_input, length_mask
|
||||||
|
|
||||||
|
|
||||||
|
def generate_inst_vec(inst, method="mean"):
|
||||||
|
model, tokenizer = load_model()
|
||||||
|
|
||||||
|
model_input, length_mask = process_input(inst, tokenizer)
|
||||||
|
length_mask = torch.from_numpy(length_mask).to(model_input["input_ids"].device)
|
||||||
|
|
||||||
|
output = model(**model_input, output_hidden_states=True)
|
||||||
|
|
||||||
|
if method == "cls":
|
||||||
|
if isinstance(inst, str):
|
||||||
|
return output.hidden_states[-1][0][0]
|
||||||
|
elif isinstance(inst, list):
|
||||||
|
return output.hidden_states[-1, :, 0, :]
|
||||||
|
elif method == "mean":
|
||||||
|
result = output.hidden_states[-1] * torch.unsqueeze(length_mask, dim=-1)
|
||||||
|
# print(result.shape)
|
||||||
|
if isinstance(inst, str):
|
||||||
|
result = torch.mean(result[0], dim=0)
|
||||||
|
elif isinstance(inst, list):
|
||||||
|
result = torch.mean(result, dim=1)
|
||||||
|
return result
|
||||||
|
elif method == "max":
|
||||||
|
result = output.hidden_states[-1] * torch.unsqueeze(length_mask, dim=-1)
|
||||||
|
# print(result.shape)
|
||||||
|
if isinstance(inst, str):
|
||||||
|
result = torch.max(result[0], dim=0)
|
||||||
|
elif isinstance(inst, list):
|
||||||
|
result = torch.max(result, dim=1)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
inst = ["mov ebp esp" for _ in range(8)]
|
||||||
|
print(generate_inst_vec(inst).shape)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Loading…
Reference in New Issue
Block a user