Compare commits
No commits in common. "master" and "Ke" have entirely different histories.
@ -1,6 +1,6 @@
|
||||
Data:
|
||||
preprocess_root: "/home/king/python/data/DatasetJSON_remake"
|
||||
train_vocab_file: "/home/king/python/data/fun_name_sort.jsonl"
|
||||
preprocess_root: "/root/autodl-tmp/"
|
||||
train_vocab_file: "/root/autodl-tmp/train_external_function_name_vocab.jsonl"
|
||||
max_vocab_size: 10000 # modify according to the result of 1BuildExternalVocab.py
|
||||
Training:
|
||||
cuda: True # enable GPU training if cuda is available
|
||||
@ -8,14 +8,14 @@ Training:
|
||||
dist_port: "1234"
|
||||
max_epoches: 10
|
||||
train_batch_size: 16
|
||||
test_batch_size: 8
|
||||
test_batch_size: 32
|
||||
seed: 19920208
|
||||
only_test_path: 'None'
|
||||
Model:
|
||||
ablation_models: "Full" # "Full"
|
||||
gnn_type: "GraphSAGE" # "GraphSAGE" / "GCN"
|
||||
pool_type: "global_max_pool" # "global_max_pool" / "global_mean_pool"
|
||||
acfg_node_init_dims: 32
|
||||
acfg_node_init_dims: 11
|
||||
cfg_filters: "200-200"
|
||||
fcg_filters: "200-200"
|
||||
number_classes: 1
|
||||
|
25
req_pip.txt
25
req_pip.txt
@ -1,25 +0,0 @@
|
||||
antlr4-python3-runtime==4.8
|
||||
ase==3.21.1
|
||||
cmake==3.18.4.post1
|
||||
dataclasses==0.6
|
||||
googledrivedownloader==0.4
|
||||
hydra-core==1.0.6
|
||||
importlib-resources==5.1.2
|
||||
intel-openmp==2021.1.2
|
||||
magma-cuda112==2.5.2
|
||||
mkl==2021.1.1
|
||||
mkl-include==2021.1.1
|
||||
ninja==1.10.0.post2
|
||||
omegaconf==2.0.6
|
||||
prefetch-generator==1.0.1
|
||||
rdflib==5.0.0
|
||||
tbb==2021.1.1
|
||||
texttable==1.6.3
|
||||
torch==1.8.0+cu111
|
||||
torch-cluster==1.5.9
|
||||
torch-geometric==1.6.3
|
||||
torch-scatter==2.0.6
|
||||
torch-sparse==0.6.9
|
||||
torch-spline-conv==1.2.1
|
||||
torchaudio==0.8.0
|
||||
torchvision==0.9.0+cu111
|
@ -1,83 +1,83 @@
|
||||
# This file may be used to create an environment using:
|
||||
# $ conda create --name <env> --file <this file>
|
||||
# platform: linux-64
|
||||
_libgcc_mutex=0.1
|
||||
antlr4-python3-runtime=4.8
|
||||
ase=3.21.1
|
||||
ca-certificates=2021.1.19
|
||||
cached-property=1.5.2
|
||||
certifi=2020.12.5
|
||||
cffi=1.14.5
|
||||
chardet=4.0.0
|
||||
cmake=3.18.4.post1
|
||||
cycler=0.10.0
|
||||
dataclasses=0.6
|
||||
decorator=4.4.2
|
||||
future=0.18.2
|
||||
googledrivedownloader=0.4
|
||||
h5py=3.2.1
|
||||
hydra-core=1.0.6
|
||||
idna=2.10
|
||||
importlib-resources=5.1.2
|
||||
intel-openmp=2021.1.2
|
||||
isodate=0.6.0
|
||||
jinja2=2.11.3
|
||||
joblib=1.0.1
|
||||
kiwisolver=1.3.1
|
||||
ld_impl_linux-64=2.33.1
|
||||
libedit=3.1.20191231
|
||||
libffi=3.3
|
||||
libgcc-ng=9.1.0
|
||||
libstdcxx-ng=9.1.0
|
||||
llvmlite=0.35.0
|
||||
magma-cuda112=2.5.2
|
||||
markupsafe=1.1.1
|
||||
matplotlib=3.3.4
|
||||
mkl=2021.1.1
|
||||
mkl-include=2021.1.1
|
||||
ncurses=6.2
|
||||
networkx=2.5
|
||||
ninja=1.10.0.post2
|
||||
numba=0.52.0
|
||||
numpy=1.20.1
|
||||
omegaconf=2.0.6
|
||||
openssl=1.1.1j
|
||||
pandas=1.2.3
|
||||
pillow=8.1.2
|
||||
pip=21.0.1
|
||||
prefetch-generator=1.0.1
|
||||
pycparser=2.20
|
||||
pyparsing=2.4.7
|
||||
python=3.7.9
|
||||
python-dateutil=2.8.1
|
||||
python-louvain=0.15
|
||||
pytz=2021.1
|
||||
pyyaml=5.4.1
|
||||
rdflib=5.0.0
|
||||
readline=8.1
|
||||
requests=2.25.1
|
||||
scikit-learn=0.24.1
|
||||
scipy=1.6.1
|
||||
seaborn=0.11.1
|
||||
setuptools=52.0.0
|
||||
six=1.15.0
|
||||
sqlite=3.33.0
|
||||
tbb=2021.1.1
|
||||
texttable=1.6.3
|
||||
threadpoolctl=2.1.0
|
||||
tk=8.6.10
|
||||
torch=1.8.0+cu111
|
||||
torch-cluster=1.5.9
|
||||
torch-geometric=1.6.3
|
||||
torch-scatter=2.0.6
|
||||
torch-sparse=0.6.9
|
||||
torch-spline-conv=1.2.1
|
||||
torchaudio=0.8.0
|
||||
torchvision=0.9.0+cu111
|
||||
tqdm=4.59.0
|
||||
typing-extensions=3.7.4.3
|
||||
urllib3=1.26.3
|
||||
wheel=0.36.2
|
||||
xz=5.2.5
|
||||
zipp=3.4.1
|
||||
zlib=1.2.11
|
||||
_libgcc_mutex=0.1=main
|
||||
antlr4-python3-runtime=4.8=pypi_0
|
||||
ase=3.21.1=pypi_0
|
||||
ca-certificates=2021.1.19=h06a4308_1
|
||||
cached-property=1.5.2=pypi_0
|
||||
certifi=2020.12.5=py37h06a4308_0
|
||||
cffi=1.14.5=pypi_0
|
||||
chardet=4.0.0=pypi_0
|
||||
cmake=3.18.4.post1=pypi_0
|
||||
cycler=0.10.0=pypi_0
|
||||
dataclasses=0.6=pypi_0
|
||||
decorator=4.4.2=pypi_0
|
||||
future=0.18.2=pypi_0
|
||||
googledrivedownloader=0.4=pypi_0
|
||||
h5py=3.2.1=pypi_0
|
||||
hydra-core=1.0.6=pypi_0
|
||||
idna=2.10=pypi_0
|
||||
importlib-resources=5.1.2=pypi_0
|
||||
intel-openmp=2021.1.2=pypi_0
|
||||
isodate=0.6.0=pypi_0
|
||||
jinja2=2.11.3=pypi_0
|
||||
joblib=1.0.1=pypi_0
|
||||
kiwisolver=1.3.1=pypi_0
|
||||
ld_impl_linux-64=2.33.1=h53a641e_7
|
||||
libedit=3.1.20191231=h14c3975_1
|
||||
libffi=3.3=he6710b0_2
|
||||
libgcc-ng=9.1.0=hdf63c60_0
|
||||
libstdcxx-ng=9.1.0=hdf63c60_0
|
||||
llvmlite=0.35.0=pypi_0
|
||||
magma-cuda112=2.5.2=1
|
||||
markupsafe=1.1.1=pypi_0
|
||||
matplotlib=3.3.4=pypi_0
|
||||
mkl=2021.1.1=pypi_0
|
||||
mkl-include=2021.1.1=pypi_0
|
||||
ncurses=6.2=he6710b0_1
|
||||
networkx=2.5=pypi_0
|
||||
ninja=1.10.0.post2=pypi_0
|
||||
numba=0.52.0=pypi_0
|
||||
numpy=1.20.1=pypi_0
|
||||
omegaconf=2.0.6=pypi_0
|
||||
openssl=1.1.1j=h27cfd23_0
|
||||
pandas=1.2.3=pypi_0
|
||||
pillow=8.1.2=pypi_0
|
||||
pip=21.0.1=py37h06a4308_0
|
||||
prefetch-generator=1.0.1=pypi_0
|
||||
pycparser=2.20=pypi_0
|
||||
pyparsing=2.4.7=pypi_0
|
||||
python=3.7.9=h7579374_0
|
||||
python-dateutil=2.8.1=pypi_0
|
||||
python-louvain=0.15=pypi_0
|
||||
pytz=2021.1=pypi_0
|
||||
pyyaml=5.4.1=pypi_0
|
||||
rdflib=5.0.0=pypi_0
|
||||
readline=8.1=h27cfd23_0
|
||||
requests=2.25.1=pypi_0
|
||||
scikit-learn=0.24.1=pypi_0
|
||||
scipy=1.6.1=pypi_0
|
||||
seaborn=0.11.1=pypi_0
|
||||
setuptools=52.0.0=py37h06a4308_0
|
||||
six=1.15.0=pypi_0
|
||||
sqlite=3.33.0=h62c20be_0
|
||||
tbb=2021.1.1=pypi_0
|
||||
texttable=1.6.3=pypi_0
|
||||
threadpoolctl=2.1.0=pypi_0
|
||||
tk=8.6.10=hbc83047_0
|
||||
torch=1.8.0+cu111=pypi_0
|
||||
torch-cluster=1.5.9=pypi_0
|
||||
torch-geometric=1.6.3=pypi_0
|
||||
torch-scatter=2.0.6=pypi_0
|
||||
torch-sparse=0.6.9=pypi_0
|
||||
torch-spline-conv=1.2.1=pypi_0
|
||||
torchaudio=0.8.0=pypi_0
|
||||
torchvision=0.9.0+cu111=pypi_0
|
||||
tqdm=4.59.0=pypi_0
|
||||
typing-extensions=3.7.4.3=pypi_0
|
||||
urllib3=1.26.3=pypi_0
|
||||
wheel=0.36.2=pyhd3eb1b0_0
|
||||
xz=5.2.5=h7b6447c_0
|
||||
zipp=3.4.1=pypi_0
|
||||
zlib=1.2.11=h7b6447c_3
|
||||
|
@ -1,85 +1,41 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from torch_geometric.data import Data
|
||||
from tqdm import tqdm
|
||||
sys.path.append(os.path.dirname(sys.path[0]))
|
||||
from src.utils.Vocabulary import Vocab
|
||||
|
||||
from utils.Vocabulary import Vocab
|
||||
|
||||
|
||||
def parse_json_list_2_pyg_object(jsonl_file: str, label: int, vocab: Vocab, save_path: str, file_type: str):
|
||||
# def parse_json_list_2_pyg_object(jsonl_file: str):
|
||||
train_type = ['train', 'valid', 'test']
|
||||
def parse_json_list_2_pyg_object(jsonl_file: str, label: int, vocab: Vocab):
|
||||
index = 0
|
||||
file_index = 0
|
||||
type_index = 0
|
||||
valid_flag = True
|
||||
test_flag = True
|
||||
file_len = len(os.listdir(jsonl_file))
|
||||
|
||||
for file in tqdm(os.listdir(jsonl_file), desc=file_type):
|
||||
if index >= file_len * 0.8 and valid_flag:
|
||||
type_index += 1
|
||||
valid_flag = False
|
||||
file_index = 0
|
||||
print("make valid set")
|
||||
elif index >= file_len * 0.9 and test_flag:
|
||||
type_index += 1
|
||||
test_flag = False
|
||||
file_index = 0
|
||||
print("make test set")
|
||||
j = json_to_pt(file=jsonl_file + file, label=label, vocab=vocab, save_path=save_path, file_type=file_type, train_type=train_type[type_index], index=file_index)
|
||||
index += 1
|
||||
file_index += 1
|
||||
|
||||
|
||||
def json_to_pt(file: str, label: int, vocab: Vocab, save_path: str, file_type: str, train_type: str, index: int):
|
||||
if not os.path.exists(save_path+f"{train_type}_{file_type}/"):
|
||||
os.makedirs(save_path+f"{train_type}_{file_type}/")
|
||||
with open(file, "r", encoding="utf-8") as item:
|
||||
line = item.readline()
|
||||
try:
|
||||
item = json.loads(line)
|
||||
except json.decoder.JSONDecodeError as e:
|
||||
print(e)
|
||||
print(file)
|
||||
return False
|
||||
item_hash = item['hash']
|
||||
acfg_list = []
|
||||
for one_acfg in item['acfg_list']: # list of dict of acfg
|
||||
block_features = one_acfg['block_features']
|
||||
block_edges = one_acfg['block_edges']
|
||||
one_acfg_data = Data(x=torch.tensor(block_features, dtype=torch.float),
|
||||
edge_index=torch.tensor(block_edges, dtype=torch.long))
|
||||
acfg_list.append(one_acfg_data)
|
||||
|
||||
item_function_names = item['function_names']
|
||||
item_function_edges = item['function_edges']
|
||||
|
||||
local_function_name_list = item_function_names[:len(acfg_list)]
|
||||
assert len(acfg_list) == len(
|
||||
local_function_name_list), "The length of ACFG_List should be equal to the length of Local_Function_List"
|
||||
external_function_name_list = item_function_names[len(acfg_list):]
|
||||
|
||||
external_function_index_list = [vocab[f_name] for f_name in external_function_name_list]
|
||||
|
||||
torch.save(Data(hash=item_hash, local_acfgs=acfg_list, external_list=external_function_index_list,
|
||||
function_edges=item_function_edges, targets=label),
|
||||
save_path + "{}_{}/{}_{}.pt".format(train_type, file_type, file_type, index))
|
||||
return True
|
||||
|
||||
with open(jsonl_file, "r", encoding="utf-8") as file:
|
||||
for item in tqdm(file):
|
||||
item = json.loads(item)
|
||||
item_hash = item['hash']
|
||||
|
||||
acfg_list = []
|
||||
for one_acfg in item['acfg_list']: # list of dict of acfg
|
||||
block_features = one_acfg['block_features']
|
||||
block_edges = one_acfg['block_edges']
|
||||
one_acfg_data = Data(x=torch.tensor(block_features, dtype=torch.float), edge_index=torch.tensor(block_edges, dtype=torch.long))
|
||||
acfg_list.append(one_acfg_data)
|
||||
|
||||
item_function_names = item['function_names']
|
||||
item_function_edges = item['function_edges']
|
||||
|
||||
local_function_name_list = item_function_names[:len(acfg_list)]
|
||||
assert len(acfg_list) == len(local_function_name_list), "The length of ACFG_List should be equal to the length of Local_Function_List"
|
||||
external_function_name_list = item_function_names[len(acfg_list):]
|
||||
|
||||
external_function_index_list = [vocab[f_name] for f_name in external_function_name_list]
|
||||
index += 1
|
||||
torch.save(Data(hash=item_hash, local_acfgs=acfg_list, external_list=external_function_index_list, function_edges=item_function_edges, targets=label), "./{}.pt".format(index))
|
||||
print(index)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
malware_json_path = "/home/king/python/data/jsonl/malware/"
|
||||
benign_json_path = "/home/king/python/data/jsonl/benign/"
|
||||
train_vocab_file = "/home/king/python/data/fun_name_sort.jsonl"
|
||||
save_vocab_file = "/home/king/python/data/DatasetJSON_remake/"
|
||||
file_type = ["malware", "benign"]
|
||||
json_path = "./sample.jsonl"
|
||||
train_vocab_file = "../ReservedDataCode/processed_dataset/train_external_function_name_vocab.jsonl"
|
||||
max_vocab_size = 10000
|
||||
vocabulary = Vocab(freq_file=train_vocab_file, max_vocab_size=max_vocab_size)
|
||||
# parse_json_list_2_pyg_object(jsonl_file=malware_json_path, label=1, vocab=vocabulary, save_path=save_vocab_file,
|
||||
# file_type=file_type[0])
|
||||
parse_json_list_2_pyg_object(jsonl_file=benign_json_path, label=0, vocab=vocabulary, save_path=save_vocab_file,
|
||||
file_type=file_type[1])
|
||||
parse_json_list_2_pyg_object(jsonl_file=json_path, label=1, vocab=vocabulary)
|
@ -1,38 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
from itertools import islice
|
||||
import heapq
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
if __name__ == '__main__':
|
||||
mal_file_name = '/home/king/python/data/jsonl/infected_jsonl/'
|
||||
ben_file_name = '/home/king/python/data/jsonl/refind_jsonl/'
|
||||
fun_name_dict = {}
|
||||
for file in tqdm(os.listdir(mal_file_name)):
|
||||
with open(mal_file_name + file, 'r') as item:
|
||||
item = json.loads(item.readline())
|
||||
item_fun_list = item['function_names']
|
||||
for fun_name in item_fun_list:
|
||||
if fun_name != 'start' and fun_name != 'start_0' and 'sub_' not in fun_name:
|
||||
if fun_name_dict.get(fun_name) is not None:
|
||||
fun_name_dict[fun_name] += 1
|
||||
else:
|
||||
fun_name_dict[fun_name] = 1
|
||||
for file in tqdm(os.listdir(ben_file_name)):
|
||||
with open(ben_file_name + file, 'r') as item:
|
||||
item = json.loads(item.readline())
|
||||
item_fun_list = item['function_names']
|
||||
for fun_name in item_fun_list:
|
||||
if fun_name != 'start' and fun_name != 'start_0' and 'sub_' not in fun_name:
|
||||
if fun_name_dict.get(fun_name) is not None:
|
||||
fun_name_dict[fun_name] += 1
|
||||
else:
|
||||
fun_name_dict[fun_name] = 1
|
||||
with open('/home/king/python/data/processed_dataset/train_external_function_name_vocab.jsonl', 'w') as file:
|
||||
largest_10000_items = heapq.nlargest(10000, fun_name_dict.items(), key=lambda item: item[1])
|
||||
for key, value in largest_10000_items:
|
||||
temp = {"f_name": key, "count": value}
|
||||
file.write(json.dumps(temp) + '\n')
|
||||
|
||||
|
1
samples/sample.jsonl
Normal file
1
samples/sample.jsonl
Normal file
@ -0,0 +1 @@
|
||||
{"function_edges": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26]], "acfg_list": [{"block_number": 3, "block_edges": [[0, 0, 1, 1], [0, 2, 0, 2]], "block_features": [[0, 2, 1, 0, 7, 0, 1, 1, 4, 0, 0], [0, 2, 0, 0, 3, 1, 0, 1, 0, 0, 0], [1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0]]}, {"block_number": 29, "block_edges": [[0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 12, 12, 13, 14, 14, 15, 16, 17, 18, 19, 19, 20, 20, 21, 21, 23, 24, 24, 26, 26, 27, 28], [16, 0, 2, 0, 4, 1, 3, 3, 3, 25, 15, 8, 6, 6, 7, 28, 12, 9, 23, 16, 25, 11, 21, 17, 13, 19, 22, 14, 19, 18, 27, 24, 23, 26, 21, 22, 25, 10, 25, 5, 14, 8]], "block_features": [[8, 2, 1, 5, 36, 0, 6, 0, 2, 0, 0], [0, 7, 0, 0, 3, 0, 1, 1, 1, 0, 0], [0, 7, 0, 0, 2, 0, 1, 1, 0, 0, 0], [0, 7, 0, 1, 8, 1, 2, 0, 0, 0, 0], [0, 7, 1, 0, 2, 0, 1, 0, 0, 0, 0], [0, 7, 0, 0, 1, 0, 0, 0, 1, 0, 0], [1, 18, 0, 1, 9, 0, 2, 1, 1, 0, 0], [1, 21, 1, 0, 3, 0, 1, 1, 0, 0, 0], [0, 21, 0, 1, 4, 1, 2, 0, 0, 0, 0], [0, 24, 0, 2, 12, 1, 3, 0, 0, 0, 0], [1, 26, 0, 3, 16, 0, 4, 1, 4, 0, 0], [1, 2, 0, 5, 22, 0, 5, 0, 1, 0, 0], [5, 4, 1, 3, 21, 0, 4, 1, 3, 0, 0], [4, 11, 0, 2, 17, 1, 2, 0, 1, 0, 0], [2, 14, 0, 1, 12, 0, 2, 1, 1, 0, 0], [3, 17, 0, 0, 10, 0, 1, 0, 1, 0, 0], [1, 1, 0, 1, 5, 0, 2, 0, 0, 0, 0], [0, 14, 0, 0, 1, 0, 0, 0, 0, 0, 0], [3, 17, 0, 0, 7, 0, 0, 0, 0, 0, 0], [0, 17, 0, 1, 5, 0, 2, 1, 1, 0, 0], [2, 28, 1, 1, 11, 1, 2, 1, 1, 0, 0], [0, 11, 0, 1, 8, 1, 2, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0], [1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0], [12, 27, 1, 7, 41, 0, 8, 1, 6, 0, 0], [0, 0, 1, 0, 7, 1, 0, 0, 0, 1, 0], [2, 9, 0, 2, 17, 0, 3, 1, 3, 0, 0], [2, 14, 0, 0, 5, 0, 1, 0, 4, 0, 0], [1, 21, 4, 1, 13, 0, 2, 0, 5, 0, 0]]}], "function_names": ["sub_401000", "start", "GetTempPathW", "GetFileSize", "GetCurrentDirectoryW", "DeleteFileW", "CloseHandle", "WriteFile", "lstrcmpW", "ReadFile", "GetModuleHandleW", "ExitProcess", "HeapCreate", "HeapAlloc", "GetModuleFileNameW", "CreateFileW", "lstrlenW", "ShellExecuteW", "wsprintfW", "HttpSendRequestW", "InternetSetOptionW", "InternetQueryOptionW", "HttpOpenRequestW", "HttpQueryInfoW", "InternetReadFile", "InternetConnectW", "InternetOpenW"], "hash": "316ebb797d5196020eee013cfe771671fff4da8859adc9f385f52a74e82f4e55", "function_number": 27}
|
@ -15,9 +15,11 @@ from hydra.utils import to_absolute_path
|
||||
from omegaconf import DictConfig
|
||||
from prefetch_generator import BackgroundGenerator
|
||||
from sklearn.metrics import roc_auc_score, roc_curve
|
||||
import matplotlib.pyplot as plt
|
||||
from torch import nn
|
||||
from torch_geometric.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from typing import List
|
||||
|
||||
from models.HierarchicalGraphModel import HierarchicalGraphNeuralNetwork
|
||||
from utils.FunctionHelpers import write_into, params_print_log, find_threshold_with_fixed_fpr
|
||||
@ -26,10 +28,6 @@ from utils.PreProcessedDataset import MalwareDetectionDataset
|
||||
from utils.RealBatch import create_real_batch_data
|
||||
from utils.Vocabulary import Vocab
|
||||
|
||||
os.environ['TORCH_USE_CUDA_DSA'] = "1"
|
||||
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
|
||||
|
||||
|
||||
|
||||
class DataLoaderX(DataLoader):
|
||||
def __iter__(self):
|
||||
@ -99,7 +97,7 @@ def train_one_epoch(local_rank, train_loader, valid_loader, model, criterion, op
|
||||
_eval_flag = "Valid_In_Train_Epoch_{}_Batch_{}".format(idx_epoch, _idx_bt)
|
||||
valid_result = validate(local_rank=local_rank, valid_loader=valid_loader, model=model, criterion=criterion, evaluate_flag=_eval_flag, distributed=True, nprocs=nprocs,
|
||||
original_valid_length=original_valid_length, result_file=result_file, details=False)
|
||||
|
||||
|
||||
if best_auc < valid_result.ROC_AUC_Score:
|
||||
_info = "[AUC Increased!] In evaluation of epoch-{} / batch-{}: AUC increased from {:.5f} < {:.5f}! Saving the model into {}".format(idx_epoch,
|
||||
_idx_bt,
|
||||
@ -122,7 +120,7 @@ def train_one_epoch(local_rank, train_loader, valid_loader, model, criterion, op
|
||||
def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distributed, nprocs, original_valid_length, result_file, details):
|
||||
model.eval()
|
||||
if distributed:
|
||||
local_device = torch.device("cpu", local_rank)
|
||||
local_device = torch.device("cuda", local_rank)
|
||||
else:
|
||||
local_device = torch.device("cuda")
|
||||
|
||||
@ -174,7 +172,6 @@ def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distribu
|
||||
|
||||
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
|
||||
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html
|
||||
|
||||
_roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_prods)
|
||||
_fpr, _tpr, _thresholds = roc_curve(y_true=gather_true_classes, y_score=gather_positive_prods)
|
||||
if details is True:
|
||||
@ -251,19 +248,10 @@ def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, m
|
||||
time_start = datetime.now()
|
||||
if local_rank == 0:
|
||||
write_into(log_result_file, "\n{} start of {}-epoch, init best_auc={}, start time={} {}".format("-" * 50, epoch, best_auc, time_start.strftime("%Y-%m-%d@%H:%M:%S"), "-" * 50))
|
||||
|
||||
smooth_avg_reduced_loss_list, best_auc = train_one_epoch(local_rank=local_rank,
|
||||
train_loader=train_loader,
|
||||
valid_loader=valid_loader,
|
||||
model=model,
|
||||
criterion=criterion,
|
||||
optimizer=optimizer,
|
||||
nprocs=nprocs,
|
||||
idx_epoch=epoch,
|
||||
best_auc=best_auc,
|
||||
best_model_file=best_model_path,
|
||||
original_valid_length=ori_valid_length,
|
||||
result_file=log_result_file)
|
||||
|
||||
smooth_avg_reduced_loss_list, best_auc = train_one_epoch(local_rank=local_rank, train_loader=train_loader, valid_loader=valid_loader, model=model, criterion=criterion,
|
||||
optimizer=optimizer, nprocs=nprocs, idx_epoch=epoch, best_auc=best_auc, best_model_file=best_model_path,
|
||||
original_valid_length=ori_valid_length, result_file=log_result_file)
|
||||
all_batch_avg_smooth_loss_list.extend(smooth_avg_reduced_loss_list)
|
||||
|
||||
# adjust learning rate
|
||||
@ -279,7 +267,7 @@ def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, m
|
||||
|
||||
|
||||
# https://hydra.cc/docs/tutorials/basic/your_first_app/defaults#overriding-a-config-group-default
|
||||
@hydra.main(config_path="../configs/", config_name="default.yaml")
|
||||
@hydra.main(config_path="../configs/", config_name="default.yaml", version_base=None)
|
||||
def main_app(config: DictConfig):
|
||||
# set seed for determinism for reproduction
|
||||
random.seed(config.Training.seed)
|
||||
@ -317,12 +305,8 @@ def main_app(config: DictConfig):
|
||||
# num_gpus = 1
|
||||
num_gpus = torch.cuda.device_count()
|
||||
log.info("Total number of GPUs = {}".format(num_gpus))
|
||||
# try:
|
||||
# torch_mp.spawn(main_train_worker, nprocs=num_gpus, args=(num_gpus, _train_params, _model_params, _optim_params, log, log_result_file,))
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
torch_mp.spawn(main_train_worker, nprocs=num_gpus, args=(num_gpus, _train_params, _model_params, _optim_params, log, log_result_file,))
|
||||
# main_train_worker(0, num_gpus, _train_params, _model_params, _optim_params, log, log_result_file)
|
||||
|
||||
best_model_file = os.path.join(os.getcwd(), 'LocalRank_{}_best_model.pt'.format(0))
|
||||
|
||||
else:
|
||||
|
@ -1,382 +0,0 @@
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as torch_mp
|
||||
import torch.utils.data
|
||||
import torch.utils.data.distributed
|
||||
from hydra.utils import to_absolute_path
|
||||
from omegaconf import DictConfig
|
||||
from prefetch_generator import BackgroundGenerator
|
||||
from sklearn.metrics import roc_auc_score, roc_curve
|
||||
from torch import nn
|
||||
from torch_geometric.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from models.HierarchicalGraphModel import HierarchicalGraphNeuralNetwork
|
||||
from utils.FunctionHelpers import write_into, params_print_log, find_threshold_with_fixed_fpr
|
||||
from utils.ParameterClasses import ModelParams, TrainParams, OptimizerParams, OneEpochResult
|
||||
from utils.PreProcessedDataset import MalwareDetectionDataset
|
||||
from utils.RealBatch import create_real_batch_data
|
||||
from utils.Vocabulary import Vocab
|
||||
|
||||
|
||||
class DataLoaderX(DataLoader):
|
||||
def __iter__(self):
|
||||
return BackgroundGenerator(super().__iter__())
|
||||
|
||||
|
||||
def reduce_sum(tensor):
|
||||
rt = tensor.clone()
|
||||
dist.all_reduce(rt, op=dist.ReduceOp.SUM) # noqa
|
||||
return rt
|
||||
|
||||
|
||||
def reduce_mean(tensor, nprocs):
|
||||
rt = tensor.clone()
|
||||
dist.all_reduce(rt, op=dist.ReduceOp.SUM) # noqa
|
||||
rt /= nprocs
|
||||
return rt
|
||||
|
||||
|
||||
|
||||
def all_gather_concat(tensor):
|
||||
tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
|
||||
dist.all_gather(tensors_gather, tensor, async_op=False)
|
||||
output = torch.cat(tensors_gather, dim=0)
|
||||
return output
|
||||
|
||||
|
||||
def train_one_epoch(local_rank, train_loader, valid_loader, model, criterion, optimizer, nprocs, idx_epoch, best_auc, best_model_file, original_valid_length, result_file):
|
||||
# print(train_loader.dataset.__dict__)
|
||||
model.train()
|
||||
local_device = torch.device("cuda", local_rank)
|
||||
write_into(file_name_path=result_file, log_str="The local device = {} among {} nprocs in the {}-th epoch.".format(local_device, nprocs, idx_epoch))
|
||||
|
||||
until_sum_reduced_loss = 0.0
|
||||
smooth_avg_reduced_loss_list = []
|
||||
|
||||
for _idx_bt, _batch in enumerate(tqdm(train_loader, desc="reading _batch from local_rank={}".format(local_rank))):
|
||||
model.train()
|
||||
_real_batch, _position, _hash, _external_list, _function_edges, _true_classes = create_real_batch_data(one_batch=_batch)
|
||||
if _real_batch is None:
|
||||
write_into(result_file,
|
||||
"{}\n_real_batch is None in creating the real batch data of training ... ".format("*-" * 100))
|
||||
continue
|
||||
|
||||
_real_batch = _real_batch.to(local_device)
|
||||
_position = torch.tensor(_position, dtype=torch.long).cuda(local_rank, non_blocking=True)
|
||||
_true_classes = _true_classes.float().cuda(local_rank, non_blocking=True)
|
||||
|
||||
train_batch_pred = model(real_local_batch=_real_batch,
|
||||
real_bt_positions=_position,
|
||||
bt_external_names=_external_list,
|
||||
bt_all_function_edges=_function_edges,
|
||||
local_device=local_device)
|
||||
train_batch_pred = train_batch_pred.squeeze()
|
||||
|
||||
loss = criterion(train_batch_pred, _true_classes)
|
||||
|
||||
torch.distributed.barrier()
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
reduced_loss = reduce_mean(loss, nprocs)
|
||||
until_sum_reduced_loss += reduced_loss.item()
|
||||
smooth_avg_reduced_loss_list.append(until_sum_reduced_loss / (_idx_bt + 1))
|
||||
|
||||
if _idx_bt != 0 and (_idx_bt % math.ceil(len(train_loader) / 3) == 0 or _idx_bt == int(len(train_loader) - 1)):
|
||||
|
||||
val_start_time = datetime.now()
|
||||
if local_rank == 0:
|
||||
write_into(result_file, "\nIn {}-th epoch, {}-th batch, we start to validate ... ".format(idx_epoch, _idx_bt))
|
||||
|
||||
_eval_flag = "Valid_In_Train_Epoch_{}_Batch_{}".format(idx_epoch, _idx_bt)
|
||||
valid_result = validate(local_rank=local_rank,
|
||||
valid_loader=valid_loader,
|
||||
model=model,
|
||||
criterion=criterion,
|
||||
evaluate_flag=_eval_flag,
|
||||
distributed=True, # 分布式
|
||||
nprocs=nprocs,
|
||||
original_valid_length=original_valid_length,
|
||||
result_file=result_file,
|
||||
details=True # 验证细节
|
||||
)
|
||||
|
||||
if best_auc < valid_result.ROC_AUC_Score:
|
||||
_info = "[AUC Increased!] In evaluation of epoch-{} / batch-{}: AUC increased from {:.5f} < {:.5f}! Saving the model into {}".format(idx_epoch,
|
||||
_idx_bt,
|
||||
best_auc,
|
||||
valid_result.ROC_AUC_Score,
|
||||
best_model_file)
|
||||
best_auc = valid_result.ROC_AUC_Score
|
||||
torch.save(model.module.state_dict(), best_model_file)
|
||||
else:
|
||||
_info = "[AUC NOT Increased!] AUC decreased from {:.5f} to {:.5f}!".format(best_auc, valid_result.ROC_AUC_Score)
|
||||
|
||||
if local_rank == 0:
|
||||
write_into(result_file, valid_result.__str__())
|
||||
write_into(result_file, _info)
|
||||
write_into(result_file, "[#One Validation Time#] Consume about {} time period for one validation.".format(datetime.now() - val_start_time))
|
||||
|
||||
return smooth_avg_reduced_loss_list, best_auc
|
||||
|
||||
|
||||
def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distributed, nprocs, original_valid_length, result_file, details):
|
||||
model.eval()
|
||||
if distributed:
|
||||
local_device = torch.device("cuda", local_rank)
|
||||
else:
|
||||
local_device = torch.device("cuda")
|
||||
|
||||
sum_loss = torch.tensor(0.0, dtype=torch.float, device=local_device)
|
||||
n_samples = torch.tensor(0, dtype=torch.int, device=local_device)
|
||||
|
||||
all_true_classes = []
|
||||
all_positive_probs = []
|
||||
|
||||
with torch.no_grad():
|
||||
for idx_batch, data in enumerate(tqdm(valid_loader)):
|
||||
_real_batch, _position, _hash, _external_list, _function_edges, _true_classes = create_real_batch_data(one_batch=data)
|
||||
if _real_batch is None:
|
||||
write_into(result_file, "{}\n_real_batch is None in creating the real batch data of validation ... ".format("*-" * 100))
|
||||
continue
|
||||
_real_batch = _real_batch.to(local_device)
|
||||
_position = torch.tensor(_position, dtype=torch.long).cuda(local_rank, non_blocking=True)
|
||||
_true_classes = _true_classes.float().cuda(local_rank, non_blocking=True)
|
||||
|
||||
batch_pred = model(real_local_batch=_real_batch,
|
||||
real_bt_positions=_position,
|
||||
bt_external_names=_external_list,
|
||||
bt_all_function_edges=_function_edges,
|
||||
local_device=local_device)
|
||||
batch_pred = batch_pred.squeeze(-1)
|
||||
loss = criterion(batch_pred, _true_classes)
|
||||
sum_loss += loss.item()
|
||||
|
||||
n_samples += len(batch_pred)
|
||||
|
||||
all_true_classes.append(_true_classes)
|
||||
all_positive_probs.append(batch_pred)
|
||||
|
||||
avg_loss = sum_loss / (idx_batch + 1)
|
||||
all_true_classes = torch.cat(all_true_classes, dim=0)
|
||||
all_positive_probs = torch.cat(all_positive_probs, dim=0)
|
||||
|
||||
if distributed:
|
||||
torch.distributed.barrier()
|
||||
reduced_n_samples = reduce_sum(n_samples)
|
||||
reduced_avg_loss = reduce_mean(avg_loss, nprocs)
|
||||
gather_true_classes = all_gather_concat(all_true_classes).detach().cpu().numpy()
|
||||
gather_positive_prods = all_gather_concat(all_positive_probs).detach().cpu().numpy()
|
||||
|
||||
gather_true_classes = gather_true_classes[:original_valid_length]
|
||||
gather_positive_prods = gather_positive_prods[:original_valid_length]
|
||||
|
||||
else:
|
||||
reduced_n_samples = n_samples
|
||||
reduced_avg_loss = avg_loss
|
||||
gather_true_classes = all_true_classes.detach().cpu().numpy()
|
||||
gather_positive_prods = all_positive_probs.detach().cpu().numpy()
|
||||
|
||||
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
|
||||
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html
|
||||
_roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_prods)
|
||||
_fpr, _tpr, _thresholds = roc_curve(y_true=gather_true_classes, y_score=gather_positive_prods)
|
||||
if details is True:
|
||||
_100_info = find_threshold_with_fixed_fpr(y_true=gather_true_classes, y_pred=gather_positive_prods, fpr_target=0.01)
|
||||
_1000_info = find_threshold_with_fixed_fpr(y_true=gather_true_classes, y_pred=gather_positive_prods, fpr_target=0.001)
|
||||
else:
|
||||
_100_info, _1000_info = "None", "None"
|
||||
|
||||
_eval_result = OneEpochResult(Epoch_Flag=evaluate_flag,
|
||||
Number_Samples=reduced_n_samples,
|
||||
Avg_Loss=reduced_avg_loss,
|
||||
Info_100=_100_info,
|
||||
Info_1000=_1000_info,
|
||||
ROC_AUC_Score=_roc_auc_score,
|
||||
Thresholds=_thresholds,
|
||||
TPRs=_tpr,
|
||||
FPRs=_fpr)
|
||||
return _eval_result
|
||||
|
||||
|
||||
def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, model_params: ModelParams, optimizer_params: OptimizerParams, global_log: logging.Logger,
|
||||
log_result_file: str):
|
||||
# dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:12345', world_size=nprocs, rank=local_rank)
|
||||
dist.init_process_group(backend='nccl', init_method='env://', world_size=nprocs, rank=local_rank)
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
# model configure
|
||||
vocab = Vocab(freq_file=train_params.external_func_vocab_file, max_vocab_size=train_params.max_vocab_size)
|
||||
|
||||
if model_params.ablation_models.lower() == "full":
|
||||
model = HierarchicalGraphNeuralNetwork(model_params=model_params, external_vocab=vocab, global_log=global_log)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
model.cuda(local_rank)
|
||||
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
if local_rank == 0:
|
||||
write_into(file_name_path=log_result_file, log_str=model.__str__())
|
||||
|
||||
# loss function
|
||||
criterion = nn.BCELoss().cuda(local_rank)
|
||||
|
||||
lr = optimizer_params.lr
|
||||
if optimizer_params.optimizer_name.lower() == 'adam':
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
||||
elif optimizer_params.optimizer_name.lower() == 'adamw':
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=optimizer_params.weight_decay)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
max_epochs = train_params.max_epochs
|
||||
|
||||
dataset_root_path = train_params.processed_files_path
|
||||
train_batch_size = train_params.train_bs
|
||||
test_batch_size = train_params.test_bs
|
||||
|
||||
# training dataset & dataloader
|
||||
train_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="train")
|
||||
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
|
||||
train_loader = DataLoaderX(dataset=train_dataset, batch_size=train_batch_size, shuffle=False, num_workers=0, pin_memory=True, sampler=train_sampler)
|
||||
# validation dataset & dataloader
|
||||
valid_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="valid")
|
||||
valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset)
|
||||
valid_loader = DataLoaderX(dataset=valid_dataset, batch_size=test_batch_size, pin_memory=True, sampler=valid_sampler)
|
||||
|
||||
if local_rank == 0:
|
||||
write_into(file_name_path=log_result_file, log_str="Training dataset={}, sampler={}, loader={}".format(len(train_dataset), len(train_sampler), len(train_loader)))
|
||||
write_into(file_name_path=log_result_file, log_str="Validation dataset={}, sampler={}, loader={}".format(len(valid_dataset), len(valid_sampler), len(valid_loader)))
|
||||
|
||||
best_auc = 0.0
|
||||
ori_valid_length = len(valid_dataset)
|
||||
best_model_path = os.path.join(os.getcwd(), 'LocalRank_{}_best_model.pt'.format(local_rank))
|
||||
|
||||
all_batch_avg_smooth_loss_list = []
|
||||
for epoch in range(max_epochs):
|
||||
train_sampler.set_epoch(epoch)
|
||||
valid_sampler.set_epoch(epoch)
|
||||
|
||||
# train for one epoch
|
||||
time_start = datetime.now()
|
||||
if local_rank == 0:
|
||||
write_into(log_result_file, "\n{} start of {}-epoch, init best_auc={}, start time={} {}".format("-" * 50, epoch, best_auc, time_start.strftime("%Y-%m-%d@%H:%M:%S"), "-" * 50))
|
||||
|
||||
smooth_avg_reduced_loss_list, best_auc = train_one_epoch(local_rank=local_rank,
|
||||
train_loader=train_loader,
|
||||
valid_loader=valid_loader,
|
||||
model=model,
|
||||
criterion=criterion,
|
||||
optimizer=optimizer,
|
||||
nprocs=nprocs,
|
||||
idx_epoch=epoch,
|
||||
best_auc=best_auc,
|
||||
best_model_file=best_model_path,
|
||||
original_valid_length=ori_valid_length,
|
||||
result_file=log_result_file)
|
||||
all_batch_avg_smooth_loss_list.extend(smooth_avg_reduced_loss_list)
|
||||
|
||||
# adjust learning rate
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = param_group['lr'] / optimizer_params.learning_anneal
|
||||
|
||||
time_end = datetime.now()
|
||||
if local_rank == 0:
|
||||
write_into(log_result_file, "\n{} end of {}-epoch, with best_auc={}, len of loss={}, end time={}, and time period={} {}".format("*" * 50, epoch, best_auc,
|
||||
len(smooth_avg_reduced_loss_list),
|
||||
time_end.strftime("%Y-%m-%d@%H:%M:%S"),
|
||||
time_end - time_start, "*" * 50))
|
||||
|
||||
|
||||
# https://hydra.cc/docs/tutorials/basic/your_first_app/defaults#overriding-a-config-group-default
|
||||
@hydra.main(config_path="../configs/", config_name="default.yaml")
|
||||
def main_app(config: DictConfig):
|
||||
# set seed for determinism for reproduction
|
||||
random.seed(config.Training.seed)
|
||||
np.random.seed(config.Training.seed)
|
||||
torch.manual_seed(config.Training.seed)
|
||||
torch.cuda.manual_seed(config.Training.seed)
|
||||
torch.cuda.manual_seed_all(config.Training.seed)
|
||||
|
||||
# setting hyper-parameter for Training / Model / Optimizer
|
||||
_train_params = TrainParams(processed_files_path=to_absolute_path(config.Data.preprocess_root), max_epochs=config.Training.max_epoches, train_bs=config.Training.train_batch_size, test_bs=config.Training.test_batch_size, external_func_vocab_file=to_absolute_path(config.Data.train_vocab_file), max_vocab_size=config.Data.max_vocab_size)
|
||||
_model_params = ModelParams(gnn_type=config.Model.gnn_type, pool_type=config.Model.pool_type, acfg_init_dims=config.Model.acfg_node_init_dims, cfg_filters=config.Model.cfg_filters, fcg_filters=config.Model.fcg_filters, number_classes=config.Model.number_classes, dropout_rate=config.Model.drapout_rate, ablation_models=config.Model.ablation_models)
|
||||
_optim_params = OptimizerParams(optimizer_name=config.Optimizer.name, lr=config.Optimizer.learning_rate, weight_decay=config.Optimizer.weight_decay, learning_anneal=config.Optimizer.learning_anneal)
|
||||
|
||||
# logging
|
||||
log = logging.getLogger("DistTrainModel.py")
|
||||
log.setLevel("DEBUG")
|
||||
log.warning("Hydra's Current Working Directory: {}".format(os.getcwd()))
|
||||
|
||||
# setting for the log directory
|
||||
result_file = '{}_{}_{}_ACFG_{}_FCG_{}_Epoch_{}_TrainBS_{}_LR_{}_Time_{}.txt'.format(_model_params.ablation_models, _model_params.gnn_type, _model_params.pool_type,
|
||||
_model_params.cfg_filters, _model_params.fcg_filters, _train_params.max_epochs,
|
||||
_train_params.train_bs, _optim_params.lr, datetime.now().strftime("%Y%m%d_%H%M%S"))
|
||||
log_result_file = os.path.join(os.getcwd(), result_file)
|
||||
|
||||
_other_params = {"Hydra's Current Working Directory": os.getcwd(), "seed": config.Training.seed, "log result file": log_result_file, "only_test_path": config.Training.only_test_path}
|
||||
|
||||
params_print_log(_train_params.__dict__, log_result_file)
|
||||
params_print_log(_model_params.__dict__, log_result_file)
|
||||
params_print_log(_optim_params.__dict__, log_result_file)
|
||||
params_print_log(_other_params, log_result_file)
|
||||
|
||||
if config.Training.only_test_path == 'None':
|
||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||
os.environ['MASTER_PORT'] = str(config.Training.dist_port)
|
||||
# num_gpus = 1
|
||||
num_gpus = torch.cuda.device_count()
|
||||
log.info("Total number of GPUs = {}".format(num_gpus))
|
||||
torch_mp.spawn(main_train_worker, nprocs=num_gpus, args=(num_gpus, _train_params, _model_params, _optim_params, log, log_result_file,))
|
||||
|
||||
best_model_file = os.path.join(os.getcwd(), 'LocalRank_{}_best_model.pt'.format(0))
|
||||
|
||||
else:
|
||||
best_model_file = config.Training.only_test_path
|
||||
|
||||
# model re-init and loading
|
||||
log.info("\n\nstarting to load the model & re-validation & testing from the file of \"{}\" \n".format(best_model_file))
|
||||
device = torch.device('cuda')
|
||||
train_vocab_path = _train_params.external_func_vocab_file
|
||||
vocab = Vocab(freq_file=train_vocab_path, max_vocab_size=_train_params.max_vocab_size)
|
||||
|
||||
if _model_params.ablation_models.lower() == "full":
|
||||
model = HierarchicalGraphNeuralNetwork(model_params=_model_params, external_vocab=vocab, global_log=log)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
model.to(device)
|
||||
model.load_state_dict(torch.load(best_model_file, map_location=device))
|
||||
criterion = nn.BCELoss().cuda()
|
||||
|
||||
test_batch_size = config.Training.test_batch_size
|
||||
dataset_root_path = _train_params.processed_files_path
|
||||
# validation dataset & dataloader
|
||||
valid_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="valid")
|
||||
valid_dataloader = DataLoaderX(dataset=valid_dataset, batch_size=test_batch_size, shuffle=False)
|
||||
log.info("Total number of all validation samples = {} ".format(len(valid_dataset)))
|
||||
|
||||
# testing dataset & dataloader
|
||||
test_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="test")
|
||||
test_dataloader = DataLoaderX(dataset=test_dataset, batch_size=test_batch_size, shuffle=False)
|
||||
log.info("Total number of all testing samples = {} ".format(len(test_dataset)))
|
||||
|
||||
_valid_result = validate(valid_loader=valid_dataloader, model=model, criterion=criterion, evaluate_flag="DoubleCheckValidation", distributed=False, local_rank=None, nprocs=None, original_valid_length=len(valid_dataset), result_file=log_result_file, details=True)
|
||||
log.warning("\n\n" + _valid_result.__str__())
|
||||
_test_result = validate(valid_loader=test_dataloader, model=model, criterion=criterion, evaluate_flag="FinalTestingResult", distributed=False, local_rank=None, nprocs=None, original_valid_length=len(test_dataset), result_file=log_result_file, details=True)
|
||||
log.warning("\n\n" + _test_result.__str__())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main_app()
|
@ -13,6 +13,8 @@ sys.path.append("..")
|
||||
from utils.ParameterClasses import ModelParams # noqa
|
||||
from utils.Vocabulary import Vocab # noqa
|
||||
|
||||
perform_MID = True
|
||||
|
||||
|
||||
def div_with_small_value(n, d, eps=1e-8):
|
||||
d = d * (d > eps).float() + eps * (d <= eps).float()
|
||||
@ -63,7 +65,7 @@ class HierarchicalGraphNeuralNetwork(nn.Module):
|
||||
self.global_log = global_log
|
||||
|
||||
# Hierarchical 1: Control Flow Graph (CFG) embedding and pooling
|
||||
# print(type(model_params.cfg_filters), model_params.cfg_filters)
|
||||
print(type(model_params.cfg_filters), model_params.cfg_filters)
|
||||
if type(model_params.cfg_filters) == str:
|
||||
cfg_filter_list = [int(number_filter) for number_filter in model_params.cfg_filters.split("-")]
|
||||
else:
|
||||
@ -89,7 +91,7 @@ class HierarchicalGraphNeuralNetwork(nn.Module):
|
||||
|
||||
# Hierarchical 2: Function Call Graph (FCG) embedding and pooling
|
||||
self.external_embedding_layer = nn.Embedding(num_embeddings=external_vocab.max_vocab_size + 2, embedding_dim=cfg_filter_list[-1], padding_idx=external_vocab.pad_idx)
|
||||
# print(type(model_params.fcg_filters), model_params.fcg_filters)
|
||||
print(type(model_params.fcg_filters), model_params.fcg_filters)
|
||||
if type(model_params.fcg_filters) == str:
|
||||
fcg_filter_list = [int(number_filter) for number_filter in model_params.fcg_filters.split("-")]
|
||||
else:
|
||||
@ -119,8 +121,11 @@ class HierarchicalGraphNeuralNetwork(nn.Module):
|
||||
|
||||
# self.last_activation = nn.Softmax(dim=1)
|
||||
# self.last_activation = nn.LogSoftmax(dim=1)
|
||||
|
||||
|
||||
def forward_cfg_gnn(self, local_batch: Batch):
|
||||
if perform_MID:
|
||||
return self.forward_MID_cfg_gnn(local_batch)
|
||||
|
||||
in_x, edge_index = local_batch.x, local_batch.edge_index
|
||||
for i in range(self.cfg_filter_length - 1):
|
||||
out_x = getattr(self, 'CFG_gnn_{}'.format(i + 1))(x=in_x, edge_index=edge_index)
|
||||
@ -129,7 +134,35 @@ class HierarchicalGraphNeuralNetwork(nn.Module):
|
||||
in_x = out_x
|
||||
local_batch.x = in_x
|
||||
return local_batch
|
||||
|
||||
|
||||
# 多实例分解的CFG嵌入学习
|
||||
def forward_MID_cfg_gnn(self, local_batch: Batch):
|
||||
device = torch.device('cuda')
|
||||
cfg_embeddings = []
|
||||
cfg_subgraph_loader = local_batch.cfg_subgraph_loader
|
||||
for acfg in cfg_subgraph_loader:
|
||||
subgraph_embeddings = []
|
||||
for subgraph in acfg:
|
||||
in_x, edge_index = subgraph.x.to(device), subgraph.edge_index.to(device)
|
||||
batch = torch.zeros(in_x.size(0), dtype=torch.long, device=device)
|
||||
for i in range(self.cfg_filter_length - 1):
|
||||
out_x = getattr(self, 'CFG_gnn_{}'.format(i + 1))(x=in_x, edge_index=edge_index)
|
||||
out_x = pt_f.relu(out_x, inplace=True)
|
||||
out_x = self.dropout(out_x)
|
||||
in_x = out_x
|
||||
subgraph_embedding = global_mean_pool(in_x, batch)
|
||||
subgraph_embeddings.append(subgraph_embedding.squeeze(0))
|
||||
cfg_embedding = torch.stack(subgraph_embeddings).mean(dim=0)
|
||||
cfg_embeddings.append(cfg_embedding)
|
||||
|
||||
cfg_embeddings = torch.stack(cfg_embeddings)
|
||||
# 创建一个新的 batch 向量
|
||||
batch_size = cfg_embeddings.size(0)
|
||||
new_batch = torch.arange(batch_size)
|
||||
local_batch.x = cfg_embeddings.to(device)
|
||||
local_batch.batch = new_batch.to(device)
|
||||
return local_batch
|
||||
|
||||
def aggregate_cfg_batch_pooling(self, local_batch: Batch):
|
||||
if self.pool == 'global_max_pool':
|
||||
x_pool = global_max_pool(x=local_batch.x, batch=local_batch.batch)
|
||||
|
@ -4,7 +4,7 @@ from datetime import datetime
|
||||
|
||||
import torch
|
||||
from torch_geometric.data import Dataset, DataLoader
|
||||
from utils.RealBatch import create_real_batch_data # noqa
|
||||
from src.utils.RealBatch import create_real_batch_data # noqa
|
||||
|
||||
|
||||
class MalwareDetectionDataset(Dataset):
|
||||
@ -23,12 +23,15 @@ class MalwareDetectionDataset(Dataset):
|
||||
if os.path.splitext(name)[-1] == '.pt':
|
||||
files.append(name)
|
||||
return files
|
||||
|
||||
def len(self):
|
||||
|
||||
def __len__(self):
|
||||
# def len(self):
|
||||
# return 201
|
||||
return len(self.malware_files) + len(self.benign_files)
|
||||
|
||||
|
||||
def len(self):
|
||||
return len(self.malware_files) + len(self.benign_files)
|
||||
|
||||
def get(self, idx):
|
||||
split = len(self.malware_files)
|
||||
# split = 100
|
||||
@ -55,6 +58,7 @@ def _simulating(_dataset, _batch_size: int):
|
||||
if index >= 3:
|
||||
break
|
||||
_real_batch, _position, _hash, _external_list, _function_edges, _true_classes = create_real_batch_data(one_batch=data)
|
||||
print(data)
|
||||
print("Hash: ", _hash)
|
||||
print("Position: ", _position)
|
||||
print("\n")
|
||||
@ -65,7 +69,8 @@ def _simulating(_dataset, _batch_size: int):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
root_path: str = '/home/king/python/data/processed_dataset/DatasetJSON_test'
|
||||
root_path: str = '/root/autodl-tmp/'
|
||||
# root_path: str = 'D:\\hkn\\infected\\datasets\\proprecessed_pt'
|
||||
i_batch_size = 2
|
||||
|
||||
train_dataset = MalwareDetectionDataset(root=root_path, train_or_test='train')
|
||||
|
@ -1,13 +1,26 @@
|
||||
import torch
|
||||
from torch_geometric.data import Batch
|
||||
from torch_geometric.data import DataLoader
|
||||
from pprint import pprint
|
||||
import pymetis as metis
|
||||
import numpy as np
|
||||
import networkx as nx
|
||||
from torch_geometric.data import Batch, Data
|
||||
from torch_geometric.utils import to_networkx, from_networkx
|
||||
from typing import List
|
||||
|
||||
from src.utils.RemoveCycleEdgesTrueskill import perform_breaking_edges
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
perform_MID = True
|
||||
|
||||
|
||||
def create_real_batch_data(one_batch: Batch):
|
||||
if perform_MID:
|
||||
return create_MID_real_batch_data(one_batch)
|
||||
|
||||
real = []
|
||||
position = [0]
|
||||
count = 0
|
||||
|
||||
assert len(one_batch.external_list) == len(one_batch.function_edges) == len(one_batch.local_acfgs) == len(one_batch.hash), "size of each component must be equal to each other"
|
||||
|
||||
for item in one_batch.local_acfgs:
|
||||
@ -20,4 +33,177 @@ def create_real_batch_data(one_batch: Batch):
|
||||
return (None for _ in range(6))
|
||||
else:
|
||||
real_batch = Batch.from_data_list(real)
|
||||
return real_batch, position, one_batch.hash, one_batch.external_list, one_batch.function_edges, one_batch.targets
|
||||
return real_batch, position, one_batch.hash, one_batch.external_list, one_batch.function_edges, one_batch.targets
|
||||
|
||||
|
||||
# cfg的多实例分解batch
|
||||
def create_MID_real_batch_data(one_batch: Batch):
|
||||
# 原始cfg列表
|
||||
real = []
|
||||
# 分解后的cfg列表,每个元素都是一个cfg分解后的子图列表list[Data],因此它是二维的
|
||||
real_decomposed_cfgs = []
|
||||
position = [0]
|
||||
count = 0
|
||||
|
||||
assert len(one_batch.external_list) == len(one_batch.function_edges) == len(one_batch.local_acfgs) == len(
|
||||
one_batch.hash), "size of each component must be equal to each other"
|
||||
|
||||
for pe in one_batch.local_acfgs:
|
||||
# 遍历pe中的acfg
|
||||
for acfg in pe:
|
||||
# 多实例分解acfg,返回一个list[Data]
|
||||
sub_graphs = multi_instance_decompose(acfg)
|
||||
real_decomposed_cfgs.append(sub_graphs)
|
||||
real.append(acfg)
|
||||
# 一个exe中的所有acfg数量
|
||||
count += len(pe)
|
||||
# 记录每个exe中acfg的数量
|
||||
position.append(count)
|
||||
|
||||
if len(one_batch.local_acfgs) == 1 and len(one_batch.local_acfgs[0]) == 0:
|
||||
return (None for _ in range(6))
|
||||
else:
|
||||
real_batch = Batch.from_data_list(real)
|
||||
real_batch.cfg_subgraph_loader = real_decomposed_cfgs
|
||||
return real_batch, position, one_batch.hash, one_batch.external_list, one_batch.function_edges, one_batch.targets
|
||||
|
||||
|
||||
# CFG的多实例分解
|
||||
# return list[Data]
|
||||
def multi_instance_decompose(acfg: Data):
|
||||
# edge_index : torch.tensor([[0, 1, 2], [1, 2, 3]])
|
||||
# acfg.x是每个块的11维属性张量
|
||||
# 只有一个节点的图,所以没有边信息,edge_index长度为0,不需要处理
|
||||
# if len(acfg.x) == 1:
|
||||
# return [acfg]
|
||||
#
|
||||
# g = nx.DiGraph()
|
||||
# g.add_edges_from(edge_index2edges(acfg.edge_index))
|
||||
|
||||
return metis_MID(acfg)
|
||||
|
||||
|
||||
def metis_MID(acfg):
|
||||
nparts = 3
|
||||
node_num = len(acfg.x)
|
||||
if node_num < 10:
|
||||
return [acfg]
|
||||
G = to_networkx(acfg).to_undirected()
|
||||
adjacency_list = [list(G.neighbors(node)) for node in sorted(G.nodes)]
|
||||
_, parts = metis.part_graph(nparts=nparts, adjacency=adjacency_list, recursive=False) # 分解为3个子图
|
||||
sub_graphs: List[Data] = []
|
||||
subgraph_nodes: List[List[int]] = []
|
||||
for i, p in enumerate(parts):
|
||||
while p >= len(subgraph_nodes):
|
||||
subgraph_nodes.append([])
|
||||
subgraph_nodes[p].append(i)
|
||||
|
||||
for sub_graph in subgraph_nodes:
|
||||
if len(sub_graph) == 0:
|
||||
continue
|
||||
indices = torch.unique(torch.tensor(sub_graph)).long()
|
||||
sub_G = G.subgraph(sub_graph)
|
||||
sub_data = from_networkx(sub_G)
|
||||
sub_data.x = acfg.x[indices]
|
||||
sub_graphs.append(sub_data)
|
||||
|
||||
return sub_graphs
|
||||
|
||||
|
||||
# 将循环结构和剩余的层次结构分别保存为Data,返回list[Data]
|
||||
def structure_MID(acfg, g):
|
||||
result = []
|
||||
|
||||
# 提取图中的自环结构
|
||||
# self_loop = nx.selfloop_edges(g)
|
||||
# result += [create_data(acfg.x, torch.tensor([[loop[0]], [loop[0]]])) for loop in self_loop]
|
||||
|
||||
# 这里不能用self_loop,因为这个变量在被读取之后会被清空
|
||||
# g.remove_edges_from(nx.selfloop_edges(g))
|
||||
|
||||
# 提取图中的循环结构
|
||||
# cycles = list(nx.simple_cycles(g))
|
||||
# if len(cycles) > 0:
|
||||
# max_cycle = max(len(cycle) for cycle in cycles)
|
||||
# max_cycle = max(cycles, key=len)
|
||||
# print(max_cycle)
|
||||
# result += [create_data(acfg.x, torch.tensor([path[:-1], path[1:]])) for path in cycles]
|
||||
|
||||
# time_start = datetime.now()
|
||||
# 将图转换为DAG,尽可能保留原图的层次结构
|
||||
perform_breaking_edges(g)
|
||||
graph_index = edges2edge_index(g.edges)
|
||||
result.append(create_data(acfg.x, graph_index))
|
||||
# time_end = datetime.now()
|
||||
# print("process time = {}".format(time_end - time_start))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# 将图进行拓扑排序后进行dfs找出图中所有最长子路径,分别保存为Data,返回list[Data]
|
||||
def topological_MID(acfg, g):
|
||||
# 将图转换为DAG,尽可能保留原图的层次结构
|
||||
perform_breaking_edges(g)
|
||||
# 拓扑排序
|
||||
topo_order = list(nx.topological_sort(g))
|
||||
# 初始化距离数组
|
||||
dist = {node: float('-inf') for node in g.nodes()}
|
||||
# 初始化前驱节点数组
|
||||
prev = {node: [] for node in g.nodes()}
|
||||
# 初始化起点节点的距离为0
|
||||
for node in g.nodes():
|
||||
if g.in_degree(node) == 0:
|
||||
dist[node] = 0
|
||||
# 遍历所有节点
|
||||
for node in topo_order:
|
||||
# 遍历所有后继节点
|
||||
for successor in g.successors(node):
|
||||
# 更新距离
|
||||
if dist[successor] < dist[node] + 1:
|
||||
dist[successor] = dist[node] + 1
|
||||
prev[successor] = [node]
|
||||
elif dist[successor] == dist[node] + 1:
|
||||
prev[successor].append(node)
|
||||
|
||||
# 计算最长路径的长度
|
||||
max_length = max(dist.values())
|
||||
|
||||
# 初始化最长路径数组
|
||||
longest_paths = []
|
||||
|
||||
# 遍历所有终点节点
|
||||
for node in g.nodes():
|
||||
if g.out_degree(node) == 0 and dist[node] == max_length:
|
||||
dfs(node, [node], prev, longest_paths)
|
||||
|
||||
# 将acfg中所有最长子图路径转换为Data集合,也就是说一个acfg被转换为一个Data列表
|
||||
return [create_data(acfg.x, torch.tensor([path[:-1], path[1:]])) for path in longest_paths]
|
||||
|
||||
|
||||
def dfs(node, path, prev, longest_paths):
|
||||
if len(prev[node]) == 0:
|
||||
longest_paths.append(path)
|
||||
else:
|
||||
for predecessor in prev[node]:
|
||||
dfs(predecessor, [predecessor] + path, prev, longest_paths)
|
||||
|
||||
|
||||
# 获取edge_index中出现过的所有元素,在x中仅保留这些元素所对应的索引
|
||||
# 用于快速创建子图的x属性,注意x和edge_index都是torch.tensor
|
||||
def create_data(x, edge_index):
|
||||
# 获取edge_index中出现过的元素
|
||||
indices = torch.unique(edge_index).long()
|
||||
return Data(x[indices], edge_index)
|
||||
|
||||
|
||||
# torch.tensor([[1, 2, 3], [2, 3, 4]]) => [(1, 2), (2, 3), (3, 4)]
|
||||
# 将edge_index张量转换为edges数组
|
||||
def edge_index2edges(edge_index):
|
||||
return list(zip(*edge_index.tolist()))
|
||||
|
||||
|
||||
# OutEdgeView([(1, 2), (2, 3), (3, 4)]) => torch.tensor([[1, 2, 3], [2, 3, 4]])
|
||||
# 将edges数组转换为edge_index张量
|
||||
def edges2edge_index(edges):
|
||||
edges = list(edges.items())
|
||||
return torch.tensor([list(edge[0]) for edge in edges]).t().contiguous()
|
||||
|
170
src/utils/RemoveCycleEdgesTrueskill.py
Normal file
170
src/utils/RemoveCycleEdgesTrueskill.py
Normal file
@ -0,0 +1,170 @@
|
||||
import random
|
||||
from trueskill import Rating, rate_1vs1
|
||||
import networkx as nx
|
||||
import os
|
||||
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
def __get_big_sccs(g):
|
||||
num_big_sccs = 0
|
||||
big_sccs = []
|
||||
for sub in (g.subgraph(c).copy() for c in nx.strongly_connected_components(g)):
|
||||
number_of_nodes = sub.number_of_nodes()
|
||||
if number_of_nodes >= 2:
|
||||
# strongly connected components
|
||||
num_big_sccs += 1
|
||||
big_sccs.append(sub)
|
||||
# print(" # big sccs: %d" % (num_big_sccs))
|
||||
return big_sccs
|
||||
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
def __nodes_in_scc(sccs):
|
||||
scc_nodes = []
|
||||
scc_edges = []
|
||||
for scc in sccs:
|
||||
scc_nodes += list(scc.nodes())
|
||||
scc_edges += list(scc.edges())
|
||||
|
||||
# print("# nodes in big sccs: %d" % len(scc_nodes))
|
||||
# print("# edges in big sccs: %d" % len(scc_edges))
|
||||
return scc_nodes
|
||||
|
||||
|
||||
def __scores_of_nodes_in_scc(sccs, players):
|
||||
scc_nodes = __nodes_in_scc(sccs)
|
||||
scc_nodes_score_dict = {}
|
||||
for node in scc_nodes:
|
||||
scc_nodes_score_dict[node] = players[node]
|
||||
# print("# scores of nodes in scc: %d" % (len(scc_nodes_score_dict)))
|
||||
return scc_nodes_score_dict
|
||||
|
||||
|
||||
def __filter_big_scc(g, edges_to_be_removed):
|
||||
# Given a graph g and edges to be removed
|
||||
# Return a list of big scc subgraphs (# of nodes >= 2)
|
||||
g.remove_edges_from(edges_to_be_removed)
|
||||
sub_graphs = filter(lambda scc: scc.number_of_nodes() >= 2,
|
||||
[g.subgraph(c).copy() for c in nx.strongly_connected_components(g)])
|
||||
return sub_graphs
|
||||
|
||||
|
||||
def __remove_cycle_edges_by_agony_iterately(sccs, players, edges_to_be_removed):
|
||||
while True:
|
||||
graph = sccs.pop()
|
||||
pair_max_agony = None
|
||||
max_agony = -1
|
||||
for pair in graph.edges():
|
||||
u, v = pair
|
||||
agony = max(players[u] - players[v], 0)
|
||||
if agony >= max_agony:
|
||||
pair_max_agony = (u, v)
|
||||
max_agony = agony
|
||||
edges_to_be_removed.append(pair_max_agony)
|
||||
# print("graph: (%d,%d), edge to be removed: %s, agony: %0.4f" % (graph.number_of_nodes(),graph.number_of_edges(),pair_max_agony,max_agony))
|
||||
graph.remove_edges_from([pair_max_agony])
|
||||
# print("graph: (%d,%d), edge to be removed: %s" % (graph.number_of_nodes(),graph.number_of_edges(),pair_max_agony))
|
||||
sub_graphs = __filter_big_scc(graph, [pair_max_agony])
|
||||
if sub_graphs:
|
||||
for index, sub in enumerate(sub_graphs):
|
||||
sccs.append(sub)
|
||||
if not sccs:
|
||||
return
|
||||
|
||||
|
||||
def __compute_trueskill(pairs, players):
|
||||
if not players:
|
||||
for u, v in pairs:
|
||||
if u not in players:
|
||||
players[u] = Rating()
|
||||
if v not in players:
|
||||
players[v] = Rating()
|
||||
|
||||
random.shuffle(pairs)
|
||||
for u, v in pairs:
|
||||
players[v], players[u] = rate_1vs1(players[v], players[u])
|
||||
|
||||
return players
|
||||
|
||||
|
||||
def __get_players_score(players, n_sigma):
|
||||
relative_score = {}
|
||||
for k, v in players.items():
|
||||
relative_score[k] = players[k].mu - n_sigma * players[k].sigma
|
||||
return relative_score
|
||||
|
||||
|
||||
def __measure_pairs_agreement(pairs, nodes_score):
|
||||
# whether nodes in pairs agree with their ranking scores
|
||||
num_correct_pairs = 0
|
||||
num_wrong_pairs = 0
|
||||
total_pairs = 0
|
||||
for u, v in pairs:
|
||||
if u in nodes_score and v in nodes_score:
|
||||
if nodes_score[u] <= nodes_score[v]:
|
||||
num_correct_pairs += 1
|
||||
else:
|
||||
num_wrong_pairs += 1
|
||||
total_pairs += 1
|
||||
if total_pairs != 0:
|
||||
acc = num_correct_pairs * 1.0 / total_pairs
|
||||
# print("correct pairs: %d, wrong pairs: %d, total pairs: %d, accuracy: %0.4f" % (num_correct_pairs,num_wrong_pairs,total_pairs,num_correct_pairs*1.0/total_pairs))
|
||||
else:
|
||||
acc = 1
|
||||
# print("total pairs: 0, accuracy: 1")
|
||||
return acc
|
||||
|
||||
|
||||
def __trueskill_ratings(pairs, iter_times=15, n_sigma=3, threshold=0.85):
|
||||
players = {}
|
||||
for i in range(iter_times):
|
||||
players = __compute_trueskill(pairs, players)
|
||||
relative_scores = __get_players_score(players, n_sigma=n_sigma)
|
||||
accu = __measure_pairs_agreement(pairs, relative_scores)
|
||||
if accu >= threshold:
|
||||
return relative_scores
|
||||
# end = datetime.now()
|
||||
# time_used = end - start
|
||||
# print("time used in computing true skill: %0.4f s, iteration time is: %i" % (time_used.seconds, (i + 1)))
|
||||
return relative_scores
|
||||
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
# def breaking_cycles_by_TS(graph_path):
|
||||
# g = nx.read_edgelist(graph_path, create_using=nx.DiGraph(), nodetype=int)
|
||||
# players_score_dict = __trueskill_ratings(list(g.edges()), iter_times=15, n_sigma=3, threshold=0.95)
|
||||
# g.remove_edges_from(list(nx.selfloop_edges(g)))
|
||||
# big_sccs = __get_big_sccs(g)
|
||||
# scc_nodes_score_dict = __scores_of_nodes_in_scc(big_sccs, players_score_dict)
|
||||
# edges_to_be_removed = []
|
||||
# if len(big_sccs) == 0:
|
||||
# print("After removal of self loop edgs: %s" % nx.is_directed_acyclic_graph(g))
|
||||
# return
|
||||
#
|
||||
# __remove_cycle_edges_by_agony_iterately(big_sccs, scc_nodes_score_dict, edges_to_be_removed)
|
||||
# g.remove_edges_from(edges_to_be_removed)
|
||||
# nx.write_edgelist(g, out_path)
|
||||
|
||||
|
||||
# edgelist形式为[(x0, y0), (x1, y1), (x2, y2), (x3, y3)]
|
||||
def perform_breaking_edges(g):
|
||||
players_score_dict = __trueskill_ratings(list(g.edges()), iter_times=15, n_sigma=3, threshold=0.95)
|
||||
g.remove_edges_from(list(nx.selfloop_edges(g)))
|
||||
big_sccs = __get_big_sccs(g)
|
||||
scc_nodes_score_dict = __scores_of_nodes_in_scc(big_sccs, players_score_dict)
|
||||
edges_to_be_removed = []
|
||||
|
||||
# 移除自环已经是DAG
|
||||
if len(big_sccs) == 0:
|
||||
return
|
||||
|
||||
__remove_cycle_edges_by_agony_iterately(big_sccs, scc_nodes_score_dict, edges_to_be_removed)
|
||||
g.remove_edges_from(edges_to_be_removed)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# for test only
|
||||
graph_path = 'D:\\hkn\\infected\\datasets\\text_only_nx\\text.edges'
|
||||
out_path = 'D:\\hkn\\infected\\datasets\\text_only_nx\\result.edges'
|
||||
|
||||
# breaking_cycles_by_TS(graph_path)
|
@ -64,7 +64,6 @@ class Vocab:
|
||||
def load_freq_counter_from_file(file_path: str, min_freq: int):
|
||||
freq_dict = {}
|
||||
with open(file_path, 'r') as f:
|
||||
|
||||
for line in tqdm(f, desc="Load frequency list from the file of {} ... ".format(file_path)):
|
||||
line = json.loads(line)
|
||||
f_name = line["f_name"]
|
||||
|
97
src/utils/util.py
Normal file
97
src/utils/util.py
Normal file
@ -0,0 +1,97 @@
|
||||
import os
|
||||
import shutil
|
||||
import random
|
||||
|
||||
|
||||
def transfer_remote():
|
||||
samples_dir = '/root/autodl-tmp'
|
||||
all_benign = '/root/autodl-tmp/all_benign'
|
||||
one_family_malware = '/root/autodl-tmp/one_family_malware'
|
||||
|
||||
sample = ['malware', 'benign']
|
||||
tags = ['test', 'train', 'valid']
|
||||
for s in sample:
|
||||
index = 0
|
||||
for t in tags:
|
||||
file_dir = os.path.join(samples_dir, '{}_{}'.format(t, s))
|
||||
for file in os.listdir(file_dir):
|
||||
dest_dir = all_benign if s == 'benign' else one_family_malware
|
||||
shutil.copy(os.path.join(file_dir, file), os.path.join(dest_dir, str(index)))
|
||||
index += 1
|
||||
|
||||
delete_remote()
|
||||
|
||||
|
||||
def delete_remote():
|
||||
samples_dir = '/root/autodl-tmp'
|
||||
sample = ['malware', 'benign']
|
||||
tags = ['test', 'train', 'valid']
|
||||
for s in sample:
|
||||
for t in tags:
|
||||
file_dir = os.path.join(samples_dir, '{}_{}'.format(t, s))
|
||||
for f in os.listdir(file_dir):
|
||||
os.remove(os.path.join(file_dir, f))
|
||||
|
||||
|
||||
def delete_remote_backup():
|
||||
samples_dir = '/root/autodl-tmp'
|
||||
dir_name = ['all', 'all_benign', 'one_family_malware', 'test_malware_backup', 'valid_malware_backup', 'train_malware_backup']
|
||||
for name in dir_name:
|
||||
file_dir = os.path.join(samples_dir, name)
|
||||
if os.path.exists(file_dir):
|
||||
for f in os.listdir(file_dir):
|
||||
os.remove(os.path.join(file_dir, f))
|
||||
|
||||
|
||||
# 重命名pt文件使之与代码相符
|
||||
def rename(file_dir, mal_or_be, postfix):
|
||||
tag_set = ['train', 'test', 'valid']
|
||||
for tag in tag_set:
|
||||
data_dir = os.path.join(file_dir, '{}_{}{}/'.format(tag, mal_or_be, postfix))
|
||||
for index, f in enumerate(os.listdir(data_dir)):
|
||||
os.rename(os.path.join(data_dir, f), os.path.join(data_dir, 'm' + f))
|
||||
for tag in tag_set:
|
||||
data_dir = os.path.join(file_dir, '{}_{}{}/'.format(tag, mal_or_be, postfix))
|
||||
for index, f in enumerate(os.listdir(data_dir)):
|
||||
os.rename(os.path.join(data_dir, f), os.path.join(data_dir, '{}_{}.pt'.format(mal_or_be, index)))
|
||||
|
||||
|
||||
def split_samples(flag):
|
||||
postfix = ''
|
||||
file_dir = '/root/autodl-tmp'
|
||||
if flag == 'one_family':
|
||||
path = os.path.join(file_dir, 'one_family_malware')
|
||||
tag = 'malware'
|
||||
elif flag == 'standard':
|
||||
path = os.path.join(file_dir, 'all')
|
||||
postfix = '_backup'
|
||||
tag = 'malware'
|
||||
elif flag == 'benign':
|
||||
path = os.path.join(file_dir, 'all_benign')
|
||||
tag = 'benign'
|
||||
else:
|
||||
print('flag not implemented')
|
||||
return
|
||||
|
||||
os_list = os.listdir(path)
|
||||
random.shuffle(os_list)
|
||||
# 6/2/2 分数据
|
||||
train_len = int(len(os_list) * 0.6)
|
||||
test_len = int(train_len / 3)
|
||||
for index, f in enumerate(os_list):
|
||||
if index < train_len:
|
||||
shutil.copy(os.path.join(path, f), os.path.join(file_dir, 'train_{}'.format(tag) + postfix))
|
||||
elif train_len <= index < train_len + test_len:
|
||||
shutil.copy(os.path.join(path, f), os.path.join(file_dir, 'test_{}'.format(tag) + postfix))
|
||||
else:
|
||||
shutil.copy(os.path.join(path, f), os.path.join(file_dir, 'valid_{}'.format(tag) + postfix))
|
||||
rename(file_dir, tag, postfix)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# delete_remote_backup()
|
||||
# transfer_remote()
|
||||
# delete_remote()
|
||||
split_samples('standard')
|
||||
split_samples('one_family')
|
||||
split_samples('benign')
|
@ -1,9 +0,0 @@
|
||||
import torch_geometric
|
||||
import torch
|
||||
if __name__ == '__main__':
|
||||
print(torch.__version__)
|
||||
print(torch.cuda.device_count())
|
||||
print(torch.cuda.get_device_name())
|
||||
print(torch.cuda.is_available())
|
||||
# print(torch.cuda.nccl.is_available())
|
||||
print(torch.cuda.nccl.version())
|
Loading…
Reference in New Issue
Block a user