first commit

This commit is contained in:
huihun 2024-04-11 16:43:57 +08:00
parent 51a232453a
commit 979573651d
13 changed files with 332 additions and 63 deletions

View File

@ -43,7 +43,7 @@ from transformers import (CONFIG_MAPPING, MODEL_MAPPING, AdamW, AutoConfig,
from my_data_collator import MyDataCollatorForPreTraining from my_data_collator import MyDataCollatorForPreTraining
from process_data.utils import CURRENT_DATA_BASE from process_data.utils import CURRENT_DATA_BASE
HIDDEN_SIZE=256 HIDDEN_SIZE=16
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
@ -57,13 +57,13 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--per_device_train_batch_size", "--per_device_train_batch_size",
type=int, type=int,
default=16, default=2048,
help="Batch size (per device) for the training dataloader.", help="Batch size (per device) for the training dataloader.",
) )
parser.add_argument( parser.add_argument(
"--per_device_eval_batch_size", "--per_device_eval_batch_size",
type=int, type=int,
default=64, default=16384,
help="Batch size (per device) for the evaluation dataloader.", help="Batch size (per device) for the evaluation dataloader.",
) )
parser.add_argument( parser.add_argument(
@ -84,7 +84,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--max_train_steps", "--max_train_steps",
type=int, type=int,
default=None, default=150000,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.", help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
) )
parser.add_argument( parser.add_argument(
@ -104,19 +104,19 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--num_warmup_steps", "--num_warmup_steps",
type=int, type=int,
default=0, default=4000,
help="Number of steps for the warmup in the lr scheduler.", help="Number of steps for the warmup in the lr scheduler.",
) )
parser.add_argument( parser.add_argument(
"--output_dir", type=str, default=None, help="Where to store the final model." "--output_dir", type=str, default='./dataset/out', help="Where to store the final model."
) )
parser.add_argument( parser.add_argument(
"--seed", type=int, default=None, help="A seed for reproducible training." "--seed", type=int, default=1234, help="A seed for reproducible training."
) )
parser.add_argument( parser.add_argument(
"--preprocessing_num_workers", "--preprocessing_num_workers",
type=int, type=int,
default=None, default=32,
help="The number of processes to use for the preprocessing.", help="The number of processes to use for the preprocessing.",
) )
parser.add_argument( parser.add_argument(
@ -189,9 +189,9 @@ def main():
# field="data", # field="data",
# ) # )
train_files = [ train_files = [
os.path.join(CURRENT_DATA_BASE, "inst.all.{}.json".format(i)) for i in [0,1,2,3,4,5,6] # ,8,9,10,11,12,13] os.path.join(CURRENT_DATA_BASE, 'json',f"inst.all.{i}.json") for i in range(8)
] ]
valid_file = "/home/ming/malware/inst2vec_bert/data/test_lm/inst.json" valid_file = os.path.join(CURRENT_DATA_BASE, 'test', "inst.json")
raw_datasets = load_dataset( raw_datasets = load_dataset(
"json", "json",
data_files={"train": train_files, "validation": valid_file,}, data_files={"train": train_files, "validation": valid_file,},
@ -205,7 +205,7 @@ def main():
# NOTE: the `length` here must be consistent with the one used in `train_my_tokenizer.py` # NOTE: the `length` here must be consistent with the one used in `train_my_tokenizer.py`
tokenizer.enable_padding( tokenizer.enable_padding(
pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]", length=32 pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]", length=50
) )
# NOTE: `max_position_embeddings` here should be consistent with `length` above # NOTE: `max_position_embeddings` here should be consistent with `length` above
@ -216,7 +216,7 @@ def main():
num_hidden_layers=4, num_hidden_layers=4,
num_attention_heads=8, num_attention_heads=8,
intermediate_size=4 * HIDDEN_SIZE, intermediate_size=4 * HIDDEN_SIZE,
max_position_embeddings=32, max_position_embeddings=50,
) )
# initialize a new BERT for pre-training # initialize a new BERT for pre-training

65
my_utils.py Normal file
View File

@ -0,0 +1,65 @@
import logging
import os
"""
日志工具
使用方法
logger = setup_logger(日志记录器的实例名字, 日志文件目录)
"""
def setup_logger(name, log_file, level=logging.INFO):
    """Create (or reconfigure) a named logger that writes to ``log_file``.

    The parent directory of ``log_file`` is created when it is missing and
    the log file is truncated, so every call starts from an empty log.
    Calling the function again with the same name and file replaces the
    previous file handler instead of stacking a second one.

    :param name: name of the logger instance (passed to ``logging.getLogger``)
    :param log_file: path of the log file to write
    :param level: minimum level recorded by the logger
    :return: the configured ``logging.Logger``
    """
    log_dir = os.path.dirname(log_file)
    # A bare file name has an empty dirname, and os.makedirs('') raises.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Remove file handlers left behind by an earlier call for the same file;
    # otherwise each record would be written once per call.
    target = os.path.abspath(log_file)
    for old_handler in list(logger.handlers):
        if isinstance(old_handler, logging.FileHandler) and old_handler.baseFilename == target:
            logger.removeHandler(old_handler)
            old_handler.close()

    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    # mode='w' truncates an existing log file when the handler opens it.
    handler = logging.FileHandler(log_file, mode='w')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    # Console output is intentionally disabled; to enable it, attach a
    # logging.StreamHandler that uses the same formatter.
    return logger
"""
多线程工具
"""
# Thread-pool size presets. os.cpu_count() may return None when the CPU count
# cannot be determined, and halving a count of 1 must not yield 0 workers.
THREAD_FULL = os.cpu_count() or 1
THREAD_HALF = max(1, THREAD_FULL // 2)


def multi_thread(func, args, thread_num=THREAD_FULL):
    """Run ``func`` once per element of ``args`` on a thread pool.

    Results are collected in completion order, not in the order of ``args``.
    A call that raises is logged to ``./multi_thread.log`` and contributes
    no entry to the result.

    :param func: callable taking a single argument
    :param args: iterable of arguments, one per call
    :param thread_num: maximum number of worker threads
    :return: list of the return values of the calls that succeeded
    """
    import concurrent.futures
    from tqdm import tqdm

    # Materialise once so generators work and the progress total is known.
    args = list(args)
    logger = setup_logger('multi_thread', './multi_thread.log')
    result = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=thread_num) as executor:
        futures_to_args = {executor.submit(func, arg): arg for arg in args}
        for future in tqdm(concurrent.futures.as_completed(futures_to_args), total=len(args)):
            try:
                result.append(future.result())
            except Exception as exc:
                logger.error('%r generated an exception: %s', futures_to_args[future], exc)
    return result

View File

@ -11,6 +11,7 @@ import torch
import transformers import transformers
from accelerate import Accelerator from accelerate import Accelerator
from datasets import load_dataset from datasets import load_dataset
from torch import nn
from torch.nn import DataParallel from torch.nn import DataParallel
from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm from tqdm.auto import tqdm
@ -33,8 +34,8 @@ from transformers import (
from my_data_collator import MyDataCollatorForPreTraining from my_data_collator import MyDataCollatorForPreTraining
from process_data.utils import CURRENT_DATA_BASE from process_data.utils import CURRENT_DATA_BASE
model_file = os.path.join(CURRENT_DATA_BASE, "bert-L2-H8.bin") model_file = os.path.join(CURRENT_DATA_BASE, 'out' ,"pytorch_model.bin")
config_file = os.path.join(CURRENT_DATA_BASE, "bert-L2-H8.config.json") config_file = os.path.join(CURRENT_DATA_BASE, 'out' ,"config.json")
tokenizer_file = os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.all.json") tokenizer_file = os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.all.json")
@ -48,7 +49,7 @@ def load_model():
tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file) tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
tokenizer.enable_padding( tokenizer.enable_padding(
pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]", length=32 pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]", length=50
) )
print("Load tokenizer successfully !") print("Load tokenizer successfully !")
return model, tokenizer return model, tokenizer
@ -120,8 +121,10 @@ def generate_inst_vec(inst, method="mean"):
def main(): def main():
inst = ["mov ebp esp" for _ in range(8)] inst = ['adc byte [ ebp - 0x74 ] cl','mov dh 0x79','adc eax 1']
print(generate_inst_vec(inst).shape) tmp = generate_inst_vec(inst, method="mean")
print(tmp.shape)
print(tmp.detach().numpy())
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -15,10 +15,10 @@ def check(filename):
def main(): def main():
longest = 0 longest = 0
# for i in range(6): # for i in range(6):
for i in range(10): for i in range(32):
for group in ("pos", "neg"): for group in ("pos", "neg"):
filename = os.path.join( filename = os.path.join(
ORIGINAL_DATA_BASE, "inst.{}.{}.txt.clean".format(i, group) ORIGINAL_DATA_BASE, f'{group}_clean',f"inst.{i}.{group}.txt.clean"
) )
longest = max(check(filename), longest) longest = max(check(filename), longest)
print("The longest sentence in all files has {} words.".format(longest)) print("The longest sentence in all files has {} words.".format(longest))

View File

@ -1,21 +1,57 @@
from utils import ORIGINAL_DATA_BASE, read_file, write_file from utils import ORIGINAL_DATA_BASE, read_file, write_file
from tqdm import tqdm from tqdm import tqdm
import os import os
from my_utils import multi_thread, setup_logger
import concurrent.futures
def remove(pos_file, neg_file):
pos = read_file(pos_file)
neg = read_file(neg_file)
rets = [] def remove(neg_list, pos_file):
for n in tqdm(neg): ret = []
if n in pos: for neg in neg_list:
if neg in pos_file:
continue continue
rets.append(n) ret.append(neg)
write_file(rets, neg_file) return ret
def split_list_evenly(lst, n):
    """Split ``lst`` into at most ``n`` contiguous chunks of near-equal size.

    Every element of ``lst`` appears in exactly one chunk and the original
    order is preserved. When ``len(lst)`` is not divisible by ``n`` the
    leading chunks receive one extra element each. Empty chunks are never
    produced, so fewer than ``n`` chunks are returned when ``lst`` has
    fewer than ``n`` elements.

    :param lst: list to split
    :param n: desired number of chunks, must be positive
    :return: list of sub-lists
    :raises ValueError: if ``n`` is not positive
    """
    if n <= 0:
        raise ValueError("n must be a positive integer")
    base_size, extra = divmod(len(lst), n)
    chunks = []
    start = 0
    for index in range(n):
        # The first `extra` chunks absorb the remainder, one element each.
        size = base_size + (1 if index < extra else 0)
        if size == 0:
            break
        chunks.append(lst[start:start + size])
        start += size
    return chunks
def main(): def main():
pos_file = os.path.join(ORIGINAL_DATA_BASE, "inst.all.pos.txt.clean") file = os.path.join('../dataset/all/all_clean')
neg_file = os.path.join(ORIGINAL_DATA_BASE, "inst.all.neg.txt.clean") pos_file = read_file(os.path.join(file, "inst.all.pos.txt.clean"))
remove(pos_file, neg_file) neg_file = split_list_evenly(read_file(os.path.join(file, "inst.all.neg.txt.clean")), int(os.cpu_count()*1000))
print(len(neg_file))
logger = setup_logger('remove', '../out/remove.log')
result = []
with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
print('start build task.')
futures_to_args = {
executor.submit(remove, neg_list, pos_file): neg_list for neg_list in neg_file
}
print('start run task.')
for future in tqdm(concurrent.futures.as_completed(futures_to_args), total=len(futures_to_args)):
try:
result.extend(future.result())
except Exception as exc:
logger.error(exc)
write_file(result, os.path.join(file, "inst.all.neg.txt.clean"))
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -19,11 +19,11 @@ def convert(fin, fout):
def main(): def main():
# for i in range(6): # for i in range(6):
for i in range(10): # for i in range(10):
fin = os.path.join(ORIGINAL_DATA_BASE, "win32_0{}xxxx.all".format(i)) # fin = os.path.join(ORIGINAL_DATA_BASE, "win32_0{}xxxx.all".format(i))
fout = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.txt".format(i)) # fout = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.txt".format(i))
convert(fin, fout) # convert(fin, fout)
convert(os.path.join('../dataset/all/win.all'), os.path.join('../dataset/all/inst.pos.txt'))
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -5,32 +5,35 @@ from tqdm import tqdm
from utils import ORIGINAL_DATA_BASE, read_file from utils import ORIGINAL_DATA_BASE, read_file
from my_utils import multi_thread
def create(pos, neg, tgt): def create(pos, neg, tgt):
pos_sents = read_file(pos) pos_sents = read_file(pos)
neg_sents = read_file(neg) neg_sents = read_file(neg)
neg_length = len(neg_sents) neg_length = len(neg_sents)
print("Start writing negative examples to {}...".format(tgt))
with open(tgt, "w", encoding="utf-8") as fout: with open(tgt, "w", encoding="utf-8") as fout:
for sent in tqdm(pos_sents): for sent in tqdm(pos_sents):
first = sent.split("\t")[0] first = sent.split("\t")[0]
index = randint(0, neg_length - 1) index = randint(0, neg_length - 1)
pair = neg_sents[index].split("\t")[randint(0, 1)].replace("\n", "") pair = neg_sents[index].split("\t")
pair = pair[randint(0, 1)]
pair = pair.replace("\n", "")
fout.write(first + "\t" + pair + "\n") fout.write(first + "\t" + pair + "\n")
def main(): def main():
# for i in range(6): # for i in range(6):
for i in range(10): # neg = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(j))
j = (i + 1) % 10 # pos = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(i))
# neg = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(j)) file = os.path.join("../dataset/all/pos_clean")
# pos = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(i)) out_file = os.path.join("../dataset/all/neg_txt")
pos = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.txt.clean".format(i)) os.makedirs(out_file, exist_ok=True)
neg = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.txt.clean".format(j)) for i in tqdm(range(os.cpu_count()), total=os.cpu_count() ):
tgt = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.neg.txt".format(i)) j = (i + 1) % os.cpu_count()
pos = os.path.join(file, f"inst.{i}.pos.txt.clean")
neg = os.path.join(file, f"inst.{j}.pos.txt.clean")
tgt = os.path.join(out_file, f"inst.{i}.neg.txt")
create(pos, neg, tgt) create(pos, neg, tgt)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

