diff --git a/my_run_mlm_no_trainer.py b/my_run_mlm_no_trainer.py
index 15261b8..e1864f4 100644
--- a/my_run_mlm_no_trainer.py
+++ b/my_run_mlm_no_trainer.py
@@ -43,6 +43,8 @@ from transformers import (CONFIG_MAPPING, MODEL_MAPPING, AdamW, AutoConfig,
 from my_data_collator import MyDataCollatorForPreTraining
 from process_data.utils import CURRENT_DATA_BASE
 
+HIDDEN_SIZE=256
+
 logger = logging.getLogger(__name__)
 MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@@ -187,7 +189,7 @@ def main():
     #     field="data",
     # )
     train_files = [
-        os.path.join(CURRENT_DATA_BASE, "inst.all.{}.json".format(i)) for i in range(2)
+        os.path.join(CURRENT_DATA_BASE, "inst.all.{}.json".format(i)) for i in [0,1,2,3,4,5,6]  # ,8,9,10,11,12,13]
     ]
     valid_file = "/home/ming/malware/inst2vec_bert/data/test_lm/inst.json"
     raw_datasets = load_dataset(
@@ -210,10 +212,10 @@
     # we use a much smaller BERT, config is:
     config = BertConfig(
         vocab_size=tokenizer.get_vocab_size(),
-        hidden_size=96,
+        hidden_size=HIDDEN_SIZE,
         num_hidden_layers=4,
-        num_attention_heads=12,
-        intermediate_size=384,
+        num_attention_heads=8,
+        intermediate_size=4 * HIDDEN_SIZE,
         max_position_embeddings=32,
     )
 
@@ -230,7 +232,10 @@ def main():
     def tokenize_function(examples):
         text = [tuple(sent) for sent in examples["text"]]
         encoded_inputs = {}
+        # try:
         results = tokenizer.encode_batch(text)
+        # except:
+        #     return None
         encoded_inputs["input_ids"] = [result.ids for result in results]
         encoded_inputs["token_type_ids"] = [result.type_ids for result in results]
         encoded_inputs["special_tokens_mask"] = [
@@ -253,7 +258,7 @@ def main():
         batched=True,
         num_proc=args.preprocessing_num_workers,
         remove_columns=column_names,
-        load_from_cache_file=False,
+        load_from_cache_file=True,
     )
 
     train_dataset = tokenized_datasets["train"]
diff --git a/process_data/check_length.py b/process_data/check_length.py
index e803142..7276d5d 100644
--- a/process_data/check_length.py
+++ b/process_data/check_length.py
@@ -1,5 +1,5 @@
 import os
-
+import pdb
 from utils import ORIGINAL_DATA_BASE, read_file
 
 
@@ -7,7 +7,7 @@ def check(filename):
     sents = read_file(filename)
     result = 0
     for sent in sents:
-        result = max(result, len(sent[-1].replace("\t", " ").split()))
+        result = max(result, len(sent[:-1].replace("\t", " ").split()))
     print("The longest sentence in {} has {} words".format(filename, result))
     return result
 
@@ -15,10 +15,10 @@ def check(filename):
 def main():
     longest = 0
     # for i in range(6):
-    for i in [1]:
+    for i in range(10):
         for group in ("pos", "neg"):
             filename = os.path.join(
-                ORIGINAL_DATA_BASE, "inst.{}.{}.txt".format(i, group)
+                ORIGINAL_DATA_BASE, "inst.{}.{}.txt.clean".format(i, group)
             )
             longest = max(check(filename), longest)
     print("The longest sentence in all files has {} words.".format(longest))
diff --git a/process_data/clean.py b/process_data/clean.py
new file mode 100644
index 0000000..134dd6c
--- /dev/null
+++ b/process_data/clean.py
@@ -0,0 +1,21 @@
+from utils import ORIGINAL_DATA_BASE, read_file, write_file
+from tqdm import tqdm
+import os
+
+def remove(pos_file, neg_file):
+    pos = read_file(pos_file)
+    neg = read_file(neg_file)
+    rets = []
+    for n in tqdm(neg):
+        if n in pos:
+            continue
+        rets.append(n)
+    write_file(rets, neg_file)
+
+def main():
+    pos_file = os.path.join(ORIGINAL_DATA_BASE, "inst.all.pos.txt.clean")
+    neg_file = os.path.join(ORIGINAL_DATA_BASE, "inst.all.neg.txt.clean")
+    remove(pos_file, neg_file)
+
+if __name__ == "__main__":
+    main()
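A note on the new `clean.py`: `read_file` presumably returns a list, so `if n in pos` rescans the whole positive list for every negative line, which gets slow once the cleaned files reach millions of instruction pairs. A set-based variant would behave the same in roughly linear time. This is only a sketch, not part of the patch: `remove_fast` is a hypothetical name and it assumes `read_file` yields hashable strings.

```python
import os

from tqdm import tqdm

from utils import ORIGINAL_DATA_BASE, read_file, write_file


def remove_fast(pos_file, neg_file):
    # Build the positive set once; each membership test is then O(1)
    # instead of a linear scan over the whole pos list per negative line.
    pos = set(read_file(pos_file))
    kept = [n for n in tqdm(read_file(neg_file)) if n not in pos]
    write_file(kept, neg_file)


if __name__ == "__main__":
    remove_fast(
        os.path.join(ORIGINAL_DATA_BASE, "inst.all.pos.txt.clean"),
        os.path.join(ORIGINAL_DATA_BASE, "inst.all.neg.txt.clean"),
    )
```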
diff --git a/process_data/convert_space_format.py b/process_data/convert_space_format.py
index 43799ed..60f262c 100644
--- a/process_data/convert_space_format.py
+++ b/process_data/convert_space_format.py
@@ -19,8 +19,8 @@ def convert(fin, fout):
 
 def main():
     # for i in range(6):
-    for i in [1]:
-        fin = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(i))
+    for i in range(10):
+        fin = os.path.join(ORIGINAL_DATA_BASE, "win32_0{}xxxx.all".format(i))
         fout = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.txt".format(i))
         convert(fin, fout)
 
diff --git a/process_data/count_word_for_vocab.py b/process_data/count_word_for_vocab.py
index 543590b..978c2c9 100644
--- a/process_data/count_word_for_vocab.py
+++ b/process_data/count_word_for_vocab.py
@@ -41,12 +41,12 @@ def counter(filename):
 
 def main():
     cnt = set()
     # for i in range(6):
-    for i in [1]:
+    for i in range(10):
        for group in ["pos", "neg"]:
            filename = os.path.join(
-                ORIGINAL_DATA_BASE, "inst.{}.{}.txt".format(i, group)
+                ORIGINAL_DATA_BASE, "inst.{}.{}.txt.clean".format(i, group)
            )
-            cnt += counter(filename)
+            cnt = cnt.union(counter(filename))
 
     print("There are {} charcters in all files".format(len(cnt)))
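The `cnt = cnt.union(...)` change above fixes an actual crash rather than style: Python sets do not implement `+` or `+=`, so the old `cnt += counter(filename)` raised a TypeError on the first file. A small standalone illustration (the two character sets below are invented):

```python
# Sets combine via union()/update() or the | and |= operators, never via +.
chars_a = set("mov eax")
chars_b = set("push ebx")

merged = chars_a.union(chars_b)  # what the patched line computes
chars_a |= chars_b               # in-place equivalent, same result

try:
    merged = merged + chars_b    # the pre-patch pattern
except TypeError as err:
    print(err)                   # unsupported operand type(s) for +: 'set' and 'set'
```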
"inst.{}.pos.label.txt".format(i)) + # neg = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.neg.label.txt".format(i)) + # json_file = os.path.join(CURRENT_DATA_BASE, "inst.{}.".format(i)) + # merge_to_json(pos, neg, json_file) + pos = os.path.join(ORIGINAL_DATA_BASE, "inst.all.pos.txt.clean.label") + neg = os.path.join(ORIGINAL_DATA_BASE, "inst.all.neg.txt.clean.label") + json_file = os.path.join(CURRENT_DATA_BASE, "inst.all.") + merge_to_json(pos, neg, json_file) if __name__ == "__main__": diff --git a/process_data/readme.md b/process_data/readme.md index 7f43fc9..fe0d0f4 100644 --- a/process_data/readme.md +++ b/process_data/readme.md @@ -31,7 +31,10 @@ We process the files containing negative examples similarly. cat inst.*.neg.txt.clean | sort -n | uniq > inst.all.neg.txt.clean ``` -Based on the `inst.all.pos.txt.clean`, we remove the lines from `inst.all.neg.txt.clean` if they also occur in `inst.all.pos.txt.clean`. This can be completed by `python clean.py`. +Based on the `inst.all.pos.txt.clean`, we remove the lines from `inst.all.neg.txt.clean` if they also occur in `inst.all.pos.txt.clean`. This can be completed by `python clean.py`, or + ### 5. convert to json format diff --git a/process_data/utils.py b/process_data/utils.py index 6183db4..fc819af 100644 --- a/process_data/utils.py +++ b/process_data/utils.py @@ -1,6 +1,7 @@ import os -ORIGINAL_DATA_BASE = "/home/ming/malware/data/elfasm_inst_pairs" +# ORIGINAL_DATA_BASE = "/home/ming/malware/data/elfasm_inst_pairs" +ORIGINAL_DATA_BASE = "/home/ming/malware/data/malasm_inst_pairs" CURRENT_DATA_BASE = "/home/ming/malware/inst2vec_bert/data/asm_bert" diff --git a/train_my_tokenizer.py b/train_my_tokenizer.py index e31eb4f..dfe7e6a 100644 --- a/train_my_tokenizer.py +++ b/train_my_tokenizer.py @@ -99,8 +99,8 @@ def main(tokenizer_file=""): # dataset = load_dataset("json", data_files=json_files, field="data") text_files = [ - os.path.join(ORIGINAL_DATA_BASE, "inst.1.{}.txt".format(group)) - for group in ["pos", "neg"] + os.path.join(ORIGINAL_DATA_BASE, "inst.{}.{}.txt.clean".format(i, group)) + for group in ["pos", "neg"] for i in range(10) ] dataset = [] @@ -121,4 +121,4 @@ def main(tokenizer_file=""): if __name__ == "__main__": - main(os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.1.json")) + main(os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.all.json"))