detect_rep/detect_script/node_feature.py

260 lines
13 KiB
Python
Raw Permalink Normal View History

2023-04-05 10:04:49 +08:00
import torch
import sys
import pandas as pd
from capstone import *
import binascii
from sklearn.feature_extraction.text import CountVectorizer
sys.path.append(r"./features_method/asm2vec_base/")
import detect_script.features_method.asm2vec_plus.asm2vec_plus_util as asm2vec_plus_obj
# import load_asm2vec_plus_model,str_hex_to_bytes,get_asm_input_vector
import detect_script.features_method.asm2vec_base.asm2vec_base_util as asm2vec_base_obj
import detect_script.features_method.s_asm2vec_base.asm2vec_base_util as s_asm2vec_base_obj
import detect_script.features_method.s368_asm2vec_base.asm2vec_base_util as s368_asm2vec_base_obj
import os
import json
# Run on the GPU when torch reports CUDA support; otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _checkpoint_path(method_dir, file_name):
    """Return ``./features_method/<method_dir>/checkpoints/<file_name>``.

    The path is assembled with ``os.path.join`` so that it uses the
    platform's own separator: on Windows the result is the same string the
    previous hard-coded backslash literals produced, and on POSIX systems it
    now resolves as well. The path is relative to the current working
    directory.
    """
    return os.path.join(".", "features_method", method_dir, "checkpoints", file_name)


# --- asm2vec_plus checkpoints ------------------------------------------------
asm2vec_plus_16_model = asm2vec_plus_obj.load_asm2vec_plus_model(path=_checkpoint_path("asm2vec_plus", "model_16_100.pt"), device=device)
asm2vec_plus_32_model = asm2vec_plus_obj.load_asm2vec_plus_model(path=_checkpoint_path("asm2vec_plus", "model_32_100.pt"), device=device)
asm2vec_plus_64_model = asm2vec_plus_obj.load_asm2vec_plus_model(path=_checkpoint_path("asm2vec_plus", "model_64_100.pt"), device=device)
asm2vec_plus_128_model = asm2vec_plus_obj.load_asm2vec_plus_model(path=_checkpoint_path("asm2vec_plus", "model_128_100.pt"), device=device)
asm2vec_plus_256_model = asm2vec_plus_obj.load_asm2vec_plus_model(path=_checkpoint_path("asm2vec_plus", "model_256_100.pt"), device=device)

# --- asm2vec_base checkpoints ------------------------------------------------
# NOTE(review): every "*_16_*" base model below was previously loaded through
# asm2vec_plus_obj.load_asm2vec_plus_model while its 32/64/128/256 siblings
# used asm2vec_base_obj.load_asm2vec_base_model. They now use the base loader
# like their siblings; confirm the 16-variant checkpoints were not
# intentionally saved in the asm2vec_plus format.
asm2vec_base_16_model = asm2vec_base_obj.load_asm2vec_base_model(_checkpoint_path("asm2vec_base", "model_16_100.pt"), device=device)
asm2vec_base_32_model = asm2vec_base_obj.load_asm2vec_base_model(_checkpoint_path("asm2vec_base", "model_32_100.pt"), device=device)
asm2vec_base_64_model = asm2vec_base_obj.load_asm2vec_base_model(_checkpoint_path("asm2vec_base", "model_64_100.pt"), device=device)
asm2vec_base_128_model = asm2vec_base_obj.load_asm2vec_base_model(_checkpoint_path("asm2vec_base", "model_128_100.pt"), device=device)
asm2vec_base_256_model = asm2vec_base_obj.load_asm2vec_base_model(_checkpoint_path("asm2vec_base", "model_256_100.pt"), device=device)

# --- "s" checkpoints ---------------------------------------------------------
# These files live in the asm2vec_base checkpoint directory and are loaded
# with asm2vec_base_obj, although asm2vec_s_base() preprocesses its input with
# s_asm2vec_base_obj.
# NOTE(review): confirm that this directory/loader pairing is intended.
asm2vec_base_16_s_model = asm2vec_base_obj.load_asm2vec_base_model(_checkpoint_path("asm2vec_base", "s_model_16_100.pt"), device=device)
asm2vec_base_32_s_model = asm2vec_base_obj.load_asm2vec_base_model(_checkpoint_path("asm2vec_base", "s_model_32_100.pt"), device=device)
asm2vec_base_64_s_model = asm2vec_base_obj.load_asm2vec_base_model(_checkpoint_path("asm2vec_base", "s_model_64_100.pt"), device=device)
asm2vec_base_128_s_model = asm2vec_base_obj.load_asm2vec_base_model(_checkpoint_path("asm2vec_base", "s_model_128_100.pt"), device=device)
asm2vec_base_256_s_model = asm2vec_base_obj.load_asm2vec_base_model(_checkpoint_path("asm2vec_base", "s_model_256_100.pt"), device=device)

# --- "s368" checkpoints ------------------------------------------------------
asm2vec_base_16_s368_model = asm2vec_base_obj.load_asm2vec_base_model(_checkpoint_path("s368_asm2vec_base", "s368_model_16_100.pt"), device=device)
asm2vec_base_32_s368_model = asm2vec_base_obj.load_asm2vec_base_model(_checkpoint_path("s368_asm2vec_base", "s368_model_32_100.pt"), device=device)
asm2vec_base_64_s368_model = asm2vec_base_obj.load_asm2vec_base_model(_checkpoint_path("s368_asm2vec_base", "s368_model_64_100.pt"), device=device)
asm2vec_base_128_s368_model = asm2vec_base_obj.load_asm2vec_base_model(_checkpoint_path("s368_asm2vec_base", "s368_model_128_100.pt"), device=device)
asm2vec_base_256_s368_model = asm2vec_base_obj.load_asm2vec_base_model(_checkpoint_path("s368_asm2vec_base", "s368_model_256_100.pt"), device=device)
def asm2vec_plus(model, hex_asm):
    """Embed a hex-encoded code snippet with an asm2vec_plus model.

    The hex string is converted with ``asm2vec_plus_obj.str_hex_to_bytes`` and
    turned into a sequence of vectors by
    ``asm2vec_plus_obj.get_asm_input_vector`` (presumably one vector per
    disassembled line; see that helper). The vectors are averaged
    element-wise, moved to ``device`` as a tensor and passed through
    ``model.linear_f``.

    Args:
        model: Object exposing ``to(device)`` and ``linear_f``.
        hex_asm: Hex string describing the code to embed.

    Returns:
        Whatever ``model.linear_f`` returns for the averaged vector.

    Raises:
        IndexError: If the helper yields an empty vector sequence.
    """
    raw_bytes = asm2vec_plus_obj.str_hex_to_bytes(hex_asm)
    # The second element of the returned pair is not needed here.
    line_vectors, _ = asm2vec_plus_obj.get_asm_input_vector(raw_bytes)

    # Average the per-line vectors to obtain one vector for the function.
    totals = [0.0] * len(line_vectors[0])
    for line_vector in line_vectors:
        for position, value in enumerate(line_vector):
            totals[position] += value
    line_count = len(line_vectors)
    mean_vector = [total / line_count for total in totals]

    mean_tensor = torch.tensor(mean_vector).to(device)
    return model.to(device).linear_f(mean_tensor)
def asm2vec_plus_16(
    model=asm2vec_plus_16_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Call ``asm2vec_plus``; ``model`` defaults to ``asm2vec_plus_16_model``."""
    return asm2vec_plus(model=model, hex_asm=hex_asm)
def asm2vec_plus_32(
    model=asm2vec_plus_32_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Call ``asm2vec_plus``; ``model`` defaults to ``asm2vec_plus_32_model``."""
    return asm2vec_plus(model=model, hex_asm=hex_asm)
def asm2vec_plus_64(
    model=asm2vec_plus_64_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Call ``asm2vec_plus``; ``model`` defaults to ``asm2vec_plus_64_model``."""
    return asm2vec_plus(model=model, hex_asm=hex_asm)
def asm2vec_plus_128(
    model=asm2vec_plus_128_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Call ``asm2vec_plus``; ``model`` defaults to ``asm2vec_plus_128_model``."""
    return asm2vec_plus(model=model, hex_asm=hex_asm)
def asm2vec_plus_256(
    model=asm2vec_plus_256_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Call ``asm2vec_plus``; ``model`` defaults to ``asm2vec_plus_256_model``."""
    return asm2vec_plus(model=model, hex_asm=hex_asm)
def asm2vec_base(model, hex_asm):
    """Embed a hex-encoded code snippet with an asm2vec_base model.

    The hex string is converted with ``asm2vec_base_obj.str_hex_to_bytes`` and
    turned into a sequence of vectors by
    ``asm2vec_base_obj.get_asm_input_vector`` (presumably one vector per
    disassembled line; see that helper). The vectors are averaged
    element-wise, moved to ``device`` as a tensor and passed through
    ``model.linear_f``.

    Args:
        model: Object exposing ``to(device)`` and ``linear_f``.
        hex_asm: Hex string describing the code to embed.

    Returns:
        Whatever ``model.linear_f`` returns for the averaged vector.

    Raises:
        IndexError: If the helper yields an empty vector sequence.
    """
    raw_bytes = asm2vec_base_obj.str_hex_to_bytes(hex_asm)
    # The second element of the returned pair is not needed here.
    line_vectors, _ = asm2vec_base_obj.get_asm_input_vector(raw_bytes)

    # Average the per-line vectors to obtain one vector for the function.
    totals = [0.0] * len(line_vectors[0])
    for line_vector in line_vectors:
        for position, value in enumerate(line_vector):
            totals[position] += value
    line_count = len(line_vectors)
    mean_vector = [total / line_count for total in totals]

    mean_tensor = torch.tensor(mean_vector).to(device)
    return model.to(device).linear_f(mean_tensor)
def asm2vec_base_16(
    model=asm2vec_base_16_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Call ``asm2vec_base``; ``model`` defaults to ``asm2vec_base_16_model``."""
    return asm2vec_base(model=model, hex_asm=hex_asm)
def asm2vec_base_32(
    model=asm2vec_base_32_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Call ``asm2vec_base``; ``model`` defaults to ``asm2vec_base_32_model``."""
    return asm2vec_base(model=model, hex_asm=hex_asm)
def asm2vec_base_64(
    model=asm2vec_base_64_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Call ``asm2vec_base``; ``model`` defaults to ``asm2vec_base_64_model``."""
    return asm2vec_base(model=model, hex_asm=hex_asm)
def asm2vec_base_128(
    model=asm2vec_base_128_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Call ``asm2vec_base``; ``model`` defaults to ``asm2vec_base_128_model``."""
    return asm2vec_base(model=model, hex_asm=hex_asm)
def asm2vec_base_256(
    model=asm2vec_base_256_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Call ``asm2vec_base``; ``model`` defaults to ``asm2vec_base_256_model``."""
    return asm2vec_base(model=model, hex_asm=hex_asm)
def asm2vec_s_base(model, hex_asm):
    """Embed a hex-encoded code snippet using the ``s_asm2vec_base`` helpers.

    The hex string is converted with ``s_asm2vec_base_obj.str_hex_to_bytes``
    and turned into a sequence of vectors by
    ``s_asm2vec_base_obj.get_asm_input_vector`` (presumably one vector per
    disassembled line; see that helper). The vectors are averaged
    element-wise, moved to ``device`` as a tensor and passed through
    ``model.linear_f``.

    Args:
        model: Object exposing ``to(device)`` and ``linear_f``.
        hex_asm: Hex string describing the code to embed.

    Returns:
        Whatever ``model.linear_f`` returns for the averaged vector.

    Raises:
        IndexError: If the helper yields an empty vector sequence.
    """
    raw_bytes = s_asm2vec_base_obj.str_hex_to_bytes(hex_asm)
    # The second element of the returned pair is not needed here.
    line_vectors, _ = s_asm2vec_base_obj.get_asm_input_vector(raw_bytes)

    # Average the per-line vectors to obtain one vector for the function.
    totals = [0.0] * len(line_vectors[0])
    for line_vector in line_vectors:
        for position, value in enumerate(line_vector):
            totals[position] += value
    line_count = len(line_vectors)
    mean_vector = [total / line_count for total in totals]

    mean_tensor = torch.tensor(mean_vector).to(device)
    return model.to(device).linear_f(mean_tensor)
def asm2vec_s_base_16(
    model=asm2vec_base_16_s_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Call ``asm2vec_s_base``; ``model`` defaults to ``asm2vec_base_16_s_model``."""
    return asm2vec_s_base(model=model, hex_asm=hex_asm)
def asm2vec_s_base_32(
    model=asm2vec_base_32_s_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Call ``asm2vec_s_base``; ``model`` defaults to ``asm2vec_base_32_s_model``."""
    return asm2vec_s_base(model=model, hex_asm=hex_asm)
def asm2vec_s_base_64(
    model=asm2vec_base_64_s_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Call ``asm2vec_s_base``; ``model`` defaults to ``asm2vec_base_64_s_model``."""
    return asm2vec_s_base(model=model, hex_asm=hex_asm)
def asm2vec_s_base_128(
    model=asm2vec_base_128_s_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Call ``asm2vec_s_base``; ``model`` defaults to ``asm2vec_base_128_s_model``."""
    return asm2vec_s_base(model=model, hex_asm=hex_asm)
def asm2vec_s_base_256(
    model=asm2vec_base_256_s_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Call ``asm2vec_s_base``; ``model`` defaults to ``asm2vec_base_256_s_model``."""
    return asm2vec_s_base(model=model, hex_asm=hex_asm)
def asm2vec_s368_base(model, hex_asm):
    """Embed a hex-encoded code snippet using the ``s368_asm2vec_base`` helpers.

    The hex string is converted with ``s368_asm2vec_base_obj.str_hex_to_bytes``
    and turned into a sequence of vectors by
    ``s368_asm2vec_base_obj.get_asm_input_vector`` (presumably one vector per
    disassembled line; see that helper). The vectors are averaged
    element-wise, moved to ``device`` as a tensor and passed through
    ``model.linear_f``.

    Args:
        model: Object exposing ``to(device)`` and ``linear_f``.
        hex_asm: Hex string describing the code to embed.

    Returns:
        Whatever ``model.linear_f`` returns for the averaged vector.

    Raises:
        IndexError: If the helper yields an empty vector sequence.
    """
    raw_bytes = s368_asm2vec_base_obj.str_hex_to_bytes(hex_asm)
    # The second element of the returned pair is not needed here.
    line_vectors, _ = s368_asm2vec_base_obj.get_asm_input_vector(raw_bytes)

    # Average the per-line vectors to obtain one vector for the function.
    totals = [0.0] * len(line_vectors[0])
    for line_vector in line_vectors:
        for position, value in enumerate(line_vector):
            totals[position] += value
    line_count = len(line_vectors)
    mean_vector = [total / line_count for total in totals]

    mean_tensor = torch.tensor(mean_vector).to(device)
    return model.to(device).linear_f(mean_tensor)
def asm2vec_s368_base_16(
    model=asm2vec_base_16_s368_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Call ``asm2vec_s368_base``; ``model`` defaults to ``asm2vec_base_16_s368_model``."""
    return asm2vec_s368_base(model=model, hex_asm=hex_asm)
def asm2vec_s368_base_32(
    model=asm2vec_base_32_s368_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Call ``asm2vec_s368_base``; ``model`` defaults to ``asm2vec_base_32_s368_model``."""
    return asm2vec_s368_base(model=model, hex_asm=hex_asm)
def asm2vec_s368_base_64(
    model=asm2vec_base_64_s368_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Call ``asm2vec_s368_base``; ``model`` defaults to ``asm2vec_base_64_s368_model``."""
    return asm2vec_s368_base(model=model, hex_asm=hex_asm)
def asm2vec_s368_base_128(
    model=asm2vec_base_128_s368_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Call ``asm2vec_s368_base``; ``model`` defaults to ``asm2vec_base_128_s368_model``."""
    return asm2vec_s368_base(model=model, hex_asm=hex_asm)
def asm2vec_s368_base_256(
    model=asm2vec_base_256_s368_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Call ``asm2vec_s368_base``; ``model`` defaults to ``asm2vec_base_256_s368_model``."""
    return asm2vec_s368_base(model=model, hex_asm=hex_asm)
def malconv(hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Return a 256-slot histogram of the byte values in a hex string.

    The string is cut into consecutive two-character groups, each group is
    parsed as a base-16 integer, and the slot with that index is incremented.

    Args:
        hex_asm: Hex-encoded bytes, two characters per byte. When the length
            is odd, the trailing single character is parsed as its own value.

    Returns:
        A list of 256 ints; element ``b`` is the number of groups whose
        parsed value is ``b``.

    Raises:
        ValueError: If a group is not a valid base-16 literal.
    """
    vec = [0] * 256
    # Slice instead of growing a string one character at a time, which was
    # quadratic in the input length.
    groups = (hex_asm[start:start + 2] for start in range(0, len(hex_asm), 2))
    # Joining on a space and re-splitting keeps the previous tokenisation
    # exactly: whitespace inside the input still acts as a separator and
    # never yields an empty token.
    for token in " ".join(groups).split():
        vec[int(token, 16)] += 1
    return vec
# Read the vocabulary consulted by n_gram() once, at import time. The path is
# relative to the current working directory.
with open("./features_method/n_gram/" + 'vocab.json', mode='r', encoding='utf-8') as fp:
    n_gram_vocab = json.loads(fp.read())
# 2-gram mnemonic features. The returned vector has vec_len = 512 slots.
# NOTE(review): the previous comment here said the length was 1024, which
# does not match vec_len; confirm which size downstream consumers expect.
def n_gram(hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Count vocabulary mnemonic 2-grams in x86 (32-bit) machine code.

    ``hex_asm`` is unhexlified and disassembled with capstone. Every pair of
    adjacent mnemonics, joined by a single space, is looked up in
    ``n_gram_vocab``; when the pair is present and its index (converted with
    ``int``) is below 512, that slot is incremented.

    Args:
        hex_asm: Hex-encoded machine code.

    Returns:
        A list of 512 ints with the per-slot 2-gram counts.
    """
    vec_len = 512
    counts = [0] * vec_len
    disassembler = Cs(CS_ARCH_X86, CS_MODE_32)
    code = binascii.unhexlify(hex_asm)
    mnemonics = [insn.mnemonic for insn in disassembler.disasm(code, 0)]

    for first, second in zip(mnemonics, mnemonics[1:]):
        bigram = first + " " + second
        if bigram not in n_gram_vocab:
            continue
        slot = int(n_gram_vocab[bigram])
        # Vocabulary entries indexed at or beyond the vector length are ignored.
        if slot < vec_len:
            counts[slot] += 1
    return counts
# Read the vocabulary consulted by word_frequency() once, at import time. The
# path is relative to the current working directory.
with open("./features_method/word_frequency/" + 'vocab.json', mode='r', encoding='utf-8') as fp:
    word_frequency_vocab = json.loads(fp.read())
def word_frequency(hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Count vocabulary mnemonics in x86 (32-bit) machine code.

    ``hex_asm`` is unhexlified and disassembled with capstone. For each
    instruction whose mnemonic is a key of ``word_frequency_vocab``, the slot
    given by the vocabulary value is incremented.

    NOTE(review): unlike ``n_gram``, the vocabulary value is used as an index
    without ``int()`` conversion or a bounds check, so a value of 256 or more
    raises ``IndexError``; confirm the vocabulary never exceeds 256 entries.

    Args:
        hex_asm: Hex-encoded machine code.

    Returns:
        A list of 256 ints with the per-slot mnemonic counts.
    """
    counts = [0] * 256
    disassembler = Cs(CS_ARCH_X86, CS_MODE_32)
    code = binascii.unhexlify(hex_asm)
    for insn in disassembler.disasm(code, 0):
        mnemonic = insn.mnemonic
        if mnemonic not in word_frequency_vocab:
            continue
        counts[word_frequency_vocab[mnemonic]] += 1
    return counts
def gcn_base():
    """Placeholder with no implementation yet; always returns ``None``."""
    return None
def asm2vec_plus_test(model=asm2vec_plus_128_model,hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Push 10024 copies of one averaged vector through ``model.linear_f``.

    The averaged vector is computed the same way as in ``asm2vec_plus``:
    ``hex_asm`` goes through ``asm2vec_plus_obj.str_hex_to_bytes`` and
    ``asm2vec_plus_obj.get_asm_input_vector``, and the resulting vectors are
    averaged element-wise. That vector is then repeated 10024 times, turned
    into a single tensor on ``device`` and passed to ``model.linear_f`` in one
    call. The ``__main__`` section of this file times this function.

    Args:
        model: Object exposing ``to(device)`` and ``linear_f``; defaults to
            ``asm2vec_plus_128_model``.
        hex_asm: Hex string describing the code to embed.

    Returns:
        Whatever ``model.linear_f`` returns for the repeated batch.

    Raises:
        IndexError: If the helper yields an empty vector sequence.
    """
    raw_bytes = asm2vec_plus_obj.str_hex_to_bytes(hex_asm)
    # The second element of the returned pair is not needed here.
    line_vectors, _ = asm2vec_plus_obj.get_asm_input_vector(raw_bytes)

    # Average the per-line vectors to obtain one vector for the function.
    totals = [0.0] * len(line_vectors[0])
    for line_vector in line_vectors:
        for position, value in enumerate(line_vector):
            totals[position] += value
    line_count = len(line_vectors)
    mean_vector = [total / line_count for total in totals]

    # Every row of the batch is the same averaged vector.
    batch = torch.tensor([mean_vector] * 10024).to(device)
    return model.to(device).linear_f(batch)
import time
if __name__ == '__main__':
    # Ad-hoc timing harness: run asm2vec_plus_test() once and print the
    # elapsed wall-clock seconds. Earlier manual experiments with
    # asm2vec_plus, asm2vec_base, n_gram and asm2vec_plus_16 (printing the
    # result and its length) were also launched from here.
    started_at = time.time()
    res = asm2vec_plus_test()
    finished_at = time.time()
    # The message text means "program running time".
    print('程序运行时间:%s' % (finished_at - started_at))