90
process_data/exe2all.py Normal file
View File

@ -0,0 +1,90 @@
import os
import r2pipe
from my_utils import setup_logger
import concurrent.futures
from tqdm import tqdm
def extract_opcode(disasm_text):
    """Split one disassembled instruction into a list of tokens.

    The text is split on whitespace. Commas are removed, and square brackets
    are emitted as stand-alone ``"["`` / ``"]"`` tokens around the text they
    were attached to. Empty tokens (for example from repeated spaces or a
    lone bracket) are dropped.

    :param disasm_text: disassembly text of a single instruction
    :return: list of token strings
    """
    tokens = []
    for raw in disasm_text.split():
        item = raw.replace(',', '')
        if '[' in item:
            tokens.append('[')
        core = item.replace('[', '').replace(']', '')
        # Skip what is left of a bare "[", "]" or "," so no empty token is emitted.
        if core:
            tokens.append(core)
        if ']' in item:
            tokens.append(']')
    return tokens
def double_exe_op_list(op_list):
    """Pair every element of ``op_list`` with its successor.

    :param op_list: list of items
    :return: list of ``(op_list[i], op_list[i + 1])`` tuples; empty when
        ``op_list`` has fewer than two elements
    """
    return list(zip(op_list, op_list[1:]))
def get_all_from_exe(file, out_file):
    """Append consecutive instruction pairs extracted from ``file`` to ``out_file``.

    The executable is opened through r2pipe. For every reported function
    whose name starts with ``fcn.``, ``loc.``, ``main`` or ``entr``, the
    instructions of each basic block are tokenised with ``extract_opcode``.
    All instructions of the file are then paired with their successor and
    written one pair per line, separated by a tab.

    Errors raised while analysing or writing are logged through the
    module-level ``logger`` and reported in the return value.

    :param file: path of the executable to analyse
    :param out_file: path of the text file the pairs are appended to
    :return: ``(True, '', '')`` on success, ``(False, file, exception)`` on failure
    """
    r2pipe_open = r2pipe.open(os.path.join(file), flags=['-2'])
    try:
        with open(out_file, 'a', encoding='utf-8') as f:
            try:
                # Fetch the function list.
                r2pipe_open.cmd("aaa")
                r2pipe_open.cmd('e arch=x86')
                function_list = r2pipe_open.cmdj("aflj")
                exe_op_list = []
                for function in function_list:
                    if function['name'][:4] not in ['fcn.', 'loc.', 'main', 'entr']:
                        continue
                    block_list = r2pipe_open.cmdj("afbj @" + str(function['offset']))
                    for block in block_list:
                        # Fetch the disassembled instructions of the basic block.
                        disasm = r2pipe_open.cmdj("pdj " + str(block["ninstr"]) + " @" + str(block["addr"]))
                        if not disasm:
                            continue
                        for op in disasm:
                            if op["type"] == "invalid" or op["opcode"] == "invalid":
                                continue
                            exe_op_list.append(' '.join(extract_opcode(op["disasm"])))
                for op_str_before, op_str_after in double_exe_op_list(exe_op_list):
                    f.write(op_str_before + '\t' + op_str_after + '\n')
            except Exception as e:
                logger.error(f"Error: get function list failed in {file} ,error {e}")
                return False, file, e
    finally:
        # Always close the r2pipe session, including on the error path and
        # when the output file cannot be opened.
        r2pipe_open.quit()
    return True, '', ''
def main():
    """Extract instruction pairs from up to 1000 samples using a thread pool.

    Each sample is handed to ``get_all_from_exe``. The output file of a
    sample is chosen by its index modulo the worker count, so several
    samples append to the same ``inst.<k>.pos.txt`` file. Failures are
    printed, and exceptions raised by a task are also logged.
    """
    sample_file_path = '/mnt/d/bishe/dataset/sample_malware/'
    sample_file_list = os.listdir(sample_file_path)[:1000]
    out_file_path = '../dataset/all'
    # Opening a file in append mode does not create missing directories.
    os.makedirs(out_file_path, exist_ok=True)
    # os.cpu_count() may return None; fall back to a single worker.
    worker_count = os.cpu_count() or 1
    with concurrent.futures.ThreadPoolExecutor(max_workers=worker_count) as executor:
        print(f"start with {worker_count} works.")
        # NOTE(review): tasks that map to the same output file may run at the
        # same time, so their appended lines can interleave; confirm this is
        # acceptable for the downstream steps.
        future_to_args = {
            executor.submit(
                get_all_from_exe,
                os.path.join(sample_file_path, sample_name),
                os.path.join(out_file_path, f'inst.{file_index % worker_count}.pos.txt'),
            ): file_index
            for file_index, sample_name in enumerate(sample_file_list)
        }
        for future in tqdm(concurrent.futures.as_completed(future_to_args), total=len(sample_file_list)):
            try:
                succeeded, failed_file, error = future.result()
                if not succeeded:
                    print(f"Error file: {failed_file}, msg {error}")
            except Exception as exc:
                logger.error(f"Error: {exc}")
                print(f"Error: {exc}")


if __name__ == '__main__':
    # get_all_from_exe and main read this module-level logger.
    logger = setup_logger('exe2all', '../log/exe2all.log')
    main()

View File

@ -7,8 +7,8 @@ from tqdm import tqdm
from utils import CURRENT_DATA_BASE, ORIGINAL_DATA_BASE, read_file from utils import CURRENT_DATA_BASE, ORIGINAL_DATA_BASE, read_file
BASE = 4600000 # BASE = 4600000
BASE = 46000
def write_worker(sents, json_file, index): def write_worker(sents, json_file, index):
examples = [] examples = []
@ -24,13 +24,21 @@ def write_worker(sents, json_file, index):
def merge_to_json(pos, neg, json_file): def merge_to_json(pos, neg, json_file):
sents = read_file(pos) sents = read_file(pos)
p = Pool(6) p = Pool(6)
for i in range(6): for i in range(6):
p.apply_async( p.apply_async(
write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, 2+i,) write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, i,)
) )
print("Waiting for all sub-processes done...") print("Waiting for all sub-processes done...")
p.close() p.close()
@ -59,7 +67,7 @@ def merge_to_json(pos, neg, json_file):
for i in range(6): for i in range(6):
p.apply_async( p.apply_async(
write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, 8 + i,) write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, 6 + i,)
) )
print("Waiting for all sub-processes done...") print("Waiting for all sub-processes done...")
p.close() p.close()
@ -85,9 +93,9 @@ def main():
# neg = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.neg.label.txt".format(i)) # neg = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.neg.label.txt".format(i))
# json_file = os.path.join(CURRENT_DATA_BASE, "inst.{}.".format(i)) # json_file = os.path.join(CURRENT_DATA_BASE, "inst.{}.".format(i))
# merge_to_json(pos, neg, json_file) # merge_to_json(pos, neg, json_file)
pos = os.path.join(ORIGINAL_DATA_BASE, "inst.all.pos.txt.clean.label") pos = os.path.join(ORIGINAL_DATA_BASE,'all_clean', "inst.all.pos.txt.clean.label")
neg = os.path.join(ORIGINAL_DATA_BASE, "inst.all.neg.txt.clean.label") neg = os.path.join(ORIGINAL_DATA_BASE, 'all_clean',"inst.all.neg.txt.clean.label")
json_file = os.path.join(CURRENT_DATA_BASE, "inst.all.") json_file = os.path.join(CURRENT_DATA_BASE, 'json',"inst.all.")
merge_to_json(pos, neg, json_file) merge_to_json(pos, neg, json_file)

View File

@ -0,0 +1,49 @@
import r2pipe
from my_utils import setup_logger, multi_thread, THREAD_FULL
import os
from tqdm import tqdm
def get_all_from_exe(file):
    """Log the number of valid instructions per basic block of ``file``.

    The executable is opened through r2pipe. For every reported function
    whose name starts with ``fcn.``, ``loc.``, ``main`` or ``entr``, the
    valid instructions of each basic block are counted. The result, a list
    with one list of block counts per function visited, is written to the
    module-level ``logger``.

    :param file: path of the executable to analyse
    :return: ``(True, '', '')`` on success, ``(False, file, exception)`` on failure
    """
    r2pipe_open = r2pipe.open(os.path.join(file), flags=['-2'])
    try:
        # Fetch the function list.
        r2pipe_open.cmd("aaa")
        r2pipe_open.cmd('e arch=x86')
        function_list = r2pipe_open.cmdj("aflj")
        exe_op_count = []
        for function in function_list:
            function_op_count_list = []
            if function['name'][:4] not in ['fcn.', 'loc.', 'main', 'entr']:
                continue
            block_list = r2pipe_open.cmdj("afbj @" + str(function['offset']))
            for block in block_list:
                # Fetch the disassembled instructions of the basic block.
                disasm = r2pipe_open.cmdj("pdj " + str(block["ninstr"]) + " @" + str(block["addr"]))
                block_op_count = 0
                if disasm:
                    # Blocks with at least 723 decoded entries have every valid
                    # instruction printed (looks like a debugging aid; the
                    # threshold's origin is not visible here).
                    print_flag = 1 if len(disasm) >= 723 else 0
                    for op in disasm:
                        if op["type"] == "invalid" or op["opcode"] == "invalid":
                            continue
                        if print_flag == 1:
                            print(op['disasm'])
                        block_op_count += 1
                function_op_count_list.append(block_op_count)
            exe_op_count.append(function_op_count_list)
        logger.info(f"{file} {exe_op_count}")
    except Exception as e:
        logger.error(f"Error: get function list failed in {file} ,error {e}")
        return False, file, e
    finally:
        # Always close the r2pipe session, including on the error path.
        r2pipe_open.quit()
    return True, '', ''


if __name__ == '__main__':
    # get_all_from_exe reads this module-level logger.
    logger = setup_logger('get_all_from_exe', '../../log/get_all_from_exe.log')
    file = '/mnt/d/bishe/dataset/sample_benign'
    # The directory listing is collected but not used below: only one
    # hard-coded sample is analysed.
    file_list = os.listdir(file)
    multi_thread(get_all_from_exe, ['/mnt/d/bishe/dataset/sample_benign/00125dcd81261701fcaaf84d0cb45d0e.exe'], thread_num=THREAD_FULL)

View File

@ -0,0 +1,18 @@
# Run shell commands from Python
import subprocess
import os
from my_utils import multi_thread
def run_shell(file_num):
    """Sort and de-duplicate one negative-sample file with a shell pipeline.

    Reads ``./neg_txt/inst.<file_num>.neg.txt`` and writes the result to
    ``./neg_clean/inst.<file_num>.neg.txt.clean``, both relative to the
    current working directory.

    :param file_num: index that selects the input and output file
    :return: ``(stdout, stderr)`` of the shell command, as bytes
    """
    # The shell redirect cannot create the output directory itself.
    os.makedirs('./neg_clean', exist_ok=True)
    com_line = f'cat ./neg_txt/inst.{file_num}.neg.txt | sort -n | uniq > ./neg_clean/inst.{file_num}.neg.txt.clean'
    # shell=True is needed for the pipe and redirect; file_num comes from a
    # range() in this script, not from external input.
    completed = subprocess.run(com_line, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return completed.stdout, completed.stderr


if __name__ == '__main__':
    os.chdir('../../dataset/all')
    # os.cpu_count() may return None; process at least file 0 in that case.
    result = multi_thread(run_shell, range(os.cpu_count() or 1))

View File

@ -1,18 +1,16 @@
import os import os
# ORIGINAL_DATA_BASE = "/home/ming/malware/data/elfasm_inst_pairs" # ORIGINAL_DATA_BASE = "/home/ming/malware/data/elfasm_inst_pairs"
ORIGINAL_DATA_BASE = "/home/ming/malware/data/malasm_inst_pairs" ORIGINAL_DATA_BASE = "/mnt/d/bishe/Inst2Vec/dataset/all"
CURRENT_DATA_BASE = "/home/ming/malware/inst2vec_bert/data/asm_bert" CURRENT_DATA_BASE = "/mnt/d/bishe/Inst2Vec/dataset/all"
def read_file(filename): def read_file(filename):
print("Reading data from {}...".format(filename))
with open(filename, "r", encoding="utf-8") as fin: with open(filename, "r", encoding="utf-8") as fin:
return fin.readlines() return fin.readlines()
def write_file(sents, filename): def write_file(sents, filename):
print("Writing data to {}...".format(filename))
with open(filename, "w", encoding="utf-8") as fout: with open(filename, "w", encoding="utf-8") as fout:
for sent in sents: for sent in sents:
fout.write(sent) fout.write(sent)

View File

@ -2,7 +2,6 @@ import argparse
import os import os
from itertools import chain from itertools import chain
from datasets import load_dataset
from tokenizers import Tokenizer from tokenizers import Tokenizer
from tokenizers.models import WordLevel from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace from tokenizers.pre_tokenizers import Whitespace
@ -11,7 +10,7 @@ from tokenizers.trainers import WordLevelTrainer
from process_data.utils import CURRENT_DATA_BASE, ORIGINAL_DATA_BASE, read_file from process_data.utils import CURRENT_DATA_BASE, ORIGINAL_DATA_BASE, read_file
BASE_PATH = "/home/ming/malware/inst2vec_bert/bert/"
def parse_args(): def parse_args():
@ -27,7 +26,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--padding_length", "--padding_length",
type=int, type=int,
default=32, default=50,
help="The length will be padded to by the tokenizer.", help="The length will be padded to by the tokenizer.",
) )
args = parser.parse_args() args = parser.parse_args()
@ -99,8 +98,8 @@ def main(tokenizer_file=""):
# dataset = load_dataset("json", data_files=json_files, field="data") # dataset = load_dataset("json", data_files=json_files, field="data")
text_files = [ text_files = [
os.path.join(ORIGINAL_DATA_BASE, "inst.{}.{}.txt.clean".format(i, group)) os.path.join(ORIGINAL_DATA_BASE, f'{group}_clean',f"inst.{i}.{group}.txt.clean")
for group in ["pos", "neg"] for i in range(10) for group in ["pos", "neg"] for i in range(32)
] ]
dataset = [] dataset = []