detect_rep/ASM2VEC_plus_scripts/train.py

40 lines
1.9 KiB
Python
Raw Permalink Normal View History

2023-04-05 10:04:49 +08:00
import torch
import click
import asm2vec
def train(ipath='./asm_func/asm_hex/func_bytes.csv', opath='./asm2vec_checkpoints', limit=1000, embedding_size=100, batch_size=1024, epochs=10, calc_acc=True, device='auto', lr=0.02, normal=True):
    """Train an asm2vec model on the data in ``ipath`` and checkpoint it to ``opath``.

    Data is loaded with ``asm2vec.utils.load_data`` and the model is trained from
    scratch (``model=None``) with ``asm2vec.utils.train``. After every epoch a
    progress line is printed and the model is saved to ``opath``.

    Args:
        ipath: Input data file passed to ``asm2vec.utils.load_data``.
        opath: Checkpoint path passed to ``asm2vec.utils.save_model`` after
            every epoch (presumably each save overwrites the previous one).
        limit: Number of input records to load (``load_data``'s ``limit``).
        embedding_size: Embedding dimension passed to ``asm2vec.utils.train``.
        batch_size: Batch size passed to ``asm2vec.utils.train``.
        epochs: Number of training epochs.
        calc_acc: Whether ``asm2vec.utils.train`` should compute accuracy.
        device: ``'auto'`` selects ``'cuda'`` when available, otherwise
            ``'cpu'``; any other value is passed through unchanged.
        lr: Learning rate (``learning_rate`` of ``asm2vec.utils.train``).
        normal: Whether ``load_data`` normalizes the input data (was
            previously hard-coded to ``True``, which remains the default).
    """
    import os

    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Create the checkpoint's parent directory up front, in case save_model
    # does not; otherwise a missing folder only surfaces after the first epoch.
    parent_dir = os.path.dirname(opath)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    functions = asm2vec.utils.load_data(ipath, limit=limit, normal=normal)

    def callback(context):
        # Called by asm2vec.utils.train with a per-epoch context mapping;
        # prints progress and saves the current model to `opath`.
        progress = f'{context["epoch"]} | time = {context["time"]:.2f}, loss = {context["loss"]:.4f}'
        accuracy = context.get("accuracy")
        # Compare with None instead of relying on truthiness: an accuracy of
        # exactly 0 is a legitimate value and must still be reported.
        if accuracy is not None:
            progress += f', accuracy = {accuracy:.4f}'
        print(progress)
        asm2vec.utils.save_model(opath, context["model"])

    asm2vec.utils.train(
        functions,
        model=None,  # no pre-trained model is supplied
        embedding_size=embedding_size,
        batch_size=batch_size,
        epochs=epochs,
        calc_acc=calc_acc,
        device=device,
        callback=callback,
        learning_rate=lr,
    )
if __name__ == '__main__':
    # Train one model per embedding size on the same input data, each run
    # saving to its own checkpoint file. Shared settings:
    #   limit          - number of input records to load
    #   embedding_size - embedding dimension (varies per run below)
    #   batch_size     - batch size used by asm2vec.utils.train
    #   epochs         - number of training epochs
    # A 16-dimensional run ('./asm2vec_checkpoints/model_16_100.pt') was
    # configured here previously but disabled; add it to `runs` to re-enable.
    runs = (
        (32, './asm2vec_checkpoints/model_32_100.pt'),
        (64, './asm2vec_checkpoints/model_64_100.pt'),
        (256, './asm2vec_checkpoints/model_256_100.pt'),
    )
    for dim, checkpoint in runs:
        train(
            ipath='../asm_func/asm_hex/func_bytes.csv',
            opath=checkpoint,
            limit=10000,
            embedding_size=dim,
            batch_size=1024,
            epochs=100,
            calc_acc=True,
            device='auto',
            lr=0.02,
        )