40 lines
1.9 KiB
Python
40 lines
1.9 KiB
Python
|
import torch
|
|||
|
import click
|
|||
|
import asm2vec
|
|||
|
|
|||
|
def train(ipath='./asm_func/asm_hex/func_bytes.csv', opath='./asm2vec_checkpoints', limit=1000, embedding_size=100, batch_size=1024, epochs=10, calc_acc=True, device='auto', lr=0.02):
    """Train an asm2vec model on *ipath*, checkpointing to *opath* as it goes.

    Loads the input with ``asm2vec.utils.load_data`` (with ``normal=True``),
    then runs ``asm2vec.utils.train`` with ``model=None``. A callback prints a
    progress line and calls ``asm2vec.utils.save_model(opath, model)`` each
    time the trainer invokes it.

    Args:
        ipath: Input data file handed to ``asm2vec.utils.load_data``.
        opath: Destination handed to ``asm2vec.utils.save_model`` on every
            callback invocation. The same path is reused each time, so later
            saves presumably overwrite earlier ones (TODO confirm in
            ``save_model``).
        limit: Number of input records to load (``limit`` of ``load_data``).
        embedding_size: Embedding vector size.
        batch_size: Training batch size.
        epochs: Number of training epochs.
        calc_acc: Forwarded to ``asm2vec.utils.train``. Whenever the callback
            context carries an accuracy value, it is appended to the
            progress line.
        device: Torch device string. ``'auto'`` selects ``'cuda'`` when
            ``torch.cuda.is_available()`` and ``'cpu'`` otherwise.
        lr: Learning rate, forwarded as ``learning_rate``.

    Returns:
        Whatever ``asm2vec.utils.train`` returns. This is presumably the
        trained model; TODO confirm against the library.
    """
    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # `normal` toggles normalization of the loaded data.
    functions = asm2vec.utils.load_data(ipath, limit=limit, normal=True)

    def callback(context):
        # Progress hook for asm2vec.utils.train. The context dict supplies
        # "epoch", "time", "loss", "model" and (optionally) "accuracy".
        progress = f'{context["epoch"]} | time = {context["time"]:.2f}, loss = {context["loss"]:.4f}'
        accuracy = context.get("accuracy")
        # Compare against None rather than testing truthiness. Otherwise an
        # accuracy of exactly 0 would be silently omitted from the log.
        if accuracy is not None:
            progress += f', accuracy = {accuracy:.4f}'
        print(progress)
        asm2vec.utils.save_model(opath, context["model"])

    # model=None presumably asks asm2vec.utils.train to build a new model
    # instead of resuming an existing one.
    return asm2vec.utils.train(
        functions,
        model=None,
        embedding_size=embedding_size,
        batch_size=batch_size,
        epochs=epochs,
        calc_acc=calc_acc,
        device=device,
        callback=callback,
        learning_rate=lr,
    )
|
|||
|
|
|||
|
if __name__ == '__main__':
    # Hyper-parameters used below:
    #   limit          - number of input records to load
    #   embedding_size - embedding vector size
    #   batch_size     - training batch size
    #   epochs         - number of training epochs
    # One model is trained per embedding size, in the order listed. Each is
    # saved to ./asm2vec_checkpoints/model_<embedding_size>_<epochs>.pt.
    # Add 16 to EMBEDDING_SIZES to also train a 16-dimensional model.
    EPOCHS = 100
    EMBEDDING_SIZES = (32, 64, 256)
    for size in EMBEDDING_SIZES:
        train(
            ipath='../asm_func/asm_hex/func_bytes.csv',
            opath=f'./asm2vec_checkpoints/model_{size}_{EPOCHS}.pt',
            limit=10000,
            embedding_size=size,
            batch_size=1024,
            epochs=EPOCHS,
            calc_acc=True,
            device='auto',
            lr=0.02,
        )
|