0.7.2 Checkpoint loading in train

parent 8ec564c9
__version__ = "0.7.1"
__version__ = "0.7.2"
__author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils
......
......@@ -19,6 +19,11 @@ def main(params):
model = SensePPIModel(params)
if params.device == 'gpu' and params.ckpt_path is not None:
checkpoint = torch.load(params.ckpt_path)
model.load_state_dict(checkpoint['state_dict'])
print('checkpoint loaded')
model.load_data(dataset=dataset, valid_size=params.valid_size)
train_set = model.train_dataloader()
val_set = model.val_dataloader()
......@@ -63,6 +68,9 @@ def add_args(parser):
train_args.add_argument("--early_stop", type=int, default=10,
help="Number of epochs to wait before stopping the training "
"(tracking is done with validation loss). By default, the is no early stopping.")
train_args.add_argument("--ckpt_path", type=str, default=None,
help="Path to the checkpoint to load the model from. Can be used for transfer learning. "
"Loads the model only for GPU training.")
parser = SensePPIModel.add_model_specific_args(parser)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment