0.5.1 model bug fix

parent aa239099
__version__ = "0.5.0" __version__ = "0.5.1"
__author__ = "Konstantin Volzhenin" __author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils from . import model, commands, esm2_model, dataset, utils, network_utils
......
...@@ -68,7 +68,7 @@ def add_args(parser): ...@@ -68,7 +68,7 @@ def add_args(parser):
parser._action_groups[0].add_argument("fasta_file", type=pathlib.Path, parser._action_groups[0].add_argument("fasta_file", type=pathlib.Path,
help="FASTA file on which to extract the ESM2 representations and then test.", help="FASTA file on which to extract the ESM2 representations and then test.",
) )
predict_args.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "..", "senseppi.ckpt"), predict_args.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "..", "default_model", "senseppi.ckpt"),
help="A path to .ckpt file that contains weights to a pretrained model. If " help="A path to .ckpt file that contains weights to a pretrained model. If "
"None, the senseppi trained version is used.") "None, the senseppi trained version is used.")
predict_args.add_argument("--pairs_file", type=str, default=None, predict_args.add_argument("--pairs_file", type=str, default=None,
......
...@@ -187,7 +187,7 @@ def add_args(parser): ...@@ -187,7 +187,7 @@ def add_args(parser):
parser._action_groups[0].add_argument("genes", type=str, nargs="+", parser._action_groups[0].add_argument("genes", type=str, nargs="+",
help="Name of gene to fetch from STRING database. Several names can be " help="Name of gene to fetch from STRING database. Several names can be "
"typed (separated by whitespaces).") "typed (separated by whitespaces).")
string_pred_args.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "..", "senseppi.ckpt"), string_pred_args.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "..", "default_model", "senseppi.ckpt"),
help="A path to .ckpt file that contains weights to a pretrained model. If " help="A path to .ckpt file that contains weights to a pretrained model. If "
"None, the senseppi trained version is used.") "None, the senseppi trained version is used.")
string_pred_args.add_argument("-s", "--species", type=int, default=9606, string_pred_args.add_argument("-s", "--species", type=int, default=9606,
......
...@@ -44,7 +44,7 @@ def add_args(parser): ...@@ -44,7 +44,7 @@ def add_args(parser):
help="FASTA file on which to extract the ESM2 " help="FASTA file on which to extract the ESM2 "
"representations and then evaluate.", "representations and then evaluate.",
) )
test_args.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "..", "senseppi.ckpt"), test_args.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "..", "default_model", "senseppi.ckpt"),
help="A path to .ckpt file that contains weights to a pretrained model. If " help="A path to .ckpt file that contains weights to a pretrained model. If "
"None, the senseppi trained version is used.") "None, the senseppi trained version is used.")
test_args.add_argument("-o", "--output", type=str, default="test_metrics", test_args.add_argument("-o", "--output", type=str, default="test_metrics",
......
...@@ -13,7 +13,7 @@ setup( ...@@ -13,7 +13,7 @@ setup(
url="", url="",
license="MIT", license="MIT",
packages=find_packages(), packages=find_packages(),
package_data={'senseppi': ['default_model/*']},
long_description=long_description, long_description=long_description,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
include_package_data=True, include_package_data=True,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment