"""Feature extractors that turn a hex-encoded x86 function body into vectors.

Provides:
  * asm2vec-style learned embeddings (``asm2vec_plus*``, ``asm2vec_base*``,
    ``asm2vec_s_base*``, ``asm2vec_s368_base*``), one wrapper per checkpoint;
  * ``malconv``: a 256-bin histogram of byte values;
  * ``n_gram``: counts of mnemonic bigrams looked up in a vocab file;
  * ``word_frequency``: counts of mnemonics looked up in a vocab file.

All checkpoint and vocab paths are relative to the current working directory
(``./features_method/...``), and every checkpoint is loaded at import time.
"""
import binascii
import json
import os
import sys
import time

import pandas as pd
import torch
from capstone import *  # provides Cs, CS_ARCH_X86, CS_MODE_32
from sklearn.feature_extraction.text import CountVectorizer

sys.path.append(r"./features_method/asm2vec_base/")

import detect_script.features_method.asm2vec_plus.asm2vec_plus_util as asm2vec_plus_obj  # noqa: E402
import detect_script.features_method.asm2vec_base.asm2vec_base_util as asm2vec_base_obj  # noqa: E402
import detect_script.features_method.s_asm2vec_base.asm2vec_base_util as s_asm2vec_base_obj  # noqa: E402
import detect_script.features_method.s368_asm2vec_base.asm2vec_base_util as s368_asm2vec_base_obj  # noqa: E402

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Sample function body used as the default input of every extractor below.
_DEFAULT_HEX_ASM = "56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"


def _feature_path(*parts):
    """Return ``./features_method/<parts...>`` joined with the host OS separator.

    ``os.path.join`` keeps these relative paths valid on both Windows and
    POSIX hosts (a hard-coded backslash path only resolves on Windows).
    """
    return os.path.join(".", "features_method", *parts)


# --- asm2vec_plus checkpoints -------------------------------------------------
asm2vec_plus_16_model = asm2vec_plus_obj.load_asm2vec_plus_model(
    path=_feature_path("asm2vec_plus", "checkpoints", "model_16_100.pt"), device=device)
asm2vec_plus_32_model = asm2vec_plus_obj.load_asm2vec_plus_model(
    path=_feature_path("asm2vec_plus", "checkpoints", "model_32_100.pt"), device=device)
asm2vec_plus_64_model = asm2vec_plus_obj.load_asm2vec_plus_model(
    path=_feature_path("asm2vec_plus", "checkpoints", "model_64_100.pt"), device=device)
asm2vec_plus_128_model = asm2vec_plus_obj.load_asm2vec_plus_model(
    path=_feature_path("asm2vec_plus", "checkpoints", "model_128_100.pt"), device=device)
asm2vec_plus_256_model = asm2vec_plus_obj.load_asm2vec_plus_model(
    path=_feature_path("asm2vec_plus", "checkpoints", "model_256_100.pt"), device=device)

# --- asm2vec_base checkpoints -------------------------------------------------
# NOTE(review): in each *base* family the *_16_* checkpoint is loaded with
# asm2vec_plus_obj.load_asm2vec_plus_model, while 32..256 use
# asm2vec_base_obj.load_asm2vec_base_model. Kept as-is; confirm this is intended.
asm2vec_base_16_model = asm2vec_plus_obj.load_asm2vec_plus_model(
    _feature_path("asm2vec_base", "checkpoints", "model_16_100.pt"), device=device)
asm2vec_base_32_model = asm2vec_base_obj.load_asm2vec_base_model(
    _feature_path("asm2vec_base", "checkpoints", "model_32_100.pt"), device=device)
asm2vec_base_64_model = asm2vec_base_obj.load_asm2vec_base_model(
    _feature_path("asm2vec_base", "checkpoints", "model_64_100.pt"), device=device)
asm2vec_base_128_model = asm2vec_base_obj.load_asm2vec_base_model(
    _feature_path("asm2vec_base", "checkpoints", "model_128_100.pt"), device=device)
asm2vec_base_256_model = asm2vec_base_obj.load_asm2vec_base_model(
    _feature_path("asm2vec_base", "checkpoints", "model_256_100.pt"), device=device)

# --- s_ checkpoints (stored under asm2vec_base/checkpoints) -------------------
# NOTE(review): these are loaded with the asm2vec_base loader, not the
# s_asm2vec_base util module used to vectorise inputs in asm2vec_s_base; confirm.
asm2vec_base_16_s_model = asm2vec_plus_obj.load_asm2vec_plus_model(
    _feature_path("asm2vec_base", "checkpoints", "s_model_16_100.pt"), device=device)
asm2vec_base_32_s_model = asm2vec_base_obj.load_asm2vec_base_model(
    _feature_path("asm2vec_base", "checkpoints", "s_model_32_100.pt"), device=device)
asm2vec_base_64_s_model = asm2vec_base_obj.load_asm2vec_base_model(
    _feature_path("asm2vec_base", "checkpoints", "s_model_64_100.pt"), device=device)
asm2vec_base_128_s_model = asm2vec_base_obj.load_asm2vec_base_model(
    _feature_path("asm2vec_base", "checkpoints", "s_model_128_100.pt"), device=device)
asm2vec_base_256_s_model = asm2vec_base_obj.load_asm2vec_base_model(
    _feature_path("asm2vec_base", "checkpoints", "s_model_256_100.pt"), device=device)

# --- s368_ checkpoints --------------------------------------------------------
# NOTE(review): loaded with the asm2vec_base loader, while asm2vec_s368_base
# vectorises inputs with the s368 util module; confirm this pairing is intended.
asm2vec_base_16_s368_model = asm2vec_plus_obj.load_asm2vec_plus_model(
    _feature_path("s368_asm2vec_base", "checkpoints", "s368_model_16_100.pt"), device=device)
asm2vec_base_32_s368_model = asm2vec_base_obj.load_asm2vec_base_model(
    _feature_path("s368_asm2vec_base", "checkpoints", "s368_model_32_100.pt"), device=device)
asm2vec_base_64_s368_model = asm2vec_base_obj.load_asm2vec_base_model(
    _feature_path("s368_asm2vec_base", "checkpoints", "s368_model_64_100.pt"), device=device)
asm2vec_base_128_s368_model = asm2vec_base_obj.load_asm2vec_base_model(
    _feature_path("s368_asm2vec_base", "checkpoints", "s368_model_128_100.pt"), device=device)
asm2vec_base_256_s368_model = asm2vec_base_obj.load_asm2vec_base_model(
    _feature_path("s368_asm2vec_base", "checkpoints", "s368_model_256_100.pt"), device=device)


def _mean_instruction_vector(util, hex_asm):
    """Average, element-wise, the input vectors *util* produces for *hex_asm*.

    *util* is one of the ``*_util`` modules; it must expose
    ``str_hex_to_bytes`` and ``get_asm_input_vector`` (the latter returning a
    pair whose first item is a list of equal-length numeric rows — presumably
    one row per instruction; TODO confirm). Returns a list of Python floats.
    If no rows are produced, ``rows[0]`` raises IndexError.
    """
    instr_bytes = util.str_hex_to_bytes(hex_asm)
    rows, _ = util.get_asm_input_vector(instr_bytes)
    row_count = len(rows)
    totals = [0.0] * len(rows[0])
    for row in rows:
        for j, value in enumerate(row):
            totals[j] += value
    return [total / row_count for total in totals]


def _embed_function(util, model, hex_asm):
    """Return ``model.linear_f`` applied to the mean instruction vector of *hex_asm*."""
    mean_vec = torch.tensor(_mean_instruction_vector(util, hex_asm)).to(device)
    return model.to(device).linear_f(mean_vec)


def asm2vec_plus(model, hex_asm):
    """Embed *hex_asm* with *model*, vectorising it via ``asm2vec_plus_util``."""
    return _embed_function(asm2vec_plus_obj, model, hex_asm)


def asm2vec_plus_16(model=asm2vec_plus_16_model, hex_asm=_DEFAULT_HEX_ASM):
    """``asm2vec_plus`` using the ``model_16_100.pt`` checkpoint by default."""
    return asm2vec_plus(model, hex_asm)


def asm2vec_plus_32(model=asm2vec_plus_32_model, hex_asm=_DEFAULT_HEX_ASM):
    """``asm2vec_plus`` using the ``model_32_100.pt`` checkpoint by default."""
    return asm2vec_plus(model, hex_asm)


def asm2vec_plus_64(model=asm2vec_plus_64_model, hex_asm=_DEFAULT_HEX_ASM):
    """``asm2vec_plus`` using the ``model_64_100.pt`` checkpoint by default."""
    return asm2vec_plus(model, hex_asm)


def asm2vec_plus_128(model=asm2vec_plus_128_model, hex_asm=_DEFAULT_HEX_ASM):
    """``asm2vec_plus`` using the ``model_128_100.pt`` checkpoint by default."""
    return asm2vec_plus(model, hex_asm)


def asm2vec_plus_256(model=asm2vec_plus_256_model, hex_asm=_DEFAULT_HEX_ASM):
    """``asm2vec_plus`` using the ``model_256_100.pt`` checkpoint by default."""
    return asm2vec_plus(model, hex_asm)


def asm2vec_base(model, hex_asm):
    """Embed *hex_asm* with *model*, vectorising it via ``asm2vec_base_util``."""
    return _embed_function(asm2vec_base_obj, model, hex_asm)


def asm2vec_base_16(model=asm2vec_base_16_model, hex_asm=_DEFAULT_HEX_ASM):
    """``asm2vec_base`` using the base ``model_16_100.pt`` checkpoint by default."""
    return asm2vec_base(model, hex_asm)


def asm2vec_base_32(model=asm2vec_base_32_model, hex_asm=_DEFAULT_HEX_ASM):
    """``asm2vec_base`` using the base ``model_32_100.pt`` checkpoint by default."""
    return asm2vec_base(model, hex_asm)


def asm2vec_base_64(model=asm2vec_base_64_model, hex_asm=_DEFAULT_HEX_ASM):
    """``asm2vec_base`` using the base ``model_64_100.pt`` checkpoint by default."""
    return asm2vec_base(model, hex_asm)


def asm2vec_base_128(model=asm2vec_base_128_model, hex_asm=_DEFAULT_HEX_ASM):
    """``asm2vec_base`` using the base ``model_128_100.pt`` checkpoint by default."""
    return asm2vec_base(model, hex_asm)


def asm2vec_base_256(model=asm2vec_base_256_model, hex_asm=_DEFAULT_HEX_ASM):
    """``asm2vec_base`` using the base ``model_256_100.pt`` checkpoint by default."""
    return asm2vec_base(model, hex_asm)


def asm2vec_s_base(model, hex_asm):
    """Embed *hex_asm* with *model*, vectorising it via the s_ variant util."""
    return _embed_function(s_asm2vec_base_obj, model, hex_asm)


def asm2vec_s_base_16(model=asm2vec_base_16_s_model, hex_asm=_DEFAULT_HEX_ASM):
    """``asm2vec_s_base`` using ``s_model_16_100.pt`` by default."""
    return asm2vec_s_base(model, hex_asm)


def asm2vec_s_base_32(model=asm2vec_base_32_s_model, hex_asm=_DEFAULT_HEX_ASM):
    """``asm2vec_s_base`` using ``s_model_32_100.pt`` by default."""
    return asm2vec_s_base(model, hex_asm)


def asm2vec_s_base_64(model=asm2vec_base_64_s_model, hex_asm=_DEFAULT_HEX_ASM):
    """``asm2vec_s_base`` using ``s_model_64_100.pt`` by default."""
    return asm2vec_s_base(model, hex_asm)


def asm2vec_s_base_128(model=asm2vec_base_128_s_model, hex_asm=_DEFAULT_HEX_ASM):
    """``asm2vec_s_base`` using ``s_model_128_100.pt`` by default."""
    return asm2vec_s_base(model, hex_asm)


def asm2vec_s_base_256(model=asm2vec_base_256_s_model, hex_asm=_DEFAULT_HEX_ASM):
    """``asm2vec_s_base`` using ``s_model_256_100.pt`` by default."""
    return asm2vec_s_base(model, hex_asm)


def asm2vec_s368_base(model, hex_asm):
    """Embed *hex_asm* with *model*, vectorising it via the s368 variant util."""
    return _embed_function(s368_asm2vec_base_obj, model, hex_asm)


def asm2vec_s368_base_16(model=asm2vec_base_16_s368_model, hex_asm=_DEFAULT_HEX_ASM):
    """``asm2vec_s368_base`` using ``s368_model_16_100.pt`` by default."""
    return asm2vec_s368_base(model, hex_asm)


def asm2vec_s368_base_32(model=asm2vec_base_32_s368_model, hex_asm=_DEFAULT_HEX_ASM):
    """``asm2vec_s368_base`` using ``s368_model_32_100.pt`` by default."""
    return asm2vec_s368_base(model, hex_asm)


def asm2vec_s368_base_64(model=asm2vec_base_64_s368_model, hex_asm=_DEFAULT_HEX_ASM):
    """``asm2vec_s368_base`` using ``s368_model_64_100.pt`` by default."""
    return asm2vec_s368_base(model, hex_asm)


def asm2vec_s368_base_128(model=asm2vec_base_128_s368_model, hex_asm=_DEFAULT_HEX_ASM):
    """``asm2vec_s368_base`` using ``s368_model_128_100.pt`` by default."""
    return asm2vec_s368_base(model, hex_asm)


def asm2vec_s368_base_256(model=asm2vec_base_256_s368_model, hex_asm=_DEFAULT_HEX_ASM):
    """``asm2vec_s368_base`` using ``s368_model_256_100.pt`` by default."""
    return asm2vec_s368_base(model, hex_asm)


def malconv(hex_asm=_DEFAULT_HEX_ASM):
    """Return a 256-bin histogram of the byte values in *hex_asm*.

    *hex_asm* is a contiguous hex string (no separators) read two digits at a
    time; an odd trailing digit is counted as its own single-digit value.
    An empty string yields all zeros.
    """
    vec = [0] * 256
    for start in range(0, len(hex_asm), 2):
        vec[int(hex_asm[start:start + 2], 16)] += 1
    return vec


with open(_feature_path("n_gram", "vocab.json"), 'r', encoding='utf-8') as fp:
    n_gram_vocab = json.load(fp)


def n_gram(hex_asm=_DEFAULT_HEX_ASM, vec_len=512):
    """Count mnemonic 2-grams of *hex_asm* into a vector of length *vec_len*.

    *hex_asm* is disassembled as 32-bit x86. Each adjacent pair of mnemonics,
    joined by a single space, is looked up in ``n_gram_vocab``; pairs missing
    from the vocab, or whose index falls outside ``[0, vec_len)``, are ignored.
    """
    vec = [0] * vec_len
    md = Cs(CS_ARCH_X86, CS_MODE_32)
    mnemonics = [insn.mnemonic for insn in md.disasm(binascii.unhexlify(hex_asm), 0)]
    for first, second in zip(mnemonics, mnemonics[1:]):
        key = first + " " + second
        if key in n_gram_vocab:
            index = int(n_gram_vocab[key])
            if 0 <= index < vec_len:
                vec[index] += 1
    return vec


with open(_feature_path("word_frequency", "vocab.json"), 'r', encoding='utf-8') as fp:
    word_frequency_vocab = json.load(fp)


def word_frequency(hex_asm=_DEFAULT_HEX_ASM, vec_len=256):
    """Count mnemonics of *hex_asm* into a vector of length *vec_len*.

    *hex_asm* is disassembled as 32-bit x86 and each mnemonic is looked up in
    ``word_frequency_vocab``. Like ``n_gram``, vocab values are cast with
    ``int()`` and indices outside ``[0, vec_len)`` are ignored rather than
    raising IndexError.
    """
    vec = [0] * vec_len
    md = Cs(CS_ARCH_X86, CS_MODE_32)
    for insn in md.disasm(binascii.unhexlify(hex_asm), 0):
        if insn.mnemonic in word_frequency_vocab:
            index = int(word_frequency_vocab[insn.mnemonic])
            if 0 <= index < vec_len:
                vec[index] += 1
    return vec


def gcn_base():
    """Placeholder for a GCN-based extractor; not implemented (returns None)."""
    # TODO: implement or remove.
    pass


def asm2vec_plus_test(model=asm2vec_plus_128_model, hex_asm=_DEFAULT_HEX_ASM, n_copies=10024):
    """Timing helper: run ``model.linear_f`` on *n_copies* identical input rows.

    The mean instruction vector of *hex_asm* (via ``asm2vec_plus_util``) is
    stacked *n_copies* times into a 2-D float tensor and passed through
    ``model.linear_f`` in a single call; the resulting tensor is returned.
    """
    mean_vec = torch.tensor(_mean_instruction_vector(asm2vec_plus_obj, hex_asm))
    batch = mean_vec.repeat(n_copies, 1).to(device)
    return model.to(device).linear_f(batch)


if __name__ == '__main__':
    # Time one batched forward pass of the 128-dim asm2vec_plus checkpoint.
    T1 = time.perf_counter()
    res = asm2vec_plus_test()
    T2 = time.perf_counter()
    print('程序运行时间:%s秒' % ((T2 - T1) ))