0.1.2 early stop added to training

parent 8cb190f6
...@@ -32,6 +32,10 @@ def main(params): ...@@ -32,6 +32,10 @@ def main(params):
monitor='val_loss', mode='min', save_top_k=1) monitor='val_loss', mode='min', save_top_k=1)
] ]
if params.early_stop is not None:
callbacks.append(pl.callbacks.EarlyStopping(monitor="val_loss", patience=params.early_stop,
verbose=False, mode="min"))
trainer = pl.Trainer(accelerator=params.device, devices=params.num_devices, num_nodes=params.num_nodes, trainer = pl.Trainer(accelerator=params.device, devices=params.num_devices, num_nodes=params.num_nodes,
max_epochs=params.num_epochs, logger=logger, callbacks=callbacks) max_epochs=params.num_epochs, logger=logger, callbacks=callbacks)
...@@ -58,6 +62,9 @@ def add_args(parser): ...@@ -58,6 +62,9 @@ def add_args(parser):
help="Number of devices to use for multi GPU training.") help="Number of devices to use for multi GPU training.")
train_args.add_argument("--num_nodes", type=int, default=1, train_args.add_argument("--num_nodes", type=int, default=1,
help="Number of nodes to use for training on a cluster.") help="Number of nodes to use for training on a cluster.")
train_args.add_argument("--early_stop", type=int, default=None,
help="Number of epochs to wait before stopping the training "
"(tracking is done with validation loss). By default, the is no early stopping.")
parser = SensePPIModel.add_model_specific_args(parser) parser = SensePPIModel.add_model_specific_args(parser)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment