final version of Inst2Vec

parent fe2de236b5
commit c050cff9f5
@@ -43,6 +43,8 @@ from transformers import (CONFIG_MAPPING, MODEL_MAPPING, AdamW, AutoConfig,
 from my_data_collator import MyDataCollatorForPreTraining
 from process_data.utils import CURRENT_DATA_BASE
 
+HIDDEN_SIZE=256
+
 logger = logging.getLogger(__name__)
 MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@@ -187,7 +189,7 @@ def main():
 # field="data",
 # )
 train_files = [
-os.path.join(CURRENT_DATA_BASE, "inst.all.{}.json".format(i)) for i in range(2)
+os.path.join(CURRENT_DATA_BASE, "inst.all.{}.json".format(i)) for i in [0,1,2,3,4,5,6] # ,8,9,10,11,12,13]
 ]
 valid_file = "/home/ming/malware/inst2vec_bert/data/test_lm/inst.json"
 raw_datasets = load_dataset(
@@ -210,10 +212,10 @@ def main():
 # we use a much smaller BERT, config is:
 config = BertConfig(
 vocab_size=tokenizer.get_vocab_size(),
-hidden_size=96,
+hidden_size=HIDDEN_SIZE,
 num_hidden_layers=4,
-num_attention_heads=12,
-intermediate_size=384,
+num_attention_heads=8,
+intermediate_size=4 * HIDDEN_SIZE,
 max_position_embeddings=32,
 )
 
@@ -230,7 +232,10 @@ def main():
 def tokenize_function(examples):
 text = [tuple(sent) for sent in examples["text"]]
 encoded_inputs = {}
+# try:
 results = tokenizer.encode_batch(text)
+# except:
+# return None
 encoded_inputs["input_ids"] = [result.ids for result in results]
 encoded_inputs["token_type_ids"] = [result.type_ids for result in results]
 encoded_inputs["special_tokens_mask"] = [
@@ -253,7 +258,7 @@ def main():
 batched=True,
 num_proc=args.preprocessing_num_workers,
 remove_columns=column_names,
-load_from_cache_file=False,
+load_from_cache_file=True,
 )
 
 train_dataset = tokenized_datasets["train"]
@@ -1,5 +1,5 @@
 import os
 
+import pdb
 from utils import ORIGINAL_DATA_BASE, read_file
 
-
@@ -7,7 +7,7 @@ def check(filename):
 sents = read_file(filename)
 result = 0
 for sent in sents:
-result = max(result, len(sent[-1].replace("\t", " ").split()))
+result = max(result, len(sent[:-1].replace("\t", " ").split()))
 print("The longest sentence in {} has {} words".format(filename, result))
 return result
 
@@ -15,10 +15,10 @@ def check(filename):
 def main():
 longest = 0
 # for i in range(6):
-for i in [1]:
+for i in range(10):
 for group in ("pos", "neg"):
 filename = os.path.join(
-ORIGINAL_DATA_BASE, "inst.{}.{}.txt".format(i, group)
+ORIGINAL_DATA_BASE, "inst.{}.{}.txt.clean".format(i, group)
 )
 longest = max(check(filename), longest)
 print("The longest sentence in all files has {} words.".format(longest))
process_data/clean.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+from utils import ORIGINAL_DATA_BASE, read_file, write_file
+from tqdm import tqdm
+import os
+
+def remove(pos_file, neg_file):
+    pos = read_file(pos_file)
+    neg = read_file(neg_file)
+    rets = []
+    for n in tqdm(neg):
+        if n in pos:
+            continue
+        rets.append(n)
+    write_file(rets, neg_file)
+
+def main():
+    pos_file = os.path.join(ORIGINAL_DATA_BASE, "inst.all.pos.txt.clean")
+    neg_file = os.path.join(ORIGINAL_DATA_BASE, "inst.all.neg.txt.clean")
+    remove(pos_file, neg_file)
+
+if __name__ == "__main__":
+    main()
@@ -19,8 +19,8 @@ def convert(fin, fout):
 
 def main():
 # for i in range(6):
-for i in [1]:
-fin = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(i))
+for i in range(10):
+fin = os.path.join(ORIGINAL_DATA_BASE, "win32_0{}xxxx.all".format(i))
 fout = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.txt".format(i))
 convert(fin, fout)
 
@@ -41,12 +41,12 @@ def counter(filename):
 def main():
 cnt = set()
 # for i in range(6):
-for i in [1]:
+for i in range(10):
 for group in ["pos", "neg"]:
 filename = os.path.join(
-ORIGINAL_DATA_BASE, "inst.{}.{}.txt".format(i, group)
+ORIGINAL_DATA_BASE, "inst.{}.{}.txt.clean".format(i, group)
 )
-cnt += counter(filename)
+cnt = cnt.union(counter(filename))
 print("There are {} charcters in all files".format(len(cnt)))
 
 
@@ -22,10 +22,12 @@ def create(pos, neg, tgt):
 
 def main():
 # for i in range(6):
-for i in [1]:
-j = (i + 1) % 6
-pos = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(i))
-neg = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(j))
+for i in range(10):
+j = (i + 1) % 10
+# neg = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(j))
+# pos = os.path.join(ORIGINAL_DATA_BASE, "linux32_0{}xxxx.all".format(i))
+pos = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.txt.clean".format(i))
+neg = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.txt.clean".format(j))
 tgt = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.neg.txt".format(i))
 create(pos, neg, tgt)
 
@@ -26,11 +26,11 @@ def write_worker(sents, json_file, index):
 def merge_to_json(pos, neg, json_file):
 sents = read_file(pos)
 
-p = Pool(36)
+p = Pool(6)
 
-for i in range(64):
+for i in range(6):
 p.apply_async(
-write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, i,)
+write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, 2+i,)
 )
 print("Waiting for all sub-processes done...")
 p.close()
@@ -55,11 +55,11 @@ def merge_to_json(pos, neg, json_file):
 
 sents = read_file(neg)
 
-p = Pool(8)
+p = Pool(6)
 
-for i in range(64):
+for i in range(6):
 p.apply_async(
-write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, 64 + i,)
+write_worker, args=(sents[i * BASE : (i + 1) * BASE], json_file, 8 + i,)
 )
 print("Waiting for all sub-processes done...")
 p.close()
@@ -80,11 +80,15 @@ def merge_to_json(pos, neg, json_file):
 
 def main():
-# for i in range(6):
-for i in [1]:
-pos = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.label.txt".format(i))
-neg = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.neg.label.txt".format(i))
-json_file = os.path.join(CURRENT_DATA_BASE, "inst.{}.".format(i))
-merge_to_json(pos, neg, json_file)
+# for i in range(6):
+# pos = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.pos.label.txt".format(i))
+# neg = os.path.join(ORIGINAL_DATA_BASE, "inst.{}.neg.label.txt".format(i))
+# json_file = os.path.join(CURRENT_DATA_BASE, "inst.{}.".format(i))
+# merge_to_json(pos, neg, json_file)
+pos = os.path.join(ORIGINAL_DATA_BASE, "inst.all.pos.txt.clean.label")
+neg = os.path.join(ORIGINAL_DATA_BASE, "inst.all.neg.txt.clean.label")
+json_file = os.path.join(CURRENT_DATA_BASE, "inst.all.")
+merge_to_json(pos, neg, json_file)
 
 
 if __name__ == "__main__":
@@ -31,7 +31,10 @@ We process the files containing negative examples similarly.
 cat inst.*.neg.txt.clean | sort -n | uniq > inst.all.neg.txt.clean
 ```
 
-Based on the `inst.all.pos.txt.clean`, we remove the lines from `inst.all.neg.txt.clean` if they also occur in `inst.all.pos.txt.clean`. This can be completed by `python clean.py`.
+Based on the `inst.all.pos.txt.clean`, we remove the lines from `inst.all.neg.txt.clean` if they also occur in `inst.all.pos.txt.clean`. This can be completed by `python clean.py`, or
+<!-- ```shell
+grep -v -f inst.all.pos.txt.clean inst.all.neg.txt.clean > inst.all.neg.txt.clean
+``` -->
 
 
 ### 5. convert to json format
@@ -1,6 +1,7 @@
 import os
 
-ORIGINAL_DATA_BASE = "/home/ming/malware/data/elfasm_inst_pairs"
+# ORIGINAL_DATA_BASE = "/home/ming/malware/data/elfasm_inst_pairs"
+ORIGINAL_DATA_BASE = "/home/ming/malware/data/malasm_inst_pairs"
 CURRENT_DATA_BASE = "/home/ming/malware/inst2vec_bert/data/asm_bert"
 
 
@@ -99,8 +99,8 @@ def main(tokenizer_file=""):
 # dataset = load_dataset("json", data_files=json_files, field="data")
 
 text_files = [
-os.path.join(ORIGINAL_DATA_BASE, "inst.1.{}.txt".format(group))
-for group in ["pos", "neg"]
+os.path.join(ORIGINAL_DATA_BASE, "inst.{}.{}.txt.clean".format(i, group))
+for group in ["pos", "neg"] for i in range(10)
 ]
 
 dataset = []
@@ -121,4 +121,4 @@ def main(tokenizer_file=""):
 
 
 if __name__ == "__main__":
-main(os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.1.json"))
+main(os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.all.json"))