first commit
This commit is contained in:
commit
da15051db9
149
.gitignore
vendored
Normal file
149
.gitignore
vendored
Normal file
@ -0,0 +1,149 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
|
||||
.DS_Store
|
||||
.idea
|
||||
*.log
|
||||
*.temp
|
||||
*.pyc
|
||||
*.py.un~
|
||||
.idea*
|
||||
.idea/*
|
||||
\.idea*
|
||||
.vscode*
|
||||
.vscode/*
|
||||
\.vscode*
|
||||
__pycache__/
|
||||
ReservedDataCode/*
|
||||
ReservedDataCode*
|
||||
src/outputs/
|
||||
*.zip
|
||||
*.sh
|
21
LICENSE
Normal file
21
LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 ryder
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
BIN
MalGraph.png
Normal file
BIN
MalGraph.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 359 KiB |
33
README.md
Normal file
33
README.md
Normal file
@ -0,0 +1,33 @@
|
||||
# MalGraph: Hierarchical Graph Neural Networks for Robust Windows Malware Detection
|
||||
|
||||
|
||||
## 1. Description
|
||||
|
||||
In this paper, we propose [**MalGraph**](https://www.github.com), a hierarchical graph neural network for effective and robust Windows PE malware detection. In particular, MalGraph makes better use of the hierarchical graph representation, which incorporates the inter-function call graph with intra-function control flow graphs to represent the executable program.
|
||||
If you find this paper or this repo useful, please cite our paper as follows.
|
||||
|
||||
![system](./MalGraph.png)
|
||||
|
||||
Xiang Ling, Lingfei Wu, Wei Deng, Sheng Zhang, Zhenqing Qu, Jiangyu Zhang, Tengfei Ma, Bin Wang, Chunming Wu and Shouling Ji, MalGraph: Hierarchical Graph Neural Networks for Robust Windows Malware Detection, IEEE International Conference on Computer Communications (**INFOCOM**) 2022.
|
||||
|
||||
```
|
||||
@inproceedings{ling2022malgraph,
|
||||
title={{MalGraph}: Hierarchical Graph Neural Networks for Robust Windows Malware Detection},
|
||||
author={Ling, Xiang and Wu, Lingfei and Deng, Wei and Qu, Zhenqing and Zhang, Jiangyu and Zhang, Sheng and Ma, Tengfei and Wang, Bin and Wu, Chunming and Ji, Shouling},
|
||||
booktitle={IEEE Conference on Computer Communications (INFOCOM)},
|
||||
pages={},
|
||||
year={2022},
|
||||
address={Virtual Event},
|
||||
publisher={IEEE}
|
||||
}
|
||||
```
|
||||
|
||||
## 2. Usage
|
||||
|
||||
It is necessary to use IDA Pro and the function `parse_json_list_2_pyg_object()` in `./samples/PreProcess.py` to generate the hierarchical graph representation as `torch_geometric.data.Data` objects for all files in the training/validation/testing sets.
|
||||
An illustrative example can be found in `samples/README.md`.
|
||||
In addition, we also provide the ``src/utils/Vocabulary.py`` which is the core implementation for us to build and save the vocabulary of external function names for all files in the training set.
|
||||
|
||||
After preparing both the preprocess_dataset and train_vocab_file, we can train our model of MalGraph.
|
||||
It is noted that all hyper-parameters are configured in ``configs/default.yaml`` which should be first configured manually.
|
||||
And then we can directly train our model via ``python DistTrainModel.py``
|
27
configs/default.yaml
Normal file
27
configs/default.yaml
Normal file
@ -0,0 +1,27 @@
|
||||
Data:
|
||||
preprocess_root: "../data/processed_dataset/DatasetJSON/"
|
||||
train_vocab_file: "../data/processed_dataset/train_external_function_name_vocab.jsonl"
|
||||
max_vocab_size: 10000 # modify according to the result of 1BuildExternalVocab.py
|
||||
Training:
|
||||
cuda: True # enable GPU training if cuda is available
|
||||
dist_backend: "nccl" # backend to use when training with torch.distributed
|
||||
dist_port: "1234"
|
||||
max_epoches: 10
|
||||
train_batch_size: 16
|
||||
test_batch_size: 32
|
||||
seed: 19920208
|
||||
only_test_path: 'None'
|
||||
Model:
|
||||
ablation_models: "Full" # "Full"
|
||||
gnn_type: "GraphSAGE" # "GraphSAGE" / "GCN"
|
||||
pool_type: "global_max_pool" # "global_max_pool" / "global_mean_pool"
|
||||
acfg_node_init_dims: 11
|
||||
cfg_filters: "200-200"
|
||||
fcg_filters: "200-200"
|
||||
number_classes: 1
|
||||
drapout_rate: 0.2
|
||||
Optimizer:
|
||||
name: "AdamW" # Adam / AdamW
|
||||
learning_rate: 1e-3 # initial learning rate
|
||||
weight_decay: 1e-5 # initial weight decay
|
||||
learning_anneal: 1.1 # Annealing applied to learning rate after each epoch
|
83
requirement_conda.txt
Normal file
83
requirement_conda.txt
Normal file
@ -0,0 +1,83 @@
|
||||
# This file may be used to create an environment using:
|
||||
# $ conda create --name <env> --file <this file>
|
||||
# platform: linux-64
|
||||
_libgcc_mutex=0.1=main
|
||||
antlr4-python3-runtime=4.8=pypi_0
|
||||
ase=3.21.1=pypi_0
|
||||
ca-certificates=2021.1.19=h06a4308_1
|
||||
cached-property=1.5.2=pypi_0
|
||||
certifi=2020.12.5=py37h06a4308_0
|
||||
cffi=1.14.5=pypi_0
|
||||
chardet=4.0.0=pypi_0
|
||||
cmake=3.18.4.post1=pypi_0
|
||||
cycler=0.10.0=pypi_0
|
||||
dataclasses=0.6=pypi_0
|
||||
decorator=4.4.2=pypi_0
|
||||
future=0.18.2=pypi_0
|
||||
googledrivedownloader=0.4=pypi_0
|
||||
h5py=3.2.1=pypi_0
|
||||
hydra-core=1.0.6=pypi_0
|
||||
idna=2.10=pypi_0
|
||||
importlib-resources=5.1.2=pypi_0
|
||||
intel-openmp=2021.1.2=pypi_0
|
||||
isodate=0.6.0=pypi_0
|
||||
jinja2=2.11.3=pypi_0
|
||||
joblib=1.0.1=pypi_0
|
||||
kiwisolver=1.3.1=pypi_0
|
||||
ld_impl_linux-64=2.33.1=h53a641e_7
|
||||
libedit=3.1.20191231=h14c3975_1
|
||||
libffi=3.3=he6710b0_2
|
||||
libgcc-ng=9.1.0=hdf63c60_0
|
||||
libstdcxx-ng=9.1.0=hdf63c60_0
|
||||
llvmlite=0.35.0=pypi_0
|
||||
magma-cuda112=2.5.2=1
|
||||
markupsafe=1.1.1=pypi_0
|
||||
matplotlib=3.3.4=pypi_0
|
||||
mkl=2021.1.1=pypi_0
|
||||
mkl-include=2021.1.1=pypi_0
|
||||
ncurses=6.2=he6710b0_1
|
||||
networkx=2.5=pypi_0
|
||||
ninja=1.10.0.post2=pypi_0
|
||||
numba=0.52.0=pypi_0
|
||||
numpy=1.20.1=pypi_0
|
||||
omegaconf=2.0.6=pypi_0
|
||||
openssl=1.1.1j=h27cfd23_0
|
||||
pandas=1.2.3=pypi_0
|
||||
pillow=8.1.2=pypi_0
|
||||
pip=21.0.1=py37h06a4308_0
|
||||
prefetch-generator=1.0.1=pypi_0
|
||||
pycparser=2.20=pypi_0
|
||||
pyparsing=2.4.7=pypi_0
|
||||
python=3.7.9=h7579374_0
|
||||
python-dateutil=2.8.1=pypi_0
|
||||
python-louvain=0.15=pypi_0
|
||||
pytz=2021.1=pypi_0
|
||||
pyyaml=5.4.1=pypi_0
|
||||
rdflib=5.0.0=pypi_0
|
||||
readline=8.1=h27cfd23_0
|
||||
requests=2.25.1=pypi_0
|
||||
scikit-learn=0.24.1=pypi_0
|
||||
scipy=1.6.1=pypi_0
|
||||
seaborn=0.11.1=pypi_0
|
||||
setuptools=52.0.0=py37h06a4308_0
|
||||
six=1.15.0=pypi_0
|
||||
sqlite=3.33.0=h62c20be_0
|
||||
tbb=2021.1.1=pypi_0
|
||||
texttable=1.6.3=pypi_0
|
||||
threadpoolctl=2.1.0=pypi_0
|
||||
tk=8.6.10=hbc83047_0
|
||||
torch=1.8.0+cu111=pypi_0
|
||||
torch-cluster=1.5.9=pypi_0
|
||||
torch-geometric=1.6.3=pypi_0
|
||||
torch-scatter=2.0.6=pypi_0
|
||||
torch-sparse=0.6.9=pypi_0
|
||||
torch-spline-conv=1.2.1=pypi_0
|
||||
torchaudio=0.8.0=pypi_0
|
||||
torchvision=0.9.0+cu111=pypi_0
|
||||
tqdm=4.59.0=pypi_0
|
||||
typing-extensions=3.7.4.3=pypi_0
|
||||
urllib3=1.26.3=pypi_0
|
||||
wheel=0.36.2=pyhd3eb1b0_0
|
||||
xz=5.2.5=h7b6447c_0
|
||||
zipp=3.4.1=pypi_0
|
||||
zlib=1.2.11=h7b6447c_3
|
BIN
samples/FunctionCallGraph.png
Normal file
BIN
samples/FunctionCallGraph.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 51 KiB |
41
samples/PreProcess.py
Normal file
41
samples/PreProcess.py
Normal file
@ -0,0 +1,41 @@
|
||||
import json
|
||||
import torch
|
||||
from torch_geometric.data import Data
|
||||
from tqdm import tqdm
|
||||
|
||||
from utils.Vocabulary import Vocab
|
||||
|
||||
|
||||
def parse_json_list_2_pyg_object(jsonl_file: str, label: int, vocab: Vocab, output_dir: str = "."):
    """Convert every JSON line of ``jsonl_file`` into a saved PyG ``Data`` object.

    Each non-empty line must be a JSON object with the keys ``hash``,
    ``acfg_list`` (a list of dicts holding ``block_features`` and
    ``block_edges``), ``function_names`` and ``function_edges``.  One file
    named ``<index>.pt`` (index starting at 1) is written per parsed line.

    Args:
        jsonl_file: path of the UTF-8 encoded JSON-lines file to read.
        label: value stored as ``targets`` in every saved object.
        vocab: object indexed with ``vocab[name]`` to map each external
            function name to its index.
        output_dir: directory receiving the ``.pt`` files; it is created if
            missing.  Defaults to the current working directory, which is
            where the files were always written before this parameter existed.

    Raises:
        AssertionError: if a record lists fewer function names than ACFGs.
    """
    import os  # local import: this file's import header does not import os

    os.makedirs(output_dir, exist_ok=True)

    index = 0
    with open(jsonl_file, "r", encoding="utf-8") as file:
        for line in tqdm(file):
            # A blank (e.g. trailing) line is not a record; json.loads would raise on it.
            if not line.strip():
                continue
            item = json.loads(line)
            item_hash = item['hash']

            # One PyG graph per local function: block features are the node
            # features and block edges are the edge index.
            acfg_list = [
                Data(x=torch.tensor(one_acfg['block_features'], dtype=torch.float),
                     edge_index=torch.tensor(one_acfg['block_edges'], dtype=torch.long))
                for one_acfg in item['acfg_list']
            ]

            item_function_names = item['function_names']
            item_function_edges = item['function_edges']

            # The first len(acfg_list) names are taken as the local functions;
            # every remaining name is treated as an external function.
            local_function_name_list = item_function_names[:len(acfg_list)]
            assert len(acfg_list) == len(local_function_name_list), "The length of ACFG_List should be equal to the length of Local_Function_List"
            external_function_name_list = item_function_names[len(acfg_list):]

            external_function_index_list = [vocab[f_name] for f_name in external_function_name_list]
            index += 1
            torch.save(Data(hash=item_hash, local_acfgs=acfg_list, external_list=external_function_index_list, function_edges=item_function_edges, targets=label),
                       os.path.join(output_dir, "{}.pt".format(index)))
            print(index)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Example driver: load the external-function vocabulary, then convert the
    # bundled sample.jsonl into PyG objects labelled 1.
    max_vocab_size = 10000
    train_vocab_file = "../ReservedDataCode/processed_dataset/train_external_function_name_vocab.jsonl"
    json_path = "./sample.jsonl"

    vocabulary = Vocab(max_vocab_size=max_vocab_size, freq_file=train_vocab_file)
    parse_json_list_2_pyg_object(vocab=vocabulary, label=1, jsonl_file=json_path)
|
21
samples/README.md
Normal file
21
samples/README.md
Normal file
@ -0,0 +1,21 @@
|
||||
# Data Preprocessing
|
||||
|
||||
### STEP 1: PE Disassemble
|
||||
|
||||
We first use IDA Pro 6.4 to disassemble one given portable executable (PE) file, obtaining one function call graph (i.e., FCG, including both external functions and local functions) and corresponding control flow graphs (CFGs) of local functions.
|
||||
In fact, the FCG can be exported in the Graph Description Language (GDL) file format, and the CFGs can be processed into ACFGs, which are mainly built on the GitHub repo https://github.com/qian-feng/Gencoding.
|
||||
We therefore refer interested readers to this repo for more details.
|
||||
|
||||
Taking one PE file as an example, we can use IDA Pro to get the following FCG (25 external functions and 2 local functions)
|
||||
![system](./FunctionCallGraph.png)
|
||||
and two CFGs of local functions, i.e., sub_401000 and 40103C as follows.
|
||||
![system](./sub_401000.png)
|
||||
![system](./sub_40103C.png)
|
||||
After that, we can save the above hierarchical graph representation into sample.jsonl as follows.
|
||||
|
||||
```
|
||||
{"function_edges": [[1, 1, ..., 1], [0, 2, ..., 26]], "acfg_list": [{"block_number": 3, "block_edges": [[0, 0, 1, 1], [0, 2, 0, 2]], "block_features": [[0, 2, ...], [0, 2, ...], [1, 0, ...]]}, {"block_number": 29, "block_edges": [[0, 1, ..., 28], [16, 0, ..., 8]], "block_features": [[8, 2, ...], [0, 7, ...], [0, 7, ...], [0, 7, ...], [0, 7, ...], [0, 7,...], [1, 18, ...], [1, 21, ...], [0, 21,...], [0, 24, ...], [1, 26, ...], [1, 2, ...], [5, 4, ...], [4, 11, ...], [2, 14, ...], [3, 17, ...], [1, 1, ...], [0, 14, ...], [3, 17, ...], [0, 17, ...], [2, 28, ...], [0, 11, ...], [0, 0, ...], [1, 1, ...], [12, 27, ...], [0, 0, ...], [2, 9, ...], [2, 14,...], [1, 21, ...]]}], "function_names": ["sub_401000", "start", "GetTempPathW", "GetFileSize", ... , "InternetOpenW"], "hash": "3***5", "function_number": 27}
|
||||
```
|
||||
|
||||
### STEP 2: Convert the resulting json file to PyG data object
|
||||
However, the resulting JSON object above cannot be fed directly into our model. We therefore convert it into a PyTorch Geometric `Data` object and provide an example Python script, `PreProcess.py`, for interested readers.
|
1
samples/sample.jsonl
Normal file
1
samples/sample.jsonl
Normal file
@ -0,0 +1 @@
|
||||
{"function_edges": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26]], "acfg_list": [{"block_number": 3, "block_edges": [[0, 0, 1, 1], [0, 2, 0, 2]], "block_features": [[0, 2, 1, 0, 7, 0, 1, 1, 4, 0, 0], [0, 2, 0, 0, 3, 1, 0, 1, 0, 0, 0], [1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0]]}, {"block_number": 29, "block_edges": [[0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 12, 12, 13, 14, 14, 15, 16, 17, 18, 19, 19, 20, 20, 21, 21, 23, 24, 24, 26, 26, 27, 28], [16, 0, 2, 0, 4, 1, 3, 3, 3, 25, 15, 8, 6, 6, 7, 28, 12, 9, 23, 16, 25, 11, 21, 17, 13, 19, 22, 14, 19, 18, 27, 24, 23, 26, 21, 22, 25, 10, 25, 5, 14, 8]], "block_features": [[8, 2, 1, 5, 36, 0, 6, 0, 2, 0, 0], [0, 7, 0, 0, 3, 0, 1, 1, 1, 0, 0], [0, 7, 0, 0, 2, 0, 1, 1, 0, 0, 0], [0, 7, 0, 1, 8, 1, 2, 0, 0, 0, 0], [0, 7, 1, 0, 2, 0, 1, 0, 0, 0, 0], [0, 7, 0, 0, 1, 0, 0, 0, 1, 0, 0], [1, 18, 0, 1, 9, 0, 2, 1, 1, 0, 0], [1, 21, 1, 0, 3, 0, 1, 1, 0, 0, 0], [0, 21, 0, 1, 4, 1, 2, 0, 0, 0, 0], [0, 24, 0, 2, 12, 1, 3, 0, 0, 0, 0], [1, 26, 0, 3, 16, 0, 4, 1, 4, 0, 0], [1, 2, 0, 5, 22, 0, 5, 0, 1, 0, 0], [5, 4, 1, 3, 21, 0, 4, 1, 3, 0, 0], [4, 11, 0, 2, 17, 1, 2, 0, 1, 0, 0], [2, 14, 0, 1, 12, 0, 2, 1, 1, 0, 0], [3, 17, 0, 0, 10, 0, 1, 0, 1, 0, 0], [1, 1, 0, 1, 5, 0, 2, 0, 0, 0, 0], [0, 14, 0, 0, 1, 0, 0, 0, 0, 0, 0], [3, 17, 0, 0, 7, 0, 0, 0, 0, 0, 0], [0, 17, 0, 1, 5, 0, 2, 1, 1, 0, 0], [2, 28, 1, 1, 11, 1, 2, 1, 1, 0, 0], [0, 11, 0, 1, 8, 1, 2, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0], [1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0], [12, 27, 1, 7, 41, 0, 8, 1, 6, 0, 0], [0, 0, 1, 0, 7, 1, 0, 0, 0, 1, 0], [2, 9, 0, 2, 17, 0, 3, 1, 3, 0, 0], [2, 14, 0, 0, 5, 0, 1, 0, 4, 0, 0], [1, 21, 4, 1, 13, 0, 2, 0, 5, 0, 0]]}], "function_names": ["sub_401000", "start", "GetTempPathW", "GetFileSize", "GetCurrentDirectoryW", "DeleteFileW", "CloseHandle", "WriteFile", "lstrcmpW", "ReadFile", "GetModuleHandleW", 
"ExitProcess", "HeapCreate", "HeapAlloc", "GetModuleFileNameW", "CreateFileW", "lstrlenW", "ShellExecuteW", "wsprintfW", "HttpSendRequestW", "InternetSetOptionW", "InternetQueryOptionW", "HttpOpenRequestW", "HttpQueryInfoW", "InternetReadFile", "InternetConnectW", "InternetOpenW"], "hash": "316ebb797d5196020eee013cfe771671fff4da8859adc9f385f52a74e82f4e55", "function_number": 27}
|
BIN
samples/sub_401000.png
Normal file
BIN
samples/sub_401000.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 196 KiB |
BIN
samples/sub_40103C.png
Normal file
BIN
samples/sub_40103C.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 61 KiB |
346
src/DistTrainModel.py
Normal file
346
src/DistTrainModel.py
Normal file
@ -0,0 +1,346 @@
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as torch_mp
|
||||
import torch.utils.data
|
||||
import torch.utils.data.distributed
|
||||
from hydra.utils import to_absolute_path
|
||||
from omegaconf import DictConfig
|
||||
from prefetch_generator import BackgroundGenerator
|
||||
from sklearn.metrics import roc_auc_score, roc_curve
|
||||
from torch import nn
|
||||
from torch_geometric.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from models.HierarchicalGraphModel import HierarchicalGraphNeuralNetwork
|
||||
from utils.FunctionHelpers import write_into, params_print_log, find_threshold_with_fixed_fpr
|
||||
from utils.ParameterClasses import ModelParams, TrainParams, OptimizerParams, OneEpochResult
|
||||
from utils.PreProcessedDataset import MalwareDetectionDataset
|
||||
from utils.RealBatch import create_real_batch_data
|
||||
from utils.Vocabulary import Vocab
|
||||
|
||||
|
||||
class DataLoaderX(DataLoader):
    """``DataLoader`` whose iterator is wrapped in a ``BackgroundGenerator``.

    NOTE(review): presumably this lets batches be prefetched while the
    consumer is busy -- confirm against the prefetch_generator documentation.
    """

    def __iter__(self):
        # Build the regular loader iterator first, then hand it to the wrapper.
        base_iterator = super().__iter__()
        return BackgroundGenerator(base_iterator)
|
||||
|
||||
|
||||
def reduce_sum(tensor):
    """Return the sum of ``tensor`` over all processes of the default group.

    The all-reduce runs in place on a clone, so the caller's tensor is left
    unchanged.
    """
    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)  # noqa
    return summed
|
||||
|
||||
|
||||
def reduce_mean(tensor, nprocs):
    """Return ``tensor`` summed over all processes and divided by ``nprocs``.

    A clone is reduced and divided in place, so the caller's tensor is left
    unchanged.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)  # noqa
    averaged /= nprocs
    return averaged
|
||||
|
||||
|
||||
def all_gather_concat(tensor):
    """Gather ``tensor`` from every process and concatenate along dim 0.

    One receive buffer shaped like ``tensor`` is allocated per process in the
    group; the gathered pieces are joined in the order of the buffer list.
    """
    receive_buffers = []
    for _ in range(torch.distributed.get_world_size()):
        receive_buffers.append(torch.ones_like(tensor))

    dist.all_gather(receive_buffers, tensor, async_op=False)
    return torch.cat(receive_buffers, dim=0)
|
||||
|
||||
|
||||
def train_one_epoch(local_rank, train_loader, valid_loader, model, criterion, optimizer, nprocs, idx_epoch, best_auc, best_model_file, original_valid_length, result_file):
    """Train ``model`` for one epoch, validating periodically and saving the best weights.

    Validation runs at every batch index that is a non-zero multiple of
    ``ceil(len(train_loader) / 3)`` and at the last batch index.  Whenever
    the validation ROC-AUC exceeds ``best_auc``, ``model.module.state_dict()``
    is saved to ``best_model_file``.

    Args:
        local_rank: CUDA device index of this process; rank 0 writes the
            validation logs.
        train_loader: iterable of training batches accepted by
            ``create_real_batch_data``.
        valid_loader: loader forwarded to ``validate``.
        model: wrapped model exposing ``.module`` (its state dict is what gets saved).
        criterion: loss called as ``criterion(prediction, target)``.
        optimizer: optimizer stepped once per processed batch.
        nprocs: number of processes used to average the reduced loss.
        idx_epoch: epoch index, used only in log messages and flags.
        best_auc: best validation ROC-AUC seen so far.
        best_model_file: path where improved weights are saved.
        original_valid_length: forwarded to ``validate``.
        result_file: log file path passed to ``write_into``.

    Returns:
        Tuple ``(smooth_avg_reduced_loss_list, best_auc)``: the running
        average of the reduced loss after each processed batch, and the
        possibly updated best ROC-AUC.
    """
    model.train()
    local_device = torch.device("cuda", local_rank)
    write_into(file_name_path=result_file, log_str="The local device = {} among {} nprocs in the {}-th epoch.".format(local_device, nprocs, idx_epoch))

    until_sum_reduced_loss = 0.0
    smooth_avg_reduced_loss_list = []

    # Loop-invariant: validate roughly three times per epoch plus once on the last batch.
    n_train_batches = len(train_loader)
    valid_interval = math.ceil(n_train_batches / 3)
    last_batch_idx = int(n_train_batches - 1)

    for _idx_bt, _batch in enumerate(tqdm(train_loader, desc="reading _batch from local_rank={}".format(local_rank))):
        # validate() switches the model to eval mode, so re-enable training mode every batch.
        model.train()
        _real_batch, _position, _hash, _external_list, _function_edges, _true_classes = create_real_batch_data(one_batch=_batch)
        if _real_batch is None:
            # NOTE(review): this rank skips the barrier/all-reduce below for this batch;
            # if other ranks do not skip the same batch the collectives may not line up -- confirm.
            write_into(result_file, "{}\n_real_batch is None in creating the real batch data of training ... ".format("*-" * 100))
            continue

        _real_batch = _real_batch.to(local_device)
        _position = torch.tensor(_position, dtype=torch.long).cuda(local_rank, non_blocking=True)
        _true_classes = _true_classes.float().cuda(local_rank, non_blocking=True)

        train_batch_pred = model(real_local_batch=_real_batch, real_bt_positions=_position, bt_external_names=_external_list, bt_all_function_edges=_function_edges, local_device=local_device)
        # Drop only the trailing dimension (as validate() does): a bare squeeze() would also
        # collapse the batch dimension of a one-sample batch and break the loss shape check.
        train_batch_pred = train_batch_pred.squeeze(-1)

        loss = criterion(train_batch_pred, _true_classes)

        torch.distributed.barrier()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        reduced_loss = reduce_mean(loss, nprocs)
        until_sum_reduced_loss += reduced_loss.item()
        smooth_avg_reduced_loss_list.append(until_sum_reduced_loss / (_idx_bt + 1))

        if _idx_bt != 0 and (_idx_bt % valid_interval == 0 or _idx_bt == last_batch_idx):
            val_start_time = datetime.now()
            if local_rank == 0:
                write_into(result_file, "\nIn {}-th epoch, {}-th batch, we start to validate ... ".format(idx_epoch, _idx_bt))

            _eval_flag = "Valid_In_Train_Epoch_{}_Batch_{}".format(idx_epoch, _idx_bt)
            valid_result = validate(local_rank=local_rank, valid_loader=valid_loader, model=model, criterion=criterion, evaluate_flag=_eval_flag, distributed=True, nprocs=nprocs,
                                    original_valid_length=original_valid_length, result_file=result_file, details=False)

            if best_auc < valid_result.ROC_AUC_Score:
                _info = "[AUC Increased!] In evaluation of epoch-{} / batch-{}: AUC increased from {:.5f} < {:.5f}! Saving the model into {}".format(idx_epoch,
                                                                                                                                                      _idx_bt,
                                                                                                                                                      best_auc,
                                                                                                                                                      valid_result.ROC_AUC_Score,
                                                                                                                                                      best_model_file)
                best_auc = valid_result.ROC_AUC_Score
                # Every rank saves; the caller chooses the (per-rank) file name.
                torch.save(model.module.state_dict(), best_model_file)
            else:
                _info = "[AUC NOT Increased!] AUC decreased from {:.5f} to {:.5f}!".format(best_auc, valid_result.ROC_AUC_Score)

            if local_rank == 0:
                write_into(result_file, valid_result.__str__())
                write_into(result_file, _info)
                write_into(result_file, "[#One Validation Time#] Consume about {} time period for one validation.".format(datetime.now() - val_start_time))

    return smooth_avg_reduced_loss_list, best_auc
|
||||
|
||||
|
||||
def validate(local_rank, valid_loader, model, criterion, evaluate_flag, distributed, nprocs, original_valid_length, result_file, details):
    """Evaluate ``model`` on ``valid_loader`` and return the metrics as a ``OneEpochResult``.

    Args:
        local_rank: CUDA device index used for the batch tensors.
        valid_loader: iterable of batches accepted by ``create_real_batch_data``.
        model: model to evaluate; it is put into eval mode.
        criterion: loss called as ``criterion(prediction, target)``.
        evaluate_flag: label stored as ``Epoch_Flag`` in the result.
        distributed: when true, sample count, loss, labels and predictions are
            combined across all processes before the metrics are computed.
        nprocs: number of processes used to average the loss when distributed.
        original_valid_length: the gathered labels/predictions are truncated
            to this many entries when distributed.
        result_file: log file path passed to ``write_into``.
        details: when ``True``, thresholds for FPR targets 0.01 and 0.001 are
            computed with ``find_threshold_with_fixed_fpr``; otherwise both
            info fields are the string ``"None"``.

    Returns:
        ``OneEpochResult`` holding the sample count, average loss, ROC-AUC and
        the ROC curve (thresholds, TPRs, FPRs).

    Raises:
        RuntimeError: if no batch could be processed (empty loader or every
            batch skipped), since no metric can be computed.
    """
    model.eval()
    if distributed:
        local_device = torch.device("cuda", local_rank)
    else:
        local_device = torch.device("cuda")

    sum_loss = torch.tensor(0.0, dtype=torch.float, device=local_device)
    n_samples = torch.tensor(0, dtype=torch.int, device=local_device)
    # Count only batches that were actually evaluated, so skipped batches
    # do not dilute the average loss.
    n_processed_batches = 0

    all_true_classes = []
    all_positive_probs = []

    with torch.no_grad():
        for data in tqdm(valid_loader):
            _real_batch, _position, _hash, _external_list, _function_edges, _true_classes = create_real_batch_data(one_batch=data)
            if _real_batch is None:
                write_into(result_file, "{}\n_real_batch is None in creating the real batch data of validation ... ".format("*-" * 100))
                continue
            _real_batch = _real_batch.to(local_device)
            _position = torch.tensor(_position, dtype=torch.long).cuda(local_rank, non_blocking=True)
            _true_classes = _true_classes.float().cuda(local_rank, non_blocking=True)

            batch_pred = model(real_local_batch=_real_batch, real_bt_positions=_position, bt_external_names=_external_list, bt_all_function_edges=_function_edges, local_device=local_device)
            batch_pred = batch_pred.squeeze(-1)
            loss = criterion(batch_pred, _true_classes)
            sum_loss += loss.item()

            n_samples += len(batch_pred)
            n_processed_batches += 1

            all_true_classes.append(_true_classes)
            all_positive_probs.append(batch_pred)

    if n_processed_batches == 0:
        raise RuntimeError("validate() processed no batch for evaluate_flag={}; cannot compute metrics.".format(evaluate_flag))

    avg_loss = sum_loss / n_processed_batches
    all_true_classes = torch.cat(all_true_classes, dim=0)
    all_positive_probs = torch.cat(all_positive_probs, dim=0)

    if distributed:
        torch.distributed.barrier()
        reduced_n_samples = reduce_sum(n_samples)
        reduced_avg_loss = reduce_mean(avg_loss, nprocs)
        gather_true_classes = all_gather_concat(all_true_classes).detach().cpu().numpy()
        gather_positive_probs = all_gather_concat(all_positive_probs).detach().cpu().numpy()

        # NOTE(review): presumably this trims the extra samples a distributed sampler adds
        # to even out the ranks -- confirm that those extras really sit at the tail.
        gather_true_classes = gather_true_classes[:original_valid_length]
        gather_positive_probs = gather_positive_probs[:original_valid_length]
    else:
        reduced_n_samples = n_samples
        reduced_avg_loss = avg_loss
        gather_true_classes = all_true_classes.detach().cpu().numpy()
        gather_positive_probs = all_positive_probs.detach().cpu().numpy()

    # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
    # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html
    _roc_auc_score = roc_auc_score(y_true=gather_true_classes, y_score=gather_positive_probs)
    _fpr, _tpr, _thresholds = roc_curve(y_true=gather_true_classes, y_score=gather_positive_probs)
    if details is True:
        _100_info = find_threshold_with_fixed_fpr(y_true=gather_true_classes, y_pred=gather_positive_probs, fpr_target=0.01)
        _1000_info = find_threshold_with_fixed_fpr(y_true=gather_true_classes, y_pred=gather_positive_probs, fpr_target=0.001)
    else:
        _100_info, _1000_info = "None", "None"

    _eval_result = OneEpochResult(Epoch_Flag=evaluate_flag, Number_Samples=reduced_n_samples, Avg_Loss=reduced_avg_loss, Info_100=_100_info, Info_1000=_1000_info,
                                  ROC_AUC_Score=_roc_auc_score, Thresholds=_thresholds, TPRs=_tpr, FPRs=_fpr)
    return _eval_result
|
||||
|
||||
|
||||
def main_train_worker(local_rank: int, nprocs: int, train_params: TrainParams, model_params: ModelParams, optimizer_params: OptimizerParams, global_log: logging.Logger,
                      log_result_file: str):
    """Per-process entry point: join the process group, build everything, and run all epochs.

    Args:
        local_rank: rank of this process, also used as its CUDA device index.
        nprocs: world size of the process group.
        train_params: supplies the vocabulary file and size, dataset root,
            batch sizes and the number of epochs.
        model_params: forwarded to the model; ``ablation_models`` must be
            ``"full"`` (case-insensitive).
        optimizer_params: supplies optimizer name (``adam`` / ``adamw``),
            learning rate, weight decay and the per-epoch annealing divisor.
        global_log: logger handed to the model constructor.
        log_result_file: log file path; only rank 0 writes to it here.

    Raises:
        NotImplementedError: for an unsupported ablation model or optimizer name.
    """
    # dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:12345', world_size=nprocs, rank=local_rank)
    dist.init_process_group(backend='nccl', init_method='env://', world_size=nprocs, rank=local_rank)
    torch.cuda.set_device(local_rank)
    is_main_process = local_rank == 0

    # ---- model ----
    external_vocab = Vocab(freq_file=train_params.external_func_vocab_file, max_vocab_size=train_params.max_vocab_size)
    if model_params.ablation_models.lower() != "full":
        raise NotImplementedError
    network = HierarchicalGraphNeuralNetwork(model_params=model_params, external_vocab=external_vocab, global_log=global_log)
    network.cuda(local_rank)
    network = nn.parallel.DistributedDataParallel(network, device_ids=[local_rank])

    if is_main_process:
        write_into(file_name_path=log_result_file, log_str=network.__str__())

    # ---- loss function ----
    loss_fn = nn.BCELoss().cuda(local_rank)

    # ---- optimizer ----
    initial_lr = optimizer_params.lr
    if optimizer_params.optimizer_name.lower() == 'adam':
        optim = torch.optim.Adam(network.parameters(), lr=initial_lr)
    elif optimizer_params.optimizer_name.lower() == 'adamw':
        optim = torch.optim.AdamW(network.parameters(), lr=initial_lr, weight_decay=optimizer_params.weight_decay)
    else:
        raise NotImplementedError

    n_epochs = train_params.max_epochs
    data_root = train_params.processed_files_path
    train_bs = train_params.train_bs
    test_bs = train_params.test_bs

    # ---- training dataset & dataloader ----
    train_set = MalwareDetectionDataset(root=data_root, train_or_test="train")
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
    train_loader = DataLoaderX(dataset=train_set, batch_size=train_bs, shuffle=False, num_workers=0, pin_memory=True, sampler=train_sampler)

    # ---- validation dataset & dataloader ----
    valid_set = MalwareDetectionDataset(root=data_root, train_or_test="valid")
    valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_set)
    valid_loader = DataLoaderX(dataset=valid_set, batch_size=test_bs, pin_memory=True, sampler=valid_sampler)

    if is_main_process:
        write_into(file_name_path=log_result_file, log_str="Training dataset={}, sampler={}, loader={}".format(len(train_set), len(train_sampler), len(train_loader)))
        write_into(file_name_path=log_result_file, log_str="Validation dataset={}, sampler={}, loader={}".format(len(valid_set), len(valid_sampler), len(valid_loader)))

    best_auc = 0.0
    valid_set_size = len(valid_set)
    # Each rank keeps its own checkpoint file in the current working directory.
    checkpoint_path = os.path.join(os.getcwd(), 'LocalRank_{}_best_model.pt'.format(local_rank))

    loss_history = []
    for epoch in range(n_epochs):
        # Give both samplers the current epoch number before iterating.
        train_sampler.set_epoch(epoch)
        valid_sampler.set_epoch(epoch)

        epoch_started = datetime.now()
        if is_main_process:
            write_into(log_result_file, "\n{} start of {}-epoch, init best_auc={}, start time={} {}".format("-" * 50, epoch, best_auc, epoch_started.strftime("%Y-%m-%d@%H:%M:%S"), "-" * 50))

        epoch_losses, best_auc = train_one_epoch(local_rank=local_rank, train_loader=train_loader, valid_loader=valid_loader, model=network, criterion=loss_fn,
                                                 optimizer=optim, nprocs=nprocs, idx_epoch=epoch, best_auc=best_auc, best_model_file=checkpoint_path,
                                                 original_valid_length=valid_set_size, result_file=log_result_file)
        loss_history.extend(epoch_losses)

        # Anneal the learning rate: divide it by the configured factor after every epoch.
        for group in optim.param_groups:
            group['lr'] = group['lr'] / optimizer_params.learning_anneal

        epoch_finished = datetime.now()
        if is_main_process:
            write_into(log_result_file, "\n{} end of {}-epoch, with best_auc={}, len of loss={}, end time={}, and time period={} {}".format("*" * 50, epoch, best_auc,
                                                                                                                                            len(epoch_losses),
                                                                                                                                            epoch_finished.strftime("%Y-%m-%d@%H:%M:%S"),
                                                                                                                                            epoch_finished - epoch_started, "*" * 50))
|
||||
|
||||
|
||||
# https://hydra.cc/docs/tutorials/basic/your_first_app/defaults#overriding-a-config-group-default
|
||||
@hydra.main(config_path="../configs/", config_name="default.yaml")
def main_app(config: DictConfig):
    """Entry point: optionally train with one process per GPU, then re-evaluate.

    When ``config.Training.only_test_path`` equals the string ``'None'`` the
    distributed training workers are spawned first and the checkpoint written
    by local rank 0 is evaluated; otherwise the checkpoint at
    ``only_test_path`` is evaluated directly. Evaluation runs ``validate`` on
    the validation split and then on the test split.
    """
    training_cfg = config.Training
    data_cfg = config.Data
    model_cfg = config.Model
    optim_cfg = config.Optimizer

    # seed every random number generator used by the run for reproducibility
    seed = training_cfg.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # hyper-parameters for Training / Model / Optimizer
    _train_params = TrainParams(
        processed_files_path=to_absolute_path(data_cfg.preprocess_root),
        max_epochs=training_cfg.max_epoches,
        train_bs=training_cfg.train_batch_size,
        test_bs=training_cfg.test_batch_size,
        external_func_vocab_file=to_absolute_path(data_cfg.train_vocab_file),
        max_vocab_size=data_cfg.max_vocab_size,
    )
    _model_params = ModelParams(
        gnn_type=model_cfg.gnn_type,
        pool_type=model_cfg.pool_type,
        acfg_init_dims=model_cfg.acfg_node_init_dims,
        cfg_filters=model_cfg.cfg_filters,
        fcg_filters=model_cfg.fcg_filters,
        number_classes=model_cfg.number_classes,
        dropout_rate=model_cfg.drapout_rate,
        ablation_models=model_cfg.ablation_models,
    )
    _optim_params = OptimizerParams(
        optimizer_name=optim_cfg.name,
        lr=optim_cfg.learning_rate,
        weight_decay=optim_cfg.weight_decay,
        learning_anneal=optim_cfg.learning_anneal,
    )

    # logging
    working_dir = os.getcwd()
    log = logging.getLogger("DistTrainModel.py")
    log.setLevel("DEBUG")
    log.warning("Hydra's Current Working Directory: {}".format(working_dir))

    # the result file name encodes the main hyper-parameters plus a timestamp
    result_file = '{}_{}_{}_ACFG_{}_FCG_{}_Epoch_{}_TrainBS_{}_LR_{}_Time_{}.txt'.format(
        _model_params.ablation_models, _model_params.gnn_type, _model_params.pool_type,
        _model_params.cfg_filters, _model_params.fcg_filters, _train_params.max_epochs,
        _train_params.train_bs, _optim_params.lr, datetime.now().strftime("%Y%m%d_%H%M%S"))
    log_result_file = os.path.join(working_dir, result_file)

    _other_params = {
        "Hydra's Current Working Directory": working_dir,
        "seed": seed,
        "log result file": log_result_file,
        "only_test_path": training_cfg.only_test_path,
    }

    for param_table in (_train_params.__dict__, _model_params.__dict__, _optim_params.__dict__, _other_params):
        params_print_log(param_table, log_result_file)

    if training_cfg.only_test_path == 'None':
        # train: one worker process per visible GPU, rendezvous on localhost
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = str(training_cfg.dist_port)
        num_gpus = torch.cuda.device_count()
        log.info("Total number of GPUs = {}".format(num_gpus))
        worker_args = (num_gpus, _train_params, _model_params, _optim_params, log, log_result_file,)
        torch_mp.spawn(main_train_worker, nprocs=num_gpus, args=worker_args)

        best_model_file = os.path.join(working_dir, 'LocalRank_{}_best_model.pt'.format(0))
    else:
        best_model_file = training_cfg.only_test_path

    # model re-init and loading
    log.info("\n\nstarting to load the model & re-validation & testing from the file of \"{}\" \n".format(best_model_file))
    device = torch.device('cuda')
    vocab = Vocab(freq_file=_train_params.external_func_vocab_file, max_vocab_size=_train_params.max_vocab_size)

    if _model_params.ablation_models.lower() != "full":
        raise NotImplementedError
    model = HierarchicalGraphNeuralNetwork(model_params=_model_params, external_vocab=vocab, global_log=log)
    model.to(device)
    model.load_state_dict(torch.load(best_model_file, map_location=device))
    criterion = nn.BCELoss().cuda()

    test_batch_size = training_cfg.test_batch_size
    dataset_root_path = _train_params.processed_files_path

    # validation dataset & dataloader
    valid_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="valid")
    valid_dataloader = DataLoaderX(dataset=valid_dataset, batch_size=test_batch_size, shuffle=False)
    log.info("Total number of all validation samples = {} ".format(len(valid_dataset)))

    # testing dataset & dataloader
    test_dataset = MalwareDetectionDataset(root=dataset_root_path, train_or_test="test")
    test_dataloader = DataLoaderX(dataset=test_dataset, batch_size=test_batch_size, shuffle=False)
    log.info("Total number of all testing samples = {} ".format(len(test_dataset)))

    _valid_result = validate(
        valid_loader=valid_dataloader, model=model, criterion=criterion,
        evaluate_flag="DoubleCheckValidation", distributed=False, local_rank=None, nprocs=None,
        original_valid_length=len(valid_dataset), result_file=log_result_file, details=True)
    log.warning("\n\n" + _valid_result.__str__())

    _test_result = validate(
        valid_loader=test_dataloader, model=model, criterion=criterion,
        evaluate_flag="FinalTestingResult", distributed=False, local_rank=None, nprocs=None,
        original_valid_length=len(test_dataset), result_file=log_result_file, details=True)
    log.warning("\n\n" + _test_result.__str__())
|
||||
|
||||
|
||||
# Run the hydra-decorated entry point when this module is executed as a script.
if __name__ == "__main__":
    main_app()
|
213
src/models/HierarchicalGraphModel.py
Normal file
213
src/models/HierarchicalGraphModel.py
Normal file
@ -0,0 +1,213 @@
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Linear
|
||||
from torch.nn import functional as pt_f
|
||||
from torch_geometric.data import Batch, Data
|
||||
from torch_geometric.nn.conv import GCNConv, SAGEConv
|
||||
from torch_geometric.nn.glob import global_max_pool, global_mean_pool
|
||||
|
||||
sys.path.append("..")
|
||||
from utils.ParameterClasses import ModelParams # noqa
|
||||
from utils.Vocabulary import Vocab # noqa
|
||||
|
||||
|
||||
def div_with_small_value(n, d, eps=1e-8):
    """Divide ``n`` by ``d`` after replacing denominators that are <= ``eps`` with ``eps``.

    Entries of ``d`` greater than ``eps`` are used as they are; every other
    entry is replaced by ``eps`` so the division never hits a (near-)zero value.
    """
    keep_mask = (d > eps).float()
    floor_mask = (d <= eps).float()
    safe_denominator = d * keep_mask + eps * floor_mask
    return n / safe_denominator
|
||||
|
||||
|
||||
def padding_tensors(tensor_list):
    """Stack tensors of different first-dimension length into one zero-padded tensor.

    Returns ``(out_tensor, mask)``, both shaped
    ``(len(tensor_list), max_len, *tensor_list[0].shape[1:])``. ``mask`` holds 1
    at every position copied from an input tensor and 0 at padded positions.
    """
    longest = max(t.shape[0] for t in tensor_list)
    template = tensor_list[0]
    padded_shape = (len(tensor_list), longest, *template.shape[1:])

    # new_zeros keeps the dtype/device of the first tensor
    out_tensor = template.data.new_zeros(padded_shape)
    mask = template.data.new_zeros(padded_shape)
    for row, tensor in enumerate(tensor_list):
        valid = tensor.size(0)
        out_tensor[row, :valid] = tensor
        mask[row, :valid] = 1
    return out_tensor, mask
|
||||
|
||||
|
||||
def inverse_padding_tensors(tensors, masks):
    """Undo ``padding_tensors``-style padding.

    Returns ``(rows, batch_idx_list)``: ``rows`` contains the entries of
    ``tensors`` where ``masks == 1`` reshaped to ``(-1, tensors.size(-1))``, and
    ``batch_idx_list`` repeats each index along the first dimension once per
    selected row of that entry.
    """
    # per-row fraction of masked-in features (1 for a fully valid row)
    row_validity = torch.sum(masks, dim=-1) / masks.size(-1)

    selected_rows = torch.masked_select(tensors, (masks == 1)).view(-1, tensors.size(-1))

    rows_per_sample = torch.sum(row_validity, dim=-1)

    batch_idx_list = []
    for sample_idx, row_count in enumerate(rows_per_sample):
        batch_idx_list.extend([sample_idx] * int(row_count))
    return selected_rows, batch_idx_list
|
||||
|
||||
|
||||
class HierarchicalGraphNeuralNetwork(nn.Module):
    """Two-level graph neural network producing one sigmoid score per sample.

    Level 1 runs a stack of GNN layers over the batched local graphs and pools
    every graph into one vector. Level 2 uses those vectors, concatenated with
    embeddings of the sample's external function names, as node features of a
    per-sample function call graph, runs a second GNN stack, pools every call
    graph into one vector and projects it through three linear layers.

    GNN layers are registered as ``CFG_gnn_<k>`` / ``FCG_gnn_<k>`` attributes
    (1-based), which fixes the parameter names in the state dict.
    """

    def __init__(self, model_params: ModelParams, external_vocab: Vocab, global_log: logging.Logger):  # device=torch.device('cuda')
        """Build the layers.

        Raises:
            NotImplementedError: if ``model_params.gnn_type`` is not
                'graphsage'/'gcn' or ``model_params.pool_type`` is not
                'global_max_pool'/'global_mean_pool' (case-insensitive).
        """
        super(HierarchicalGraphNeuralNetwork, self).__init__()

        self.conv = model_params.gnn_type.lower()
        if self.conv not in ['graphsage', 'gcn']:
            raise NotImplementedError
        self.pool = model_params.pool_type.lower()
        if self.pool not in ["global_max_pool", "global_mean_pool"]:
            raise NotImplementedError

        # self.device = device
        self.global_log = global_log

        # Hierarchical 1: Control Flow Graph (CFG) embedding and pooling
        print(type(model_params.cfg_filters), model_params.cfg_filters)
        cfg_filter_list = self._parse_filters(model_params.cfg_filters)
        cfg_filter_list.insert(0, model_params.acfg_init_dims)
        self.cfg_filter_length = len(cfg_filter_list)
        self._add_gnn_stack('CFG', cfg_filter_list)

        # self.dropout = nn.Dropout(p=model_params.dropout_rate).to(self.device)
        self.dropout = nn.Dropout(p=model_params.dropout_rate)

        # Hierarchical 2: Function Call Graph (FCG) embedding and pooling
        # external functions are embedded into the same width as the pooled CFG vectors
        self.external_embedding_layer = nn.Embedding(num_embeddings=external_vocab.max_vocab_size + 2, embedding_dim=cfg_filter_list[-1], padding_idx=external_vocab.pad_idx)
        print(type(model_params.fcg_filters), model_params.fcg_filters)
        fcg_filter_list = self._parse_filters(model_params.fcg_filters)
        fcg_filter_list.insert(0, cfg_filter_list[-1])
        self.fcg_filter_length = len(fcg_filter_list)
        self._add_gnn_stack('FCG', fcg_filter_list)

        # Last Projection Function: gradually project with more linear layers
        final_dim = fcg_filter_list[-1]
        self.pj1 = Linear(in_features=final_dim, out_features=final_dim // 2)
        self.pj2 = Linear(in_features=final_dim // 2, out_features=final_dim // 4)
        self.pj3 = Linear(in_features=final_dim // 4, out_features=1)

        self.last_activation = nn.Sigmoid()

        # self.last_activation = nn.Softmax(dim=1)
        # self.last_activation = nn.LogSoftmax(dim=1)

    @staticmethod
    def _parse_filters(filters):
        """Turn a 'a-b-c' string (or a single number) into a list of ints."""
        if isinstance(filters, str):
            return [int(number_filter) for number_filter in filters.split("-")]
        return [int(filters)]

    def _build_conv(self, in_channels, out_channels):
        """Create one graph convolution of the configured type."""
        if self.conv == 'graphsage':
            return SAGEConv(in_channels=in_channels, out_channels=out_channels, bias=True)
        return GCNConv(in_channels=in_channels, out_channels=out_channels, cached=False, bias=True)

    def _add_gnn_stack(self, prefix, filter_list):
        """Register ``<prefix>_gnn_1..k`` layers mapping filter_list[i] -> filter_list[i + 1]."""
        for i in range(len(filter_list) - 1):
            setattr(self, '{}_gnn_{}'.format(prefix, i + 1), self._build_conv(filter_list[i], filter_list[i + 1]))

    def _run_gnn_stack(self, prefix, num_layers, x, edge_index):
        """Apply ``<prefix>_gnn_1..num_layers``, each followed by ReLU and dropout."""
        for i in range(num_layers):
            x = getattr(self, '{}_gnn_{}'.format(prefix, i + 1))(x=x, edge_index=edge_index)
            x = pt_f.relu(x, inplace=True)
            x = self.dropout(x)
        return x

    def _pool_nodes(self, x, batch):
        """Pool node features per graph with the configured global pooling."""
        if self.pool == 'global_max_pool':
            return global_max_pool(x=x, batch=batch)
        if self.pool == 'global_mean_pool':
            return global_mean_pool(x=x, batch=batch)
        raise NotImplementedError

    def forward_cfg_gnn(self, local_batch: Batch):
        """Run the CFG layers; the result is stored in ``local_batch.x`` and the batch is returned."""
        local_batch.x = self._run_gnn_stack('CFG', self.cfg_filter_length - 1, local_batch.x, local_batch.edge_index)
        return local_batch

    def aggregate_cfg_batch_pooling(self, local_batch: Batch):
        """Pool the node features of every graph in ``local_batch`` into one vector."""
        return self._pool_nodes(x=local_batch.x, batch=local_batch.batch)

    def forward_fcg_gnn(self, function_batch: Batch):
        """Run the FCG layers; the result is stored in ``function_batch.x`` and the batch is returned."""
        function_batch.x = self._run_gnn_stack('FCG', self.fcg_filter_length - 1, function_batch.x, function_batch.edge_index)
        return function_batch

    def aggregate_fcg_batch_pooling(self, function_batch: Batch):
        """Pool the node features of every graph in ``function_batch`` into one vector."""
        return self._pool_nodes(x=function_batch.x, batch=function_batch.batch)

    def aggregate_final_skip_pooling(self, x, batch):
        """Pool arbitrary node features ``x`` grouped by the ``batch`` assignment vector."""
        return self._pool_nodes(x=x, batch=batch)

    @staticmethod
    def cosine_attention(mtx1, mtx2):
        """Batched pairwise cosine similarity between the rows of ``mtx1`` and ``mtx2``.

        Both inputs are indexed as 3-D tensors (norm over dim 2, ``torch.bmm``).
        """
        v1_norm = mtx1.norm(p=2, dim=2, keepdim=True)
        v2_norm = mtx2.norm(p=2, dim=2, keepdim=True).permute(0, 2, 1)
        a = torch.bmm(mtx1, mtx2.permute(0, 2, 1))
        d = v1_norm * v2_norm

        return div_with_small_value(a, d)

    def forward(self, real_local_batch: Batch, real_bt_positions: list, bt_external_names: list, bt_all_function_edges: list, local_device: torch.device):
        """Score every sample of the batch.

        Args:
            real_local_batch: batch of all local graphs of all samples.
            real_bt_positions: cumulative offsets; sample ``i`` owns the pooled
                graph vectors ``[positions[i], positions[i + 1])``.
            bt_external_names: per sample, the vocabulary indices of its
                external functions (fed to the embedding layer).
            bt_all_function_edges: per sample, the edge index of its call graph.
            local_device: device the call-graph tensors are moved to.

        Returns:
            The sigmoid output of the final projection, one row per sample.
        """
        rtn_local_batch = self.forward_cfg_gnn(local_batch=real_local_batch)
        x_cfg_pool = self.aggregate_cfg_batch_pooling(local_batch=rtn_local_batch)

        # build the Function Call Graph (FCG) Data/Batch datasets
        assert len(real_bt_positions) - 1 == len(bt_external_names), "all should be equal to the batch size ... "
        assert len(real_bt_positions) - 1 == len(bt_all_function_edges), "all should be equal to the batch size ... "

        fcg_list = []
        for idx_batch in range(len(real_bt_positions) - 1):
            start_pos, end_pos = real_bt_positions[idx_batch: idx_batch + 2]

            # internal functions: pooled CFG vectors belonging to this sample
            idx_x_cfg = x_cfg_pool[start_pos: end_pos]

            # external functions: looked up in the embedding table
            idx_x_external = self.external_embedding_layer(torch.tensor([bt_external_names[idx_batch]], dtype=torch.long).to(local_device))
            idx_x_external = idx_x_external.squeeze(dim=0)

            # FCG nodes are ordered: internal functions first, then external ones
            idx_x_total = torch.cat([idx_x_cfg, idx_x_external], dim=0)
            idx_function_edge = torch.tensor(bt_all_function_edges[idx_batch], dtype=torch.long)
            idx_graph_data = Data(x=idx_x_total, edge_index=idx_function_edge).to(local_device)

            fcg_list.append(idx_graph_data)
        fcg_batch = Batch.from_data_list(fcg_list)

        # Hierarchical 2: Function Call Graph (FCG) embedding and pooling
        rtn_fcg_batch = self.forward_fcg_gnn(function_batch=fcg_batch)
        x_fcg_pool = self.aggregate_fcg_batch_pooling(function_batch=rtn_fcg_batch)
        batch_final = x_fcg_pool

        # final projection down to a single output followed by the sigmoid
        bt_final_embed = self.pj3(self.pj2(self.pj1(batch_final)))
        bt_pred = self.last_activation(bt_final_embed)
        return bt_pred
|
0
src/models/__init__.py
Normal file
0
src/models/__init__.py
Normal file
108
src/utils/FunctionHelpers.py
Normal file
108
src/utils/FunctionHelpers.py
Normal file
@ -0,0 +1,108 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict
|
||||
|
||||
from sklearn.metrics import auc, confusion_matrix, balanced_accuracy_score
|
||||
from texttable import Texttable
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def only_get_fpr(y_true, y_pred):
    """Return the false positive rate: benign samples (label 0) predicted as 1, over all benign samples."""
    benign_mask = y_true == 0
    false_alarms = (y_pred[benign_mask] == 1).sum()
    return float(false_alarms) / float(benign_mask.sum())
|
||||
|
||||
|
||||
def get_fpr(y_true, y_pred):
    """Return the false positive rate FP / (FP + TN) from the binary confusion matrix.

    ``labels=[0, 1]`` forces a 2x2 matrix even when only one class occurs in
    the inputs; without it the four-way unpacking below fails.
    """
    tn, fp, fn, tp = confusion_matrix(y_true=y_true, y_pred=y_pred, labels=[0, 1]).ravel()
    return float(fp) / float(fp + tn)
|
||||
|
||||
|
||||
def find_threshold_with_fixed_fpr(y_true, y_pred, fpr_target):
    """Scan thresholds upward until the false positive rate drops to ``fpr_target``.

    The threshold starts at 0.0 and grows in steps of 0.0001 while the FPR of
    ``y_pred > threshold`` exceeds ``fpr_target`` and the threshold is <= 1.0.

    Returns:
        A formatted string with the threshold, the confusion-matrix counts,
        TPR, FPR, accuracy, balanced accuracy and the elapsed search time.
    """
    start_time = datetime.now()

    threshold = 0.0
    fpr = only_get_fpr(y_true, y_pred > threshold)
    while fpr > fpr_target and threshold <= 1.0:
        threshold += 0.0001
        fpr = only_get_fpr(y_true, y_pred > threshold)

    predicted = y_pred > threshold
    # labels=[0, 1] keeps the matrix 2x2 even if only one class is present
    tn, fp, fn, tp = confusion_matrix(y_true=y_true, y_pred=predicted, labels=[0, 1]).ravel()
    tpr = tp / (tp + fn)
    fpr = fp / (fp + tn)
    acc = (tp + tn) / (tn + fp + fn + tp)  # equal to accuracy_score(y_true=y_true, y_pred=y_pred > threshold)
    balanced_acc = balanced_accuracy_score(y_true=y_true, y_pred=predicted)

    _info = "Threshold: {:.6f}, TN: {}, FP: {}, FN: {}, TP: {}, TPR: {:.6f}, FPR: {:.6f}, ACC: {:.6f}, Balanced_ACC: {:.6f}. consume about {} time in find threshold".format(
        threshold, tn, fp, fn, tp, tpr, fpr, acc, balanced_acc, datetime.now() - start_time)
    return _info
|
||||
|
||||
|
||||
def alphabet_lower_strip(str1):
    """Drop every character that is not an ASCII letter and lower-case the rest."""
    letters_only = re.sub("[^A-Za-z]", "", str1)
    return letters_only.lower()
|
||||
|
||||
|
||||
def filter_counter_with_threshold(counter: Counter, min_threshold):
    """Return a plain dict of the entries whose count is at least ``min_threshold``."""
    return {key: freq for key, freq in counter.items() if freq >= min_threshold}
|
||||
|
||||
|
||||
def create_dir_if_not_exists(new_dir: str, log: logging.Logger):
    """Create ``new_dir`` (including parents) and log whether it was created.

    The directory is created first and an already-existing path is detected
    from the resulting ``FileExistsError``; this avoids the window between an
    existence check and the creation in which another process can create the
    same directory.
    """
    try:
        os.makedirs(new_dir)
    except FileExistsError:
        log.info('We CANNOT creat the dir of \"{}\" as it is already exists.'.format(new_dir))
    else:
        log.info('We are creating the dir of \"{}\" '.format(new_dir))
|
||||
|
||||
|
||||
def get_jsonl_files_from_path(file_path: str, log: logging.Logger):
    """Recursively collect every ``*.jsonl`` file below ``file_path``.

    Returns the sorted list of full paths; the count and the list itself are
    written to ``log`` at INFO level.
    """
    file_list = sorted(
        os.path.join(folder, name)
        for folder, _sub_dirs, names in os.walk(file_path)
        for name in names
        if name.endswith(".jsonl")
    )
    log.info("{}\nFrom the path of {}, we obtain a list of {} files as follows.".format("-" * 50, file_path, len(file_list)))
    log.info("\n" + '\n'.join(str(f) for f in file_list))
    return file_list
|
||||
|
||||
|
||||
def write_into(file_name_path: str, log_str: str, print_flag=True):
    """Append ``log_str`` plus a newline to ``file_name_path``, optionally echoing it.

    Args:
        file_name_path: log file; created when it does not exist yet.
        log_str: text to write; ``None`` is written as the string 'None'.
        print_flag: when true, ``log_str`` is also printed to stdout first.

    The file is always opened in append mode. Append mode already creates a
    missing file, and unlike a separate "exists? else open with 'w+'" step it
    cannot truncate a file that another process created in the meantime.
    """
    if print_flag:
        print(log_str)
    if log_str is None:
        log_str = 'None'
    with open(file_name_path, 'a') as log_file:
        log_file.write(log_str + '\n')
|
||||
|
||||
|
||||
def params_print_log(param_dict: Dict, log_path: str):
    """Render ``param_dict`` as a text table (sorted by key) and write it via ``write_into``."""
    ordered_names = sorted(param_dict.keys())

    table = Texttable()
    table.set_precision(6)
    table.set_cols_align(["l", "l", "c"])

    # header row first, then one numbered row per parameter
    rows = [["Index", "Parameters", "Values"]]
    rows.extend([position, name, str(param_dict[name])] for position, name in enumerate(ordered_names))
    for row in rows:
        table.add_row(row)

    write_into(file_name_path=log_path, log_str=table.draw())
|
||||
|
||||
|
||||
def dataclasses_to_string(ins: dataclass):
    """Format an instance as '<ClassName>:' followed by one '=>name = value' line per attribute."""
    header = type(ins).__name__
    attribute_lines = ["{} = {}".format(attr, val) for attr, val in vars(ins).items()]
    return "{}:\n=>{}\n".format(header, '\n=>'.join(attribute_lines))
|
||||
|
||||
|
||||
# Nothing to do when this helper module is executed directly.
if __name__ == "__main__":
    pass
|
60
src/utils/ParameterClasses.py
Normal file
60
src/utils/ParameterClasses.py
Normal file
@ -0,0 +1,60 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class TrainParams:
    """Data locations and batch/epoch settings of one training run."""

    # root directory of the pre-processed dataset splits
    processed_files_path: str
    # number of training epochs
    max_epochs: int
    # batch size used for training
    train_bs: int
    # batch size used for validation/testing
    test_bs: int
    # frequency file the external-function vocabulary is built from
    external_func_vocab_file: str
    # upper bound on the number of vocabulary entries taken from that file
    max_vocab_size: int
|
||||
|
||||
|
||||
@dataclass
class OptimizerParams:
    """Optimizer selection and its numeric settings."""

    # name of the optimizer to use
    optimizer_name: str
    # initial learning rate
    lr: float
    # weight decay passed to the optimizer
    weight_decay: float
    # annealing factor for the learning rate
    learning_anneal: float
|
||||
|
||||
|
||||
@dataclass
class ModelParams:
    """Architecture settings of the hierarchical graph model."""

    # graph convolution type
    gnn_type: str
    # global pooling type
    pool_type: str
    # input feature width of the first-level graph nodes
    acfg_init_dims: int
    # first-level layer widths
    cfg_filters: str
    # second-level layer widths
    fcg_filters: str
    # number of output classes
    number_classes: int
    # dropout probability
    dropout_rate: float
    # model variant selector
    ablation_models: str
|
||||
|
||||
|
||||
@dataclass
class OneEpochResult:
    """Metrics collected from one evaluation pass."""

    # label identifying which evaluation this result belongs to
    Epoch_Flag: str
    # number of evaluated samples
    Number_Samples: int
    # average loss
    Avg_Loss: float
    # free-text summaries; NOTE(review): presumably the threshold reports at
    # FPR 1/100 and 1/1000 - confirm against the code that fills them
    Info_100: str
    Info_1000: str
    # area under the ROC curve
    ROC_AUC_Score: float
    # ROC curve data (not included in the string form)
    Thresholds: list
    TPRs: list
    FPRs: list

    def __str__(self):
        """Return the multi-line summary; every field line starts with '=>'."""
        s = "\nResult of \"{}\":\n=>Epoch_Flag = {}\n=>Number of samples = {}\n=>Avg_Loss = {}\n=>Info_100 = {}\n=>Info_1000 = {}\n=>ROC_AUC_score = {}\n".format(
            self.Epoch_Flag,
            self.Epoch_Flag,
            self.Number_Samples,
            self.Avg_Loss,
            self.Info_100,
            self.Info_1000,
            self.ROC_AUC_Score)
        return s
|
||||
|
||||
|
||||
# Nothing to do when this parameter module is executed directly.
if __name__ == "__main__":
    pass
|
85
src/utils/PreProcessedDataset.py
Normal file
85
src/utils/PreProcessedDataset.py
Normal file
@ -0,0 +1,85 @@
|
||||
import os
|
||||
import os.path as osp
|
||||
from datetime import datetime
|
||||
|
||||
import torch
|
||||
from torch_geometric.data import Dataset, DataLoader
|
||||
from utils.RealBatch import create_real_batch_data # noqa
|
||||
|
||||
|
||||
class MalwareDetectionDataset(Dataset):
    """Dataset over pre-processed ``.pt`` samples stored in two directories.

    For a split name ``<flag>`` the samples are read from
    ``<root>/<flag>_malware/malware_<i>.pt`` and
    ``<root>/<flag>_benign/benign_<i>.pt``. Malware samples come first in the
    index space, benign samples follow.
    """

    def __init__(self, root, train_or_test, transform=None, pre_transform=None):
        """List the sample files of the ``train_or_test`` split below ``root``."""
        super(MalwareDetectionDataset, self).__init__(None, transform, pre_transform)
        self.flag = train_or_test.lower()
        self.malware_root = os.path.join(root, "{}_malware".format(self.flag))
        self.benign_root = os.path.join(root, "{}_benign".format(self.flag))
        # Only '.pt' files are loadable by get(); counting other directory
        # entries would make __len__ report indices that cannot be loaded.
        self.malware_files = self._list_files_for_pt(self.malware_root)
        self.benign_files = self._list_files_for_pt(self.benign_root)

    @staticmethod
    def _list_files_for_pt(the_path):
        """Return the names of the entries in ``the_path`` that end in '.pt'."""
        return [name for name in os.listdir(the_path) if os.path.splitext(name)[-1] == '.pt']

    def __len__(self):
        """Total number of samples (malware + benign)."""
        return len(self.malware_files) + len(self.benign_files)

    def get(self, idx):
        """Load sample ``idx``: malware for idx < number of malware files, benign otherwise."""
        split = len(self.malware_files)
        if idx < split:
            idx_data = torch.load(osp.join(self.malware_root, 'malware_{}.pt'.format(idx)))
        else:
            # benign files are numbered from 0 again
            over_fit_idx = idx - split
            idx_data = torch.load(osp.join(self.benign_root, "benign_{}.pt".format(over_fit_idx)))
        return idx_data
|
||||
|
||||
|
||||
def _simulating(_dataset, _batch_size: int):
    """Smoke-test loading: print the first three shuffled batches and the elapsed time."""
    print("\nBatch size = {}".format(_batch_size))
    started_at = datetime.now()
    print("start time: " + started_at.strftime("%Y-%m-%d@%H:%M:%S"))

    # https://github.com/pytorch/fairseq/issues/1560
    # https://github.com/pytorch/pytorch/issues/973#issuecomment-459398189
    # increasing the shared memory: ulimit -SHn 51200
    loader = DataLoader(dataset=_dataset, batch_size=_batch_size, shuffle=True)  # default of prefetch_factor = 2  # num_workers=4

    for batch_no, batch in enumerate(loader):
        if batch_no >= 3:
            break
        _real_batch, _position, _hash, _external_list, _function_edges, _true_classes = create_real_batch_data(one_batch=batch)
        print(batch)
        print("Hash: ", _hash)
        print("Position: ", _position)
        print("\n")

    finished_at = datetime.now()
    print("end time: " + finished_at.strftime("%Y-%m-%d@%H:%M:%S"))
    print("All time = {}\n\n".format(finished_at - started_at))
|
||||
|
||||
|
||||
# Manual smoke test: load each split from a fixed local path and print a few batches.
if __name__ == '__main__':
    root_path: str = '/home/xiang/MalGraph/data/processed_dataset/DatasetJSON/'
    i_batch_size = 2

    for split_name in ('train', 'valid', 'test'):
        split_dataset = MalwareDetectionDataset(root=root_path, train_or_test=split_name)
        print(split_dataset.malware_root, split_dataset.benign_root)
        print(len(split_dataset.malware_files), len(split_dataset.benign_files), len(split_dataset))
        _simulating(_dataset=split_dataset, _batch_size=i_batch_size)
|
24
src/utils/RealBatch.py
Normal file
24
src/utils/RealBatch.py
Normal file
@ -0,0 +1,24 @@
|
||||
import torch
|
||||
from torch_geometric.data import Batch
|
||||
from torch_geometric.data import DataLoader
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
def create_real_batch_data(one_batch: Batch):
    """Flatten the per-sample graph lists of ``one_batch`` into one real batch.

    Returns a 6-tuple ``(real_batch, position, hash, external_list,
    function_edges, targets)``. ``position`` holds cumulative offsets:
    sample ``i`` owns the graphs ``[position[i], position[i + 1])`` of
    ``real_batch``. When the batch consists of a single sample without any
    local graph, a tuple of six ``None`` values is returned instead.
    """
    assert len(one_batch.external_list) == len(one_batch.function_edges) == len(one_batch.local_acfgs) == len(one_batch.hash), "size of each component must be equal to each other"

    real = []
    position = [0]
    count = 0
    for item in one_batch.local_acfgs:
        real.extend(item)
        count += len(item)
        position.append(count)

    if len(one_batch.local_acfgs) == 1 and len(one_batch.local_acfgs[0]) == 0:
        # a real tuple (not a one-shot generator), like the regular return value
        return (None,) * 6

    real_batch = Batch.from_data_list(real)
    return real_batch, position, one_batch.hash, one_batch.external_list, one_batch.function_edges, one_batch.targets
|
91
src/utils/Vocabulary.py
Normal file
91
src/utils/Vocabulary.py
Normal file
@ -0,0 +1,91 @@
|
||||
import json
|
||||
import os
|
||||
from collections import Counter
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class Vocab:
    """Token <-> index vocabulary built from a JSON-lines frequency file.

    Index order: ``unk_token``, ``pad_token``, ``special_tokens`` (each only
    when not ``None``), then the ``max_vocab_size`` most frequent names.
    """

    def __init__(self, freq_file: str, max_vocab_size: int, min_freq: int = 1, unk_token: str = '<unk>', pad_token: str = '<pad>', special_tokens: list = None):
        """Load ``freq_file`` and build both lookup tables.

        Args:
            freq_file: JSON-lines file with ``f_name`` and ``count`` per line.
            max_vocab_size: number of most frequent names to keep.
            min_freq: reading stops at the first line whose count is below this.
            unk_token: token returned for unknown names, or ``None``.
            pad_token: padding token, or ``None``.
            special_tokens: optional extra tokens placed before the names.
        """
        self.max_vocab_size = max_vocab_size
        self.min_freq = min_freq

        self.unk_token = unk_token
        self.pad_token = pad_token
        self.special_tokens = special_tokens

        assert os.path.exists(freq_file), "The file of {} is not exist".format(freq_file)
        freq_counter = self.load_freq_counter_from_file(file_path=freq_file, min_freq=self.min_freq)

        self.token_2_index, self.index_2_token = self.create_vocabulary(freq_counter=freq_counter)

        self.unk_idx = None if self.unk_token is None else self.token_2_index[self.unk_token]
        self.pad_idx = None if self.pad_token is None else self.token_2_index[self.pad_token]

    def __len__(self):
        """Number of entries, including the unk/pad/special tokens."""
        return len(self.index_2_token)

    def __getitem__(self, item: str):
        """Return the index of ``item``, falling back to the unk index.

        Raises:
            KeyError: if ``item`` is unknown and ``unk_token`` is ``None``.
        """
        assert isinstance(item, str)
        if item in self.token_2_index:
            return self.token_2_index[item]
        if self.unk_token is not None:
            return self.token_2_index[self.unk_token]
        raise KeyError("{} is not in the vocabulary, and self.unk_token is None".format(item))

    def create_vocabulary(self, freq_counter: Counter):
        """Build ``(token_2_index, index_2_token)`` from the frequency counter."""
        index_2_token = []  # list
        if self.unk_token is not None:
            index_2_token.append(self.unk_token)
        if self.pad_token is not None:
            index_2_token.append(self.pad_token)
        if self.special_tokens is not None:
            index_2_token.extend(self.special_tokens)

        # set mirror of index_2_token: O(1) duplicate checks instead of scanning the list
        seen = set(index_2_token)
        for f_name, _count in tqdm(freq_counter.most_common(self.max_vocab_size), desc="creating vocab ... "):
            if f_name in seen:
                print("trying to add {} to the vocabulary, but it already exists !!!".format(f_name))
                continue
            seen.add(f_name)
            index_2_token.append(f_name)

        # reverse mapping; for a repeated token the last position wins
        token_2_index = {token: index for index, token in enumerate(index_2_token)}

        return token_2_index, index_2_token

    @staticmethod
    def load_freq_counter_from_file(file_path: str, min_freq: int):
        """Read name frequencies from ``file_path`` into a ``Counter``.

        Reading stops at the first line whose count is below ``min_freq``.
        NOTE(review): this presumably relies on the file being sorted by
        descending count - confirm against the code that writes it.
        """
        freq_dict = {}
        with open(file_path, 'r') as f:
            for line in tqdm(f, desc="Load frequency list from the file of {} ... ".format(file_path)):
                line = json.loads(line)
                f_name = line["f_name"]
                count = int(line["count"])

                assert f_name not in freq_dict, "trying to add {} to the vocabulary, but it already exists !!!".format(f_name)
                if count < min_freq:
                    print(line, "break")
                    break

                freq_dict[f_name] = count
        return Counter(freq_dict)
|
||||
|
||||
|
||||
# Manual check: build a 1000-entry vocabulary and print its tables and a few lookups.
if __name__ == '__main__':
    max_vocab_size = 1000
    vocab = Vocab(freq_file="../../data/processed_dataset/train_external_function_name_vocab.jsonl", max_vocab_size=max_vocab_size)

    for table in (vocab.token_2_index, vocab.index_2_token):
        print(len(table), table)
    print(vocab.unk_token, vocab.unk_idx)
    print(vocab.pad_token, vocab.pad_idx)

    for query in ('queryperformancecounter', 'EmptyClipboard', str.lower('EmptyClipboard'), 'X_Y_Z'):
        print(vocab[query])
|
0
src/utils/__init__.py
Normal file
0
src/utils/__init__.py
Normal file
Loading…
Reference in New Issue
Block a user