small changes in training args

parent f2864b0e
__version__ = "0.6.0" __version__ = "0.6.1"
__author__ = "Konstantin Volzhenin" __author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils from . import model, commands, esm2_model, dataset, utils, network_utils
......
...@@ -36,8 +36,8 @@ def main(params): ...@@ -36,8 +36,8 @@ def main(params):
callbacks.append(pl.callbacks.EarlyStopping(monitor="val_loss", patience=params.early_stop, callbacks.append(pl.callbacks.EarlyStopping(monitor="val_loss", patience=params.early_stop,
verbose=False, mode="min")) verbose=False, mode="min"))
trainer = pl.Trainer(accelerator=params.device, devices=params.num_devices, num_nodes=params.num_nodes, trainer = pl.Trainer(accelerator=params.device, num_nodes=params.num_nodes, max_epochs=params.num_epochs,
max_epochs=params.num_epochs, logger=logger, callbacks=callbacks) logger=logger, callbacks=callbacks)
trainer.fit(model, train_set, val_set) trainer.fit(model, train_set, val_set)
...@@ -58,11 +58,9 @@ def add_args(parser): ...@@ -58,11 +58,9 @@ def add_args(parser):
help="Fraction of the training data to use for validation.") help="Fraction of the training data to use for validation.")
train_args.add_argument("--seed", type=int, default=None, help="Global training seed.") train_args.add_argument("--seed", type=int, default=None, help="Global training seed.")
train_args.add_argument("--num_epochs", type=int, default=100, help="Number of training epochs.") train_args.add_argument("--num_epochs", type=int, default=100, help="Number of training epochs.")
train_args.add_argument("--num_devices", type=int, default=1,
help="Number of devices to use for multi GPU training.")
train_args.add_argument("--num_nodes", type=int, default=1, train_args.add_argument("--num_nodes", type=int, default=1,
help="Number of nodes to use for training on a cluster.") help="Number of nodes to use for training on a cluster.")
train_args.add_argument("--early_stop", type=int, default=None, train_args.add_argument("--early_stop", type=int, default=10,
help="Number of epochs to wait before stopping the training " help="Number of epochs to wait before stopping the training "
"(tracking is done with validation loss). By default, the is no early stopping.") "(tracking is done with validation loss). By default, the is no early stopping.")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment