detect_rep/detect_script/extract.py
2023-04-05 10:04:49 +08:00

247 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import angr
import csv
from angrutils import plot_cfg, hook0, set_plot_style
import bingraphvis
import networkx as nx
import os
from tqdm import tqdm
import sys
# sys.path.append(r'../ASM2VEC_plus_scripts/')
# from func2vec import func2vec,load_model
# from node_feature import asm2vec_plus
import lief
from detect_pe_packer import detect_pack_res
# exit()
from node_feature import *
def my_round(vec_list):
    """Round every element of ``vec_list`` to 5 decimal places, in place.

    Args:
        vec_list: Index-assignable sequence of numbers (for example a list).

    Returns:
        The same ``vec_list`` object, now holding the rounded values.
    """
    for position, value in enumerate(vec_list):
        vec_list[position] = round(value, 5)
    return vec_list
def get_node_feature(hex_asm="558bec83ec085756bf0bb80000ff15d4804000", node_feature_method="n_gram"):
    """Compute one feature vector for a block of code given as a hex string.

    ``node_feature_method`` is the name of the extractor to run. Every
    supported name is also the name of a module-level callable (brought in by
    the file's ``from node_feature import *``), which is invoked as
    ``extractor(hex_asm=hex_asm)``.

    Args:
        hex_asm: Hex string handed unchanged to the selected extractor.
        node_feature_method: Name of the extractor to use.

    Returns:
        * for the ``asm2vec_*`` methods: the extractor result converted with
          ``.tolist()``;
        * for ``"malconv"``, ``"n_gram"`` and ``"word_frequency"``: the
          extractor result as returned;
        * ``None`` when ``node_feature_method`` is not a supported name.
    """
    # Extractors whose result needs ``.tolist()`` before being returned.
    needs_tolist = (
        "asm2vec_plus_16", "asm2vec_plus_32", "asm2vec_plus_64",
        "asm2vec_plus_128", "asm2vec_plus_256",
        "asm2vec_base_16", "asm2vec_base_32", "asm2vec_base_64",
        "asm2vec_base_128", "asm2vec_base_256",
        "asm2vec_s_base_16", "asm2vec_s_base_32", "asm2vec_s_base_64",
        "asm2vec_s_base_128", "asm2vec_s_base_256",
        "asm2vec_s368_base_16", "asm2vec_s368_base_32", "asm2vec_s368_base_64",
        "asm2vec_s368_base_128", "asm2vec_s368_base_256",
        "asm2vec_plus_small", "asm2vec_base_small",
    )
    # Extractors whose result is returned untouched.
    returned_as_is = ("malconv", "n_gram", "word_frequency")

    # The extractor is looked up by name only after the method has been
    # recognised, so an unsupported method never touches any extractor.
    if node_feature_method in needs_tolist:
        extractor = globals()[node_feature_method]
        return extractor(hex_asm=hex_asm).tolist()
    if node_feature_method in returned_as_is:
        extractor = globals()[node_feature_method]
        return extractor(hex_asm=hex_asm)
    return None
# Relative path to an asm2vec checkpoint file; no live code visible in this file reads it.
asm2vec_model_path="../ASM2VEC_plus_scripts/asm2vec_checkpoints/model_100.pt"
# print(bengin_list)
# jmp_family=["jmp","call","ret","retf",
# "ja","jnbe","jae","jnb","jb","jane","jbe","jna",
# "jg","jnle","jge","jnl","jl","jnge","jle","jng",
# "je","jz","jne","jnz","jc","jnc","jno","jnp","jpo",
# "jns","jo","jp","jpe","js"
# "loop","loope","loopz","loopne","loopnz","jcxz","jecxz"
# ]
def get_new_section_addr(bin_parse):
    """Return the file offset that directly follows the entry-point section.

    The section is selected by *name*: the name of the section containing the
    entry point is looked up, then ``bin_parse.sections`` is scanned for that
    name. The scan does not stop at the first hit, so if several sections
    share the name the last one is used.

    Args:
        bin_parse: Parsed PE binary exposing ``optional_header``,
            ``section_from_rva`` and ``sections`` (presumably the object
            returned by ``lief.PE.parse`` - confirm against callers).

    Returns:
        ``section.offset + section.size`` of the selected section, read back
        from the ``offset`` attribute of a temporary ``lief.PE.Section``.
    """
    entry_rva = bin_parse.optional_header.addressof_entrypoint
    # Name of the section that contains the entry point.
    entry_section_name = bin_parse.section_from_rva(entry_rva).name

    for candidate in bin_parse.sections:
        if candidate.name != entry_section_name:
            continue
        base_address = candidate.virtual_address
        section_size = candidate.size
        section_offset = candidate.offset

    # Number of 0x1000-sized units needed to cover the section (rounded up).
    unit_count = int(section_size / 0x1000) + (1 if section_size % 0x1000 else 0)

    # Build the new section and place it right behind the selected one.
    # Only its offset is returned; the virtual address is set but not used.
    appended = lief.PE.Section("test")
    appended.virtual_address = base_address + unit_count * 0x1000
    appended.offset = section_offset + section_size
    return appended.offset
# Feature names stored on every CFG node, in the order they are written.
# The first entry doubles as the probe that decides whether a sample is kept.
_NODE_FEATURE_METHODS = (
    "asm2vec_plus_16", "asm2vec_plus_32", "asm2vec_plus_64",
    "asm2vec_plus_128", "asm2vec_plus_256",
    "asm2vec_base_16", "asm2vec_base_32", "asm2vec_base_64",
    "asm2vec_base_128", "asm2vec_base_256",
    "asm2vec_s_base_16", "asm2vec_s_base_32", "asm2vec_s_base_64",
    "asm2vec_s_base_128", "asm2vec_s_base_256",
    "asm2vec_s368_base_16", "asm2vec_s368_base_32", "asm2vec_s368_base_64",
    "asm2vec_s368_base_128", "asm2vec_s368_base_256",
    "n_gram", "malconv", "word_frequency",
)


def _feature_text(asm_hex, method):
    """Return the rounded ``method`` feature of ``asm_hex`` as a string."""
    return str(my_round(get_node_feature(hex_asm=asm_hex, node_feature_method=method)))


def _annotate_nodes(graph):
    """Store every feature of ``_NODE_FEATURE_METHODS`` on each node of ``graph``.

    Returns:
        ``False`` as soon as reading a node's block bytes or computing its
        first feature fails (the graph is then left partially annotated),
        ``True`` once every node has been annotated. Errors raised by the
        remaining features are not caught and propagate to the caller.
    """
    probe_method, *other_methods = _NODE_FEATURE_METHODS
    for node, attrs in graph.nodes.items():
        try:
            # Bytes of this basic block as a hex string.
            asm_hex = node.block.bytes.hex()
            attrs[probe_method] = _feature_text(asm_hex, probe_method)
        except Exception:
            return False
        for method in other_methods:
            attrs[method] = _feature_text(asm_hex, method)
    return True


def cfg_extract(file_list,data_dir="../data/malware",CFG_dir="../CFG_data/malware",csv_save_path="../CFG_data/malware_msg.csv",header = ['malware_name','nodes_num','edgs_num']):
    """Build a feature-annotated CFG for each sample and save it as GEXF.

    For every name in ``file_list`` the sample ``data_dir/name`` is checked
    with ``detect_pack_res`` (packed samples are skipped), loaded with angr,
    and its ``CFGFast`` graph is annotated node by node with all features in
    ``_NODE_FEATURE_METHODS``. A fully annotated graph is written to
    ``CFG_dir/name.gexf`` and summarised as one CSV row.

    Args:
        file_list: Sample file names relative to ``data_dir``.
        data_dir: Directory holding the samples.
        CFG_dir: Directory receiving the ``.gexf`` files.
        csv_save_path: Path of the summary CSV (overwritten).
        header: First CSV row. The default list is never mutated here.

    Returns:
        None. Results are the ``.gexf`` files and the CSV file.

    Raises:
        Errors from loading/analysing a sample are swallowed and the sample
        is skipped; errors from later feature extractors, from writing the
        GEXF file or from writing the CSV propagate.
    """
    csv_rows = []
    for file_item in tqdm(file_list):
        try:
            sample_path = os.path.join(data_dir, file_item)
            # Drop packed programs. The comparison is kept as ``== True``
            # because the return type of detect_pack_res is not visible here.
            if detect_pack_res(sample_path) == True:  # noqa: E712
                continue
            # Result unused; kept so that a parse error still skips the sample.
            lief.PE.parse(sample_path)
            project = angr.Project(sample_path, load_options={'auto_load_libs': False})
            cfg = project.analyses.CFGFast(show_progressbar=True, normalize=False, resolve_indirect_jumps=False,
                                           force_smart_scan=False, symbols=False, data_references=False)
        except Exception:
            # ``except Exception`` instead of a bare ``except`` so that
            # Ctrl-C / SystemExit stop the run instead of skipping a sample.
            continue

        graph = cfg.graph
        # A sample whose nodes cannot be turned into features is discarded.
        if not _annotate_nodes(graph):
            continue

        gexf_path = os.path.join(CFG_dir, file_item) + ".gexf"
        nx.write_gexf(graph, gexf_path)
        print("成功写入", gexf_path)
        # Record node and edge counts for the summary file.
        csv_rows.append([file_item + ".gexf", str(len(graph.nodes())), str(len(graph.edges()))])

    # NOTE(review): the source's indentation was lost; the summary is assumed
    # to be written once, after all samples were processed - confirm.
    with open(csv_save_path, 'w', encoding='utf-8', newline='') as fp:
        writer = csv.writer(fp)
        writer.writerow(header)
        writer.writerows(csv_rows)
def write_csv(CFG_dir = "../CFG_data/malware",csv_save_path="../CFG_data/malware_msg.csv",header = ['malware_name','nodes_num','edgs_num']):
    """Rebuild the summary CSV straight from GEXF files that already exist.

    Every entry of ``CFG_dir`` is read with ``nx.read_gexf`` and contributes
    one row ``[file name, node count, edge count]`` (counts as strings).

    Args:
        CFG_dir: Directory whose entries are all read as GEXF graphs.
        csv_save_path: Path of the CSV file to (over)write.
        header: First row of the CSV; each graph row has three columns.

    Returns:
        None. Prints a completion message after the CSV has been written.
    """
    rows = []
    for gexf_name in tqdm(os.listdir(CFG_dir)):
        graph = nx.read_gexf(os.path.join(CFG_dir, gexf_name))
        node_total = str(len(graph.nodes()))
        edge_total = str(len(graph.edges()))
        rows.append([gexf_name, node_total, edge_total])

    with open(csv_save_path, 'w', encoding='utf-8', newline='') as out_file:
        sheet = csv.writer(out_file)
        # Header first, then one row per graph.
        sheet.writerows([header] + rows)
    print("成功结束")
if __name__ == '__main__':
    # Script entry point: extract feature-annotated CFGs for the benign set.

    # Sample directories (input).
    malware_data_dir = "../data/malware"
    bengin_data_dir = "../data/benign"

    # Graph directories (output).
    malware_CFG_dir = "../cfg_data_with_feature/malware"
    bengin_CFG_dir = "../cfg_data_with_feature/benign"

    # Both directories are listed, although only the benign list is used below.
    malware_list = os.listdir(malware_data_dir)
    bengin_list = os.listdir(bengin_data_dir)

    # The malware run is currently disabled. To enable it, call
    #   cfg_extract(malware_list, data_dir=malware_data_dir, CFG_dir=malware_CFG_dir,
    #               csv_save_path="../cfg_data_with_feature/malware_msg1.csv",
    #               header=['malware_name', 'nodes_num', 'edgs_num'])
    cfg_extract(
        bengin_list,
        data_dir=bengin_data_dir,
        CFG_dir=bengin_CFG_dir,
        csv_save_path="../cfg_data_with_feature/benign_msg.csv",
        header=['bengin_name', 'nodes_num', 'edgs_num'],
    )

    # write_csv(...) can rebuild a summary CSV from an existing GEXF directory
    # without re-running the extraction.