import torch
import click
import asm2vec


def train(
    ipath: str = './asm_func/asm_hex/func_bytes.csv',
    opath: str = './asm2vec_checkpoints',
    limit: int = 1000,
    embedding_size: int = 100,
    batch_size: int = 1024,
    epochs: int = 10,
    calc_acc: bool = True,
    device: str = 'auto',
    lr: float = 0.02,
) -> None:
    """Train an asm2vec model from scratch and checkpoint it after each callback.

    Args:
        ipath: Path handed to ``asm2vec.utils.load_data`` as the data source.
        opath: Path handed to ``asm2vec.utils.save_model`` on every callback.
        limit: Forwarded to ``load_data`` as ``limit`` (number of input
            records, per the original author's note).
        embedding_size: Forwarded to ``asm2vec.utils.train``.
        batch_size: Forwarded to ``asm2vec.utils.train``.
        epochs: Forwarded to ``asm2vec.utils.train``.
        calc_acc: Forwarded to ``asm2vec.utils.train``.
        device: ``'auto'`` selects ``'cuda'`` when CUDA is available and
            ``'cpu'`` otherwise; any other value is forwarded unchanged.
        lr: Forwarded to ``asm2vec.utils.train`` as ``learning_rate``.
    """
    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # `normal` controls whether the loaded data is normalized.
    functions = asm2vec.utils.load_data(ipath, limit=limit, normal=True)

    def callback(context):
        """Print a progress line and save the current model to ``opath``."""
        progress = (
            f'{context["epoch"]} | time = {context["time"]:.2f}, '
            f'loss = {context["loss"]:.4f}'
        )
        accuracy = context["accuracy"]
        # Compare against None rather than relying on truthiness so that a
        # computed accuracy of exactly 0.0 is still reported.
        # NOTE(review): assumes the library reports None when accuracy is not
        # computed -- confirm against asm2vec.utils.train.
        if accuracy is not None:
            progress += f', accuracy = {accuracy:.4f}'
        print(progress)
        asm2vec.utils.save_model(opath, context["model"])

    asm2vec.utils.train(
        functions,
        model=None,  # no existing model is passed in
        embedding_size=embedding_size,
        batch_size=batch_size,
        epochs=epochs,
        calc_acc=calc_acc,
        device=device,
        callback=callback,
        learning_rate=lr,
    )


if __name__ == '__main__':
    # limit = number of input records, embedding_size = input vector size,
    # batch_size = batch size, epochs = number of epochs.
    # One run per embedding size; every other setting is shared, and the
    # checkpoint file name embeds the embedding size.
    for size in (16, 32, 64, 128, 256):
        train(
            ipath='../asm_func/asm_hex/func_bytes.csv',
            opath=f'./asm2vec_checkpoints/s372_model_{size}_100.pt',
            limit=10000,
            embedding_size=size,
            batch_size=1024,
            epochs=100,
            calc_acc=True,
            device='auto',
            lr=0.02,
        )