0.5.1 model bug fix

parent aa239099
__version__ = "0.5.0"
__version__ = "0.5.1"
__author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils
......
......@@ -68,7 +68,7 @@ def add_args(parser):
parser._action_groups[0].add_argument("fasta_file", type=pathlib.Path,
help="FASTA file on which to extract the ESM2 representations and then test.",
)
predict_args.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "..", "senseppi.ckpt"),
predict_args.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "..", "default_model", "senseppi.ckpt"),
help="A path to .ckpt file that contains weights to a pretrained model. If "
"None, the senseppi trained version is used.")
predict_args.add_argument("--pairs_file", type=str, default=None,
......
......@@ -187,7 +187,7 @@ def add_args(parser):
parser._action_groups[0].add_argument("genes", type=str, nargs="+",
help="Name of gene to fetch from STRING database. Several names can be "
"typed (separated by whitespaces).")
string_pred_args.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "..", "senseppi.ckpt"),
string_pred_args.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "..", "default_model", "senseppi.ckpt"),
help="A path to .ckpt file that contains weights to a pretrained model. If "
"None, the senseppi trained version is used.")
string_pred_args.add_argument("-s", "--species", type=int, default=9606,
......
......@@ -44,7 +44,7 @@ def add_args(parser):
help="FASTA file on which to extract the ESM2 "
"representations and then evaluate.",
)
test_args.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "..", "senseppi.ckpt"),
test_args.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "..", "default_model", "senseppi.ckpt"),
help="A path to .ckpt file that contains weights to a pretrained model. If "
"None, the senseppi trained version is used.")
test_args.add_argument("-o", "--output", type=str, default="test_metrics",
......
......@@ -13,7 +13,7 @@ setup(
url="",
license="MIT",
packages=find_packages(),
package_data={'senseppi': ['default_model/*']},
long_description=long_description,
long_description_content_type="text/markdown",
include_package_data=True,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment