All commands needed for version 0.1.0 are included.

Working commands: predict, test, string_dataset_create
train and predict_string are still in progress
parent 04240968
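For orientation, here is a minimal sketch (not part of the commit itself) of how the working commands listed above are meant to be driven through the package entry point. The checkpoint and FASTA file names, and the order of the positional arguments, are illustrative assumptions only.

import sys
from senseppi.__main__ import main

# Hypothetical invocation of the predict subcommand; the file names and the
# positional argument order are placeholders, not taken from this commit.
sys.argv = ["senseppi", "predict", "senseppi.ckpt", "proteins.fasta"]
main()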
__version__ = "0.1.0"
__author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils
__all__ = [
"model",
"commands",
"esm2_model",
"dataset",
"utils"
"utils",
"network_utils"
]
\ No newline at end of file
import argparse
import logging
from .commands import *
from senseppi import __version__
def main():
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser(
description="SENSE_PPI: Sequence-based EvolutIoNary ScalE Protein-Protein Interaction prediction",
usage="senseppi <command> [<args>]",
@@ -15,7 +18,12 @@ def main():
subparsers = parser.add_subparsers(title="The list of SENSE-PPI commands:", required=True, dest="cmd")
modules = {'train': train,
'predict': predict,
'string_dataset_create': string_dataset_create,
'test': test,
'predict_string': predict_string
}
for name, module in modules.items():
sp = subparsers.add_parser(name)
@@ -25,7 +33,7 @@ def main():
params = parser.parse_args()
#WARNING: due to some internal issues of torch, the mps backend is temporarily disabled
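# hasattr guard: some subcommands may not define a --device argument at all.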
if hasattr(params, 'device') and params.device == 'mps':
logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled. '
'The cpu backend will be used instead.')
params.device = 'cpu'
......
__all__ = ['predict', 'train', 'string_dataset_create', 'test', 'predict_string']
\ No newline at end of file
@@ -4,6 +4,8 @@ from itertools import permutations, product
import numpy as np
import pandas as pd
import logging
import pathlib
import argparse
from ..dataset import PairSequenceData
from ..model import SensePPIModel
from ..utils import *
@@ -66,6 +68,9 @@ def add_args(parser):
predict_args = parser.add_argument_group(title="Predict args")
parser._action_groups[0].add_argument("model_path", type=str,
help="A path to .ckpt file that contains weights to a pretrained model.")
parser._action_groups[0].add_argument("fasta_file", type=pathlib.Path,
help="FASTA file on which to extract the ESM2 representations and then test.",
)
predict_args.add_argument("--pairs_file", type=str, default=None,
help="A path to a .tsv file with pairs of proteins to test (Optional). If not provided, "
"all-to-all pairs will be generated.")
@@ -88,8 +93,6 @@ def add_args(parser):
def main(params):
logging.info("Device used: ", params.device)
process_string_fasta(params.fasta_file, min_len=params.min_len, max_len=params.max_len)
if params.pairs_file is None:
generate_pairs(params.fasta_file, 'protein.pairs.tsv', with_self=params.with_self)
......
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import pandas as pd
import logging
import pathlib
import argparse
from ..dataset import PairSequenceData
from ..model import SensePPIModel
from ..utils import *
from ..esm2_model import add_esm_args, compute_embeddings
def test(params):
eval_data = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=params.pairs_file,
max_len=params.max_len, labels=True)
pretrained_model = SensePPIModel(params)
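# Load the checkpoint onto the requested backend; map_location is needed when the
# checkpoint was saved on a GPU but is being loaded on cpu or mps.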
if params.device == 'gpu':
checkpoint = torch.load(params.model_path)
elif params.device == 'mps':
checkpoint = torch.load(params.model_path, map_location=torch.device('mps'))
else:
checkpoint = torch.load(params.model_path, map_location=torch.device('cpu'))
pretrained_model.load_state_dict(checkpoint['state_dict'])
trainer = pl.Trainer(accelerator=params.device, logger=False)
eval_loader = DataLoader(dataset=eval_data,
batch_size=params.batch_size,
num_workers=4)
return trainer.test(pretrained_model, eval_loader)
def add_args(parser):
parser = add_general_args(parser)
predict_args = parser.add_argument_group(title="Predict args")
parser._action_groups[0].add_argument("model_path", type=str,
help="A path to .ckpt file that contains weights to a pretrained model.")
parser._action_groups[0].add_argument("pairs_file", type=str, default=None,
help="A path to a .tsv file with pairs of proteins to test.")
parser._action_groups[0].add_argument("fasta_file",
type=pathlib.Path,
help="FASTA file on which to extract the ESM2 "
"representations and then evaluate.",
)
predict_args.add_argument("-o", "--output", type=str, default="test_metrics",
help="A path to a file where the test metrics will be saved. "
"(.tsv format will be added automatically)")
predict_args.add_argument("--crop_data_to_model_lims", action="store_true",
help="If set, the data will be cropped to the limits of the model: "
"evaluations will be done only for proteins >50aa and <800aa.")
parser = SensePPIModel.add_model_specific_args(parser)
remove_argument(parser, "--lr")
add_esm_args(parser)
return parser
def main(params):
if params.crop_data_to_model_lims:
process_string_fasta(params.fasta_file, min_len=params.min_len, max_len=params.max_len)
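# Keep only pairs whose proteins survived the length filtering above, then rewrite the pairs file in place.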
data = pd.read_csv(params.pairs_file, delimiter='\t', names=["seq1", "seq2", "label"])
data = data[data['seq1'].isin(get_fasta_ids(params.fasta_file))]
data = data[data['seq2'].isin(get_fasta_ids(params.fasta_file))]
data.to_csv(params.pairs_file, sep='\t', index=False, header=False)
compute_embeddings(params)
logging.info('Evaluating...')
test_metrics = test(params)[0]
test_metrics_df = pd.DataFrame.from_dict(test_metrics, orient='index')
test_metrics_df.to_csv(params.output + '.tsv', sep='\t', header=False)
if __name__ == '__main__':
test_parser = argparse.ArgumentParser()
parser = add_args(test_parser)
params = test_parser.parse_args()
main(params)
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import pathlib
from ..utils import add_general_args
from ..model import SensePPIModel
from ..dataset import PairSequenceData
from ..utils import *
from ..esm2_model import add_esm_args, compute_embeddings
@@ -45,6 +46,9 @@ def add_args(parser):
"Required format: 3 tab separated columns: first protein, "
"second protein (protein names have to be present in fasta_file), "
"label (0 or 1).")
parser._action_groups[0].add_argument("fasta_file", type=pathlib.Path,
help="FASTA file on which to extract the ESM2 representations and then train.",
)
train_args.add_argument("--valid_size", type=float, default=0.1,
help="Fraction of the training data to use for validation.")
train_args.add_argument("--seed", type=int, default=None, help="Global training seed.")
......
@@ -26,7 +26,7 @@ class PairSequenceData(Dataset):
dtypes.update({'label': np.float16})
self.actions = pd.read_csv(self.action_path, delimiter='\t', names=["seq1", "seq2", "label"], dtype=dtypes)
else:
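# usecols=[0, 1] presumably allows reusing a pairs file that also carries a label column when labels are not requested.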
self.actions = pd.read_csv(self.action_path, delimiter='\t', usecols=[0, 1], names=["seq1", "seq2"], dtype=dtypes)
def get_emb(self, emb_id):
f = os.path.join(self.emb_dir, '{}.pt'.format(emb_id))
......
@@ -104,7 +104,7 @@ def compute_embeddings(params):
# Compute ESM embeddings
logging.info('Computing ESM embeddings if they are not already computed. '
'If all the files already exist in {} folder, this step will be skipped.'.format(params.output_dir_esm))
if not os.path.exists(params.output_dir_esm):
run(params)
......
import json
from Bio import SeqIO
from itertools import permutations, product
import pandas as pd
import numpy as np
import os
import urllib.request
import time
from tqdm import tqdm
from copy import deepcopy
import requests
import gzip
import shutil
def generate_pairs(fasta_file, mode='all_to_all', with_self=False, delete_proteins=None):
ids = []
for record in SeqIO.parse(fasta_file, "fasta"):
ids.append(record.id)
if mode == 'all_to_all':
if with_self:
all_pairs = [p for p in product(ids, repeat=2)]
else:
all_pairs = [p for p in permutations(ids, 2)]
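# Keep each unordered pair only once: skip (B, A) when (A, B) has already been added.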
pairs = []
for p in all_pairs:
if (p[1], p[0]) not in pairs and (p[0], p[1]) not in pairs:
pairs.append(p)
pairs = pd.DataFrame(pairs, columns=['seq1', 'seq2'])
data = pd.read_csv('string_interactions.tsv', delimiter='\t')
# Create a dictionary mapping preferred gene names to their STRING ids
ids_dict = dict(zip(data['preferredName_A'], data['stringId_A']))
ids_dict.update(dict(zip(data['preferredName_B'], data['stringId_B'])))
data = data[['stringId_A', 'stringId_B', 'score']]
data.columns = ['seq1', 'seq2', 'label']
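# Pairs that are not present in string_interactions.tsv receive a label of 0 after the merge.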
pairs = pairs.merge(data, on=['seq1', 'seq2'], how='left').fillna(0)
if delete_proteins is not None:
print('Labels removed: ', delete_proteins)
string_ids_to_delete = []
for label in delete_proteins:
string_ids_to_delete.append(ids_dict[label])
print('String ids to delete: ', string_ids_to_delete)
pairs = pairs[~pairs['seq1'].isin(string_ids_to_delete)]
pairs = pairs[~pairs['seq2'].isin(string_ids_to_delete)]
pairs.to_csv('protein.actions.tsv', sep='\t', index=False, header=False)
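# Illustrative call (the FASTA name is a placeholder): expects string_interactions.tsv in the
# working directory and writes the labelled pairs to protein.actions.tsv.
# generate_pairs('sequences.fasta', mode='all_to_all', with_self=False)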
def generate_dscript_gene_names(file_path,
only_positives=True,
species='9606'):
data = pd.read_csv(file_path, delimiter='\t', names=['seq1', 'seq2', 'label'])
if only_positives:
train_ids = set(data['seq1'][data['label'] == 1].values).union(set(data['seq2'][data['label'] == 1].values))
else:
train_ids = set(data['seq1'].values).union(set(data['seq2'].values))
# train_ids = [train_id.split('.')[1] for train_id in train_ids]
train_ids = [train_id for train_id in train_ids if train_id.startswith(species)]
if len(train_ids) == 0:
return None
# Query the STRING API to get the gene names for the ids in train_ids.
# The request is split into chunks of chunk_size ids.
chunk_size = 300
genes_string = pd.DataFrame()
for i in tqdm(range(0, len(train_ids), chunk_size)):
chunk = deepcopy(train_ids[i:i + chunk_size])
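# The STRING API expects multiple identifiers joined by '%0d' (a URL-encoded carriage return).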
url = 'https://string-db.org/api/tsv/get_string_ids?identifiers=%s&species={}'.format(species) % '%0d'.join(
[c.split('.')[-1] for c in chunk])
response = urllib.request.urlopen(url)
data = response.read()
text = data.decode('utf-8')
text = text.split('\n')
# text = [t for t in text if t]
text = [t.split('\t') for t in text]
df = pd.DataFrame(text,
columns=['queryIndex', 'stringId', 'ncbiTaxonId', 'taxonName', 'preferredName', 'annotation'])
# Remove line if queryIndex is not int
df = df[df['queryIndex'].apply(lambda x: x.isdigit())]
df['QueryString'] = df['queryIndex'].apply(lambda x: chunk[int(x)])
# Append the QueryString / preferredName mapping for this chunk to genes_string
genes_string = pd.concat([genes_string, df[['QueryString', 'preferredName']]])
# time.sleep(0.2)
return genes_string
def get_names_from_string(ids, species):
string_api_url, _ = get_string_url()
params = {
"identifiers": "\r".join(ids), # your protein list
"species": species, # species NCBI identifier
"limit": 1, # only one (best) identifier per input protein
"echo_query": 1, # see your input identifiers in the output
}
request_url = "/".join([string_api_url, "tsv", "get_string_ids"])
results = requests.post(request_url, data=params)
lines = results.text.strip().split("\n")
return pd.DataFrame([line.split('\t') for line in lines[1:]], columns=lines[0].split('\t'))
def get_string_url():
# Get stable api and current STRING version
request_url = "/".join(["https://string-db.org/api", "json", "version"])
response = requests.post(request_url)
version = json.loads(response.text)[0]['string_version']
stable_address = json.loads(response.text)[0]['stable_address']
return "/".join([stable_address, "api"]), version
def get_interactions_from_string(gene_names, species=9606, add_nodes=10, required_score=500, network_type='physical'):
string_api_url, version = get_string_url()
output_format = "tsv"
method = "network"
# Download protein sequences for given species if not downloaded yet
if not os.path.isfile('{}.protein.sequences.v{}.fa'.format(species, version)):
print('Downloading protein sequences')
url = 'https://stringdb-static.org/download/protein.sequences.v{}/{}.protein.sequences.v{}.fa.gz'.format(
version, species, version)
urllib.request.urlretrieve(url, '{}.protein.sequences.v{}.fa.gz'.format(species, version))
print('Unzipping protein sequences')
with gzip.open('{}.protein.sequences.v{}.fa.gz'.format(species, version), 'rb') as f_in:
with open('{}.protein.sequences.v{}.fa'.format(species, version), 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
os.remove('{}.protein.sequences.v{}.fa.gz'.format(species, version))
print('Done')
request_url = "/".join([string_api_url, output_format, method])
if isinstance(gene_names, str):
gene_names = [gene_names]
params = {
"identifiers": "%0d".join(gene_names),
"species": species,
"required_score": required_score,
"add_nodes": add_nodes,
"network_type": network_type
}
response = requests.post(request_url, data=params)
lines = response.text.strip().split("\n")
string_interactions = pd.DataFrame([line.split('\t') for line in lines[1:]], columns=lines[0].split('\t'))
if 'Error' in string_interactions.columns:
raise Exception(string_interactions['ErrorMessage'].values[0])
if len(string_interactions) == 0:
raise Exception('No interactions found. Please revise your input parameters.')
# Remove duplicated interactions
string_interactions.drop_duplicates(inplace=True)
# Make the interactions symmetric: add the interactions where the first and second columns are swapped
string_interactions = pd.concat([string_interactions, string_interactions.rename(
columns={'stringId_A': 'stringId_B', 'stringId_B': 'stringId_A', 'preferredName_A': 'preferredName_B',
'preferredName_B': 'preferredName_A'})])
# Look up STRING ids for the input gene_names and add ghost self-connections, so that proteins without any interactions still keep their names in the output file
string_names_input_genes = get_names_from_string(gene_names, species)
string_names_input_genes['stringId_A'] = string_names_input_genes['stringId']
string_names_input_genes['preferredName_A'] = string_names_input_genes['preferredName']
string_names_input_genes['stringId_B'] = string_names_input_genes['stringId']
string_names_input_genes['preferredName_B'] = string_names_input_genes['preferredName']
string_interactions = pd.concat([string_interactions, string_names_input_genes[
['stringId_A', 'preferredName_A', 'stringId_B', 'preferredName_B']]])
string_interactions.fillna(0, inplace=True)
# For all proteins in the first and second columns, extract their sequences from the downloaded {species}.protein.sequences.v{version}.fa file and write them to sequences.fasta
ids = list(string_interactions['stringId_A'].values) + list(string_interactions['stringId_B'].values) + \
string_names_input_genes['stringId'].to_list()
ids = set(ids)
with open('sequences.fasta', 'w') as f:
for record in SeqIO.parse('{}.protein.sequences.v{}.fa'.format(species, version), "fasta"):
if record.id in ids:
SeqIO.write(record, f, "fasta")
string_interactions.to_csv('string_interactions.tsv', sep='\t', index=False)
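# Illustrative call (gene names and thresholds are placeholders): fetch the physical STRING
# subnetwork around two human genes and write sequences.fasta and string_interactions.tsv.
# get_interactions_from_string(['EGFR', 'GRB2'], species=9606, add_nodes=10, required_score=700)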
if __name__ == '__main__':
print(generate_dscript_gene_names(
file_path=os.path.join('..', 'STRING_full', 'preprocessed', 'protein.actions_full.tsv'),
only_positives=True,
species='362663'))
\ No newline at end of file
from Bio import SeqIO
import os
import argparse
from senseppi import __version__
import pathlib
import torch
def add_general_args(parser):
parser.add_argument("-v", "--version", action="version", version="SENSE_PPI v{}".format(__version__))
parser.add_argument(
"fasta_file",
type=pathlib.Path,
help="FASTA file on which to extract the ESM2 representations and then train or test.",
)
parser.add_argument("--min_len", type=int, default=50,
help="Minimum length of the protein sequence. "
"The sequences with smaller length will not be considered.")
@@ -50,6 +43,13 @@ def process_string_fasta(fasta_file, min_len, max_len):
os.rename('file.tmp', fasta_file)
def get_fasta_ids(fasta_file):
ids = []
for record in SeqIO.parse(fasta_file, "fasta"):
ids.append(record.id)
return ids
def remove_argument(parser, arg):
for action in parser._actions:
opts = action.option_strings
......
@@ -13,19 +13,23 @@ setup(
url="",
license="MIT",
packages=find_packages(),
long_description=long_description,
long_description_content_type="text/markdown",
include_package_data=True,
install_requires=[
"numpy",
"pandas",
"wget",
"torch>=1.12",
"matplotlib",
"seaborn",
"tqdm",
"scikit-learn",
"pytorch-lightning==1.9.0",
"torchmetrics",
"biopython",
"fair-esm"
"fair-esm",
"mmseqs2"
],
)
\ No newline at end of file