diff --git a/src/trainer/pretrain.py b/src/trainer/pretrain.py index 63149fe..1f3d31f 100644 --- a/src/trainer/pretrain.py +++ b/src/trainer/pretrain.py @@ -3,8 +3,8 @@ import torch.nn as nn from torch.optim import Adam from torch.utils.data import DataLoader -from model.bert import BERT -from model.language_model import BERTLM +from ..model.bert import BERT +from ..model.language_model import BERTLM from .optim_schedule import ScheduledOptim import tqdm