final version of Inst2Vec

commit c050cff9f5 (parent fe2de236b5)
@@ -43,6 +43,8 @@ from transformers import (CONFIG_MAPPING, MODEL_MAPPING, AdamW, AutoConfig,
 from my_data_collator import MyDataCollatorForPreTraining
 from process_data.utils import CURRENT_DATA_BASE
 
+HIDDEN_SIZE=256
+
 logger = logging.getLogger(__name__)
 MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

@@ -187,7 +189,7 @@ def main():
 # field="data",
 # )
 train_files = [
-os.path.join(CURRENT_DATA_BASE, "inst.all.{}.json".format(i)) for i in range(2)
+os.path.join(CURRENT_DATA_BASE, "inst.all.{}.json".format(i)) for i in [0,1,2,3,4,5,6] # ,8,9,10,11,12,13]
 ]
 valid_file = "/home/ming/malware/inst2vec_bert/data/test_lm/inst.json"
 raw_datasets = load_dataset(

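The `load_dataset` call itself is cut off by the hunk above. A minimal sketch of how such JSON files are typically passed to `datasets.load_dataset`; the file paths here are illustrative, not the repo's:

```python
# Sketch only: feeding train/validation JSON files to datasets.load_dataset.
from datasets import load_dataset

train_files = ["inst.all.0.json", "inst.all.1.json"]   # illustrative paths
valid_file = "inst.json"                               # illustrative path
raw_datasets = load_dataset(
    "json",
    data_files={"train": train_files, "validation": valid_file},
)
```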
@@ -210,10 +212,10 @@ def main():
 # we use a much smaller BERT, config is:
 config = BertConfig(
 vocab_size=tokenizer.get_vocab_size(),
-hidden_size=96,
+hidden_size=HIDDEN_SIZE,
 num_hidden_layers=4,
-num_attention_heads=12,
-intermediate_size=384,
+num_attention_heads=8,
+intermediate_size=4 * HIDDEN_SIZE,
 max_position_embeddings=32,
 )
 

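For a rough sense of scale, the new hyperparameters (hidden size 256, 4 layers, 8 heads, FFN 4×256) describe a model far smaller than BERT-base (~110M parameters). A hedged sketch to check the size; `vocab_size=5000` is a placeholder, the script takes the real value from `tokenizer.get_vocab_size()`:

```python
# Sketch only: instantiate the shrunken config and count parameters.
from transformers import BertConfig, BertForMaskedLM

HIDDEN_SIZE = 256
config = BertConfig(
    vocab_size=5000,                  # placeholder; real value comes from the tokenizer
    hidden_size=HIDDEN_SIZE,
    num_hidden_layers=4,
    num_attention_heads=8,
    intermediate_size=4 * HIDDEN_SIZE,
    max_position_embeddings=32,
)
model = BertForMaskedLM(config)
print("{:,} parameters".format(model.num_parameters()))
```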
@@ -230,7 +232,10 @@ def main():
 def tokenize_function(examples):
 text = [tuple(sent) for sent in examples["text"]]
 encoded_inputs = {}
+# try:
 results = tokenizer.encode_batch(text)
+# except:
+# return None
 encoded_inputs["input_ids"] = [result.ids for result in results]
 encoded_inputs["token_type_ids"] = [result.type_ids for result in results]
 encoded_inputs["special_tokens_mask"] = [

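`tokenize_function` leans on the `tokenizers` library: `encode_batch()` accepts a list of sentence pairs and returns `Encoding` objects whose `.ids`, `.type_ids`, and `.special_tokens_mask` fields map directly onto the dict built above. A minimal sketch; the tokenizer path and sample instruction pairs are assumptions:

```python
# Sketch only: what encode_batch returns and how it feeds encoded_inputs.
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("tokenizer-inst.all.json")        # hypothetical path
batch = [("mov eax , 1", "push eax"), ("xor eax , eax", "ret")]   # hypothetical pairs
results = tokenizer.encode_batch(batch)

encoded_inputs = {
    "input_ids": [r.ids for r in results],
    "token_type_ids": [r.type_ids for r in results],
    "special_tokens_mask": [r.special_tokens_mask for r in results],
}
```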
@@ -253,7 +258,7 @@ def main():
 batched=True,
 num_proc=args.preprocessing_num_workers,
 remove_columns=column_names,
-load_from_cache_file=False,
+load_from_cache_file=True,
 )
 
 train_dataset = tokenized_datasets["train"]

@@ -1,5 +1,5 @@
 import os
+import pdb
 from utils import ORIGINAL_DATA_BASE, read_file
 
 

@@ -7,7 +7,7 @@ def check(filename):
 sents = read_file(filename)
 result = 0
 for sent in sents:
-result = max(result, len(sent[-1].replace("\t", " ").split()))
+result = max(result, len(sent[:-1].replace("\t", " ").split()))
 print("The longest sentence in {} has {} words".format(filename, result))
 return result
 

@@ -15,10 +15,10 @@ def check(filename):
 def main():
 longest = 0
 # for i in range(6):
-for i in [1]:
+for i in range(10):
 for group in ("pos", "neg"):
 filename = os.path.join(
-ORIGINAL_DATA_BASE, "inst.{}.{}.txt".format(i, group)
+ORIGINAL_DATA_BASE, "inst.{}.{}.txt.clean".format(i, group)
 )
 longest = max(check(filename), longest)
 print("The longest sentence in all files has {} words.".format(longest))

process_data/clean.py (new file, +21 lines)
@@ -0,0 +1,21 @@
+from utils import ORIGINAL_DATA_BASE, read_file, write_file
+from tqdm import tqdm
+import os
+
+def remove(pos_file, neg_file):
+    pos = read_file(pos_file)
+    neg = read_file(neg_file)
+    rets = []
+    for n in tqdm(neg):
+        if n in pos:
+            continue
+        rets.append(n)
+    write_file(rets, neg_file)
+
+def main():
+    pos_file = os.path.join(ORIGINAL_DATA_BASE, "inst.all.pos.txt.clean")
+    neg_file = os.path.join(ORIGINAL_DATA_BASE, "inst.all.neg.txt.clean")
+    remove(pos_file, neg_file)
+
+if __name__ == "__main__":
+    main()

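One note on the new `clean.py`: `if n in pos` is a membership test on a plain list, so `remove()` is quadratic in the number of lines. If that ever becomes a bottleneck, a set lookup gives the same result in linear time; a hedged sketch, assuming `read_file`/`write_file` behave as in `process_data/utils.py` and return hashable lines:

```python
# Sketch only: same filtering as remove() in clean.py, with an O(1) set lookup.
def remove_fast(pos_file, neg_file):
    pos = set(read_file(pos_file))
    rets = [n for n in read_file(neg_file) if n not in pos]
    write_file(rets, neg_file)
```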
@@ -19,8 +19,8 @@ def convert(fin, fout):
 
 def main():
 # for i in range(6):
-for i in [1]:
-fin = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(i))
+for i in range(10):
+fin = os.path.join(ORIGINAL_DATA_BASE, "win32_0{}xxxx.all".format(i))
 fout = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.txt".format(i))
 convert(fin, fout)
 

@@ -41,12 +41,12 @@ def counter(filename):
 def main():
 cnt = set()
 # for i in range(6):
-for i in [1]:
+for i in range(10):
 for group in ["pos", "neg"]:
 filename = os.path.join(
-ORIGINAL_DATA_BASE, "inst.{}.{}.txt".format(i, group)
+ORIGINAL_DATA_BASE, "inst.{}.{}.txt.clean".format(i, group)
 )
-cnt += counter(filename)
+cnt = cnt.union(counter(filename))
 print("There are {} charcters in all files".format(len(cnt)))
 
 

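The `cnt = cnt.union(counter(filename))` fix above is needed because Python sets do not support `+` or `+=` with another set; the in-place union operator is an equivalent, slightly shorter form:

```python
cnt |= counter(filename)  # in-place set union, same effect as cnt = cnt.union(counter(filename))
```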
@@ -22,10 +22,12 @@ def create(pos, neg, tgt):
 
 def main():
 # for i in range(6):
-for i in [1]:
-j = (i + 1) % 6
-pos = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(i))
-neg = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(j))
+for i in range(10):
+j = (i + 1) % 10
+# neg = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(j))
+# pos = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(i))
+pos = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.txt.clean".format(i))
+neg = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.txt.clean".format(j))
 tgt = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.neg.txt".format(i))
 create(pos, neg, tgt)
 

@@ -26,11 +26,11 @@ def write_worker(sents, json_file, index):
 def merge_to_json(pos, neg, json_file):
 sents = read_file(pos)
 
-p = Pool(36)
+p = Pool(6)
 
-for i in range(64):
+for i in range(6):
 p.apply_async(
-write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, i,)
+write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, 2+i,)
 )
 print("Waiting for all sub-processes done...")
 p.close()

@@ -55,11 +55,11 @@ def merge_to_json(pos, neg, json_file):
 
 sents = read_file(neg)
 
-p = Pool(8)
+p = Pool(6)
 
-for i in range(64):
+for i in range(6):
 p.apply_async(
-write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, 64 + i,)
+write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, 8 + i,)
 )
 print("Waiting for all sub-processes done...")
 p.close()

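Both Pool blocks above follow the standard `multiprocessing` fan-out pattern (submit with `apply_async`, then `close`, then `join`). A self-contained sketch of that pattern; the worker and chunk size are illustrative, not the repo's `write_worker`:

```python
# Sketch only: the apply_async / close / join pattern used by merge_to_json.
from multiprocessing import Pool

BASE = 1000  # illustrative chunk size

def worker(chunk, out_prefix, index):
    # the real write_worker presumably writes the chunk to a JSON file derived
    # from out_prefix and index; here we just report the chunk size
    print(out_prefix, index, len(chunk))

if __name__ == "__main__":
    sents = list(range(6 * BASE))
    p = Pool(6)
    for i in range(6):
        p.apply_async(worker, args=(sents[i * BASE : (i + 1) * BASE], "inst.all.", i))
    print("Waiting for all sub-processes done...")
    p.close()  # no more tasks will be submitted
    p.join()   # wait for every async task to finish
```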
@@ -80,11 +80,15 @@ def merge_to_json(pos, neg, json_file):
 
 def main():
 # for i in range(6):
-for i in [1]:
-pos = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.label.txt".format(i))
-neg = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.neg.label.txt".format(i))
-json_file = os.path.join(CURRENT_DATA_BASE, "inst.{}.".format(i))
-merge_to_json(pos, neg, json_file)
+# for i in range(6):
+# pos = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.label.txt".format(i))
+# neg = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.neg.label.txt".format(i))
+# json_file = os.path.join(CURRENT_DATA_BASE, "inst.{}.".format(i))
+# merge_to_json(pos, neg, json_file)
+pos = os.path.join(ORIGINAL_DATA_BASE, "inst.all.pos.txt.clean.label")
+neg = os.path.join(ORIGINAL_DATA_BASE, "inst.all.neg.txt.clean.label")
+json_file = os.path.join(CURRENT_DATA_BASE, "inst.all.")
+merge_to_json(pos, neg, json_file)
 
 
 if __name__ == "__main__":

@@ -31,7 +31,10 @@ We process the files containing negative examples similarly.
 cat inst.*.neg.txt.clean | sort -n | uniq > inst.all.neg.txt.clean
 ```
 
-Based on the `inst.all.pos.txt.clean`, we remove the lines from `inst.all.neg.txt.clean` if they also occur in `inst.all.pos.txt.clean`. This can be completed by `python clean.py`.
+Based on the `inst.all.pos.txt.clean`, we remove the lines from `inst.all.neg.txt.clean` if they also occur in `inst.all.pos.txt.clean`. This can be completed by `python clean.py`, or
+<!-- ```shell
+grep -v -f inst.all.pos.txt.clean inst.all.neg.txt.clean > inst.all.neg.txt.clean
+``` -->
 
 
 ### 5. convert to json format

@@ -1,6 +1,7 @@
 import os
 
-ORIGINAL_DATA_BASE = "/home/ming/malware/data/elfasm_inst_pairs"
+# ORIGINAL_DATA_BASE = "/home/ming/malware/data/elfasm_inst_pairs"
+ORIGINAL_DATA_BASE = "/home/ming/malware/data/malasm_inst_pairs"
 CURRENT_DATA_BASE = "/home/ming/malware/inst2vec_bert/data/asm_bert"
 
 

@@ -99,8 +99,8 @@ def main(tokenizer_file=""):
 # dataset = load_dataset("json", data_files=json_files, field="data")
 
 text_files = [
-os.path.join(ORIGINAL_DATA_BASE, "inst.1.{}.txt".format(group))
-for group in ["pos", "neg"]
+os.path.join(ORIGINAL_DATA_BASE, "inst.{}.{}.txt.clean".format(i, group))
+for group in ["pos", "neg"] for i in range(10)
 ]
 
 dataset = []

@@ -121,4 +121,4 @@ def main(tokenizer_file=""):
 
 
 if __name__ == "__main__":
-main(os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.1.json"))
+main(os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.all.json"))