40 lines
1.9 KiB
Python
40 lines
1.9 KiB
Python
import torch
|
||
import click
|
||
import asm2vec
|
||
|
||
def train(
    ipath='./asm_func/asm_hex/func_bytes.csv',
    opath='./asm2vec_checkpoints',
    limit=1000,
    embedding_size=100,
    batch_size=1024,
    epochs=10,
    calc_acc=True,
    device='auto',
    lr=0.02,
):
    """Load functions from ``ipath`` and train an asm2vec model on them.

    Each time the training loop invokes the callback (presumably once per
    epoch -- confirm against ``asm2vec.utils.train``), a progress line is
    printed and the current model is saved to ``opath``.

    Args:
        ipath: Input path handed to ``asm2vec.utils.load_data``.
        opath: Destination handed to ``asm2vec.utils.save_model``.
        limit: Forwarded to ``load_data`` as ``limit``.
        embedding_size: Forwarded to ``asm2vec.utils.train``.
        batch_size: Forwarded to ``asm2vec.utils.train``.
        epochs: Forwarded to ``asm2vec.utils.train``.
        calc_acc: Forwarded to ``asm2vec.utils.train``.
        device: Device name. ``'auto'`` resolves to ``'cuda'`` when
            ``torch.cuda.is_available()`` is true and to ``'cpu'`` otherwise;
            any other value is passed through unchanged.
        lr: Forwarded to ``asm2vec.utils.train`` as ``learning_rate``.
    """
    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Per the original author's note, ``normal`` selects whether the loaded
    # data is normalized.
    functions = asm2vec.utils.load_data(ipath, limit=limit, normal=True)

    def callback(context):
        """Print one progress line and checkpoint the model to ``opath``."""
        progress = f'{context["epoch"]} | time = {context["time"]:.2f}, loss = {context["loss"]:.4f}'
        # Compare against None instead of relying on truthiness so that a
        # legitimate accuracy of exactly 0.0 is still reported.
        # NOTE(review): assumes the training loop reports a missing accuracy
        # as None -- confirm against asm2vec.utils.train.
        if context["accuracy"] is not None:
            progress += f', accuracy = {context["accuracy"]:.4f}'
        print(progress)
        asm2vec.utils.save_model(opath, context["model"])

    asm2vec.utils.train(
        functions,
        model=None,  # no existing model is supplied
        embedding_size=embedding_size,
        batch_size=batch_size,
        epochs=epochs,
        calc_acc=calc_acc,
        device=device,
        callback=callback,
        learning_rate=lr,
    )
|
||
|
||
if __name__ == '__main__':
    # limit = number of input records, embedding_size = size of the embedding
    # vectors, batch_size = batch size, epochs = number of training epochs.
    #
    # One training run per (embedding_size, checkpoint path) pair, executed in
    # the order listed. The 16-dimensional run is currently disabled.
    runs = (
        # (16, './asm2vec_checkpoints/model_16_100.pt'),
        (32, './asm2vec_checkpoints/model_32_100.pt'),
        (64, './asm2vec_checkpoints/model_64_100.pt'),
        (256, './asm2vec_checkpoints/model_256_100.pt'),
    )
    for size, checkpoint in runs:
        train(
            ipath='../asm_func/asm_hex/func_bytes.csv',
            opath=checkpoint,
            limit=10000,
            embedding_size=size,
            batch_size=1024,
            epochs=100,
            calc_acc=True,
            device='auto',
            lr=0.02,
        )
|