detect_rep/detect_script/node_feature.py
2023-04-05 10:04:49 +08:00

260 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import sys
import pandas as pd
from capstone import *
import binascii
from sklearn.feature_extraction.text import CountVectorizer
sys.path.append(r"./features_method/asm2vec_base/")
import detect_script.features_method.asm2vec_plus.asm2vec_plus_util as asm2vec_plus_obj
# import load_asm2vec_plus_model,str_hex_to_bytes,get_asm_input_vector
import detect_script.features_method.asm2vec_base.asm2vec_base_util as asm2vec_base_obj
import detect_script.features_method.s_asm2vec_base.asm2vec_base_util as s_asm2vec_base_obj
import detect_script.features_method.s368_asm2vec_base.asm2vec_base_util as s368_asm2vec_base_obj
import os
import json
# Device used for every model below and by the feature functions in this module.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _checkpoint(*parts):
    """Return the path of a checkpoint file under ``./features_method``.

    The paths used to be hard-coded with Windows-only backslash separators,
    which cannot resolve on POSIX systems. ``os.path.join`` yields the same
    string on Windows and a valid path elsewhere. Like before, the path is
    relative to the current working directory.
    """
    return os.path.join(".", "features_method", *parts)


# asm2vec_plus checkpoints. The numeric suffix presumably is the embedding
# size (16..256) — TODO confirm.
asm2vec_plus_16_model = asm2vec_plus_obj.load_asm2vec_plus_model(
    path=_checkpoint("asm2vec_plus", "checkpoints", "model_16_100.pt"), device=device)
asm2vec_plus_32_model = asm2vec_plus_obj.load_asm2vec_plus_model(
    path=_checkpoint("asm2vec_plus", "checkpoints", "model_32_100.pt"), device=device)
asm2vec_plus_64_model = asm2vec_plus_obj.load_asm2vec_plus_model(
    path=_checkpoint("asm2vec_plus", "checkpoints", "model_64_100.pt"), device=device)
asm2vec_plus_128_model = asm2vec_plus_obj.load_asm2vec_plus_model(
    path=_checkpoint("asm2vec_plus", "checkpoints", "model_128_100.pt"), device=device)
asm2vec_plus_256_model = asm2vec_plus_obj.load_asm2vec_plus_model(
    path=_checkpoint("asm2vec_plus", "checkpoints", "model_256_100.pt"), device=device)

# asm2vec_base checkpoints.
# NOTE(review): the 16 model is loaded with the asm2vec_plus loader while the
# 32..256 models use the asm2vec_base loader — confirm this is intended.
asm2vec_base_16_model = asm2vec_plus_obj.load_asm2vec_plus_model(
    _checkpoint("asm2vec_base", "checkpoints", "model_16_100.pt"), device=device)
asm2vec_base_32_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint("asm2vec_base", "checkpoints", "model_32_100.pt"), device=device)
asm2vec_base_64_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint("asm2vec_base", "checkpoints", "model_64_100.pt"), device=device)
asm2vec_base_128_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint("asm2vec_base", "checkpoints", "model_128_100.pt"), device=device)
asm2vec_base_256_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint("asm2vec_base", "checkpoints", "model_256_100.pt"), device=device)

# "s_model" checkpoints, stored in the asm2vec_base checkpoint directory.
# NOTE(review): same loader mismatch for the 16 model as above — confirm.
asm2vec_base_16_s_model = asm2vec_plus_obj.load_asm2vec_plus_model(
    _checkpoint("asm2vec_base", "checkpoints", "s_model_16_100.pt"), device=device)
asm2vec_base_32_s_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint("asm2vec_base", "checkpoints", "s_model_32_100.pt"), device=device)
asm2vec_base_64_s_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint("asm2vec_base", "checkpoints", "s_model_64_100.pt"), device=device)
asm2vec_base_128_s_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint("asm2vec_base", "checkpoints", "s_model_128_100.pt"), device=device)
asm2vec_base_256_s_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint("asm2vec_base", "checkpoints", "s_model_256_100.pt"), device=device)

# "s368" checkpoints.
# NOTE(review): these use the asm2vec_plus / asm2vec_base loaders even though
# s368_asm2vec_base_obj is imported, and the 16 model again uses a different
# loader from its siblings — confirm both are intended.
asm2vec_base_16_s368_model = asm2vec_plus_obj.load_asm2vec_plus_model(
    _checkpoint("s368_asm2vec_base", "checkpoints", "s368_model_16_100.pt"), device=device)
asm2vec_base_32_s368_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint("s368_asm2vec_base", "checkpoints", "s368_model_32_100.pt"), device=device)
asm2vec_base_64_s368_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint("s368_asm2vec_base", "checkpoints", "s368_model_64_100.pt"), device=device)
asm2vec_base_128_s368_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint("s368_asm2vec_base", "checkpoints", "s368_model_128_100.pt"), device=device)
asm2vec_base_256_s368_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint("s368_asm2vec_base", "checkpoints", "s368_model_256_100.pt"), device=device)
def asm2vec_plus(model, hex_asm):
    """Embed a hex-encoded code snippet with an asm2vec_plus model.

    ``asm2vec_plus_obj.str_hex_to_bytes`` decodes ``hex_asm`` and
    ``asm2vec_plus_obj.get_asm_input_vector`` turns it into one input vector
    per line of code. Those vectors are averaged element-wise into a single
    function vector, which is passed through ``model.linear_f``.

    Args:
        model: loaded model exposing a ``linear_f`` layer.
        hex_asm: machine code as a hex string.

    Returns:
        The output of ``model.linear_f`` (computed on ``device``).

    Raises:
        ValueError: if no per-line vectors are produced for ``hex_asm``.
    """
    decoded = asm2vec_plus_obj.str_hex_to_bytes(hex_asm)
    # The second return value (opcode/operand sequence) is not needed here.
    line_vectors, _ = asm2vec_plus_obj.get_asm_input_vector(decoded)
    n_lines = len(line_vectors)
    if n_lines == 0:
        raise ValueError("hex_asm produced no instruction vectors")
    width = len(line_vectors[0])
    # Element-wise mean over the code lines. Summing from 0.0 in row order
    # keeps the floating-point result identical to the previous running-sum loop.
    mean_vector = [
        sum((row[j] for row in line_vectors), 0.0) / n_lines for j in range(width)
    ]
    fun2vec = torch.tensor(mean_vector).to(device)
    return model.to(device).linear_f(fun2vec)
def asm2vec_plus_16(
    model=asm2vec_plus_16_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Embed ``hex_asm`` via ``asm2vec_plus``; ``model`` defaults to ``asm2vec_plus_16_model``."""
    return asm2vec_plus(model=model, hex_asm=hex_asm)


def asm2vec_plus_32(
    model=asm2vec_plus_32_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Embed ``hex_asm`` via ``asm2vec_plus``; ``model`` defaults to ``asm2vec_plus_32_model``."""
    return asm2vec_plus(model=model, hex_asm=hex_asm)


def asm2vec_plus_64(
    model=asm2vec_plus_64_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Embed ``hex_asm`` via ``asm2vec_plus``; ``model`` defaults to ``asm2vec_plus_64_model``."""
    return asm2vec_plus(model=model, hex_asm=hex_asm)


def asm2vec_plus_128(
    model=asm2vec_plus_128_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Embed ``hex_asm`` via ``asm2vec_plus``; ``model`` defaults to ``asm2vec_plus_128_model``."""
    return asm2vec_plus(model=model, hex_asm=hex_asm)


def asm2vec_plus_256(
    model=asm2vec_plus_256_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Embed ``hex_asm`` via ``asm2vec_plus``; ``model`` defaults to ``asm2vec_plus_256_model``."""
    return asm2vec_plus(model=model, hex_asm=hex_asm)
def asm2vec_base(model, hex_asm):
    """Embed a hex-encoded code snippet using the asm2vec_base helpers.

    ``asm2vec_base_obj.str_hex_to_bytes`` decodes ``hex_asm`` and
    ``asm2vec_base_obj.get_asm_input_vector`` turns it into one input vector
    per line of code. Those vectors are averaged element-wise into a single
    function vector, which is passed through ``model.linear_f``.

    Args:
        model: loaded model exposing a ``linear_f`` layer.
        hex_asm: machine code as a hex string.

    Returns:
        The output of ``model.linear_f`` (computed on ``device``).

    Raises:
        ValueError: if no per-line vectors are produced for ``hex_asm``.
    """
    decoded = asm2vec_base_obj.str_hex_to_bytes(hex_asm)
    # The second return value (opcode/operand sequence) is not needed here.
    line_vectors, _ = asm2vec_base_obj.get_asm_input_vector(decoded)
    n_lines = len(line_vectors)
    if n_lines == 0:
        raise ValueError("hex_asm produced no instruction vectors")
    width = len(line_vectors[0])
    # Element-wise mean over the code lines. Summing from 0.0 in row order
    # keeps the floating-point result identical to the previous running-sum loop.
    mean_vector = [
        sum((row[j] for row in line_vectors), 0.0) / n_lines for j in range(width)
    ]
    fun2vec = torch.tensor(mean_vector).to(device)
    return model.to(device).linear_f(fun2vec)
def asm2vec_base_16(
    model=asm2vec_base_16_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Embed ``hex_asm`` via ``asm2vec_base``; ``model`` defaults to ``asm2vec_base_16_model``."""
    return asm2vec_base(model=model, hex_asm=hex_asm)


def asm2vec_base_32(
    model=asm2vec_base_32_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Embed ``hex_asm`` via ``asm2vec_base``; ``model`` defaults to ``asm2vec_base_32_model``."""
    return asm2vec_base(model=model, hex_asm=hex_asm)


def asm2vec_base_64(
    model=asm2vec_base_64_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Embed ``hex_asm`` via ``asm2vec_base``; ``model`` defaults to ``asm2vec_base_64_model``."""
    return asm2vec_base(model=model, hex_asm=hex_asm)


def asm2vec_base_128(
    model=asm2vec_base_128_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Embed ``hex_asm`` via ``asm2vec_base``; ``model`` defaults to ``asm2vec_base_128_model``."""
    return asm2vec_base(model=model, hex_asm=hex_asm)


def asm2vec_base_256(
    model=asm2vec_base_256_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Embed ``hex_asm`` via ``asm2vec_base``; ``model`` defaults to ``asm2vec_base_256_model``."""
    return asm2vec_base(model=model, hex_asm=hex_asm)
def asm2vec_s_base(model, hex_asm):
    """Embed a hex-encoded code snippet using the s_asm2vec_base helpers.

    ``s_asm2vec_base_obj.str_hex_to_bytes`` decodes ``hex_asm`` and
    ``s_asm2vec_base_obj.get_asm_input_vector`` turns it into one input
    vector per line of code. Those vectors are averaged element-wise into a
    single function vector, which is passed through ``model.linear_f``.

    Args:
        model: loaded model exposing a ``linear_f`` layer.
        hex_asm: machine code as a hex string.

    Returns:
        The output of ``model.linear_f`` (computed on ``device``).

    Raises:
        ValueError: if no per-line vectors are produced for ``hex_asm``.
    """
    decoded = s_asm2vec_base_obj.str_hex_to_bytes(hex_asm)
    # The second return value (opcode/operand sequence) is not needed here.
    line_vectors, _ = s_asm2vec_base_obj.get_asm_input_vector(decoded)
    n_lines = len(line_vectors)
    if n_lines == 0:
        raise ValueError("hex_asm produced no instruction vectors")
    width = len(line_vectors[0])
    # Element-wise mean over the code lines. Summing from 0.0 in row order
    # keeps the floating-point result identical to the previous running-sum loop.
    mean_vector = [
        sum((row[j] for row in line_vectors), 0.0) / n_lines for j in range(width)
    ]
    fun2vec = torch.tensor(mean_vector).to(device)
    return model.to(device).linear_f(fun2vec)
def asm2vec_s_base_16(
    model=asm2vec_base_16_s_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Embed ``hex_asm`` via ``asm2vec_s_base``; ``model`` defaults to ``asm2vec_base_16_s_model``."""
    return asm2vec_s_base(model=model, hex_asm=hex_asm)


def asm2vec_s_base_32(
    model=asm2vec_base_32_s_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Embed ``hex_asm`` via ``asm2vec_s_base``; ``model`` defaults to ``asm2vec_base_32_s_model``."""
    return asm2vec_s_base(model=model, hex_asm=hex_asm)


def asm2vec_s_base_64(
    model=asm2vec_base_64_s_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Embed ``hex_asm`` via ``asm2vec_s_base``; ``model`` defaults to ``asm2vec_base_64_s_model``."""
    return asm2vec_s_base(model=model, hex_asm=hex_asm)


def asm2vec_s_base_128(
    model=asm2vec_base_128_s_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Embed ``hex_asm`` via ``asm2vec_s_base``; ``model`` defaults to ``asm2vec_base_128_s_model``."""
    return asm2vec_s_base(model=model, hex_asm=hex_asm)


def asm2vec_s_base_256(
    model=asm2vec_base_256_s_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Embed ``hex_asm`` via ``asm2vec_s_base``; ``model`` defaults to ``asm2vec_base_256_s_model``."""
    return asm2vec_s_base(model=model, hex_asm=hex_asm)
def asm2vec_s368_base(model, hex_asm):
    """Embed a hex-encoded code snippet using the s368_asm2vec_base helpers.

    ``s368_asm2vec_base_obj.str_hex_to_bytes`` decodes ``hex_asm`` and
    ``s368_asm2vec_base_obj.get_asm_input_vector`` turns it into one input
    vector per line of code. Those vectors are averaged element-wise into a
    single function vector, which is passed through ``model.linear_f``.

    Args:
        model: loaded model exposing a ``linear_f`` layer.
        hex_asm: machine code as a hex string.

    Returns:
        The output of ``model.linear_f`` (computed on ``device``).

    Raises:
        ValueError: if no per-line vectors are produced for ``hex_asm``.
    """
    decoded = s368_asm2vec_base_obj.str_hex_to_bytes(hex_asm)
    # The second return value (opcode/operand sequence) is not needed here.
    line_vectors, _ = s368_asm2vec_base_obj.get_asm_input_vector(decoded)
    n_lines = len(line_vectors)
    if n_lines == 0:
        raise ValueError("hex_asm produced no instruction vectors")
    width = len(line_vectors[0])
    # Element-wise mean over the code lines. Summing from 0.0 in row order
    # keeps the floating-point result identical to the previous running-sum loop.
    mean_vector = [
        sum((row[j] for row in line_vectors), 0.0) / n_lines for j in range(width)
    ]
    fun2vec = torch.tensor(mean_vector).to(device)
    return model.to(device).linear_f(fun2vec)
def asm2vec_s368_base_16(
    model=asm2vec_base_16_s368_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Embed ``hex_asm`` via ``asm2vec_s368_base``; ``model`` defaults to ``asm2vec_base_16_s368_model``."""
    return asm2vec_s368_base(model=model, hex_asm=hex_asm)


def asm2vec_s368_base_32(
    model=asm2vec_base_32_s368_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Embed ``hex_asm`` via ``asm2vec_s368_base``; ``model`` defaults to ``asm2vec_base_32_s368_model``."""
    return asm2vec_s368_base(model=model, hex_asm=hex_asm)


def asm2vec_s368_base_64(
    model=asm2vec_base_64_s368_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Embed ``hex_asm`` via ``asm2vec_s368_base``; ``model`` defaults to ``asm2vec_base_64_s368_model``."""
    return asm2vec_s368_base(model=model, hex_asm=hex_asm)


def asm2vec_s368_base_128(
    model=asm2vec_base_128_s368_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Embed ``hex_asm`` via ``asm2vec_s368_base``; ``model`` defaults to ``asm2vec_base_128_s368_model``."""
    return asm2vec_s368_base(model=model, hex_asm=hex_asm)


def asm2vec_s368_base_256(
    model=asm2vec_base_256_s368_model,
    hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
):
    """Embed ``hex_asm`` via ``asm2vec_s368_base``; ``model`` defaults to ``asm2vec_base_256_s368_model``."""
    return asm2vec_s368_base(model=model, hex_asm=hex_asm)
def malconv(hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Return a byte-value histogram of a hex-encoded byte string.

    Args:
        hex_asm: hex string with two hex digits per byte (either case).

    Returns:
        A list of 256 ints; element ``b`` is the number of times byte value
        ``b`` occurs in the decoded input.

    Raises:
        binascii.Error: (a ``ValueError`` subclass) if ``hex_asm`` has odd
            length or contains non-hex characters, rather than silently
            counting a trailing lone nibble as a whole byte.
    """
    vec = [0] * 256
    # Iterating over bytes yields ints in 0..255, which index vec directly.
    for byte_value in binascii.unhexlify(hex_asm):
        vec[byte_value] += 1
    return vec
# Vocabulary consumed by n_gram(): it looks up "mnemonic mnemonic" strings
# here and converts the stored values with int(). The path is relative to the
# current working directory.
with open("./features_method/n_gram/" + 'vocab.json', mode='r', encoding='utf-8') as fp:
    n_gram_vocab = json.loads(fp.read())
# Mnemonic 2-gram (bigram) features; the output vector has 512 slots.
def n_gram(hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Count adjacent mnemonic pairs in a hex-encoded 32-bit x86 snippet.

    The snippet is disassembled with Capstone (``CS_ARCH_X86``,
    ``CS_MODE_32``) starting at address 0. Each pair of consecutive
    mnemonics is joined with a single space and looked up in
    ``n_gram_vocab``. On a hit, the slot at ``int(vocab value)`` is
    incremented.

    Args:
        hex_asm: machine code as a hex string.

    Returns:
        A list of 512 ints. Pairs missing from the vocabulary, and vocabulary
        indices ``>= 512``, do not contribute.
    """
    vec_len = 512
    vec = [0] * vec_len
    md = Cs(CS_ARCH_X86, CS_MODE_32)
    mnemonics = [insn.mnemonic for insn in md.disasm(binascii.unhexlify(hex_asm), 0)]
    # zip(seq, seq[1:]) walks every pair of consecutive mnemonics.
    for first, second in zip(mnemonics, mnemonics[1:]):
        raw_index = n_gram_vocab.get(first + " " + second)
        if raw_index is None:
            continue
        index = int(raw_index)
        if index < vec_len:
            vec[index] += 1
    return vec
# Vocabulary consumed by word_frequency(): maps an instruction mnemonic to its
# slot in the 256-entry count vector. The path is relative to the current
# working directory.
with open("./features_method/word_frequency/" + 'vocab.json', mode='r', encoding='utf-8') as fp:
    word_frequency_vocab = json.loads(fp.read())
def word_frequency(hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Count instruction mnemonics of a hex-encoded 32-bit x86 snippet.

    The snippet is disassembled with Capstone (``CS_ARCH_X86``,
    ``CS_MODE_32``) starting at address 0. Each mnemonic found in
    ``word_frequency_vocab`` increments the slot at ``int(vocab value)``.

    Args:
        hex_asm: machine code as a hex string.

    Returns:
        A list of 256 ints. Mnemonics missing from the vocabulary, and
        vocabulary indices outside ``0..255``, do not contribute.
    """
    vec_len = 256
    vec = [0] * vec_len
    md = Cs(CS_ARCH_X86, CS_MODE_32)
    for insn in md.disasm(binascii.unhexlify(hex_asm), 0):
        raw_index = word_frequency_vocab.get(insn.mnemonic)
        if raw_index is None:
            continue
        index = int(raw_index)
        # Guard as n_gram() does: a vocabulary with more entries than the
        # vector has slots must not raise IndexError (or wrap on negatives).
        if 0 <= index < vec_len:
            vec[index] += 1
    return vec
def gcn_base():
    """Placeholder for a GCN-based feature extractor; not implemented yet.

    Takes no arguments, does nothing and returns ``None``.
    """
    return None
def asm2vec_plus_test(model=asm2vec_plus_128_model,
                      hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3",
                      batch_size=10024):
    """Timing helper: run ``model.linear_f`` on a batch of identical function vectors.

    ``hex_asm`` is turned into one input vector per line of code by the
    ``asm2vec_plus_obj`` helpers. Those vectors are averaged element-wise into
    a single function vector, which is stacked ``batch_size`` times into a
    ``(batch_size, width)`` tensor and passed through ``model.linear_f`` in
    one call.

    Args:
        model: loaded model exposing a ``linear_f`` layer.
        hex_asm: machine code as a hex string.
        batch_size: number of copies in the batch (previously fixed at 10024).

    Returns:
        The output of ``model.linear_f`` on the batch (computed on ``device``).

    Raises:
        ValueError: if no per-line vectors are produced for ``hex_asm``.
    """
    decoded = asm2vec_plus_obj.str_hex_to_bytes(hex_asm)
    # The second return value (opcode/operand sequence) is not needed here.
    line_vectors, _ = asm2vec_plus_obj.get_asm_input_vector(decoded)
    n_lines = len(line_vectors)
    if n_lines == 0:
        raise ValueError("hex_asm produced no instruction vectors")
    width = len(line_vectors[0])
    # Element-wise mean over the code lines, summed in row order from 0.0.
    mean_vector = [
        sum((row[j] for row in line_vectors), 0.0) / n_lines for j in range(width)
    ]
    # repeat(batch_size, 1) turns the 1-D (width,) tensor into (batch_size, width),
    # matching the previous list-of-copies construction without the Python loop.
    batch = torch.tensor(mean_vector).repeat(batch_size, 1).to(device)
    return model.to(device).linear_f(batch)
import time
if __name__ == '__main__':
    # Time one batched embedding pass through asm2vec_plus_test().
    # perf_counter() is monotonic and high-resolution, unlike time.time(),
    # which can jump if the wall clock is adjusted.
    T1 = time.perf_counter()
    res = asm2vec_plus_test()
    T2 = time.perf_counter()
    # Prints "program running time: <elapsed seconds>".
    print('程序运行时间:%s' % ((T2 - T1) ))