Compare commits
No commits in common. "979573651d6d48c10ab69f244a5961eeb13200ea" and "1d1acb29b32f56a7338ad18c7dacd0c81e655110" have entirely different histories.
979573651d ... 1d1acb29b3
README.md (21 changed lines)
@@ -1,19 +1,2 @@
- # Inst2Vec Model
+ # ASM-BERT

- Using [HuggingFace Transformers](https://github.com/huggingface/transformers) to train a BERT with dynamic masking for assembly language from scratch. We name it `Inst2Vec` because it is designed to generate vectors for assembly instructions.
+ Using HuggingFace Transformers to train a BERT for assembly language

- It is a part of the model introduced in the ICONIP 2021 paper [A Hierarchical Graph-based Neural Network for Malware Classification](https://link.springer.com/chapter/10.1007%2F978-3-030-92273-3_51).
-
- The preprocessing procedure can be found in [process_data](./process_data/readme.md).
-
- You can simply run `python train_my_tokenizer.py` to obtain an assembly tokenizer.
-
- The script I use to train the `Inst2Vec` model is as follows:
-
- ```
- python my_run_mlm_no_trainer.py \
-     --per_device_train_batch_size 8192 \
-     --per_device_eval_batch_size 16384 \
-     --num_warmup_steps 4000 --output_dir ./ \
-     --seed 1234 --preprocessing_num_workers 32 \
-     --max_train_steps 150000 \
-     --eval_every_steps 1000
- ```
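The "dynamic mask" mentioned in the README means that masked positions are re-sampled every time a batch is collated, rather than being fixed once during preprocessing. A minimal sketch of that masking step, following the standard BERT 80/10/10 recipe (an illustration only, not the repository's exact `MyDataCollatorForPreTraining` implementation):

```python
import torch

def dynamic_mask(input_ids, special_tokens_mask, mask_token_id, vocab_size, mlm_probability=0.15):
    """Re-sample which tokens are masked for every incoming batch (dynamic masking)."""
    labels = input_ids.clone()
    # Pick ~15% of positions, but never special tokens such as [CLS], [SEP], [PAD].
    probability_matrix = torch.full(labels.shape, mlm_probability)
    probability_matrix.masked_fill_(special_tokens_mask.bool(), value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # loss is computed only on masked positions

    # 80% of masked positions become [MASK].
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    input_ids[indices_replaced] = mask_token_id

    # 10% become a random token; the remaining 10% keep the original token.
    indices_random = (
        torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    )
    random_tokens = torch.randint(vocab_size, labels.shape, dtype=torch.long)
    input_ids[indices_random] = random_tokens[indices_random]
    return input_ids, labels
```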
@@ -1,6 +1,15 @@
  from dataclasses import dataclass
- from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional,
-                     Sequence, Tuple, Union)
+ from typing import (
+     TYPE_CHECKING,
+     Any,
+     Dict,
+     List,
+     NamedTuple,
+     Optional,
+     Sequence,
+     Tuple,
+     Union,
+ )

  import numpy as np
  import tokenizers

@@ -18,8 +27,6 @@ class MyDataCollatorForPreTraining:
      pad_to_multiple_of: Optional[int] = None

      def __post_init__(self):
-         # print(self.mlm, self.tokenzier.token_to_id("[MASK]"))
-         # input()
          if self.mlm and self.tokenizer.token_to_id("[MASK]") is None:
              raise ValueError(
                  "This tokenizer does not have a mask token which is necessary for masked language modeling. "

@@ -29,7 +36,6 @@ class MyDataCollatorForPreTraining:
      def __call__(
          self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]],
      ) -> Dict[str, torch.Tensor]:
-         # print(examples)
          # Handle dict or lists with proper padding and conversion to tensor.
          if isinstance(examples[0], (dict, BatchEncoding)):
              batch = pad(

@@ -50,9 +56,6 @@ class MyDataCollatorForPreTraining:
              batch["input_ids"], batch["labels"] = self.mask_tokens(
                  batch["input_ids"], special_tokens_mask=special_tokens_mask
              )
-         else:
-             batch["input_ids"] = torch.squeeze(batch["input_ids"], dim=0)
-             batch["token_type_ids"] = torch.squeeze(batch["token_type_ids"], dim=0)
          return batch

      def mask_tokens(
@@ -43,8 +43,6 @@ from transformers import (CONFIG_MAPPING, MODEL_MAPPING, AdamW, AutoConfig,
  from my_data_collator import MyDataCollatorForPreTraining
  from process_data.utils import CURRENT_DATA_BASE

- HIDDEN_SIZE=16
-
  logger = logging.getLogger(__name__)
  MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

@@ -57,13 +55,13 @@ def parse_args():
      parser.add_argument(
          "--per_device_train_batch_size",
          type=int,
-         default=2048,
+         default=16,
          help="Batch size (per device) for the training dataloader.",
      )
      parser.add_argument(
          "--per_device_eval_batch_size",
          type=int,
-         default=16384,
+         default=64,
          help="Batch size (per device) for the evaluation dataloader.",
      )
      parser.add_argument(

@@ -84,7 +82,7 @@ def parse_args():
      parser.add_argument(
          "--max_train_steps",
          type=int,
-         default=150000,
+         default=None,
          help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
      )
      parser.add_argument(

@@ -104,19 +102,19 @@ def parse_args():
      parser.add_argument(
          "--num_warmup_steps",
          type=int,
-         default=4000,
+         default=0,
          help="Number of steps for the warmup in the lr scheduler.",
      )
      parser.add_argument(
-         "--output_dir", type=str, default='./dataset/out', help="Where to store the final model."
+         "--output_dir", type=str, default=None, help="Where to store the final model."
      )
      parser.add_argument(
-         "--seed", type=int, default=1234, help="A seed for reproducible training."
+         "--seed", type=int, default=None, help="A seed for reproducible training."
      )
      parser.add_argument(
          "--preprocessing_num_workers",
          type=int,
-         default=32,
+         default=None,
          help="The number of processes to use for the preprocessing.",
      )
      parser.add_argument(

@@ -189,9 +187,10 @@ def main():
      # field="data",
      # )
      train_files = [
-         os.path.join(CURRENT_DATA_BASE, 'json', f"inst.all.{i}.json") for i in range(8)
+         os.path.join(CURRENT_DATA_BASE, "inst.1.{}.json".format(i))
+         for i in range(0, 128, 2)
      ]
-     valid_file = os.path.join(CURRENT_DATA_BASE, 'test', "inst.json")
+     valid_file = "/home/ming/malware/inst2vec_bert/data/test_lm/inst.json"
      raw_datasets = load_dataset(
          "json",
          data_files={"train": train_files, "validation": valid_file,},

@@ -200,23 +199,23 @@ def main():

      # we use the tokenizer previously trained on the dataset above
      tokenizer = tokenizers.Tokenizer.from_file(
-         os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.all.json")
+         os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.1.json")
      )

      # NOTE: have to promise the `length` here is consistent with the one used in `train_my_tokenizer.py`
      tokenizer.enable_padding(
-         pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]", length=50
+         pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]", length=32
      )

      # NOTE: `max_position_embeddings` here should be consistent with `length` above
      # we use a much smaller BERT, config is:
      config = BertConfig(
          vocab_size=tokenizer.get_vocab_size(),
-         hidden_size=HIDDEN_SIZE,
+         hidden_size=96,
          num_hidden_layers=4,
-         num_attention_heads=8,
+         num_attention_heads=12,
-         intermediate_size=4 * HIDDEN_SIZE,
+         intermediate_size=384,
-         max_position_embeddings=50,
+         max_position_embeddings=32,
      )

      # initalize a new BERT for pre-training

@@ -232,10 +231,7 @@ def main():
      def tokenize_function(examples):
          text = [tuple(sent) for sent in examples["text"]]
          encoded_inputs = {}
-         # try:
          results = tokenizer.encode_batch(text)
-         # except:
-         #     return None
          encoded_inputs["input_ids"] = [result.ids for result in results]
          encoded_inputs["token_type_ids"] = [result.type_ids for result in results]
          encoded_inputs["special_tokens_mask"] = [

@@ -258,7 +254,6 @@ def main():
          batched=True,
          num_proc=args.preprocessing_num_workers,
          remove_columns=column_names,
-         load_from_cache_file=True,
      )

      train_dataset = tokenized_datasets["train"]
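Putting the hunks above in context: the collator produces `input_ids` and `labels` with masked positions, and the script trains a small BERT on them. A minimal sketch of that wiring under stated assumptions (it uses `BertForMaskedLM` for brevity, whereas the script above imports `BertForPreTraining`; the learning rate is a placeholder; `tokenizer` and `train_dataset` are assumed to be built as in the hunks above, and the config values follow the `HIDDEN_SIZE=16` variant):

```python
import torch
from torch.utils.data.dataloader import DataLoader
from transformers import BertConfig, BertForMaskedLM

from my_data_collator import MyDataCollatorForPreTraining

# A small BERT, matching the HIDDEN_SIZE=16 configuration shown above.
config = BertConfig(
    vocab_size=tokenizer.get_vocab_size(),
    hidden_size=16,
    num_hidden_layers=4,
    num_attention_heads=8,
    intermediate_size=64,
    max_position_embeddings=50,
)
model = BertForMaskedLM(config)

# The collator pads each batch and applies dynamic masking when mlm=True.
data_collator = MyDataCollatorForPreTraining(tokenizer=tokenizer, mlm=True)
train_dataloader = DataLoader(
    train_dataset, shuffle=True, collate_fn=data_collator, batch_size=8192
)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)  # assumed learning rate
model.train()
for batch in train_dataloader:
    outputs = model(
        input_ids=batch["input_ids"],
        token_type_ids=batch["token_type_ids"],
        labels=batch["labels"],  # -100 everywhere except the masked positions
    )
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```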
my_utils.py (65 changed lines)
@@ -1,65 +0,0 @@
- import logging
- import os
-
- """
- Logging utilities
-
- Usage:
-     logger = setup_logger(<logger instance name>, <log file path>)
- """
- def setup_logger(name, log_file, level=logging.INFO):
-     """Function to set up as many loggers as you want"""
-     if not os.path.exists(os.path.dirname(log_file)):
-         os.makedirs(os.path.dirname(log_file))
-
-     formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
-
-     handler = logging.FileHandler(log_file)
-     handler.setFormatter(formatter)
-
-     # whether to also print log messages to the console
-     # stream_handler = logging.StreamHandler()
-     # stream_handler.setFormatter(formatter)
-
-     logger = logging.getLogger(name)
-     logger.setLevel(level)
-     logger.addHandler(handler)
-     # console output
-     # logger.addHandler(stream_handler)
-
-     # truncate any existing log file
-     if os.path.exists(log_file):
-         open(log_file, 'w').close()
-
-     return logger
-
-
- """
- Multi-threading utilities
- """
- THREAD_FULL = os.cpu_count()
- THREAD_HALF = int(os.cpu_count() / 2)
- def multi_thread(func, args, thread_num=THREAD_FULL):
-     """
-     Run a function over a list of arguments in a thread pool.
-     :param func: the function to run
-     :param args: list of arguments, one per call
-     :param thread_num: number of threads
-     :return:
-     """
-     import concurrent.futures
-     from tqdm import tqdm
-     logger = setup_logger('multi_thread', './multi_thread.log')
-     result = []
-     with concurrent.futures.ThreadPoolExecutor(max_workers=thread_num) as executor:
-         futures_to_args = {
-             executor.submit(func, arg): arg for arg in args
-         }
-         for future in tqdm(concurrent.futures.as_completed(futures_to_args), total=len(args)):
-             try:
-                 result.append(future.result())
-             except Exception as exc:
-                 logger.error('%r generated an exception: %s' % (futures_to_args[future], exc))
-     return result
@@ -1,131 +0,0 @@
- import argparse
- import logging
- import math
- import os
- import random
-
- import datasets
- import numpy as np
- import tokenizers
- import torch
- import transformers
- from accelerate import Accelerator
- from datasets import load_dataset
- from torch import nn
- from torch.nn import DataParallel
- from torch.utils.data.dataloader import DataLoader
- from tqdm.auto import tqdm
- from transformers import (
-     CONFIG_MAPPING,
-     MODEL_MAPPING,
-     AdamW,
-     AutoConfig,
-     AutoModelForMaskedLM,
-     AutoTokenizer,
-     BatchEncoding,
-     BertConfig,
-     BertForPreTraining,
-     DataCollatorForLanguageModeling,
-     SchedulerType,
-     get_scheduler,
-     set_seed,
- )
-
- from my_data_collator import MyDataCollatorForPreTraining
- from process_data.utils import CURRENT_DATA_BASE
-
- model_file = os.path.join(CURRENT_DATA_BASE, 'out', "pytorch_model.bin")
- config_file = os.path.join(CURRENT_DATA_BASE, 'out', "config.json")
- tokenizer_file = os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.all.json")
-
-
- def load_model():
-     config = BertConfig.from_json_file(config_file)
-     model = BertForPreTraining(config)
-     state_dict = torch.load(model_file)
-     model.load_state_dict(state_dict)
-     model.eval()
-     print("Load model successfully !")
-
-     tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
-     tokenizer.enable_padding(
-         pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]", length=50
-     )
-     print("Load tokenizer successfully !")
-     return model, tokenizer
-
-
- def process_input(inst, tokenizer):
-     encoded_input = {}
-     if isinstance(inst, str):
-         # make a batch by myself
-         inst = [inst for _ in range(8)]
-     results = tokenizer.encode_batch(inst)
-     encoded_input["input_ids"] = [result.ids for result in results]
-     encoded_input["token_type_ids"] = [result.type_ids for result in results]
-     encoded_input["special_tokens_mask"] = [
-         result.special_tokens_mask for result in results
-     ]
-
-     # print(encoded_input["input_ids"])
-
-     # use `np` rather than `pt` in case of reporting of error
-     batch_output = BatchEncoding(
-         encoded_input, tensor_type="np", prepend_batch_axis=False,
-     )
-
-     # print(batch_output["input_ids"])
-
-     # NOTE: utilize the "special_tokens_mask",
-     # only works if the input consists of a single instruction
-     length_mask = 1 - batch_output["special_tokens_mask"]
-
-     data_collator = MyDataCollatorForPreTraining(tokenizer=tokenizer, mlm=False)
-
-     model_input = data_collator([batch_output])
-
-     # print(model_input["input_ids"])
-
-     return model_input, length_mask
-
-
- def generate_inst_vec(inst, method="mean"):
-     model, tokenizer = load_model()
-
-     model_input, length_mask = process_input(inst, tokenizer)
-     length_mask = torch.from_numpy(length_mask).to(model_input["input_ids"].device)
-
-     output = model(**model_input, output_hidden_states=True)
-
-     if method == "cls":
-         if isinstance(inst, str):
-             return output.hidden_states[-1][0][0]
-         elif isinstance(inst, list):
-             return output.hidden_states[-1, :, 0, :]
-     elif method == "mean":
-         result = output.hidden_states[-1] * torch.unsqueeze(length_mask, dim=-1)
-         # print(result.shape)
-         if isinstance(inst, str):
-             result = torch.mean(result[0], dim=0)
-         elif isinstance(inst, list):
-             result = torch.mean(result, dim=1)
-         return result
-     elif method == "max":
-         result = output.hidden_states[-1] * torch.unsqueeze(length_mask, dim=-1)
-         # print(result.shape)
-         if isinstance(inst, str):
-             result = torch.max(result[0], dim=0)
-         elif isinstance(inst, list):
-             result = torch.max(result, dim=1)
-         return result
-
-
- def main():
-     inst = ['adc byte [ ebp - 0x74 ] cl', 'mov dh 0x79', 'adc eax 1']
-     tmp = generate_inst_vec(inst, method="mean")
-     print(tmp.shape)
-     print(tmp.detach().numpy())
-
-
- if __name__ == "__main__":
-     main()
@@ -1,5 +1,5 @@
  import os
- import pdb

  from utils import ORIGINAL_DATA_BASE, read_file


@@ -7,7 +7,7 @@ def check(filename):
      sents = read_file(filename)
      result = 0
      for sent in sents:
-         result = max(result, len(sent[:-1].replace("\t", " ").split()))
+         result = max(result, len(sent[-1].replace("\t", " ").split()))
      print("The longest sentence in {} has {} words".format(filename, result))
      return result


@@ -15,10 +15,10 @@ def check(filename):
  def main():
      longest = 0
      # for i in range(6):
-     for i in range(32):
+     for i in [1]:
          for group in ("pos", "neg"):
              filename = os.path.join(
-                 ORIGINAL_DATA_BASE, f'{group}_clean', f"inst.{i}.{group}.txt.clean"
+                 ORIGINAL_DATA_BASE, "inst.{}.{}.txt".format(i, group)
              )
              longest = max(check(filename), longest)
      print("The longest sentence in all files has {} words.".format(longest))
@@ -1,57 +0,0 @@
- from utils import ORIGINAL_DATA_BASE, read_file, write_file
- from tqdm import tqdm
- import os
- from my_utils import multi_thread, setup_logger
- import concurrent.futures
-
-
-
- def remove(neg_list, pos_file):
-     ret = []
-     for neg in neg_list:
-         if neg in pos_file:
-             continue
-         ret.append(neg)
-     return ret
-
- def split_list_evenly(lst, n):
-     # size of each chunk (integer division; the last chunk may be slightly shorter)
-     chunk_size = len(lst) // n
-     # the last chunk may need the leftover elements
-     last_chunk_size = len(lst) % n
-     # collect the sliced chunks here
-     chunks = []
-
-     # the first n-1 chunks
-     for i in range(0, (n - (last_chunk_size > 0)), chunk_size):
-         chunks.append(lst[i:i + chunk_size])
-
-     # append the final, possibly shorter, chunk
-     if last_chunk_size > 0:
-         chunks.append(lst[(n - (last_chunk_size > 0)) * chunk_size:])
-
-     return chunks
-
- def main():
-     file = os.path.join('../dataset/all/all_clean')
-     pos_file = read_file(os.path.join(file, "inst.all.pos.txt.clean"))
-     neg_file = split_list_evenly(read_file(os.path.join(file, "inst.all.neg.txt.clean")), int(os.cpu_count()*1000))
-     print(len(neg_file))
-     logger = setup_logger('remove', '../out/remove.log')
-     result = []
-     with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
-         print('start build task.')
-         futures_to_args = {
-             executor.submit(remove, neg_list, pos_file): neg_list for neg_list in neg_file
-         }
-         print('start run task.')
-         for future in tqdm(concurrent.futures.as_completed(futures_to_args), total=len(futures_to_args)):
-             try:
-                 result.extend(future.result())
-             except Exception as exc:
-                 logger.error(exc)
-     write_file(result, os.path.join(file, "inst.all.neg.txt.clean"))
-
- if __name__ == "__main__":
-     main()
@@ -19,11 +19,11 @@ def convert(fin, fout):

  def main():
      # for i in range(6):
-     # for i in range(10):
-     # fin = os.path.join(ORIGINAL_DATA_BASE, "win32_0{}xxxx.all".format(i))
-     # fout = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.txt".format(i))
-     # convert(fin, fout)
-     convert(os.path.join('../dataset/all/win.all'), os.path.join('../dataset/all/inst.pos.txt'))
+     for i in [1]:
+         fin = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(i))
+         fout = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.txt".format(i))
+         convert(fin, fout)

  if __name__ == "__main__":
      main()
@@ -41,12 +41,12 @@ def counter(filename):
  def main():
      cnt = set()
      # for i in range(6):
-     for i in range(10):
+     for i in [1]:
          for group in ["pos", "neg"]:
              filename = os.path.join(
-                 ORIGINAL_DATA_BASE, "inst.{}.{}.txt.clean".format(i, group)
+                 ORIGINAL_DATA_BASE, "inst.{}.{}.txt".format(i, group)
              )
-             cnt = cnt.union(counter(filename))
+             cnt += counter(filename)
      print("There are {} charcters in all files".format(len(cnt)))

@@ -5,35 +5,30 @@ from tqdm import tqdm

  from utils import ORIGINAL_DATA_BASE, read_file

- from my_utils import multi_thread


  def create(pos, neg, tgt):
      pos_sents = read_file(pos)

      neg_sents = read_file(neg)
      neg_length = len(neg_sents)
+     print("Start writing negative examples to {}...".format(tgt))
      with open(tgt, "w", encoding="utf-8") as fout:
          for sent in tqdm(pos_sents):
              first = sent.split("\t")[0]
              index = randint(0, neg_length - 1)
-             pair = neg_sents[index].split("\t")
-             pair = pair[randint(0, 1)]
-             pair = pair.replace("\n", "")
+             pair = neg_sents[index].split("\t")[randint(0, 1)].replace("\n", "")
              fout.write(first + "\t" + pair + "\n")


  def main():
      # for i in range(6):
-     # neg = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(j))
-     # pos = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(i))
-     file = os.path.join("../dataset/all/pos_clean")
-     out_file = os.path.join("../dataset/all/neg_txt")
-     os.makedirs(out_file, exist_ok=True)
-     for i in tqdm(range(os.cpu_count()), total=os.cpu_count()):
-         j = (i + 1) % os.cpu_count()
-         pos = os.path.join(file, f"inst.{i}.pos.txt.clean")
-         neg = os.path.join(file, f"inst.{j}.pos.txt.clean")
-         tgt = os.path.join(out_file, f"inst.{i}.neg.txt")
+     for i in [1]:
+         j = (i + 1) % 6
+         pos = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(i))
+         neg = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(j))
+         tgt = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.neg.txt".format(i))
          create(pos, neg, tgt)


  if __name__ == "__main__":
      main()
@@ -1,90 +0,0 @@
- import os
- import r2pipe
- from my_utils import setup_logger
- import concurrent.futures
- from tqdm import tqdm
-
-
- def extract_opcode(disasm_text):
-     """
-     Extract the opcode and operands from a line of disassembly text,
-     allowing for operands that contain spaces and commas.
-     """
-     op_list = disasm_text.split(' ')
-     res = []
-     for item in op_list:
-         item = item.strip().replace(',', '')
-         if '[' in item:
-             res.append('[')
-         res.append(item.replace('[', '').replace(']', ''))
-         if ']' in item:
-             res.append(']')
-     return res
-
-
- def double_exe_op_list(op_list):
-     double_exe_op_list = []
-     for i in range(len(op_list) - 1):
-         double_exe_op_list.append((op_list[i], op_list[i + 1]))
-     return double_exe_op_list
-
- def get_all_from_exe(file, out_file):
-     # collect the opcode sequences inside basic blocks
-     r2pipe_open = r2pipe.open(os.path.join(file), flags=['-2'])
-     with open(out_file, 'a') as f:
-         try:
-             # get the function list
-             r2pipe_open.cmd("aaa")
-             r2pipe_open.cmd('e arch=x86')
-             function_list = r2pipe_open.cmdj("aflj")
-             exe_op_list = []
-             for function in function_list:
-                 if function['name'][:4] not in ['fcn.', 'loc.', 'main', 'entr']:
-                     continue
-                 block_list = r2pipe_open.cmdj("afbj @" + str(function['offset']))
-                 for block in block_list:
-                     # get the disassembled instructions of the basic block
-                     disasm = r2pipe_open.cmdj("pdj " + str(block["ninstr"]) + " @" + str(block["addr"]))
-                     if disasm:
-                         for op in disasm:
-                             if op["type"] == "invalid" or op["opcode"] == "invalid":
-                                 continue
-                             op_list = extract_opcode(op["disasm"])
-                             exe_op_list.append(' '.join(op_list))
-             exe_op_list = double_exe_op_list(exe_op_list)
-             for op_str_before, op_str_after in exe_op_list:
-                 f.write(op_str_before + '\t' + op_str_after + '\n')
-         except Exception as e:
-             logger.error(f"Error: get function list failed in {file} ,error {e}")
-             return False, file, e
-     r2pipe_open.quit()
-     return True, '', ''
-
-
- def main():
-     sample_file_path = '/mnt/d/bishe/dataset/sample_malware/'
-     sample_file_list = os.listdir(sample_file_path)[:1000]
-     out_file_path = '../dataset/all'
-     with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
-         print(f"start with {os.cpu_count()} works.")
-         future_to_args = {
-             executor.submit(
-                 get_all_from_exe,
-                 os.path.join(sample_file_path, sample_file_list[file_index]),
-                 os.path.join(out_file_path, str(f'inst.{file_index%os.cpu_count()}.pos.txt'))
-             ):
-             file_index for file_index in range(len(sample_file_list))
-         }
-         for future in tqdm(concurrent.futures.as_completed(future_to_args), total=len(sample_file_list)):
-             try:
-                 future.result()
-                 if not future.result()[0]:
-                     print(f"Error file: {future.result()[1]}, msg {future.result()[2]}")
-             except Exception as exc:
-                 logger.error(f"Error: {exc}")
-                 print(f"Error: {exc}")
-
-
- if __name__ == '__main__':
-     logger = setup_logger('exe2all', '../log/exe2all.log')
-     main()
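For reference, tracing `extract_opcode` from the removed script above on a typical radare2 disassembly line (with the indentation as reconstructed here) splits the memory-operand brackets into standalone tokens, which matches the instruction format seen elsewhere in the repository, e.g. `adc byte [ ebp - 0x74 ] cl`:

```python
# Illustrative trace only; extract_opcode is the function shown in the removed script above.
print(extract_opcode("mov dword [ebp - 0x74], eax"))
# ['mov', 'dword', '[', 'ebp', '-', '0x74', ']', 'eax']
```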
@@ -7,8 +7,8 @@ from tqdm import tqdm

  from utils import CURRENT_DATA_BASE, ORIGINAL_DATA_BASE, read_file

- # BASE = 4600000
- BASE = 46000
+ BASE = 4600000

  def write_worker(sents, json_file, index):
      examples = []

@@ -24,19 +24,11 @@ def write_worker(sents, json_file, index):


  def merge_to_json(pos, neg, json_file):
-
-
-
-
-
-
-
-
      sents = read_file(pos)

-     p = Pool(6)
+     p = Pool(36)

-     for i in range(6):
+     for i in range(64):
          p.apply_async(
              write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, i,)
          )

@@ -63,11 +55,11 @@ def merge_to_json(pos, neg, json_file):

      sents = read_file(neg)

-     p = Pool(6)
+     p = Pool(8)

-     for i in range(6):
+     for i in range(64):
          p.apply_async(
-             write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, 6 + i,)
+             write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, 64 + i,)
          )
      print("Waiting for all sub-processes done...")
      p.close()

@@ -88,14 +80,10 @@ def merge_to_json(pos, neg, json_file):

  def main():
      # for i in range(6):
-     # for i in range(6):
-     # pos = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.label.txt".format(i))
-     # neg = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.neg.label.txt".format(i))
-     # json_file = os.path.join(CURRENT_DATA_BASE, "inst.{}.".format(i))
-     # merge_to_json(pos, neg, json_file)
-     pos = os.path.join(ORIGINAL_DATA_BASE, 'all_clean', "inst.all.pos.txt.clean.label")
-     neg = os.path.join(ORIGINAL_DATA_BASE, 'all_clean', "inst.all.neg.txt.clean.label")
-     json_file = os.path.join(CURRENT_DATA_BASE, 'json', "inst.all.")
+     for i in [1]:
+         pos = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.label.txt".format(i))
+         neg = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.neg.label.txt".format(i))
+         json_file = os.path.join(CURRENT_DATA_BASE, "inst.{}.".format(i))
      merge_to_json(pos, neg, json_file)

@@ -1,49 +0,0 @@
- import r2pipe
- from my_utils import setup_logger, multi_thread, THREAD_FULL
- import os
- from tqdm import tqdm
-
- def get_all_from_exe(file):
-     # collect the opcode counts inside basic blocks
-     r2pipe_open = r2pipe.open(os.path.join(file), flags=['-2'])
-     try:
-         # get the function list
-         r2pipe_open.cmd("aaa")
-         r2pipe_open.cmd('e arch=x86')
-         function_list = r2pipe_open.cmdj("aflj")
-         exe_op_count = []
-         for function in function_list:
-             function_op_count_list = []
-             if function['name'][:4] not in ['fcn.', 'loc.', 'main', 'entr']:
-                 continue
-             block_list = r2pipe_open.cmdj("afbj @" + str(function['offset']))
-
-             for block in block_list:
-                 # get the disassembled instructions of the basic block
-                 disasm = r2pipe_open.cmdj("pdj " + str(block["ninstr"]) + " @" + str(block["addr"]))
-                 block_op_count = 0
-                 if disasm:
-                     print_flag = 1 if len(disasm) >= 723 else 0
-                     for op in disasm:
-                         if op["type"] == "invalid" or op["opcode"] == "invalid":
-                             continue
-                         if print_flag == 1:
-                             print(op['disasm'])
-                         block_op_count += 1
-                 function_op_count_list.append(block_op_count)
-             exe_op_count.append(function_op_count_list)
-
-         logger.info(f"{file} {exe_op_count}")
-
-     except Exception as e:
-         logger.error(f"Error: get function list failed in {file} ,error {e}")
-         return False, file, e
-     r2pipe_open.quit()
-     return True, '', ''
-
- if __name__ == '__main__':
-     logger = setup_logger('get_all_from_exe', '../../log/get_all_from_exe.log')
-     file = '/mnt/d/bishe/dataset/sample_benign'
-     file_list = os.listdir(file)
-     multi_thread(get_all_from_exe, ['/mnt/d/bishe/dataset/sample_benign/00125dcd81261701fcaaf84d0cb45d0e.exe'], thread_num=THREAD_FULL)
@@ -2,64 +2,21 @@
  ### 1. run `convert_space_format.py`
  Convert the string `<space>` to `SPACE`

- `linux32_0ixxxx.all -> inst.i.pos.txt` located at `/home/ming/malware/data/elfasm_inst_pairs`
+ ### 2. run `create_negtive_examples.py`

- ### 2. remove the repeated lines in the `inst.i.pos.txt` files
- Using a Python script is too slow, so we use the shell instead.
-
- ``` shell
- cat inst.i.pos.txt | sort -n | uniq > inst.i.pos.txt.clean
- ```
-
-
- ### 3. create_negtive_examples
  We use the next file of the current file as its negative examples, which is a reasonable choice.

  Specifically, for each instruction in the current positive file, we randomly choose a line in its next file and select one of the two instructions in that line as its negative example.

- `python create_negtive_examples.py`, generating `inst.i.neg.txt` located at `/home/ming/malware/data/elfasm_inst_pairs`
+ ### 3. run `merge_examples_to_json.py`
+ We dump the positive and negative examples with their corresponding labels into several json files.
+ Each json file contains 20m lines of examples.

+ ### 4. run `check_length.py`
- ### 4. merge all of the files
- We concatenate all of the `inst.i.pos.txt.clean` files and remove possible repeated lines between different files:
- ``` shell
- cat inst.*.pos.txt.clean | sort -n | uniq > inst.all.pos.txt.clean
- ```
-
- We process the files containing negative examples similarly.
- ``` shell
- cat inst.*.neg.txt.clean | sort -n | uniq > inst.all.neg.txt.clean
- ```
-
- Based on `inst.all.pos.txt.clean`, we remove the lines from `inst.all.neg.txt.clean` that also occur in `inst.all.pos.txt.clean`. This can be completed by `python clean.py`, or
- <!-- ```shell
- grep -v -f inst.all.pos.txt.clean inst.all.neg.txt.clean > inst.all.neg.txt.clean
- ``` -->
-
-
- ### 5. convert to json format
- We first add labels for the positive and negative examples
- ```shell
- cat inst.all.neg.txt.clean | sed 's/^/0\t&/g' > inst.all.neg.txt.clean.label
- cat inst.all.pos.txt.clean | sed 's/^/1\t&/g' > inst.all.pos.txt.clean.label
- ```
-
- We dump the positive and negative examples with their corresponding labels into several json files, using `python merge_examples_to_json.py`.
-
- Generate `inst.all.{0,1}.json` located at `/home/ming/malware/inst2vec_bert/data/asm_bert`.
-
-
- ### 6. get the maximum length of the examples
  We will specify the length padded to when we use the tokenizer, `tokenizer.enable_padding(..., length=)`.

  So we need to know the longest sentences in the dataset.

- The result is `28`, so I set `length=32`
+ ### 5. run `count_word_for_vocab.py`

- ### 7. get the vocabulary size of the examples
  Similarly, we also need to specify the size of the vocabulary when we train the tokenizer, `WordLevelTrainer(vocab_size=, ...)`.

  So we need to know how many characters are in the dataset.

- The result is `1016`, so I set `vocab_size=2000`.
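The labeling step in section 5 of the readme above can equally be done in Python. A small sketch equivalent to the two `sed` commands (file names taken from the shell commands; `1` marks positive pairs, `0` marks negative pairs):

```python
def add_label(src, dst, label):
    """Prepend the label and a tab to every line, mirroring the sed commands above."""
    with open(src, "r", encoding="utf-8") as fin, open(dst, "w", encoding="utf-8") as fout:
        for line in fin:
            fout.write(f"{label}\t{line}")

add_label("inst.all.pos.txt.clean", "inst.all.pos.txt.clean.label", 1)
add_label("inst.all.neg.txt.clean", "inst.all.neg.txt.clean.label", 0)
```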
@@ -1,18 +0,0 @@
- # Run shell commands from Python
- import subprocess
- import os
- from my_utils import multi_thread
-
-
- def run_shell(file_num):
-     com_line = f'cat ./neg_txt/inst.{file_num}.neg.txt | sort -n | uniq > ./neg_clean/inst.{file_num}.neg.txt.clean'
-     p = subprocess.Popen(com_line, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-     stdout, stderr = p.communicate()
-     return stdout, stderr
-
-
- if __name__ == '__main__':
-     os.chdir('../../dataset/all')
-     result = multi_thread(run_shell, range(os.cpu_count()))
@@ -1,16 +1,17 @@
  import os

- # ORIGINAL_DATA_BASE = "/home/ming/malware/data/elfasm_inst_pairs"
- ORIGINAL_DATA_BASE = "/mnt/d/bishe/Inst2Vec/dataset/all"
- CURRENT_DATA_BASE = "/mnt/d/bishe/Inst2Vec/dataset/all"
+ ORIGINAL_DATA_BASE = "/home/ming/malware/data/elfasm_inst_pairs"
+ CURRENT_DATA_BASE = "/home/ming/malware/inst2vec_bert/data/asm_bert"


  def read_file(filename):
+     print("Reading data from {}...".format(filename))
      with open(filename, "r", encoding="utf-8") as fin:
          return fin.readlines()


  def write_file(sents, filename):
+     print("Writing data to {}...".format(filename))
      with open(filename, "w", encoding="utf-8") as fout:
          for sent in sents:
              fout.write(sent)
@@ -2,6 +2,7 @@ import argparse
  import os
  from itertools import chain

+ from datasets import load_dataset
  from tokenizers import Tokenizer
  from tokenizers.models import WordLevel
  from tokenizers.pre_tokenizers import Whitespace

@@ -10,7 +11,7 @@ from tokenizers.trainers import WordLevelTrainer

  from process_data.utils import CURRENT_DATA_BASE, ORIGINAL_DATA_BASE, read_file

+ BASE_PATH = "/home/ming/malware/inst2vec_bert/bert/"


  def parse_args():

@@ -26,7 +27,7 @@ def parse_args():
      parser.add_argument(
          "--padding_length",
          type=int,
-         default=50,
+         default=32,
          help="The length will be padded to by the tokenizer.",
      )
      args = parser.parse_args()

@@ -98,8 +99,8 @@ def main(tokenizer_file=""):
      # dataset = load_dataset("json", data_files=json_files, field="data")

      text_files = [
-         os.path.join(ORIGINAL_DATA_BASE, f'{group}_clean', f"inst.{i}.{group}.txt.clean")
-         for group in ["pos", "neg"] for i in range(32)
+         os.path.join(ORIGINAL_DATA_BASE, "inst.1.{}.txt".format(group))
+         for group in ["pos", "neg"]
      ]

      dataset = []

@@ -120,4 +121,4 @@ def main(tokenizer_file=""):


  if __name__ == "__main__":
-     main(os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.all.json"))
+     main(os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.1.json"))
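`train_my_tokenizer.py` is not shown in full here, but given the imports above (WordLevel model, Whitespace pre-tokenizer, WordLevelTrainer) and the readme note that `vocab_size=2000`, a word-level instruction tokenizer is trained roughly like this. This is a sketch under those assumptions; the script's exact special tokens, padding length, and post-processing may differ:

```python
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer

# `lines` is assumed to be an iterable of instruction strings read from the
# inst.*.txt.clean files, e.g. via read_file() from process_data/utils.py.
lines = ["adc byte [ ebp - 0x74 ] cl", "mov dh 0x79", "adc eax 1"]

tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
trainer = WordLevelTrainer(
    vocab_size=2000,
    special_tokens=["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"],
)
tokenizer.train_from_iterator(lines, trainer=trainer)

# Pad to a fixed length, consistent with max_position_embeddings in the BERT config.
tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]", length=50)
tokenizer.save("tokenizer-inst.all.json")
```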