0.5.0 senseppi.ckpt added to the package as a default model

parent 6f5330e5
__version__ = "0.4.1" __version__ = "0.5.0"
__author__ = "Konstantin Volzhenin" __author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils from . import model, commands, esm2_model, dataset, utils, network_utils
......
...@@ -65,11 +65,12 @@ def add_args(parser): ...@@ -65,11 +65,12 @@ def add_args(parser):
parser = add_general_args(parser) parser = add_general_args(parser)
predict_args = parser.add_argument_group(title="Predict args") predict_args = parser.add_argument_group(title="Predict args")
parser._action_groups[0].add_argument("model_path", type=str,
help="A path to .ckpt file that contains weights to a pretrained model.")
parser._action_groups[0].add_argument("fasta_file", type=pathlib.Path, parser._action_groups[0].add_argument("fasta_file", type=pathlib.Path,
help="FASTA file on which to extract the ESM2 representations and then test.", help="FASTA file on which to extract the ESM2 representations and then test.",
) )
predict_args.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "..", "senseppi.ckpt"),
help="A path to .ckpt file that contains weights to a pretrained model. If "
"None, the senseppi trained version is used.")
predict_args.add_argument("--pairs_file", type=str, default=None, predict_args.add_argument("--pairs_file", type=str, default=None,
help="A path to a .tsv file with pairs of proteins to test (Optional). If not provided, " help="A path to a .tsv file with pairs of proteins to test (Optional). If not provided, "
"all-to-all pairs will be generated.") "all-to-all pairs will be generated.")
......
...@@ -184,11 +184,12 @@ def add_args(parser): ...@@ -184,11 +184,12 @@ def add_args(parser):
parser = add_general_args(parser) parser = add_general_args(parser)
string_pred_args = parser.add_argument_group(title="General options") string_pred_args = parser.add_argument_group(title="General options")
parser._action_groups[0].add_argument("model_path", type=str,
help="A path to .ckpt file that contains weights to a pretrained model.")
parser._action_groups[0].add_argument("genes", type=str, nargs="+", parser._action_groups[0].add_argument("genes", type=str, nargs="+",
help="Name of gene to fetch from STRING database. Several names can be " help="Name of gene to fetch from STRING database. Several names can be "
"typed (separated by whitespaces).") "typed (separated by whitespaces).")
string_pred_args.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "..", "senseppi.ckpt"),
help="A path to .ckpt file that contains weights to a pretrained model. If "
"None, the senseppi trained version is used.")
string_pred_args.add_argument("-s", "--species", type=int, default=9606, string_pred_args.add_argument("-s", "--species", type=int, default=9606,
help="Species from STRING database. Default: 9606 (H. Sapiens)") help="Species from STRING database. Default: 9606 (H. Sapiens)")
string_pred_args.add_argument("-n", "--nodes", type=int, default=10, string_pred_args.add_argument("-n", "--nodes", type=int, default=10,
......
...@@ -37,8 +37,6 @@ def add_args(parser): ...@@ -37,8 +37,6 @@ def add_args(parser):
parser = add_general_args(parser) parser = add_general_args(parser)
test_args = parser.add_argument_group(title="Predict args") test_args = parser.add_argument_group(title="Predict args")
parser._action_groups[0].add_argument("model_path", type=str,
help="A path to .ckpt file that contains weights to a pretrained model.")
parser._action_groups[0].add_argument("pairs_file", type=str, default=None, parser._action_groups[0].add_argument("pairs_file", type=str, default=None,
help="A path to a .tsv file with pairs of proteins to test.") help="A path to a .tsv file with pairs of proteins to test.")
parser._action_groups[0].add_argument("fasta_file", parser._action_groups[0].add_argument("fasta_file",
...@@ -46,6 +44,9 @@ def add_args(parser): ...@@ -46,6 +44,9 @@ def add_args(parser):
help="FASTA file on which to extract the ESM2 " help="FASTA file on which to extract the ESM2 "
"representations and then evaluate.", "representations and then evaluate.",
) )
test_args.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "..", "senseppi.ckpt"),
help="A path to .ckpt file that contains weights to a pretrained model. If "
"None, the senseppi trained version is used.")
test_args.add_argument("-o", "--output", type=str, default="test_metrics", test_args.add_argument("-o", "--output", type=str, default="test_metrics",
help="A path to a file where the test metrics will be saved. " help="A path to a file where the test metrics will be saved. "
"(.tsv format will be added automatically)") "(.tsv format will be added automatically)")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment