0.5.6 documentation and parser updates

parent c5196603
__version__ = "0.5.5" __version__ = "0.5.6"
__author__ = "Konstantin Volzhenin" __author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils from . import model, commands, esm2_model, dataset, utils, network_utils
......
import argparse import argparse
import logging import logging
import os
import torch import torch
from .commands import * from .commands import *
from senseppi import __version__ from senseppi import __version__
...@@ -44,8 +45,12 @@ def main(): ...@@ -44,8 +45,12 @@ def main():
logging.info('Device used: {}'.format(params.device)) logging.info('Device used: {}'.format(params.device))
if hasattr(params, 'model_path'):
if params.model_path is None:
params.model_path = os.path.join(os.path.dirname(__file__), "default_model", "senseppi.ckpt")
params.func(params) params.func(params)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
\ No newline at end of file
...@@ -258,8 +258,9 @@ def add_args(parser): ...@@ -258,8 +258,9 @@ def add_args(parser):
help="The sequences file downloaded from the same page of STRING. " help="The sequences file downloaded from the same page of STRING. "
"For both files see https://string-db.org/cgi/download") "For both files see https://string-db.org/cgi/download")
parser.add_argument("--not_remove_long_short_proteins", action='store_true', parser.add_argument("--not_remove_long_short_proteins", action='store_true',
help="Whether to remove proteins that are too short or too long. " help="If specified, does not remove proteins "
"Normally, the long and short proteins are removed.") "shorter than --min_length and longer than --max_length. "
"By default, long and short proteins are removed.")
parser.add_argument("--min_length", type=int, default=50, parser.add_argument("--min_length", type=int, default=50,
help="The minimum length of a protein to be included in the dataset.") help="The minimum length of a protein to be included in the dataset.")
parser.add_argument("--max_length", type=int, default=800, parser.add_argument("--max_length", type=int, default=800,
...@@ -290,27 +291,30 @@ def main(params): ...@@ -290,27 +291,30 @@ def main(params):
logging.info('STRING version: {}'.format(version)) logging.info('STRING version: {}'.format(version))
try: try:
url = "{0}protein.physical.links.full.v{1}/{2}.protein.physical.links.full.v{1}.txt.gz".format(DOWNLOAD_LINK_STRING, version, params.species) url = "{0}protein.physical.links.full.v{1}/{2}.protein.physical.links.full.v{1}.txt.gz".format(
DOWNLOAD_LINK_STRING, version, params.species)
string_file_name_links = "{1}.protein.physical.links.full.v{0}.txt".format(version, params.species) string_file_name_links = "{1}.protein.physical.links.full.v{0}.txt".format(version, params.species)
wget.download(url, out=string_file_name_links+'.gz') wget.download(url, out=string_file_name_links + '.gz')
with gzip.open(string_file_name_links+'.gz', 'rb') as f_in: with gzip.open(string_file_name_links + '.gz', 'rb') as f_in:
with open(string_file_name_links, 'wb') as f_out: with open(string_file_name_links, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out) shutil.copyfileobj(f_in, f_out)
url = "{0}protein.sequences.v{1}/{2}.protein.sequences.v{1}.fa.gz".format(DOWNLOAD_LINK_STRING, version, params.species) url = "{0}protein.sequences.v{1}/{2}.protein.sequences.v{1}.fa.gz".format(DOWNLOAD_LINK_STRING, version,
params.species)
string_file_name_seqs = "{1}.protein.sequences.v{0}.fa".format(version, params.species) string_file_name_seqs = "{1}.protein.sequences.v{0}.fa".format(version, params.species)
wget.download(url, out=string_file_name_seqs+'.gz') wget.download(url, out=string_file_name_seqs + '.gz')
with gzip.open(string_file_name_seqs+'.gz', 'rb') as f_in: with gzip.open(string_file_name_seqs + '.gz', 'rb') as f_in:
with open(string_file_name_seqs, 'wb') as f_out: with open(string_file_name_seqs, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out) shutil.copyfileobj(f_in, f_out)
except HTTPError: except HTTPError:
raise Exception('The files are not available for the specified species. ' raise Exception('The files are not available for the specified species. '
'There might be two reasons for that: \n ' 'There might be two reasons for that: \n '
'1) the species is not available in STRING. Please check the STRING species list to verify. \n' '1) the species is not available in STRING. Please check the STRING species list to '
'2) the download link has changed. Please raise an issue in the repository. ') 'verify. \n '
'2) the download link has changed. Please raise an issue in the repository. ')
os.remove(string_file_name_seqs+'.gz') os.remove(string_file_name_seqs + '.gz')
os.remove(string_file_name_links+'.gz') os.remove(string_file_name_links + '.gz')
params.interactions = string_file_name_links params.interactions = string_file_name_links
params.sequences = string_file_name_seqs params.sequences = string_file_name_seqs
...@@ -330,4 +334,3 @@ if __name__ == '__main__': ...@@ -330,4 +334,3 @@ if __name__ == '__main__':
parser = add_args(parser) parser = add_args(parser)
params = parser.parse_args() params = parser.parse_args()
main(params) main(params)
...@@ -68,9 +68,10 @@ def add_args(parser): ...@@ -68,9 +68,10 @@ def add_args(parser):
parser._action_groups[0].add_argument("fasta_file", type=pathlib.Path, parser._action_groups[0].add_argument("fasta_file", type=pathlib.Path,
help="FASTA file on which to extract the ESM2 representations and then test.", help="FASTA file on which to extract the ESM2 representations and then test.",
) )
predict_args.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "..", "default_model", "senseppi.ckpt"), predict_args.add_argument("--model_path", type=str, default=None,
help="A path to .ckpt file that contains weights to a pretrained model. If " help="A path to .ckpt file that contains weights to a pretrained model. If "
"None, the senseppi trained version is used.") "None, the preinstalled senseppi.ckpt trained version is used. "
"(Trained on human PPIs)")
predict_args.add_argument("--pairs_file", type=str, default=None, predict_args.add_argument("--pairs_file", type=str, default=None,
help="A path to a .tsv file with pairs of proteins to test (Optional). If not provided, " help="A path to a .tsv file with pairs of proteins to test (Optional). If not provided, "
"all-to-all pairs will be generated.") "all-to-all pairs will be generated.")
...@@ -79,7 +80,7 @@ def add_args(parser): ...@@ -79,7 +80,7 @@ def add_args(parser):
"(.tsv format will be added automatically)") "(.tsv format will be added automatically)")
predict_args.add_argument("--with_self", action='store_true', predict_args.add_argument("--with_self", action='store_true',
help="Include self-interactions in the predictions." help="Include self-interactions in the predictions."
"By default they are not included since they were not part of training but" "By default they are not included since they were not part of training but "
"they can be included by setting this flag to True.") "they can be included by setting this flag to True.")
predict_args.add_argument("-p", "--pred_threshold", type=float, default=0.5, predict_args.add_argument("-p", "--pred_threshold", type=float, default=0.5,
help="Prediction threshold to determine interacting pairs that " help="Prediction threshold to determine interacting pairs that "
......
...@@ -109,7 +109,7 @@ def main(params): ...@@ -109,7 +109,7 @@ def main(params):
sns.heatmap(labels_heatmap, cmap=cmap, vmin=0, vmax=1, sns.heatmap(labels_heatmap, cmap=cmap, vmin=0, vmax=1,
ax=ax1, mask=labels_heatmap == -1, ax=ax1, mask=labels_heatmap == -1,
cbar=False, square=True) # , linewidths=0.5, linecolor='white') cbar=False, square=True)
cbar = ax1.figure.colorbar(ax1.collections[0], ax=ax1, location='right', pad=0.15) cbar = ax1.figure.colorbar(ax1.collections[0], ax=ax1, location='right', pad=0.15)
cbar.ax.yaxis.set_ticks_position('right') cbar.ax.yaxis.set_ticks_position('right')
...@@ -172,26 +172,28 @@ def add_args(parser): ...@@ -172,26 +172,28 @@ def add_args(parser):
parser._action_groups[0].add_argument("genes", type=str, nargs="+", parser._action_groups[0].add_argument("genes", type=str, nargs="+",
help="Name of gene to fetch from STRING database. Several names can be " help="Name of gene to fetch from STRING database. Several names can be "
"typed (separated by whitespaces).") "typed (separated by whitespaces).")
string_pred_args.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "..", "default_model", "senseppi.ckpt"), string_pred_args.add_argument("--model_path", type=str, default=None,
help="A path to .ckpt file that contains weights to a pretrained model. If " help="A path to .ckpt file that contains weights to a pretrained model. If "
"None, the senseppi trained version is used.") "None, the preinstalled senseppi.ckpt trained version is used. "
"(Trained on human PPIs)")
string_pred_args.add_argument("-s", "--species", type=int, default=9606, string_pred_args.add_argument("-s", "--species", type=int, default=9606,
help="Species from STRING database. Default: 9606 (H. Sapiens)") help="Species from STRING database. Default: H. Sapiens")
string_pred_args.add_argument("-n", "--nodes", type=int, default=10, string_pred_args.add_argument("-n", "--nodes", type=int, default=10,
help="Number of nodes to fetch from STRING database. Default: 10") help="Number of nodes to fetch from STRING database. ")
string_pred_args.add_argument("-r", "--score", type=int, default=0, string_pred_args.add_argument("-r", "--score", type=int, default=0,
help="Score threshold for STRING connections. Range: (0, 1000). Default: 0") help="Score threshold for STRING connections. Range: (0, 1000). ")
string_pred_args.add_argument("-p", "--pred_threshold", type=int, default=500, string_pred_args.add_argument("-p", "--pred_threshold", type=int, default=500,
help="Prediction threshold. Range: (0, 1000). Default: 500") help="Prediction threshold. Range: (0, 1000). ")
string_pred_args.add_argument("--graphs", action='store_true', string_pred_args.add_argument("--graphs", action='store_true',
help="Enables plotting the heatmap and a network graph.") help="Enables plotting the heatmap and a network graph.")
string_pred_args.add_argument("-o", "--output", type=str, default="preds_from_string", string_pred_args.add_argument("-o", "--output", type=str, default="preds_from_string",
help="A path to a file where the predictions will be saved. " help="A path to a file where the predictions will be saved. "
"(.tsv format will be added automatically)") "(.tsv format will be added automatically)")
string_pred_args.add_argument("--network_type", type=str, default="physical", string_pred_args.add_argument("--network_type", type=str, default="physical", choices=['physical', 'functional'],
help="Network type: \"physical\" or \"functional\". Default: \"physical\"") help="Network type to fetch from STRING database. ")
string_pred_args.add_argument("--delete_proteins", type=str, nargs="+", default=None, string_pred_args.add_argument("--delete_proteins", type=str, nargs="+", default=None,
help="List of proteins to delete from the graph. Default: None") help="List of proteins to delete from the graph. "
"Several names can be specified separated by whitespaces. ")
parser = SensePPIModel.add_model_specific_args(parser) parser = SensePPIModel.add_model_specific_args(parser)
remove_argument(parser, "--lr") remove_argument(parser, "--lr")
......
...@@ -37,16 +37,17 @@ def add_args(parser): ...@@ -37,16 +37,17 @@ def add_args(parser):
parser = add_general_args(parser) parser = add_general_args(parser)
test_args = parser.add_argument_group(title="Predict args") test_args = parser.add_argument_group(title="Predict args")
parser._action_groups[0].add_argument("pairs_file", type=str, default=None, parser._action_groups[0].add_argument("pairs_file", type=str,
help="A path to a .tsv file with pairs of proteins to test.") help="A path to a .tsv file with pairs of proteins to test.")
parser._action_groups[0].add_argument("fasta_file", parser._action_groups[0].add_argument("fasta_file",
type=pathlib.Path, type=pathlib.Path,
help="FASTA file on which to extract the ESM2 " help="FASTA file on which to extract the ESM2 "
"representations and then evaluate.", "representations and then evaluate.",
) )
test_args.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "..", "default_model", "senseppi.ckpt"), test_args.add_argument("--model_path", type=str, default=None,
help="A path to .ckpt file that contains weights to a pretrained model. If " help="A path to .ckpt file that contains weights to a pretrained model. If "
"None, the senseppi trained version is used.") "None, the preinstalled senseppi.ckpt trained version is used. "
"(Trained on human PPIs)")
test_args.add_argument("-o", "--output", type=str, default="test_metrics", test_args.add_argument("-o", "--output", type=str, default="test_metrics",
help="A path to a file where the test metrics will be saved. " help="A path to a file where the test metrics will be saved. "
"(.tsv format will be added automatically)") "(.tsv format will be added automatically)")
......
...@@ -16,8 +16,10 @@ def add_esm_args(parent_parser): ...@@ -16,8 +16,10 @@ def add_esm_args(parent_parser):
parser = parent_parser.add_argument_group(title="ESM2 model args", parser = parent_parser.add_argument_group(title="ESM2 model args",
description="ESM2: Extract per-token representations and model " description="ESM2: Extract per-token representations and model "
"outputs for sequences in a FASTA file. " "outputs for sequences in a FASTA file. "
"If you would like to use the basic version of SENSE-PPI " "The representations are saved in --output_dir_esm folder so "
"do no edit the default values of the arguments below. ") "they can be reused in multiple runs. In order to reuse the "
"embeddings, make sure that --output_dir_esm is set to the "
"correct folder.")
parser.add_argument( parser.add_argument(
"--model_location_esm", "--model_location_esm",
type=str, default="esm2_t36_3B_UR50D", type=str, default="esm2_t36_3B_UR50D",
......
...@@ -107,7 +107,9 @@ def get_interactions_from_string(gene_names, species=9606, add_nodes=10, require ...@@ -107,7 +107,9 @@ def get_interactions_from_string(gene_names, species=9606, add_nodes=10, require
string_interactions = pd.DataFrame([line.split('\t') for line in lines[1:]], columns=lines[0].split('\t')) string_interactions = pd.DataFrame([line.split('\t') for line in lines[1:]], columns=lines[0].split('\t'))
if 'Error' in string_interactions.columns: if 'Error' in string_interactions.columns:
raise Exception(string_interactions['ErrorMessage'].values[0]) err_msg = string_interactions['ErrorMessage'].values[0]
err_msg = err_msg.replace('<br>', '\n').replace('<br/>', '\n').replace('<p>', '\n').replace('</p>', '\n')
raise Exception(err_msg)
if len(string_interactions) == 0: if len(string_interactions) == 0:
raise Exception('No interactions found. Please revise your input parameters.') raise Exception('No interactions found. Please revise your input parameters.')
......
...@@ -7,15 +7,28 @@ import argparse ...@@ -7,15 +7,28 @@ import argparse
class ArgumentParserWithDefaults(argparse.ArgumentParser): class ArgumentParserWithDefaults(argparse.ArgumentParser):
def add_argument(self, *args, help=None, default=None, **kwargs): def add_argument(self, *args, **kwargs):
if help is not None: if 'help' in kwargs and kwargs['help'] is not argparse.SUPPRESS and 'default' in kwargs and args[0] != '-h':
kwargs['help'] = help kwargs['help'] += ' (Default: {})'.format(kwargs['default'])
if default is not None and args[0] != '-h':
kwargs['default'] = default
if help is not None:
kwargs['help'] += ' Default: {}'.format(default)
super().add_argument(*args, **kwargs) super().add_argument(*args, **kwargs)
def add_argument_group(self, *args, **kwargs):
group = ArgumentGroupWithDefaults(self, *args, **kwargs)
self._action_groups.append(group)
return group
class ArgumentGroupWithDefaults(argparse._ArgumentGroup):
def add_argument(self, *args, **kwargs):
if 'help' in kwargs and kwargs['help'] is not argparse.SUPPRESS and 'default' in kwargs and args[0] != '-h':
kwargs['help'] += ' (Default: {})'.format(kwargs['default'])
super().add_argument(*args, **kwargs)
def add_argument_group(self, *args, **kwargs):
group = self._ArgumentGroup(self, *args, **kwargs)
self._action_groups.append(group)
return group
def add_general_args(parser): def add_general_args(parser):
parser.add_argument("-v", "--version", action="version", version="SENSE_PPI v{}".format(__version__)) parser.add_argument("-v", "--version", action="version", version="SENSE_PPI v{}".format(__version__))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment