final version of Inst2Vec

zyr 2021-06-30 19:20:12 +08:00
parent fe2de236b5
commit c050cff9f5
10 changed files with 70 additions and 34 deletions

View File

@@ -43,6 +43,8 @@ from transformers import (CONFIG_MAPPING, MODEL_MAPPING, AdamW, AutoConfig,
 from my_data_collator import MyDataCollatorForPreTraining
 from process_data.utils import CURRENT_DATA_BASE
+HIDDEN_SIZE=256
 logger = logging.getLogger(__name__)
 MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@@ -187,7 +189,7 @@ def main():
     # field="data",
     # )
     train_files = [
-        os.path.join(CURRENT_DATA_BASE, "inst.all.{}.json".format(i)) for i in range(2)
+        os.path.join(CURRENT_DATA_BASE, "inst.all.{}.json".format(i)) for i in [0,1,2,3,4,5,6]  # ,8,9,10,11,12,13]
     ]
     valid_file = "/home/ming/malware/inst2vec_bert/data/test_lm/inst.json"
     raw_datasets = load_dataset(
@@ -210,10 +212,10 @@ def main():
     # we use a much smaller BERT, config is:
     config = BertConfig(
         vocab_size=tokenizer.get_vocab_size(),
-        hidden_size=96,
+        hidden_size=HIDDEN_SIZE,
         num_hidden_layers=4,
-        num_attention_heads=12,
-        intermediate_size=384,
+        num_attention_heads=8,
+        intermediate_size=4 * HIDDEN_SIZE,
         max_position_embeddings=32,
     )
@@ -230,7 +232,10 @@ def main():
     def tokenize_function(examples):
         text = [tuple(sent) for sent in examples["text"]]
         encoded_inputs = {}
+        # try:
         results = tokenizer.encode_batch(text)
+        # except:
+        # return None
         encoded_inputs["input_ids"] = [result.ids for result in results]
         encoded_inputs["token_type_ids"] = [result.type_ids for result in results]
         encoded_inputs["special_tokens_mask"] = [
@@ -253,7 +258,7 @@ def main():
         batched=True,
         num_proc=args.preprocessing_num_workers,
         remove_columns=column_names,
-        load_from_cache_file=False,
+        load_from_cache_file=True,
     )
     train_dataset = tokenized_datasets["train"]
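
For orientation, the config above now describes a deliberately small BERT: hidden size 256, 4 layers, 8 heads, and a 4x feed-forward expansion. The head count has to change together with the hidden size, because `transformers` requires `hidden_size` to be divisible by `num_attention_heads` (256 is not divisible by 12). A minimal sketch of the resulting model, assuming `BertForPreTraining` as the head (the `MyDataCollatorForPreTraining` collator suggests it, but the script does not show it here) and a placeholder vocabulary size:

```python
from transformers import BertConfig, BertForPreTraining

HIDDEN_SIZE = 256
VOCAB_SIZE = 30000  # placeholder; the script uses tokenizer.get_vocab_size()

config = BertConfig(
    vocab_size=VOCAB_SIZE,
    hidden_size=HIDDEN_SIZE,
    num_hidden_layers=4,
    num_attention_heads=8,              # 256 / 8 = 32-dim heads; 12 would not divide 256 evenly
    intermediate_size=4 * HIDDEN_SIZE,  # conventional 4x FFN width
    max_position_embeddings=32,         # assembly instruction pairs are short
)

model = BertForPreTraining(config)
print(sum(p.numel() for p in model.parameters()))  # on the order of 10M parameters
```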

View File

@@ -1,5 +1,5 @@
 import os
-import pdb
 from utils import ORIGINAL_DATA_BASE, read_file
@@ -7,7 +7,7 @@ def check(filename):
     sents = read_file(filename)
     result = 0
     for sent in sents:
-        result = max(result, len(sent[-1].replace("\t", " ").split()))
+        result = max(result, len(sent[:-1].replace("\t", " ").split()))
     print("The longest sentence in {} has {} words".format(filename, result))
     return result
@@ -15,10 +15,10 @@ def check(filename):
 def main():
     longest = 0
     # for i in range(6):
-    for i in [1]:
+    for i in range(10):
         for group in ("pos", "neg"):
             filename = os.path.join(
-                ORIGINAL_DATA_BASE, "inst.{}.{}.txt".format(i, group)
+                ORIGINAL_DATA_BASE, "inst.{}.{}.txt.clean".format(i, group)
             )
             longest = max(check(filename), longest)
     print("The longest sentence in all files has {} words.".format(longest))

process_data/clean.py Normal file
View File

@@ -0,0 +1,21 @@
+from utils import ORIGINAL_DATA_BASE, read_file, write_file
+from tqdm import tqdm
+import os
+
+def remove(pos_file, neg_file):
+    pos = read_file(pos_file)
+    neg = read_file(neg_file)
+    rets = []
+    for n in tqdm(neg):
+        if n in pos:
+            continue
+        rets.append(n)
+    write_file(rets, neg_file)
+
+def main():
+    pos_file = os.path.join(ORIGINAL_DATA_BASE, "inst.all.pos.txt.clean")
+    neg_file = os.path.join(ORIGINAL_DATA_BASE, "inst.all.neg.txt.clean")
+    remove(pos_file, neg_file)
+
+if __name__ == "__main__":
+    main()
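
One thing to note about the new `clean.py`: `if n in pos` is a linear scan over the whole positive list for every negative line, so the deduplication is quadratic in corpus size. A sketch of an equivalent but much faster variant, assuming `read_file`/`write_file` behave as used above (a list of lines in, a list of lines out):

```python
def remove_fast(pos_file, neg_file):
    pos = set(read_file(pos_file))           # O(1) membership tests instead of a list scan
    neg = read_file(neg_file)
    rets = [n for n in neg if n not in pos]  # keep negatives that never occur as positives
    write_file(rets, neg_file)
```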

View File

@@ -19,8 +19,8 @@ def convert(fin, fout):
 def main():
     # for i in range(6):
-    for i in [1]:
-        fin = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(i))
+    for i in range(10):
+        fin = os.path.join(ORIGINAL_DATA_BASE, "win32_0{}xxxx.all".format(i))
         fout = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.txt".format(i))
         convert(fin, fout)

View File

@@ -41,12 +41,12 @@ def counter(filename):
 def main():
     cnt = set()
     # for i in range(6):
-    for i in [1]:
+    for i in range(10):
         for group in ["pos", "neg"]:
             filename = os.path.join(
-                ORIGINAL_DATA_BASE, "inst.{}.{}.txt".format(i, group)
+                ORIGINAL_DATA_BASE, "inst.{}.{}.txt.clean".format(i, group)
             )
-            cnt += counter(filename)
+            cnt = cnt.union(counter(filename))
     print("There are {} charcters in all files".format(len(cnt)))

View File

@@ -22,10 +22,12 @@ def create(pos, neg, tgt):
 def main():
     # for i in range(6):
-    for i in [1]:
-        j = (i + 1) % 6
-        pos = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(i))
-        neg = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(j))
+    for i in range(10):
+        j = (i + 1) % 10
+        # neg = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(j))
+        # pos = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(i))
+        pos = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.txt.clean".format(i))
+        neg = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.txt.clean".format(j))
         tgt = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.neg.txt".format(i))
         create(pos, neg, tgt)

View File

@@ -26,11 +26,11 @@ def write_worker(sents, json_file, index):
 def merge_to_json(pos, neg, json_file):
     sents = read_file(pos)
-    p = Pool(36)
-    for i in range(64):
+    p = Pool(6)
+    for i in range(6):
         p.apply_async(
-            write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, i,)
+            write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, 2+i,)
         )
     print("Waiting for all sub-processes done...")
     p.close()
@@ -55,11 +55,11 @@ def merge_to_json(pos, neg, json_file):
     sents = read_file(neg)
-    p = Pool(8)
-    for i in range(64):
+    p = Pool(6)
+    for i in range(6):
         p.apply_async(
-            write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, 64 + i,)
+            write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, 8 + i,)
         )
     print("Waiting for all sub-processes done...")
     p.close()
@@ -80,10 +80,14 @@ def merge_to_json(pos, neg, json_file):
 def main():
     # for i in range(6):
-    for i in [1]:
-        pos = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.label.txt".format(i))
-        neg = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.neg.label.txt".format(i))
-        json_file = os.path.join(CURRENT_DATA_BASE, "inst.{}.".format(i))
+    # for i in range(6):
+    # pos = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.label.txt".format(i))
+    # neg = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.neg.label.txt".format(i))
+    # json_file = os.path.join(CURRENT_DATA_BASE, "inst.{}.".format(i))
+    # merge_to_json(pos, neg, json_file)
+    pos = os.path.join(ORIGINAL_DATA_BASE, "inst.all.pos.txt.clean.label")
+    neg = os.path.join(ORIGINAL_DATA_BASE, "inst.all.neg.txt.clean.label")
+    json_file = os.path.join(CURRENT_DATA_BASE, "inst.all.")
     merge_to_json(pos, neg, json_file)

View File

@@ -31,7 +31,10 @@ We process the files containing negative examples similarly.
 cat inst.*.neg.txt.clean | sort -n | uniq > inst.all.neg.txt.clean
 ```
-Based on the `inst.all.pos.txt.clean`, we remove the lines from `inst.all.neg.txt.clean` if they also occur in `inst.all.pos.txt.clean`. This can be completed by `python clean.py`.
+Based on the `inst.all.pos.txt.clean`, we remove the lines from `inst.all.neg.txt.clean` if they also occur in `inst.all.pos.txt.clean`. This can be completed by `python clean.py`, or
+<!-- ```shell
+grep -v -f inst.all.pos.txt.clean inst.all.neg.txt.clean > inst.all.neg.txt.clean
+``` -->
 ### 5. convert to json format

View File

@@ -1,6 +1,7 @@
 import os
-ORIGINAL_DATA_BASE = "/home/ming/malware/data/elfasm_inst_pairs"
+# ORIGINAL_DATA_BASE = "/home/ming/malware/data/elfasm_inst_pairs"
+ORIGINAL_DATA_BASE = "/home/ming/malware/data/malasm_inst_pairs"
 CURRENT_DATA_BASE = "/home/ming/malware/inst2vec_bert/data/asm_bert"

View File

@@ -99,8 +99,8 @@ def main(tokenizer_file=""):
     # dataset = load_dataset("json", data_files=json_files, field="data")
     text_files = [
-        os.path.join(ORIGINAL_DATA_BASE, "inst.1.{}.txt".format(group))
-        for group in ["pos", "neg"]
+        os.path.join(ORIGINAL_DATA_BASE, "inst.{}.{}.txt.clean".format(i, group))
+        for group in ["pos", "neg"] for i in range(10)
     ]
     dataset = []
@@ -121,4 +121,4 @@ def main(tokenizer_file=""):
 if __name__ == "__main__":
-    main(os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.1.json"))
+    main(os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.all.json"))
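
After this change the tokenizer trains on all twenty cleaned files rather than a single split; the doubled comprehension expands as follows (paths taken from `utils.py` above):

```python
import os

ORIGINAL_DATA_BASE = "/home/ming/malware/data/malasm_inst_pairs"

text_files = [
    os.path.join(ORIGINAL_DATA_BASE, "inst.{}.{}.txt.clean".format(i, group))
    for group in ["pos", "neg"]
    for i in range(10)
]
# 20 paths: inst.0.pos.txt.clean ... inst.9.pos.txt.clean, then the ten neg files
print(len(text_files))  # 20
```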