0.5.6 documentation and parser updates

parent c5196603
__version__ = "0.5.5"
__version__ = "0.5.6"
__author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils
......
import argparse
import logging
import os
import torch
from .commands import *
from senseppi import __version__
......@@ -44,8 +45,12 @@ def main():
logging.info('Device used: {}'.format(params.device))
if hasattr(params, 'model_path'):
if params.model_path is None:
params.model_path = os.path.join(os.path.dirname(__file__), "default_model", "senseppi.ckpt")
params.func(params)
if __name__ == "__main__":
main()
\ No newline at end of file
main()
......@@ -258,8 +258,9 @@ def add_args(parser):
help="The sequences file downloaded from the same page of STRING. "
"For both files see https://string-db.org/cgi/download")
parser.add_argument("--not_remove_long_short_proteins", action='store_true',
help="Whether to remove proteins that are too short or too long. "
"Normally, the long and short proteins are removed.")
help="If specified, does not remove proteins "
"shorter than --min_length and longer than --max_length. "
"By default, long and short proteins are removed.")
parser.add_argument("--min_length", type=int, default=50,
help="The minimum length of a protein to be included in the dataset.")
parser.add_argument("--max_length", type=int, default=800,
......@@ -290,27 +291,30 @@ def main(params):
logging.info('STRING version: {}'.format(version))
try:
url = "{0}protein.physical.links.full.v{1}/{2}.protein.physical.links.full.v{1}.txt.gz".format(DOWNLOAD_LINK_STRING, version, params.species)
url = "{0}protein.physical.links.full.v{1}/{2}.protein.physical.links.full.v{1}.txt.gz".format(
DOWNLOAD_LINK_STRING, version, params.species)
string_file_name_links = "{1}.protein.physical.links.full.v{0}.txt".format(version, params.species)
wget.download(url, out=string_file_name_links+'.gz')
with gzip.open(string_file_name_links+'.gz', 'rb') as f_in:
wget.download(url, out=string_file_name_links + '.gz')
with gzip.open(string_file_name_links + '.gz', 'rb') as f_in:
with open(string_file_name_links, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
url = "{0}protein.sequences.v{1}/{2}.protein.sequences.v{1}.fa.gz".format(DOWNLOAD_LINK_STRING, version, params.species)
url = "{0}protein.sequences.v{1}/{2}.protein.sequences.v{1}.fa.gz".format(DOWNLOAD_LINK_STRING, version,
params.species)
string_file_name_seqs = "{1}.protein.sequences.v{0}.fa".format(version, params.species)
wget.download(url, out=string_file_name_seqs+'.gz')
with gzip.open(string_file_name_seqs+'.gz', 'rb') as f_in:
wget.download(url, out=string_file_name_seqs + '.gz')
with gzip.open(string_file_name_seqs + '.gz', 'rb') as f_in:
with open(string_file_name_seqs, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
except HTTPError:
raise Exception('The files are not available for the specified species. '
'There might be two reasons for that: \n '
'1) the species is not available in STRING. Please check the STRING species list to verify. \n'
'2) the download link has changed. Please raise an issue in the repository. ')
'There might be two reasons for that: \n '
'1) the species is not available in STRING. Please check the STRING species list to '
'verify. \n '
'2) the download link has changed. Please raise an issue in the repository. ')
os.remove(string_file_name_seqs+'.gz')
os.remove(string_file_name_links+'.gz')
os.remove(string_file_name_seqs + '.gz')
os.remove(string_file_name_links + '.gz')
params.interactions = string_file_name_links
params.sequences = string_file_name_seqs
......@@ -330,4 +334,3 @@ if __name__ == '__main__':
parser = add_args(parser)
params = parser.parse_args()
main(params)
......@@ -68,9 +68,10 @@ def add_args(parser):
parser._action_groups[0].add_argument("fasta_file", type=pathlib.Path,
help="FASTA file on which to extract the ESM2 representations and then test.",
)
predict_args.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "..", "default_model", "senseppi.ckpt"),
help="A path to .ckpt file that contains weights to a pretrained model. If "
"None, the senseppi trained version is used.")
predict_args.add_argument("--model_path", type=str, default=None,
help="A path to .ckpt file that contains weights to a pretrained model. If "
"None, the preinstalled senseppi.ckpt trained version is used. "
"(Trained on human PPIs)")
predict_args.add_argument("--pairs_file", type=str, default=None,
help="A path to a .tsv file with pairs of proteins to test (Optional). If not provided, "
"all-to-all pairs will be generated.")
......@@ -79,7 +80,7 @@ def add_args(parser):
"(.tsv format will be added automatically)")
predict_args.add_argument("--with_self", action='store_true',
help="Include self-interactions in the predictions."
"By default they are not included since they were not part of training but"
"By default they are not included since they were not part of training but "
"they can be included by setting this flag to True.")
predict_args.add_argument("-p", "--pred_threshold", type=float, default=0.5,
help="Prediction threshold to determine interacting pairs that "
......
......@@ -109,7 +109,7 @@ def main(params):
sns.heatmap(labels_heatmap, cmap=cmap, vmin=0, vmax=1,
ax=ax1, mask=labels_heatmap == -1,
cbar=False, square=True) # , linewidths=0.5, linecolor='white')
cbar=False, square=True)
cbar = ax1.figure.colorbar(ax1.collections[0], ax=ax1, location='right', pad=0.15)
cbar.ax.yaxis.set_ticks_position('right')
......@@ -172,26 +172,28 @@ def add_args(parser):
parser._action_groups[0].add_argument("genes", type=str, nargs="+",
help="Name of gene to fetch from STRING database. Several names can be "
"typed (separated by whitespaces).")
string_pred_args.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "..", "default_model", "senseppi.ckpt"),
string_pred_args.add_argument("--model_path", type=str, default=None,
help="A path to .ckpt file that contains weights to a pretrained model. If "
"None, the senseppi trained version is used.")
"None, the preinstalled senseppi.ckpt trained version is used. "
"(Trained on human PPIs)")
string_pred_args.add_argument("-s", "--species", type=int, default=9606,
help="Species from STRING database. Default: 9606 (H. Sapiens)")
help="Species from STRING database. Default: H. Sapiens")
string_pred_args.add_argument("-n", "--nodes", type=int, default=10,
help="Number of nodes to fetch from STRING database. Default: 10")
help="Number of nodes to fetch from STRING database. ")
string_pred_args.add_argument("-r", "--score", type=int, default=0,
help="Score threshold for STRING connections. Range: (0, 1000). Default: 0")
help="Score threshold for STRING connections. Range: (0, 1000). ")
string_pred_args.add_argument("-p", "--pred_threshold", type=int, default=500,
help="Prediction threshold. Range: (0, 1000). Default: 500")
help="Prediction threshold. Range: (0, 1000). ")
string_pred_args.add_argument("--graphs", action='store_true',
help="Enables plotting the heatmap and a network graph.")
string_pred_args.add_argument("-o", "--output", type=str, default="preds_from_string",
help="A path to a file where the predictions will be saved. "
"(.tsv format will be added automatically)")
string_pred_args.add_argument("--network_type", type=str, default="physical",
help="Network type: \"physical\" or \"functional\". Default: \"physical\"")
string_pred_args.add_argument("--network_type", type=str, default="physical", choices=['physical', 'functional'],
help="Network type to fetch from STRING database. ")
string_pred_args.add_argument("--delete_proteins", type=str, nargs="+", default=None,
help="List of proteins to delete from the graph. Default: None")
help="List of proteins to delete from the graph. "
"Several names can be specified separated by whitespaces. ")
parser = SensePPIModel.add_model_specific_args(parser)
remove_argument(parser, "--lr")
......
......@@ -37,16 +37,17 @@ def add_args(parser):
parser = add_general_args(parser)
test_args = parser.add_argument_group(title="Predict args")
parser._action_groups[0].add_argument("pairs_file", type=str, default=None,
parser._action_groups[0].add_argument("pairs_file", type=str,
help="A path to a .tsv file with pairs of proteins to test.")
parser._action_groups[0].add_argument("fasta_file",
type=pathlib.Path,
help="FASTA file on which to extract the ESM2 "
"representations and then evaluate.",
)
test_args.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "..", "default_model", "senseppi.ckpt"),
test_args.add_argument("--model_path", type=str, default=None,
help="A path to .ckpt file that contains weights to a pretrained model. If "
"None, the senseppi trained version is used.")
"None, the preinstalled senseppi.ckpt trained version is used. "
"(Trained on human PPIs)")
test_args.add_argument("-o", "--output", type=str, default="test_metrics",
help="A path to a file where the test metrics will be saved. "
"(.tsv format will be added automatically)")
......
......@@ -16,8 +16,10 @@ def add_esm_args(parent_parser):
parser = parent_parser.add_argument_group(title="ESM2 model args",
description="ESM2: Extract per-token representations and model "
"outputs for sequences in a FASTA file. "
"If you would like to use the basic version of SENSE-PPI "
"do no edit the default values of the arguments below. ")
"The representations are saved in --output_dir_esm folder so "
"they can be reused in multiple runs. In order to reuse the "
"embeddings, make sure that --output_dir_esm is set to the "
"correct folder.")
parser.add_argument(
"--model_location_esm",
type=str, default="esm2_t36_3B_UR50D",
......
......@@ -107,7 +107,9 @@ def get_interactions_from_string(gene_names, species=9606, add_nodes=10, require
string_interactions = pd.DataFrame([line.split('\t') for line in lines[1:]], columns=lines[0].split('\t'))
if 'Error' in string_interactions.columns:
raise Exception(string_interactions['ErrorMessage'].values[0])
err_msg = string_interactions['ErrorMessage'].values[0]
err_msg = err_msg.replace('<br>', '\n').replace('<br/>', '\n').replace('<p>', '\n').replace('</p>', '\n')
raise Exception(err_msg)
if len(string_interactions) == 0:
raise Exception('No interactions found. Please revise your input parameters.')
......
......@@ -7,15 +7,28 @@ import argparse
class ArgumentParserWithDefaults(argparse.ArgumentParser):
    """ArgumentParser that appends each argument's default value to its help text.

    Arguments registered through ``add_argument`` whose keyword arguments
    contain both ``help`` and ``default`` get ``' (Default: <value>)'``
    appended to the help string. Groups created through
    ``add_argument_group`` are ``ArgumentGroupWithDefaults`` instances.
    """

    @staticmethod
    def _append_default_to_help(args, kwargs):
        """Return ``kwargs`` with the default value appended to ``help``.

        ``kwargs`` is returned unchanged when there is no help text, when the
        help or the default is ``argparse.SUPPRESS``, when no default was
        passed, or when the argument is the ``-h`` help flag. The caller's
        dict is never mutated.
        """
        help_text = kwargs.get('help')
        if help_text is None or help_text is argparse.SUPPRESS:
            return kwargs
        if 'default' not in kwargs or kwargs['default'] is argparse.SUPPRESS:
            return kwargs
        # Guard against an empty ``args`` tuple (e.g. ``add_argument(dest=...)``).
        if args and args[0] == '-h':
            return kwargs
        # argparse %-formats help strings when rendering them, so a literal
        # '%' coming from the default value has to be escaped.
        default_text = str(kwargs['default']).replace('%', '%%')
        annotated = dict(kwargs)
        annotated['help'] = '{} (Default: {})'.format(help_text, default_text)
        return annotated

    def add_argument(self, *args, **kwargs):
        """Register an argument and return the created ``argparse.Action``."""
        return super().add_argument(*args, **self._append_default_to_help(args, kwargs))

    def add_argument_group(self, *args, **kwargs):
        """Create, register and return an ``ArgumentGroupWithDefaults``."""
        group = ArgumentGroupWithDefaults(self, *args, **kwargs)
        self._action_groups.append(group)
        return group
class ArgumentGroupWithDefaults(argparse._ArgumentGroup):
    """Argument group that appends each argument's default value to its help text.

    Arguments registered through ``add_argument`` whose keyword arguments
    contain both ``help`` and ``default`` get ``' (Default: <value>)'``
    appended to the help string.
    """

    @staticmethod
    def _append_default_to_help(args, kwargs):
        """Return ``kwargs`` with the default value appended to ``help``.

        ``kwargs`` is returned unchanged when there is no help text, when the
        help or the default is ``argparse.SUPPRESS``, when no default was
        passed, or when the argument is the ``-h`` help flag. The caller's
        dict is never mutated.
        """
        help_text = kwargs.get('help')
        if help_text is None or help_text is argparse.SUPPRESS:
            return kwargs
        if 'default' not in kwargs or kwargs['default'] is argparse.SUPPRESS:
            return kwargs
        # Guard against an empty ``args`` tuple (e.g. ``add_argument(dest=...)``).
        if args and args[0] == '-h':
            return kwargs
        # argparse %-formats help strings when rendering them, so a literal
        # '%' coming from the default value has to be escaped.
        default_text = str(kwargs['default']).replace('%', '%%')
        annotated = dict(kwargs)
        annotated['help'] = '{} (Default: {})'.format(help_text, default_text)
        return annotated

    def add_argument(self, *args, **kwargs):
        """Register an argument and return the created ``argparse.Action``."""
        return super().add_argument(*args, **self._append_default_to_help(args, kwargs))

    def add_argument_group(self, *args, **kwargs):
        """Create, register and return a nested group of this same class.

        The previous ``self._ArgumentGroup`` lookup raised AttributeError
        because argparse groups have no such attribute.
        NOTE(review): argparse itself discourages nesting argument groups;
        this override is kept only for interface compatibility.
        """
        group = ArgumentGroupWithDefaults(self, *args, **kwargs)
        self._action_groups.append(group)
        return group
def add_general_args(parser):
parser.add_argument("-v", "--version", action="version", version="SENSE_PPI v{}".format(__version__))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment