From 979573651d6d48c10ab69f244a5961eeb13200ea Mon Sep 17 00:00:00 2001
From: huihun <781165206@qq.com>
Date: Thu, 11 Apr 2024 16:43:57 +0800
Subject: [PATCH] first commit

---
 my_run_mlm_no_trainer.py                     | 24 +++---
 my_utils.py                                  | 65 ++++++++++++++
 obtain_inst_vec.py                           | 13 +--
 process_data/check_length.py                 |  4 +-
 process_data/clean.py                        | 58 ++++++++++---
 process_data/convert_space_format.py         | 10 +-
 process_data/create_negative_examples.py     | 27 +++---
 process_data/exe2all.py                      | 90 ++++++++++++++++++++
 process_data/merge_examples_to_json.py       | 22 +++--
 process_data/r2test/max_op_count.py          | 49 +++++++++++
 process_data/shell_com/remove_repete_line.py | 18 ++++
 process_data/utils.py                        |  6 +-
 train_my_tokenizer.py                        |  9 +-
 13 files changed, 332 insertions(+), 63 deletions(-)
 create mode 100644 my_utils.py
 create mode 100644 process_data/exe2all.py
 create mode 100644 process_data/r2test/max_op_count.py
 create mode 100644 process_data/shell_com/remove_repete_line.py

diff --git a/my_run_mlm_no_trainer.py b/my_run_mlm_no_trainer.py
index e1864f4..1e50fee 100644
--- a/my_run_mlm_no_trainer.py
+++ b/my_run_mlm_no_trainer.py
@@ -43,7 +43,7 @@ from transformers import (CONFIG_MAPPING, MODEL_MAPPING, AdamW, AutoConfig,
 from my_data_collator import MyDataCollatorForPreTraining
 from process_data.utils import CURRENT_DATA_BASE
 
-HIDDEN_SIZE=256
+HIDDEN_SIZE=16
 
 logger = logging.getLogger(__name__)
 MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
@@ -57,13 +57,13 @@ def parse_args():
     parser.add_argument(
         "--per_device_train_batch_size",
         type=int,
-        default=16,
+        default=2048,
         help="Batch size (per device) for the training dataloader.",
     )
     parser.add_argument(
         "--per_device_eval_batch_size",
         type=int,
-        default=64,
+        default=16384,
         help="Batch size (per device) for the evaluation dataloader.",
     )
     parser.add_argument(
@@ -84,7 +84,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=None,
+        default=150000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -104,19 +104,19 @@ def parse_args():
     parser.add_argument(
         "--num_warmup_steps",
         type=int,
-        default=0,
+        default=4000,
         help="Number of steps for the warmup in the lr scheduler.",
     )
     parser.add_argument(
-        "--output_dir", type=str, default=None, help="Where to store the final model."
+        "--output_dir", type=str, default='./dataset/out', help="Where to store the final model."
    )
    parser.add_argument(
-        "--seed", type=int, default=None, help="A seed for reproducible training."
+        "--seed", type=int, default=1234, help="A seed for reproducible training."
    )
    parser.add_argument(
        "--preprocessing_num_workers",
        type=int,
-        default=None,
+        default=32,
        help="The number of processes to use for the preprocessing.",
    )
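For context on the retuned schedule (150k training steps with a 4k-step warmup): `*_no_trainer.py` scripts typically wire these arguments into `transformers.get_scheduler`. A minimal sketch, assuming this script does the same; the `build_scheduler` helper is illustrative, not part of the patch:

```python
# Hypothetical sketch: how the new defaults (num_warmup_steps=4000,
# max_train_steps=150000) would typically drive the LR schedule.
# `optimizer` is assumed to be the AdamW instance built elsewhere in main().
from transformers import get_scheduler

def build_scheduler(optimizer, num_warmup_steps=4000, max_train_steps=150000):
    # Linear warmup to the peak LR over the first 4k steps, then linear decay
    # to zero at step 150k.
    return get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=max_train_steps,
    )
```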
@@ -189,9 +189,9 @@ def main():
     #     field="data",
     # )
     train_files = [
-        os.path.join(CURRENT_DATA_BASE, "inst.all.{}.json".format(i)) for i in [0,1,2,3,4,5,6]  # ,8,9,10,11,12,13]
+        os.path.join(CURRENT_DATA_BASE, 'json', f"inst.all.{i}.json") for i in range(8)
     ]
-    valid_file = "/home/ming/malware/inst2vec_bert/data/test_lm/inst.json"
+    valid_file = os.path.join(CURRENT_DATA_BASE, 'test', "inst.json")
     raw_datasets = load_dataset(
         "json",
         data_files={"train": train_files, "validation": valid_file,},
@@ -205,7 +205,7 @@ def main():
 
     # NOTE: have to promise the `length` here is consistent with the one used in `train_my_tokenizer.py`
     tokenizer.enable_padding(
-        pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]", length=32
+        pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]", length=50
     )
 
     # NOTE: `max_position_embeddings` here should be consistent with `length` above
@@ -216,7 +216,7 @@ def main():
         num_hidden_layers=4,
         num_attention_heads=8,
         intermediate_size=4 * HIDDEN_SIZE,
-        max_position_embeddings=32,
+        max_position_embeddings=50,
     )
 
     # initalize a new BERT for pre-training
diff --git a/my_utils.py b/my_utils.py
new file mode 100644
index 0000000..1a9dead
--- /dev/null
+++ b/my_utils.py
@@ -0,0 +1,65 @@
+import logging
+import os
+
+"""
+Logging helper.
+
+Usage:
+    logger = setup_logger(<logger instance name>, <log file path>)
+"""
+def setup_logger(name, log_file, level=logging.INFO):
+    """Set up as many independent loggers as you need."""
+    if not os.path.exists(os.path.dirname(log_file)):
+        os.makedirs(os.path.dirname(log_file))
+
+    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
+
+    handler = logging.FileHandler(log_file)
+    handler.setFormatter(formatter)
+
+    # Uncomment to also echo log records to the console.
+    # stream_handler = logging.StreamHandler()
+    # stream_handler.setFormatter(formatter)
+
+    logger = logging.getLogger(name)
+    logger.setLevel(level)
+    logger.addHandler(handler)
+    # Console output:
+    # logger.addHandler(stream_handler)
+
+    # Truncate any log file left over from a previous run.
+    if os.path.exists(log_file):
+        open(log_file, 'w').close()
+
+    return logger
+
+
+"""
+Multithreading helper.
+"""
+THREAD_FULL = os.cpu_count()
+THREAD_HALF = int(os.cpu_count() / 2)
+def multi_thread(func, args, thread_num=THREAD_FULL):
+    """
+    Run `func` over `args` in a thread pool.
+    :param func: the function to call
+    :param args: list of per-call arguments, one element per call
+    :param thread_num: number of worker threads
+    :return: list of results, in completion order
+    """
+    import concurrent.futures
+    from tqdm import tqdm
+    logger = setup_logger('multi_thread', './multi_thread.log')
+    result = []
+    with concurrent.futures.ThreadPoolExecutor(max_workers=thread_num) as executor:
+        futures_to_args = {
+            executor.submit(func, arg): arg for arg in args
+        }
+        for future in tqdm(concurrent.futures.as_completed(futures_to_args), total=len(args)):
+            try:
+                result.append(future.result())
+            except Exception as exc:
+                logger.error('%r generated an exception: %s' % (futures_to_args[future], exc))
+    return result
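A minimal usage sketch for the two helpers in `my_utils.py`; the `square` worker is hypothetical. Note that `multi_thread` collects results in completion order, not submission order:

```python
# Minimal usage sketch for my_utils (the worker below is made up).
from my_utils import setup_logger, multi_thread, THREAD_HALF

logger = setup_logger('demo', './log/demo.log')

def square(x):
    # Any single-argument callable works; per-call exceptions are
    # logged by multi_thread rather than raised.
    return x * x

if __name__ == '__main__':
    results = multi_thread(square, list(range(100)), thread_num=THREAD_HALF)
    logger.info('got %d results', len(results))
```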
diff --git a/obtain_inst_vec.py b/obtain_inst_vec.py
index 174f53b..6f08dff 100644
--- a/obtain_inst_vec.py
+++ b/obtain_inst_vec.py
@@ -11,6 +11,7 @@ import torch
 import transformers
 from accelerate import Accelerator
 from datasets import load_dataset
+from torch import nn
 from torch.nn import DataParallel
 from torch.utils.data.dataloader import DataLoader
 from tqdm.auto import tqdm
@@ -33,8 +34,8 @@ from transformers import (
 from my_data_collator import MyDataCollatorForPreTraining
 from process_data.utils import CURRENT_DATA_BASE
 
-model_file = os.path.join(CURRENT_DATA_BASE, "bert-L2-H8.bin")
-config_file = os.path.join(CURRENT_DATA_BASE, "bert-L2-H8.config.json")
+model_file = os.path.join(CURRENT_DATA_BASE, 'out', "pytorch_model.bin")
+config_file = os.path.join(CURRENT_DATA_BASE, 'out', "config.json")
 tokenizer_file = os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.all.json")
 
@@ -48,7 +49,7 @@ def load_model():
     tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
     tokenizer.enable_padding(
-        pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]", length=32
+        pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]", length=50
     )
     print("Load tokenizer successfully !")
     return model, tokenizer
@@ -120,8 +121,10 @@ def generate_inst_vec(inst, method="mean"):
 
 
 def main():
-    inst = ["mov ebp esp" for _ in range(8)]
-    print(generate_inst_vec(inst).shape)
+    inst = ['adc byte [ ebp - 0x74 ] cl', 'mov dh 0x79', 'adc eax 1']
+    tmp = generate_inst_vec(inst, method="mean")
+    print(tmp.shape)
+    print(tmp.detach().numpy())
 
 
 if __name__ == "__main__":
diff --git a/process_data/check_length.py b/process_data/check_length.py
index 7276d5d..bcdd4be 100644
--- a/process_data/check_length.py
+++ b/process_data/check_length.py
@@ -15,10 +15,10 @@ def check(filename):
 def main():
     longest = 0
     # for i in range(6):
-    for i in range(10):
+    for i in range(32):
         for group in ("pos", "neg"):
             filename = os.path.join(
-                ORIGINAL_DATA_BASE, "inst.{}.{}.txt.clean".format(i, group)
+                ORIGINAL_DATA_BASE, f'{group}_clean', f"inst.{i}.{group}.txt.clean"
             )
             longest = max(check(filename), longest)
     print("The longest sentence in all files has {} words.".format(longest))
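For reference, the `method="mean"` path of `generate_inst_vec` suggests masked mean pooling over the final hidden states. A hedged sketch of that pooling, assuming BERT-style `last_hidden_state` and `attention_mask` tensors; the actual implementation in `obtain_inst_vec.py` may differ:

```python
# Hedged sketch of the "mean" pooling the function name suggests.
import torch

def mean_pool(last_hidden_state, attention_mask):
    # last_hidden_state: (batch, seq_len, hidden); attention_mask: (batch, seq_len)
    mask = attention_mask.unsqueeze(-1).float()      # (batch, seq_len, 1)
    summed = (last_hidden_state * mask).sum(dim=1)   # ignore [PAD] positions
    counts = mask.sum(dim=1).clamp(min=1e-9)         # avoid division by zero
    return summed / counts                           # (batch, hidden)
```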
diff --git a/process_data/clean.py b/process_data/clean.py
index 134dd6c..34af5ad 100644
--- a/process_data/clean.py
+++ b/process_data/clean.py
@@ -1,21 +1,57 @@
 from utils import ORIGINAL_DATA_BASE, read_file, write_file
 from tqdm import tqdm
 import os
+from my_utils import multi_thread, setup_logger
+import concurrent.futures
 
-def remove(pos_file, neg_file):
-    pos = read_file(pos_file)
-    neg = read_file(neg_file)
-    rets = []
-    for n in tqdm(neg):
-        if n in pos:
+
+def remove(neg_list, pos_file):
+    # Keep only the negatives that never occur among the positives.
+    ret = []
+    for neg in neg_list:
+        if neg in pos_file:
             continue
-        rets.append(n)
-    write_file(rets, neg_file)
+        ret.append(neg)
+    return ret
+
+
+def split_list_evenly(lst, n):
+    # Split lst into chunks of len(lst) // n items; the final slice holds
+    # any remainder, so at most n + 1 chunks are returned.
+    chunk_size = max(len(lst) // n, 1)
+    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
+
 
 def main():
-    pos_file = os.path.join(ORIGINAL_DATA_BASE, "inst.all.pos.txt.clean")
-    neg_file = os.path.join(ORIGINAL_DATA_BASE, "inst.all.neg.txt.clean")
-    remove(pos_file, neg_file)
+    file = os.path.join('../dataset/all/all_clean')
+    # A set makes the membership test in remove() O(1) per line.
+    pos_file = set(read_file(os.path.join(file, "inst.all.pos.txt.clean")))
+    neg_file = split_list_evenly(read_file(os.path.join(file, "inst.all.neg.txt.clean")), int(os.cpu_count() * 1000))
+    print(len(neg_file))
+    logger = setup_logger('remove', '../out/remove.log')
+    result = []
+    with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
+        print('start build task.')
+        futures_to_args = {
+            executor.submit(remove, neg_list, pos_file): neg_list for neg_list in neg_file
+        }
+        print('start run task.')
+        for future in tqdm(concurrent.futures.as_completed(futures_to_args), total=len(futures_to_args)):
+            try:
+                result.extend(future.result())
+            except Exception as exc:
+                logger.error(exc)
+    write_file(result, os.path.join(file, "inst.all.neg.txt.clean"))
 
 
 if __name__ == "__main__":
     main()
diff --git a/process_data/convert_space_format.py b/process_data/convert_space_format.py
index 60f262c..e9b1584 100644
--- a/process_data/convert_space_format.py
+++ b/process_data/convert_space_format.py
@@ -19,11 +19,11 @@ def convert(fin, fout):
 
 def main():
     # for i in range(6):
-    for i in range(10):
-        fin = os.path.join(ORIGINAL_DATA_BASE, "win32_0{}xxxx.all".format(i))
-        fout = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.txt".format(i))
-        convert(fin, fout)
-
+    # for i in range(10):
+    #     fin = os.path.join(ORIGINAL_DATA_BASE, "win32_0{}xxxx.all".format(i))
+    #     fout = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.txt".format(i))
+    #     convert(fin, fout)
+    convert(os.path.join('../dataset/all/win.all'), os.path.join('../dataset/all/inst.pos.txt'))
 
 if __name__ == "__main__":
     main()
diff --git a/process_data/create_negative_examples.py b/process_data/create_negative_examples.py
index 25afb56..4306922 100644
--- a/process_data/create_negative_examples.py
+++ b/process_data/create_negative_examples.py
@@ -5,32 +5,35 @@ from tqdm import tqdm
 
 from utils import ORIGINAL_DATA_BASE, read_file
 
+from my_utils import multi_thread
+
 
 def create(pos, neg, tgt):
     pos_sents = read_file(pos)
     neg_sents = read_file(neg)
     neg_length = len(neg_sents)
-    print("Start writing negative examples to {}...".format(tgt))
     with open(tgt, "w", encoding="utf-8") as fout:
         for sent in tqdm(pos_sents):
             first = sent.split("\t")[0]
             index = randint(0, neg_length - 1)
-            pair = neg_sents[index].split("\t")[randint(0, 1)].replace("\n", "")
+            pair = neg_sents[index].split("\t")
+            pair = pair[randint(0, 1)]
+            pair = pair.replace("\n", "")
             fout.write(first + "\t" + pair + "\n")
 
 
 def main():
     # for i in range(6):
-    for i in range(10):
-        j = (i + 1) % 10
-        # neg = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(j))
-        # pos = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(i))
-        pos = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.txt.clean".format(i))
-        neg = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.txt.clean".format(j))
-        tgt = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.neg.txt".format(i))
+    # neg = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(j))
+    # pos = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(i))
+    file = os.path.join("../dataset/all/pos_clean")
+    out_file = os.path.join("../dataset/all/neg_txt")
+    os.makedirs(out_file, exist_ok=True)
+    for i in tqdm(range(os.cpu_count()), total=os.cpu_count()):
+        j = (i + 1) % os.cpu_count()
+        pos = os.path.join(file, f"inst.{i}.pos.txt.clean")
+        neg = os.path.join(file, f"inst.{j}.pos.txt.clean")
+        tgt = os.path.join(out_file, f"inst.{i}.neg.txt")
         create(pos, neg, tgt)
 
 
 if __name__ == "__main__":
     main()
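A toy walk-through of what `create()` emits: the anchor instruction of each positive pair is joined with a random half of a pair drawn from a different shard. The data below is made up:

```python
# Toy illustration of the negative-example construction above.
from random import randint, seed

seed(0)
pos_sents = ["push ebp\tmov ebp esp\n"]   # a real adjacent-instruction pair
neg_sents = ["xor eax eax\tret\n"]        # pairs taken from a different shard

for sent in pos_sents:
    first = sent.split("\t")[0]                   # keep the anchor instruction
    pair = neg_sents[randint(0, len(neg_sents) - 1)].split("\t")
    pair = pair[randint(0, 1)].replace("\n", "")  # random half of a foreign pair
    print(first + "\t" + pair)                    # e.g. "push ebp\txor eax eax"
```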
diff --git a/process_data/exe2all.py b/process_data/exe2all.py
new file mode 100644
index 0000000..15b8bff
--- /dev/null
+++ b/process_data/exe2all.py
@@ -0,0 +1,90 @@
+import os
+import r2pipe
+from my_utils import setup_logger
+import concurrent.futures
+from tqdm import tqdm
+
+
+def extract_opcode(disasm_text):
+    """
+    Tokenize one disassembled instruction into opcode and operand tokens.
+    Commas are stripped, and square brackets are emitted as standalone
+    tokens so memory operands keep their structure.
+    """
+    op_list = disasm_text.split(' ')
+    res = []
+    for item in op_list:
+        item = item.strip().replace(',', '')
+        if '[' in item:
+            res.append('[')
+        res.append(item.replace('[', '').replace(']', ''))
+        if ']' in item:
+            res.append(']')
+    return res
+
+
+def double_exe_op_list(op_list):
+    # Pair each instruction with its successor: [a, b, c] -> [(a, b), (b, c)].
+    pairs = []
+    for i in range(len(op_list) - 1):
+        pairs.append((op_list[i], op_list[i + 1]))
+    return pairs
+
+
+def get_all_from_exe(file, out_file):
+    # Extract the opcode sequences inside each basic block of the binary.
+    r2pipe_open = r2pipe.open(os.path.join(file), flags=['-2'])
+    with open(out_file, 'a') as f:
+        try:
+            # Analyze the binary and fetch the function list.
+            r2pipe_open.cmd("aaa")
+            r2pipe_open.cmd('e arch=x86')
+            function_list = r2pipe_open.cmdj("aflj")
+            exe_op_list = []
+            for function in function_list:
+                if function['name'][:4] not in ['fcn.', 'loc.', 'main', 'entr']:
+                    continue
+                block_list = r2pipe_open.cmdj("afbj @" + str(function['offset']))
+                for block in block_list:
+                    # Disassemble the instructions of this basic block.
+                    disasm = r2pipe_open.cmdj("pdj " + str(block["ninstr"]) + " @" + str(block["addr"]))
+                    if disasm:
+                        for op in disasm:
+                            if op["type"] == "invalid" or op["opcode"] == "invalid":
+                                continue
+                            op_list = extract_opcode(op["disasm"])
+                            exe_op_list.append(' '.join(op_list))
+            exe_op_list = double_exe_op_list(exe_op_list)
+            for op_str_before, op_str_after in exe_op_list:
+                f.write(op_str_before + '\t' + op_str_after + '\n')
+        except Exception as e:
+            logger.error(f"Error: get function list failed in {file}, error {e}")
+            return False, file, e
+        finally:
+            # Always close the radare2 session, even on failure.
+            r2pipe_open.quit()
+    return True, '', ''
+
+
+def main():
+    sample_file_path = '/mnt/d/bishe/dataset/sample_malware/'
+    sample_file_list = os.listdir(sample_file_path)[:1000]
+    out_file_path = '../dataset/all'
+    with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
+        print(f"start with {os.cpu_count()} workers.")
+        future_to_args = {
+            executor.submit(get_all_from_exe,
+                            os.path.join(sample_file_path, sample_file_list[file_index]),
+                            os.path.join(out_file_path, f'inst.{file_index % os.cpu_count()}.pos.txt')
+                            ):
+                file_index for file_index in range(len(sample_file_list))
+        }
+        for future in tqdm(concurrent.futures.as_completed(future_to_args), total=len(sample_file_list)):
+            try:
+                ok, bad_file, err = future.result()
+                if not ok:
+                    print(f"Error file: {bad_file}, msg {err}")
+            except Exception as exc:
+                logger.error(f"Error: {exc}")
+                print(f"Error: {exc}")
+
+
+if __name__ == '__main__':
+    logger = setup_logger('exe2all', '../log/exe2all.log')
+    main()
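A worked example of the two helpers above on simulated radare2 output, assuming the module imports as `process_data.exe2all`. Note how brackets become standalone tokens, which matches the instruction format fed to `obtain_inst_vec.py`:

```python
# Worked example (radare2 output is simulated; no binary is needed).
from process_data.exe2all import extract_opcode, double_exe_op_list

ops = [' '.join(extract_opcode(d)) for d in
       ["mov ebp, esp", "adc byte [ebp - 0x74], cl", "mov dh, 0x79"]]
# ops == ['mov ebp esp', 'adc byte [ ebp - 0x74 ] cl', 'mov dh 0x79']

for before, after in double_exe_op_list(ops):
    # These are the tab-separated pairs written to inst.*.pos.txt.
    print(before + '\t' + after)
```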
diff --git a/process_data/merge_examples_to_json.py b/process_data/merge_examples_to_json.py
index 9ffb66f..08054ee 100644
--- a/process_data/merge_examples_to_json.py
+++ b/process_data/merge_examples_to_json.py
@@ -7,8 +7,8 @@ from tqdm import tqdm
 
 from utils import CURRENT_DATA_BASE, ORIGINAL_DATA_BASE, read_file
 
-BASE = 4600000
-
+# BASE = 4600000
+BASE = 46000
 
 def write_worker(sents, json_file, index):
     examples = []
@@ -24,13 +24,21 @@ def write_worker(sents, json_file, index):
 
 
 def merge_to_json(pos, neg, json_file):
     sents = read_file(pos)
 
     p = Pool(6)
     for i in range(6):
         p.apply_async(
-            write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, 2+i,)
+            write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, i,)
         )
     print("Waiting for all sub-processes done...")
     p.close()
@@ -59,7 +67,7 @@ def merge_to_json(pos, neg, json_file):
 
     for i in range(6):
         p.apply_async(
-            write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, 8 + i,)
+            write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, 6 + i,)
         )
     print("Waiting for all sub-processes done...")
     p.close()
@@ -85,9 +93,9 @@ def main():
     #     neg = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.neg.label.txt".format(i))
     #     json_file = os.path.join(CURRENT_DATA_BASE, "inst.{}.".format(i))
     #     merge_to_json(pos, neg, json_file)
-    pos = os.path.join(ORIGINAL_DATA_BASE, "inst.all.pos.txt.clean.label")
-    neg = os.path.join(ORIGINAL_DATA_BASE, "inst.all.neg.txt.clean.label")
-    json_file = os.path.join(CURRENT_DATA_BASE, "inst.all.")
+    pos = os.path.join(ORIGINAL_DATA_BASE, 'all_clean', "inst.all.pos.txt.clean.label")
+    neg = os.path.join(ORIGINAL_DATA_BASE, 'all_clean', "inst.all.neg.txt.clean.label")
+    json_file = os.path.join(CURRENT_DATA_BASE, 'json', "inst.all.")
     merge_to_json(pos, neg, json_file)
diff --git a/process_data/r2test/max_op_count.py b/process_data/r2test/max_op_count.py
new file mode 100644
index 0000000..83d8b34
--- /dev/null
+++ b/process_data/r2test/max_op_count.py
@@ -0,0 +1,49 @@
+import r2pipe
+from my_utils import setup_logger, multi_thread, THREAD_FULL
+import os
+from tqdm import tqdm
+
+def get_all_from_exe(file):
+    # Count the opcodes inside each basic block of the binary.
+    r2pipe_open = r2pipe.open(os.path.join(file), flags=['-2'])
+    try:
+        # Analyze the binary and fetch the function list.
+        r2pipe_open.cmd("aaa")
+        r2pipe_open.cmd('e arch=x86')
+        function_list = r2pipe_open.cmdj("aflj")
+        exe_op_count = []
+        for function in function_list:
+            function_op_count_list = []
+            if function['name'][:4] not in ['fcn.', 'loc.', 'main', 'entr']:
+                continue
+            block_list = r2pipe_open.cmdj("afbj @" + str(function['offset']))
+
+            for block in block_list:
+                # Disassemble the instructions of this basic block.
+                disasm = r2pipe_open.cmdj("pdj " + str(block["ninstr"]) + " @" + str(block["addr"]))
+                block_op_count = 0
+                if disasm:
+                    # Dump unusually large blocks (>= 723 instructions) for inspection.
+                    print_flag = 1 if len(disasm) >= 723 else 0
+                    for op in disasm:
+                        if op["type"] == "invalid" or op["opcode"] == "invalid":
+                            continue
+                        if print_flag == 1:
+                            print(op['disasm'])
+                        block_op_count += 1
+                function_op_count_list.append(block_op_count)
+            exe_op_count.append(function_op_count_list)
+
+        logger.info(f"{file} {exe_op_count}")
+    except Exception as e:
+        logger.error(f"Error: get function list failed in {file}, error {e}")
+        return False, file, e
+    finally:
+        # Always close the radare2 session, even on failure.
+        r2pipe_open.quit()
+    return True, '', ''
+
+if __name__ == '__main__':
+    logger = setup_logger('get_all_from_exe', '../../log/get_all_from_exe.log')
+    file = '/mnt/d/bishe/dataset/sample_benign'
+    file_list = os.listdir(file)
+    # Single-file smoke test; pass file_list instead to scan the whole directory.
+    multi_thread(get_all_from_exe, ['/mnt/d/bishe/dataset/sample_benign/00125dcd81261701fcaaf84d0cb45d0e.exe'], thread_num=THREAD_FULL)
\ No newline at end of file
diff --git a/process_data/shell_com/remove_repete_line.py b/process_data/shell_com/remove_repete_line.py
new file mode 100644
index 0000000..a978702
--- /dev/null
+++ b/process_data/shell_com/remove_repete_line.py
@@ -0,0 +1,18 @@
+# Run a shell pipeline from Python to deduplicate the negative-pair shards.
+import subprocess
+import os
+from my_utils import multi_thread
+
+
+def run_shell(file_num):
+    com_line = f'cat ./neg_txt/inst.{file_num}.neg.txt | sort -n | uniq > ./neg_clean/inst.{file_num}.neg.txt.clean'
+    p = subprocess.Popen(com_line, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    stdout, stderr = p.communicate()
+    return stdout, stderr
+
+if __name__ == '__main__':
+    os.chdir('../../dataset/all')
+    result = multi_thread(run_shell, range(os.cpu_count()))
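The same deduplication can be done without shelling out. A pure-Python sketch for one shard, assuming the same directory layout; the ordering differs from `sort -n` on non-numeric lines, but the resulting unique set is identical:

```python
# Pure-Python equivalent of the sort|uniq pipeline for one shard.
def dedup_file(file_num):
    with open(f'./neg_txt/inst.{file_num}.neg.txt', encoding='utf-8') as fin:
        unique_lines = sorted(set(fin))   # drop duplicates, then sort
    with open(f'./neg_clean/inst.{file_num}.neg.txt.clean', 'w', encoding='utf-8') as fout:
        fout.writelines(unique_lines)
```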
diff --git a/process_data/utils.py b/process_data/utils.py
index fc819af..e8b47d3 100644
--- a/process_data/utils.py
+++ b/process_data/utils.py
@@ -1,18 +1,16 @@
 import os
 
 # ORIGINAL_DATA_BASE = "/home/ming/malware/data/elfasm_inst_pairs"
-ORIGINAL_DATA_BASE = "/home/ming/malware/data/malasm_inst_pairs"
-CURRENT_DATA_BASE = "/home/ming/malware/inst2vec_bert/data/asm_bert"
+ORIGINAL_DATA_BASE = "/mnt/d/bishe/Inst2Vec/dataset/all"
+CURRENT_DATA_BASE = "/mnt/d/bishe/Inst2Vec/dataset/all"
 
 
 def read_file(filename):
-    print("Reading data from {}...".format(filename))
     with open(filename, "r", encoding="utf-8") as fin:
         return fin.readlines()
 
 
 def write_file(sents, filename):
-    print("Writing data to {}...".format(filename))
     with open(filename, "w", encoding="utf-8") as fout:
         for sent in sents:
             fout.write(sent)
diff --git a/train_my_tokenizer.py b/train_my_tokenizer.py
index dfe7e6a..fefff6c 100644
--- a/train_my_tokenizer.py
+++ b/train_my_tokenizer.py
@@ -2,7 +2,6 @@ import argparse
 import os
 from itertools import chain
 
-from datasets import load_dataset
 from tokenizers import Tokenizer
 from tokenizers.models import WordLevel
 from tokenizers.pre_tokenizers import Whitespace
@@ -11,7 +10,7 @@ from tokenizers.trainers import WordLevelTrainer
 
 from process_data.utils import CURRENT_DATA_BASE, ORIGINAL_DATA_BASE, read_file
 
-BASE_PATH = "/home/ming/malware/inst2vec_bert/bert/"
+
 
 def parse_args():
@@ -27,7 +26,7 @@ def parse_args():
     parser.add_argument(
         "--padding_length",
         type=int,
-        default=32,
+        default=50,
         help="The length will be padded to by the tokenizer.",
     )
     args = parser.parse_args()
@@ -99,8 +98,8 @@ def main(tokenizer_file=""):
     # dataset = load_dataset("json", data_files=json_files, field="data")
 
     text_files = [
-        os.path.join(ORIGINAL_DATA_BASE, "inst.{}.{}.txt.clean".format(i, group))
-        for group in ["pos", "neg"] for i in range(10)
+        os.path.join(ORIGINAL_DATA_BASE, f'{group}_clean', f"inst.{i}.{group}.txt.clean")
+        for group in ["pos", "neg"] for i in range(32)
     ]
 
     dataset = []
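For reference, a sketch of the WordLevel training that `train_my_tokenizer.py` performs with these settings, assuming the standard `tokenizers` API and the padding length of 50 used throughout; the one-element shard list stands in for the `text_files` built in `main()`:

```python
# Hedged sketch of the tokenizer training flow, consistent with length=50.
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer

text_files = ["inst.0.pos.txt.clean"]   # placeholder for the shard list in main()

tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()  # one token per whitespace-separated word
trainer = WordLevelTrainer(special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])
tokenizer.train(files=text_files, trainer=trainer)

# Must match the padding length used in my_run_mlm_no_trainer.py and obtain_inst_vec.py.
tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"),
                         pad_token="[PAD]", length=50)
tokenizer.save("tokenizer-inst.all.json")
```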