detect_rep/data_extract/node_feature.py

373 lines
18 KiB
Python
Raw Permalink Normal View History

2023-04-05 10:04:49 +08:00
import torch
import sys
import pandas as pd
from capstone import *
import binascii
from sklearn.feature_extraction.text import CountVectorizer
sys.path.append(r"./features_method/asm2vec_base/")
import detect_script.features_method.asm2vec_plus.asm2vec_plus_util as asm2vec_plus_obj
# import load_asm2vec_plus_model,str_hex_to_bytes,get_asm_input_vector
import detect_script.features_method.asm2vec_base.asm2vec_base_util as asm2vec_base_obj
import detect_script.features_method.s_asm2vec_base.asm2vec_base_util as s_asm2vec_base_obj
import detect_script.features_method.s368_asm2vec_base.asm2vec_base_util as s368_asm2vec_base_obj
import os
import json
# Run on the GPU when one is available; models and feature tensors are moved here.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _checkpoint_path(method_dir, file_name):
    """Return ./features_method/<method_dir>/checkpoints/<file_name>.

    Built with os.path.join rather than hard-coded backslashes so the path
    also resolves on POSIX systems; on Windows it is the same string as the
    previous literal. The path is relative to the current working directory.
    """
    return os.path.join(".", "features_method", method_dir, "checkpoints", file_name)


# Every model below is loaded eagerly, at import time.

# asm2vec_plus models, one checkpoint per embedding size.
asm2vec_plus_16_model = asm2vec_plus_obj.load_asm2vec_plus_model(
    path=_checkpoint_path("asm2vec_plus", "model_16_100.pt"), device=device)
asm2vec_plus_32_model = asm2vec_plus_obj.load_asm2vec_plus_model(
    path=_checkpoint_path("asm2vec_plus", "model_32_100.pt"), device=device)
asm2vec_plus_64_model = asm2vec_plus_obj.load_asm2vec_plus_model(
    path=_checkpoint_path("asm2vec_plus", "model_64_100.pt"), device=device)
asm2vec_plus_128_model = asm2vec_plus_obj.load_asm2vec_plus_model(
    path=_checkpoint_path("asm2vec_plus", "model_128_100.pt"), device=device)
asm2vec_plus_256_model = asm2vec_plus_obj.load_asm2vec_plus_model(
    path=_checkpoint_path("asm2vec_plus", "model_256_100.pt"), device=device)

# NOTE: the *_16 checkpoints of the three base families below used to be
# loaded with asm2vec_plus_obj.load_asm2vec_plus_model, while every other size
# in the same family used asm2vec_base_obj.load_asm2vec_base_model. They now
# use the base loader, like their siblings.
# TODO(review): confirm the 16-dim base checkpoints load with the base loader.

# asm2vec_base models.
asm2vec_base_16_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint_path("asm2vec_base", "model_16_100.pt"), device=device)
asm2vec_base_32_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint_path("asm2vec_base", "model_32_100.pt"), device=device)
asm2vec_base_64_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint_path("asm2vec_base", "model_64_100.pt"), device=device)
asm2vec_base_128_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint_path("asm2vec_base", "model_128_100.pt"), device=device)
asm2vec_base_256_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint_path("asm2vec_base", "model_256_100.pt"), device=device)

# "s" variants; their checkpoints live in the asm2vec_base directory.
asm2vec_base_16_s_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint_path("asm2vec_base", "s_model_16_100.pt"), device=device)
asm2vec_base_32_s_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint_path("asm2vec_base", "s_model_32_100.pt"), device=device)
asm2vec_base_64_s_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint_path("asm2vec_base", "s_model_64_100.pt"), device=device)
asm2vec_base_128_s_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint_path("asm2vec_base", "s_model_128_100.pt"), device=device)
asm2vec_base_256_s_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint_path("asm2vec_base", "s_model_256_100.pt"), device=device)

# "s368" variants.
asm2vec_base_16_s368_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint_path("s368_asm2vec_base", "s368_model_16_100.pt"), device=device)
asm2vec_base_32_s368_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint_path("s368_asm2vec_base", "s368_model_32_100.pt"), device=device)
asm2vec_base_64_s368_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint_path("s368_asm2vec_base", "s368_model_64_100.pt"), device=device)
asm2vec_base_128_s368_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint_path("s368_asm2vec_base", "s368_model_128_100.pt"), device=device)
asm2vec_base_256_s368_model = asm2vec_base_obj.load_asm2vec_base_model(
    _checkpoint_path("s368_asm2vec_base", "s368_model_256_100.pt"), device=device)
def asm2vec_plus(model, hex_asm):
    """Embed a batch of functions with an asm2vec_plus model.

    Steps for each function:
      1. Convert its hex string with ``asm2vec_plus_obj.str_hex_to_bytes``.
      2. Turn the result into per-instruction input vectors with
         ``asm2vec_plus_obj.get_asm_input_vector``.
      3. Average those vectors element-wise into one function vector.

    The function vectors are then stacked into a tensor on ``device`` and
    passed through ``model.linear_f``.

    Args:
        model: a loaded model exposing ``linear_f`` (e.g. asm2vec_plus_16_model).
        hex_asm: an iterable of hex strings, one per function. A single hex
            string is treated as a one-function batch. Previously a bare
            string (such as the asm2vec_plus_* wrappers' default) was
            processed one hex digit at a time. Only the first 10000 functions
            are embedded.

    Returns:
        The result of ``model.linear_f`` on a tensor holding one averaged
        vector per function.

    Raises:
        ValueError: if a function yields no instruction vectors. The old code
            failed on such input with an IndexError.
    """
    if isinstance(hex_asm, str):
        hex_asm = [hex_asm]
    func_vectors = []
    # The batch is capped at 10000 functions.
    for func_index, hex_asm_item in enumerate(hex_asm[:10000]):
        byte_seq = asm2vec_plus_obj.str_hex_to_bytes(hex_asm_item)
        instr_vectors, _ = asm2vec_plus_obj.get_asm_input_vector(byte_seq)
        n_instr = len(instr_vectors)
        if n_instr == 0:
            raise ValueError(
                "function #%d produced no instruction vectors" % func_index)
        # Element-wise mean over the instruction vectors -> one vector per function.
        func_vectors.append(
            [sum(column, 0.0) / n_instr for column in zip(*instr_vectors)])
    func_tensor = torch.tensor(func_vectors).to(device)
    return model.to(device).linear_f(func_tensor)
# Thin wrappers that bind one asm2vec_plus model per embedding size.
# The default hex_asm is a single sample hex string, not a list of strings.
def asm2vec_plus_16(model=asm2vec_plus_16_model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Forward to asm2vec_plus(); the model defaults to asm2vec_plus_16_model."""
    embedding = asm2vec_plus(model=model, hex_asm=hex_asm)
    return embedding


def asm2vec_plus_32(model=asm2vec_plus_32_model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Forward to asm2vec_plus(); the model defaults to asm2vec_plus_32_model."""
    embedding = asm2vec_plus(model=model, hex_asm=hex_asm)
    return embedding


def asm2vec_plus_64(model=asm2vec_plus_64_model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Forward to asm2vec_plus(); the model defaults to asm2vec_plus_64_model."""
    embedding = asm2vec_plus(model=model, hex_asm=hex_asm)
    return embedding


def asm2vec_plus_128(model=asm2vec_plus_128_model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Forward to asm2vec_plus(); the model defaults to asm2vec_plus_128_model."""
    embedding = asm2vec_plus(model=model, hex_asm=hex_asm)
    return embedding


def asm2vec_plus_256(model=asm2vec_plus_256_model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Forward to asm2vec_plus(); the model defaults to asm2vec_plus_256_model."""
    embedding = asm2vec_plus(model=model, hex_asm=hex_asm)
    return embedding
def asm2vec_base(model, hex_asm):
    """Embed a batch of functions with an asm2vec_base model.

    Steps for each function:
      1. Convert its hex string with ``asm2vec_base_obj.str_hex_to_bytes``.
      2. Turn the result into per-instruction input vectors with
         ``asm2vec_base_obj.get_asm_input_vector``.
      3. Average those vectors element-wise into one function vector.

    The function vectors are then stacked into a tensor on ``device`` and
    passed through ``model.linear_f``.

    Args:
        model: a loaded model exposing ``linear_f`` (e.g. asm2vec_base_32_model).
        hex_asm: an iterable of hex strings, one per function. A single hex
            string is treated as a one-function batch. Previously a bare
            string (such as the asm2vec_base_* wrappers' default) was
            processed one hex digit at a time.

    Returns:
        The result of ``model.linear_f`` on a tensor holding one averaged
        vector per function.

    Raises:
        ValueError: if a function yields no instruction vectors. The old code
            failed on such input with an IndexError.
    """
    if isinstance(hex_asm, str):
        hex_asm = [hex_asm]
    func_vectors = []
    for func_index, hex_asm_item in enumerate(hex_asm):
        byte_seq = asm2vec_base_obj.str_hex_to_bytes(hex_asm_item)
        instr_vectors, _ = asm2vec_base_obj.get_asm_input_vector(byte_seq)
        n_instr = len(instr_vectors)
        if n_instr == 0:
            raise ValueError(
                "function #%d produced no instruction vectors" % func_index)
        # Element-wise mean over the instruction vectors -> one vector per function.
        func_vectors.append(
            [sum(column, 0.0) / n_instr for column in zip(*instr_vectors)])
    func_tensor = torch.tensor(func_vectors).to(device)
    return model.to(device).linear_f(func_tensor)
# Thin wrappers that bind one asm2vec_base model per embedding size.
# The default hex_asm is a single sample hex string, not a list of strings.
def asm2vec_base_16(model=asm2vec_base_16_model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Forward to asm2vec_base(); the model defaults to asm2vec_base_16_model."""
    embedding = asm2vec_base(model=model, hex_asm=hex_asm)
    return embedding


def asm2vec_base_32(model=asm2vec_base_32_model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Forward to asm2vec_base(); the model defaults to asm2vec_base_32_model."""
    embedding = asm2vec_base(model=model, hex_asm=hex_asm)
    return embedding


def asm2vec_base_64(model=asm2vec_base_64_model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Forward to asm2vec_base(); the model defaults to asm2vec_base_64_model."""
    embedding = asm2vec_base(model=model, hex_asm=hex_asm)
    return embedding


def asm2vec_base_128(model=asm2vec_base_128_model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Forward to asm2vec_base(); the model defaults to asm2vec_base_128_model."""
    embedding = asm2vec_base(model=model, hex_asm=hex_asm)
    return embedding


def asm2vec_base_256(model=asm2vec_base_256_model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Forward to asm2vec_base(); the model defaults to asm2vec_base_256_model."""
    embedding = asm2vec_base(model=model, hex_asm=hex_asm)
    return embedding
def asm2vec_s_base(model, hex_asm):
    """Embed a batch of functions with an "s" asm2vec_base model.

    Steps for each function:
      1. Convert its hex string with ``s_asm2vec_base_obj.str_hex_to_bytes``.
      2. Turn the result into per-instruction input vectors with
         ``s_asm2vec_base_obj.get_asm_input_vector``.
      3. Average those vectors element-wise into one function vector.

    The function vectors are then stacked into a tensor on ``device`` and
    passed through ``model.linear_f``.

    Args:
        model: a loaded model exposing ``linear_f`` (e.g. asm2vec_base_32_s_model).
        hex_asm: an iterable of hex strings, one per function. A single hex
            string is treated as a one-function batch. Previously a bare
            string (such as the asm2vec_s_base_* wrappers' default) was
            processed one hex digit at a time.

    Returns:
        The result of ``model.linear_f`` on a tensor holding one averaged
        vector per function.

    Raises:
        ValueError: if a function yields no instruction vectors. The old code
            failed on such input with an IndexError.
    """
    if isinstance(hex_asm, str):
        hex_asm = [hex_asm]
    func_vectors = []
    for func_index, hex_asm_item in enumerate(hex_asm):
        byte_seq = s_asm2vec_base_obj.str_hex_to_bytes(hex_asm_item)
        instr_vectors, _ = s_asm2vec_base_obj.get_asm_input_vector(byte_seq)
        n_instr = len(instr_vectors)
        if n_instr == 0:
            raise ValueError(
                "function #%d produced no instruction vectors" % func_index)
        # Element-wise mean over the instruction vectors -> one vector per function.
        func_vectors.append(
            [sum(column, 0.0) / n_instr for column in zip(*instr_vectors)])
    func_tensor = torch.tensor(func_vectors).to(device)
    return model.to(device).linear_f(func_tensor)
# Thin wrappers that bind one "s" asm2vec_base model per embedding size.
# The default hex_asm is a single sample hex string, not a list of strings.
def asm2vec_s_base_16(model=asm2vec_base_16_s_model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Forward to asm2vec_s_base(); the model defaults to asm2vec_base_16_s_model."""
    embedding = asm2vec_s_base(model=model, hex_asm=hex_asm)
    return embedding


def asm2vec_s_base_32(model=asm2vec_base_32_s_model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Forward to asm2vec_s_base(); the model defaults to asm2vec_base_32_s_model."""
    embedding = asm2vec_s_base(model=model, hex_asm=hex_asm)
    return embedding


def asm2vec_s_base_64(model=asm2vec_base_64_s_model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Forward to asm2vec_s_base(); the model defaults to asm2vec_base_64_s_model."""
    embedding = asm2vec_s_base(model=model, hex_asm=hex_asm)
    return embedding


def asm2vec_s_base_128(model=asm2vec_base_128_s_model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Forward to asm2vec_s_base(); the model defaults to asm2vec_base_128_s_model."""
    embedding = asm2vec_s_base(model=model, hex_asm=hex_asm)
    return embedding


def asm2vec_s_base_256(model=asm2vec_base_256_s_model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Forward to asm2vec_s_base(); the model defaults to asm2vec_base_256_s_model."""
    embedding = asm2vec_s_base(model=model, hex_asm=hex_asm)
    return embedding
def asm2vec_s368_base(model, hex_asm):
    """Embed a batch of functions with an "s368" asm2vec_base model.

    Steps for each function:
      1. Convert its hex string with ``s368_asm2vec_base_obj.str_hex_to_bytes``.
      2. Turn the result into per-instruction input vectors with
         ``s368_asm2vec_base_obj.get_asm_input_vector``.
      3. Average those vectors element-wise into one function vector.

    The function vectors are then stacked into a tensor on ``device`` and
    passed through ``model.linear_f``.

    Args:
        model: a loaded model exposing ``linear_f`` (e.g. asm2vec_base_32_s368_model).
        hex_asm: an iterable of hex strings, one per function. A single hex
            string is treated as a one-function batch. Previously a bare
            string (such as the asm2vec_s368_base_* wrappers' default) was
            processed one hex digit at a time.

    Returns:
        The result of ``model.linear_f`` on a tensor holding one averaged
        vector per function.

    Raises:
        ValueError: if a function yields no instruction vectors. The old code
            failed on such input with an IndexError.
    """
    if isinstance(hex_asm, str):
        hex_asm = [hex_asm]
    func_vectors = []
    for func_index, hex_asm_item in enumerate(hex_asm):
        byte_seq = s368_asm2vec_base_obj.str_hex_to_bytes(hex_asm_item)
        instr_vectors, _ = s368_asm2vec_base_obj.get_asm_input_vector(byte_seq)
        n_instr = len(instr_vectors)
        if n_instr == 0:
            raise ValueError(
                "function #%d produced no instruction vectors" % func_index)
        # Element-wise mean over the instruction vectors -> one vector per function.
        func_vectors.append(
            [sum(column, 0.0) / n_instr for column in zip(*instr_vectors)])
    func_tensor = torch.tensor(func_vectors).to(device)
    return model.to(device).linear_f(func_tensor)
# Thin wrappers that bind one "s368" asm2vec_base model per embedding size.
# The default hex_asm is a single sample hex string, not a list of strings.
def asm2vec_s368_base_16(model=asm2vec_base_16_s368_model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Forward to asm2vec_s368_base(); the model defaults to asm2vec_base_16_s368_model."""
    embedding = asm2vec_s368_base(model=model, hex_asm=hex_asm)
    return embedding


def asm2vec_s368_base_32(model=asm2vec_base_32_s368_model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Forward to asm2vec_s368_base(); the model defaults to asm2vec_base_32_s368_model."""
    embedding = asm2vec_s368_base(model=model, hex_asm=hex_asm)
    return embedding


def asm2vec_s368_base_64(model=asm2vec_base_64_s368_model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Forward to asm2vec_s368_base(); the model defaults to asm2vec_base_64_s368_model."""
    embedding = asm2vec_s368_base(model=model, hex_asm=hex_asm)
    return embedding


def asm2vec_s368_base_128(model=asm2vec_base_128_s368_model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Forward to asm2vec_s368_base(); the model defaults to asm2vec_base_128_s368_model."""
    embedding = asm2vec_s368_base(model=model, hex_asm=hex_asm)
    return embedding


def asm2vec_s368_base_256(model=asm2vec_base_256_s368_model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Forward to asm2vec_s368_base(); the model defaults to asm2vec_base_256_s368_model."""
    embedding = asm2vec_s368_base(model=model, hex_asm=hex_asm)
    return embedding
def malconv(hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3"):
    """Compute a 256-bin byte-value histogram for each function.

    Args:
        hex_asm: an iterable of hex strings, one per function. A single hex
            string is treated as a one-function batch. Previously a bare
            string, including this default, was iterated character by
            character, producing one histogram per hex digit.

    Returns:
        A list with one 256-element list of ints per function. Entry ``b``
        counts how often byte value ``b`` occurs in that function's code.
    """
    if isinstance(hex_asm, str):
        hex_asm = [hex_asm]
    vec_list = []
    for hex_asm_item in hex_asm:
        vec = [0] * 256
        # Read two hex digits (one byte) at a time. A trailing odd digit is
        # read on its own, as the previous implementation did.
        for start in range(0, len(hex_asm_item), 2):
            vec[int(hex_asm_item[start:start + 2], 16)] += 1
        vec_list.append(vec)
    return vec_list
# Bigram vocabulary used by n_gram(). Keys have the form
# "<mnemonic> <mnemonic>"; values are converted with int() and used as bin
# indices. The path is relative to the current working directory.
with open("./features_method/n_gram/" + 'vocab.json', mode='r', encoding='utf-8') as fp:
    n_gram_vocab = json.loads(fp.read())
# Mnemonic 2-gram (bigram) features.
def n_gram(hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3", vec_len=512):
    """Count mnemonic bigrams for each function.

    Each hex string is unhexlified and disassembled with capstone as 32-bit
    x86, starting at address 0. For every pair of consecutive mnemonics, the
    key "<first> <second>" is looked up in ``n_gram_vocab``. When the key is
    present, its value (converted with int) selects the bin to increment.
    Bins outside ``[0, vec_len)`` are skipped; previously a negative index
    silently wrapped around to the end of the vector.

    Args:
        hex_asm: an iterable of hex strings, one per function. A single hex
            string is treated as a one-function batch. Previously a bare
            string, including this default, was iterated character by
            character.
        vec_len: number of bins per vector. Defaults to 512, the length the
            code has always used (an older comment claimed 1024).

    Returns:
        A list with one list of ``vec_len`` ints per function.
    """
    if isinstance(hex_asm, str):
        hex_asm = [hex_asm]
    # One disassembler instance serves the whole batch.
    md = Cs(CS_ARCH_X86, CS_MODE_32)
    vec_list = []
    for hex_asm_item in hex_asm:
        vec = [0] * vec_len
        mnemonics = [insn.mnemonic
                     for insn in md.disasm(binascii.unhexlify(hex_asm_item), 0)]
        for first, second in zip(mnemonics, mnemonics[1:]):
            bigram = first + " " + second
            if bigram not in n_gram_vocab:
                continue
            index = int(n_gram_vocab[bigram])
            if 0 <= index < vec_len:
                vec[index] += 1
        vec_list.append(vec)
    return vec_list
# Mnemonic vocabulary used by word_frequency(). Each key is an instruction
# mnemonic and its value is the histogram bin index. The path is relative to
# the current working directory.
with open("./features_method/word_frequency/" + 'vocab.json', mode='r', encoding='utf-8') as fp:
    word_frequency_vocab = json.loads(fp.read())
def word_frequency(hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3", vec_len=256):
    """Count mnemonic occurrences for each function.

    Each hex string is unhexlified and disassembled with capstone as 32-bit
    x86, starting at address 0. Every mnemonic found in
    ``word_frequency_vocab`` increments the bin given by its vocabulary value
    (converted with int). Bins outside ``[0, vec_len)`` are skipped, matching
    n_gram(). Previously an index >= 256 raised IndexError and a negative
    index silently wrapped around.

    Args:
        hex_asm: an iterable of hex strings, one per function. A single hex
            string is treated as a one-function batch. Previously a bare
            string, including this default, was iterated character by
            character.
        vec_len: number of bins per vector. Defaults to 256, as before.

    Returns:
        A list with one list of ``vec_len`` ints per function.
    """
    if isinstance(hex_asm, str):
        hex_asm = [hex_asm]
    # One disassembler instance serves the whole batch.
    md = Cs(CS_ARCH_X86, CS_MODE_32)
    vec_list = []
    for hex_asm_item in hex_asm:
        vec = [0] * vec_len
        for insn in md.disasm(binascii.unhexlify(hex_asm_item), 0):
            if insn.mnemonic not in word_frequency_vocab:
                continue
            index = int(word_frequency_vocab[insn.mnemonic])
            if 0 <= index < vec_len:
                vec[index] += 1
        vec_list.append(vec)
    return vec_list
def gcn_base():
    """Placeholder for a feature extractor (the name suggests GCN-based).

    It currently does nothing and returns None.
    """
    return None
def asm2vec_plus_test(model=asm2vec_plus_128_model, hex_asm="56a194382b56508b35cc912b56ffd6ff3584382b56ffd6a174382b5650ffd65ec3", n_copies=10024):
    """Benchmark helper: embed one function replicated ``n_copies`` times.

    The hex string goes through ``asm2vec_plus_obj.str_hex_to_bytes`` and
    ``get_asm_input_vector``. The per-instruction vectors are averaged
    element-wise into one function vector, and that vector is repeated
    ``n_copies`` times to form a batch. The batch is moved to ``device`` and
    passed through ``model.linear_f``.

    Args:
        model: a loaded model exposing ``linear_f``.
        hex_asm: a single hex string of machine code.
        n_copies: batch size. Defaults to 10024, the value previously
            hard-coded.

    Returns:
        The result of ``model.linear_f`` on the replicated batch.

    Raises:
        ValueError: if ``hex_asm`` yields no instruction vectors.
    """
    byte_seq = asm2vec_plus_obj.str_hex_to_bytes(hex_asm)
    instr_vectors, _ = asm2vec_plus_obj.get_asm_input_vector(byte_seq)
    n_instr = len(instr_vectors)
    if n_instr == 0:
        raise ValueError("hex_asm produced no instruction vectors")
    # Element-wise mean over the instruction vectors.
    func_vector = [sum(column, 0.0) / n_instr for column in zip(*instr_vectors)]
    batch = torch.tensor([func_vector] * n_copies).to(device)
    return model.to(device).linear_f(batch)
import time
if __name__ == '__main__':
    # Time one batched forward pass of asm2vec_plus_test() and print the
    # elapsed seconds (the message reads "program run time").
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # durations. time.time() is wall-clock time and can jump.
    T1 = time.perf_counter()
    res = asm2vec_plus_test()
    # CUDA kernels run asynchronously. Wait for them so the measurement covers
    # the actual computation, not just the kernel launches.
    if device == 'cuda':
        torch.cuda.synchronize()
    T2 = time.perf_counter()
    print('程序运行时间:%s' % (T2 - T1))