All commands needed for version 0.1.0 are included.

Working commands: predict, test, string_dataset_create.
train and predict_string are still in progress.
parent 04240968
__version__ = "0.1.0"
__author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils
from . import model, commands, esm2_model, dataset, utils, network_utils
__all__ = [
"model",
"commands",
"esm2_model",
"dataset",
"utils"
"utils",
"network_utils"
]
\ No newline at end of file
import argparse
import logging
from .commands import *
from senseppi import __version__
def main():
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser(
description="SENSE_PPI: Sequence-based EvolutIoNary ScalE Protein-Protein Interaction prediction",
usage="senseppi <command> [<args>]",
......@@ -15,7 +18,12 @@ def main():
subparsers = parser.add_subparsers(title="The list of SENSE-PPI commands:", required=True, dest="cmd")
modules = {'train': train, 'predict': predict}
modules = {'train': train,
'predict': predict,
'string_dataset_create': string_dataset_create,
'test': test,
'predict_string': predict_string
}
for name, module in modules.items():
sp = subparsers.add_parser(name)
......@@ -25,7 +33,7 @@ def main():
params = parser.parse_args()
#WARNING: due to some internal issues of torch, the mps backend is temporarily disabled
if params.device == 'mps':
if hasattr(params, 'device') and params.device == 'mps':
logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled. '
'The cpu backend will be used instead.')
params.device = 'cpu'
......
__all__ = ['predict', 'train']
\ No newline at end of file
__all__ = ['predict', 'train', 'string_dataset_create', 'test', 'predict_string']
\ No newline at end of file
......@@ -4,6 +4,8 @@ from itertools import permutations, product
import numpy as np
import pandas as pd
import logging
import pathlib
import argparse
from ..dataset import PairSequenceData
from ..model import SensePPIModel
from ..utils import *
......@@ -66,6 +68,9 @@ def add_args(parser):
predict_args = parser.add_argument_group(title="Predict args")
parser._action_groups[0].add_argument("model_path", type=str,
help="A path to .ckpt file that contains weights to a pretrained model.")
parser._action_groups[0].add_argument("fasta_file", type=pathlib.Path,
help="FASTA file on which to extract the ESM2 representations and then test.",
)
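# Example invocation (hypothetical paths): senseppi predict model.ckpt proteins.fasta [--pairs_file pairs.tsv]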
predict_args.add_argument("--pairs_file", type=str, default=None,
help="A path to a .tsv file with pairs of proteins to test (Optional). If not provided, "
"all-to-all pairs will be generated.")
......@@ -88,8 +93,6 @@ def add_args(parser):
def main(params):
logging.info("Device used: %s", params.device)
process_string_fasta(params.fasta_file, min_len=params.min_len, max_len=params.max_len)
if params.pairs_file is None:
generate_pairs(params.fasta_file, 'protein.pairs.tsv', with_self=params.with_self)
......
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from torchmetrics import AUROC, Accuracy, Precision, Recall, F1Score, MatthewsCorrCoef
import networkx as nx
import seaborn as sns
import matplotlib
from matplotlib.lines import Line2D
from scipy.cluster.hierarchy import linkage, fcluster
from matplotlib.patches import Rectangle
import matplotlib.pyplot as plt
from pathlib import Path
from ..model import SensePPIModel
from ..utils import *
from ..network_utils import *
from ..esm2_model import add_esm_args
def main(hparams):
LABEL_THRESHOLD = hparams.score / 1000.
PRED_THRESHOLD = hparams.pred_threshold / 1000.
test_data = DscriptData(emb_dir='esm_emb_3B', max_len=800, dir_path='', actions_file='protein.actions.tsv')
actions_path = os.path.join('..', 'Data', 'Dscript', 'preprocessed', 'human_train.tsv')
loadpath = os.path.join('..', DSCRIPT_PATH)
model = SensePPIModel(hparams)
if hparams.nogpu:
checkpoint = torch.load(loadpath, map_location=torch.device('cpu'))
else:
checkpoint = torch.load(loadpath)
model.load_state_dict(checkpoint['state_dict'])
trainer = pl.Trainer(accelerator="cpu" if hparams.nogpu else 'gpu', logger=False)
test_loader = DataLoader(dataset=test_data,
batch_size=64,
num_workers=4)
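# Flatten the per-batch prediction tensors into a single list; the order matches the rows of
# protein.actions.tsv because the DataLoader is created without shuffling.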
preds = [pred for batch in trainer.predict(model, test_loader) for pred in batch.squeeze().tolist()]
preds = np.asarray(preds)
# open the actions tsv file as dataframe and add the last column with the predictions
data = pd.read_csv('protein.actions.tsv', delimiter='\t', names=["seq1", "seq2", "label"])
data['binary_label'] = data['label'].apply(lambda x: 1 if x > LABEL_THRESHOLD else 0)
data['preds'] = preds
if hparams.normalize:
data['preds'] = (data['preds'] - data['preds'].min()) / (data['preds'].max() - data['preds'].min())
print(data.sort_values(by=['preds'], ascending=False).to_string())
# Calculate torch metrics based on data['binary_label'] and data['preds']
torch_labels = torch.tensor(data['binary_label'])
torch_preds = torch.tensor(data['preds'])
print('Accuracy: ', Accuracy(threshold=PRED_THRESHOLD, task='binary')(torch_preds, torch_labels))
print('Precision: ', Precision(threshold=PRED_THRESHOLD, task='binary')(torch_preds, torch_labels))
print('Recall: ', Recall(threshold=PRED_THRESHOLD, task='binary')(torch_preds, torch_labels))
print('F1Score: ', F1Score(threshold=PRED_THRESHOLD, task='binary')(torch_preds, torch_labels))
print('MatthewsCorrCoef: ',
MatthewsCorrCoef(num_classes=2, threshold=PRED_THRESHOLD, task='binary')(torch_preds, torch_labels))
print('ROCAUC: ', AUROC(task='binary')(torch_preds, torch_labels))
# Create a dictionary of string ids and gene names from string_interactions_short.tsv
string_ids = {}
string_tsv = pd.read_csv('string_interactions.tsv', delimiter='\t')[
['preferredName_A', 'preferredName_B', 'stringId_A', 'stringId_B']]
for i, row in string_tsv.iterrows():
string_ids[row['stringId_A']] = row['preferredName_A']
string_ids[row['stringId_B']] = row['preferredName_B']
print('Fetching gene names for training set from STRING...')
if not os.path.exists('all_genes_train.tsv'):
all_genes = generate_dscript_gene_names(
file_path=actions_path,
only_positives=True,
species=str(hparams.species))
all_genes.to_csv('all_genes_train.tsv', sep='\t', index=False)
else:
all_genes = pd.read_csv('all_genes_train.tsv', sep='\t')
# Create tuples of gene pairs present in the training data; the corresponding gene names are taken from the 'all_genes' DataFrame
full_train_data = pd.read_csv(actions_path,
delimiter='\t', names=['seq1', 'seq2', 'label'])
# To make sure that we do not use the test species in the training data
full_train_data = full_train_data[full_train_data.seq1.str.startswith('6239') == False]
full_train_data = full_train_data[full_train_data.seq2.str.startswith('6239') == False]
if all_genes is not None:
full_train_data = full_train_data.merge(all_genes, left_on='seq1', right_on='QueryString', how='left').merge(
all_genes, left_on='seq2', right_on='QueryString', how='left')
full_train_data = full_train_data[['preferredName_x', 'preferredName_y', 'label']]
positive_train_data = full_train_data[full_train_data['label'] == 1][['preferredName_x', 'preferredName_y']]
full_train_data = full_train_data[['preferredName_x', 'preferredName_y']]
full_train_data = [tuple(x) for x in full_train_data.values]
positive_train_data = [tuple(x) for x in positive_train_data.values]
else:
full_train_data = None
positive_train_data = None
if not hparams.no_graphs:
# Create two subplots with a short gap between them
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 5), gridspec_kw={'width_ratios': [1, 1], 'wspace': 0.2})
# Plot the predictions as matrix but do not sort the labels
data_heatmap = data.pivot(index='seq1', columns='seq2', values='preds')
labels_heatmap = data.pivot(index='seq1', columns='seq2', values='label')
# Produce the list of protein names from sequences.fasta file
protein_names = [line.strip()[1:] for line in open('sequences.fasta', 'r') if line.startswith('>')]
# Produce the list of gene names from genes Dataframe
gene_names = [string_ids[prot] for prot in protein_names]
# Sort the preds for index and columns based on the order of proteins in the fasta file
data_heatmap = data_heatmap.reindex(index=protein_names, columns=protein_names)
labels_heatmap = labels_heatmap.reindex(index=protein_names, columns=protein_names)
# Replace labels with gene names
data_heatmap.index = gene_names
data_heatmap.columns = gene_names
labels_heatmap.index = gene_names
labels_heatmap.columns = gene_names
# Remove genes that are in hparams.delete_proteins
if hparams.delete_proteins is not None:
for protein in hparams.delete_proteins:
if protein in data_heatmap.index:
data_heatmap = data_heatmap.drop(protein, axis=0)
data_heatmap = data_heatmap.drop(protein, axis=1)
labels_heatmap = labels_heatmap.drop(protein, axis=0)
labels_heatmap = labels_heatmap.drop(protein, axis=1)
# Make sure that the matrices are symmetric for clustering
labels_heatmap = labels_heatmap.fillna(value=0)
labels_heatmap = labels_heatmap + labels_heatmap.T
np.fill_diagonal(labels_heatmap.values, -1)
data_heatmap = data_heatmap.fillna(value=0)
data_heatmap = data_heatmap + data_heatmap.T
np.fill_diagonal(data_heatmap.values, -1)
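# Hierarchically cluster the symmetric STRING label matrix so that interacting genes end up next to each other;
# the resulting row/column order is then applied to the prediction matrix as well.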
linkages = linkage(labels_heatmap, method='complete', metric='euclidean')
new_labels = np.argsort(fcluster(linkages, 0.05, criterion='distance'))
col_order = labels_heatmap.columns[new_labels]
row_order = labels_heatmap.index[new_labels]
labels_heatmap = labels_heatmap.reindex(index=row_order, columns=col_order)
data_heatmap = data_heatmap.reindex(index=row_order, columns=col_order)
# Fill the upper triangle of labels_heatmap with values from data_heatmap
labels_heatmap.values[np.triu_indices_from(labels_heatmap.values)] = data_heatmap.values[
np.triu_indices_from(data_heatmap.values)]
labels_heatmap.fillna(value=-1, inplace=True)
# In the combined heatmap, colour a cell black in the upper triangle if that pair of genes appears in the training data
if full_train_data is not None:
for i, row in labels_heatmap.iterrows():
for j, _ in row.items():
if (i, j) in full_train_data or (j, i) in full_train_data:
labels_heatmap.loc[i, j] = -1
cmap = matplotlib.cm.get_cmap('coolwarm').copy()
cmap.set_bad("black")
sns.heatmap(labels_heatmap, cmap=cmap, vmin=0, vmax=1,
ax=ax1, mask=labels_heatmap == -1,
cbar=False, square=True) # , linewidths=0.5, linecolor='white')
cbar = ax1.figure.colorbar(ax1.collections[0], ax=ax1, location='right', pad=0.15)
cbar.ax.yaxis.set_ticks_position('right')
ax1.set_yticklabels(ax1.get_yticklabels(), rotation=0)
ax1.set_ylabel('String interactions', weight='bold', fontsize=18)
ax1.set_title('Predictions', weight='bold', fontsize=18)
ax1.yaxis.tick_right()
# ax1.plot([0, len(labels_heatmap)], [0, len(labels_heatmap)], color='white')
for i in range(len(labels_heatmap)):
ax1.add_patch(Rectangle((i, i), 1, 1, fill=True, color='white', alpha=1, zorder=100))
# print(data[data['label'] > 0].sort_values(by=['preds'], ascending=False))
# Build a graph from the data, using the predictions as edge weights; see the colour scheme note below
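# Edge colour scheme (matches the legend below): dotted black = STRING label only (false negative),
# solid limegreen = predicted and labelled in STRING (true positive), solid red = predicted but absent
# from STRING (false positive); edges between genes found in the positive training pairs are later
# recoloured dark blue.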
G = nx.Graph()
for i, row in data.iterrows():
if row['label'] > LABEL_THRESHOLD:
G.add_edge(row['seq1'], row['seq2'], color='black', weight=row['label'], style='dotted')
if row['preds'] > PRED_THRESHOLD and G.has_edge(row['seq1'], row['seq2']):
G[row['seq1']][row['seq2']]['style'] = 'solid'
G[row['seq1']][row['seq2']]['color'] = 'limegreen'
if row['preds'] > PRED_THRESHOLD and row['label'] <= LABEL_THRESHOLD:
G.add_edge(row['seq1'], row['seq2'], color='red', weight=row['preds'], style='solid')
# Replace the string ids with gene names
G = nx.relabel_nodes(G, string_ids)
# If edge is present in training data make it blue
if positive_train_data is not None:
for edge in G.edges():
if (edge[0], edge[1]) in positive_train_data or (edge[1], edge[0]) in positive_train_data:
print('TRAINING EDGE: ', edge)
G[edge[0]][edge[1]]['color'] = 'darkblue'
# G[edge[0]][edge[1]]['weight'] = 1
# Colour nodes orange if they are present in the training data, light grey otherwise
for node in G.nodes():
if all_genes is not None and node in all_genes['preferredName'].values:
G.nodes[node]['color'] = 'orange'
else:
G.nodes[node]['color'] = 'lightgrey'
# nx.draw_networkx_edge_labels(G, pos=nx.spring_layout(G), edge_labels=nx.get_edge_attributes(G, 'weight'))
# Make edges longer and do not put nodes too close to each other
pos = nx.spring_layout(G, k=2., iterations=100)
# Draw the graph in ax2, using the per-edge colours, widths and styles set above
nx.draw(G, pos=pos, with_labels=True, ax=ax2,
edge_color=[G[u][v]['color'] for u, v in G.edges()], width=[G[u][v]['weight'] for u, v in G.edges()],
style=[G[u][v]['style'] for u, v in G.edges()],
node_color=[G.nodes[node]['color'] for node in G.nodes()])
# Put a legend for colors
legend_elements = [
Line2D([0], [0], marker='_', color='darkblue', label='PP from training data', markerfacecolor='darkblue',
markersize=10),
Line2D([0], [0], marker='_', color='limegreen', label='PP', markerfacecolor='limegreen', markersize=10),
Line2D([0], [0], marker='_', color='red', label='FP', markerfacecolor='red', markersize=10),
Line2D([0], [0], marker='_', color='black', label='FN - based on STRING', markerfacecolor='black',
markersize=10, linestyle='dotted')]
plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.2, 0.0), ncol=1, fontsize=8)
savepath = 'graph_{}_{}'.format('_'.join(hparams.genes), hparams.species)
if 'Dscript' in actions_path:
savepath += '_Dscript'
if 'HVLSTM' in loadpath:
savepath += '_HVLSTM'
else:
version = re.search('version_([0-9]+)', loadpath)
if version is not None:
savepath += '_v' + version.group(1)
savepath += '.pdf'
plt.savefig(savepath, bbox_inches='tight', dpi=600)
print("The graphs were saved in: ", savepath)
plt.show()
plt.close()
def add_args(parser):
parser = add_general_args(parser)
parser2 = parser.add_argument_group(title="General options")
parser2.add_argument("--no_graphs", action='store_true', help="No plotting testing graphs.")
parser2.add_argument("-g", "--genes", type=str, nargs="+", default="RFC5",
help="Names of the genes to fetch from the STRING database. Several names can be given, separated by whitespace. Default: RFC5")
parser2.add_argument("-s", "--species", type=int, default=9606,
help="Species taxid from the STRING database. Default: 9606 (H. sapiens)")
parser2.add_argument("-n", "--nodes", type=int, default=10,
help="Number of nodes to fetch from STRING database. Default: 10")
parser2.add_argument("-r", "--score", type=int, default=500,
help="Score threshold for STRING connections. Range: (0, 1000). Default: 500")
parser2.add_argument("-p", "--pred_threshold", type=int, default=500,
help="Prediction threshold. Range: (0, 1000). Default: 500")
parser2.add_argument("--network_type", type=str, default="physical",
help="Network type: \"physical\" or \"functional\". Default: \"physical\"")
parser2.add_argument("--normalize", action='store_true', help="Normalize the predictions.")
parser2.add_argument("--delete_proteins", type=str, nargs="+", default=None,
help="List of proteins to delete from the graph. Default: None")
parser = SensePPIModel.add_model_specific_args(parser)
remove_argument(parser, "--lr")
add_esm_args(parser)
return parser
if __name__ == '__main__':
params = parser.parse_args()
if torch.cuda.is_available():
# torch.cuda.set_per_process_memory_fraction(0.9, 0)
print('Number of devices: ', torch.cuda.device_count())
print('GPU used: ', torch.cuda.get_device_name(0))
torch.set_float32_matmul_precision('high')
else:
print('No GPU available, using the CPU instead.')
params.nogpu = True
print('Fetching interactions from STRING database...')
get_interactions_from_string(params.genes, species=params.species, add_nodes=params.nodes,
required_score=params.score, network_type=params.network_type)
process_string_fasta('sequences.fasta')
generate_pairs('sequences.fasta', mode='all_to_all', with_self=False, delete_proteins=params.delete_proteins)
# Compute ESM embeddings
# First, check if all embeddings are already computed
params.model_location = 'esm2_t36_3B_UR50D'
params.fasta_file = 'sequences.fasta'
params.output_dir = Path('esm_emb_3B')
params.include = 'per_tok'
with open(params.fasta_file, 'r') as f:
seq_ids = [line.strip().split(' ')[0].replace('>', '') for line in f.readlines() if line.startswith('>')]
if not os.path.exists(params.output_dir):
print('Computing ESM embeddings...')
esm_extract.run(params)
else:
for seq_id in seq_ids:
if not os.path.exists(os.path.join(params.output_dir, seq_id + '.pt')):
print('Computing ESM embeddings...')
esm_extract.run(params)
break
print('Predicting...')
main(params)
os.remove('sequences.fasta')
os.remove('protein.actions.tsv')
os.remove('string_interactions.tsv')
# srun --gres gpu:1 python network_test.py -n 30 --pred_threshold 500 -r 0 -s 9606 -g C1R RFC5 --delete_proteins C1S RFC3 RFC4 DSCC1 CHTF8 RAD17 RPA1 RPA2
\ No newline at end of file
import pandas as pd
import os
from tqdm import tqdm
from Bio import SeqIO
import logging
import argparse
import subprocess
import wget
import gzip
import shutil
import random
from ..network_utils import get_string_url
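# Helper that yields the interactions file in 1 MiB chunks, used below to count its lines
# without loading the whole file into memory.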
def _count_generator(reader):
b = reader(1024 * 1024)
while b:
yield b
b = reader(1024 * 1024)
# A class containing methods for preprocessing the full STRING data
class STRINGDatasetCreation:
def __init__(self, params):
self.interactions_file = params.interactions
self.sequences_file = params.sequences
self.min_length = params.min_length
self.max_length = params.max_length
self.species = params.species
self.max_positive_pairs = params.max_positive_pairs
self.combined_score = params.combined_score
self.experimental_score = params.experimental_score
self.intermediate_file = 'interactions_intermediate.tmp'
self.preprocessed_protein_set = None
if not params.not_remove_long_short_proteins:
self.process_fasta_file()
self._select_interactions()
def _select_interactions(self):
# Create a new intermediate file containing only the interactions that are suitable for training.
# Such interactions happen only between proteins of appropriate length
# and with a combined score above the chosen threshold (self.combined_score).
# Further on, redundant interactions are removed, as well as sequences of inappropriate
# length and interactions based on homology.
if not os.path.isfile(self.intermediate_file):
if not os.path.isfile('clusters.tsv'):
logging.info('Running mmseqs to cluster the proteins')
logging.info('This might take a while if you decide to process the whole STRING database.')
logging.info(
'In order to install mmseqs (if not installed), please visit https://github.com/soedinglab/MMseqs2')
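# The mmseqs pipeline below clusters all sequences at 40% identity; the resulting clusters.tsv
# maps every protein to its cluster representative (columns: cluster, protein) and is used
# further down to keep at most one interaction per pair of clusters.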
commands = "; ".join(["mkdir mmseqDBs",
"mmseqs createdb {} mmseqDBs/DB".format(self.sequences_file),
"mmseqs cluster mmseqDBs/DB mmseqDBs/clusterDB tmp --min-seq-id 0.4 --alignment-mode 3 --cov-mode 1 --threads 8",
"mmseqs createtsv mmseqDBs/DB mmseqDBs/DB mmseqDBs/clusterDB clusters.tsv",
"rm -r mmseqDBs",
"rm -r tmp"])
ret = subprocess.run(commands, shell=True, capture_output=True)
print(ret.stdout.decode())
print(ret.stderr.decode())
logging.info('Clusters file created')
# Compute the length of self.interactions_file
with open(self.interactions_file, 'rb') as f:
c_generator = _count_generator(f.raw.read)
n_interactions = sum(buffer.count(b'\n') for buffer in c_generator) + 1
logging.info('Removing redundant interactions')
clusters = pd.read_csv('clusters.tsv', sep='\t', header=None,
names=['cluster', 'protein']).set_index('protein')
clusters = clusters.to_dict()['cluster']
logging.info('Clusters file loaded')
print('Proteins in clusters: {}'.format(len(clusters)))
existing_cluster_pairs = set()
infomsg = 'Extracting only entries with no homology and:\n' \
'combined score >= {} '.format(self.combined_score)
if self.experimental_score is not None:
infomsg += '\nexperimental score >= {}'.format(self.experimental_score)
logging.info(infomsg)
with open(self.intermediate_file, 'w') as f:
with open(self.interactions_file, 'r') as f2:
f.write('\t'.join(f2.readline().strip().split(' ')) + '\n')
for line in tqdm(f2, total=n_interactions):
line = line.strip().split(' ')
if self.species is not None:
if not line[0].startswith(self.species) or not line[1].startswith(self.species):
continue
if self.experimental_score is not None and int(line[3]) < self.experimental_score:
continue
if int(line[2]) == 0 and int(line[-1]) >= self.combined_score:
try:
cluster1 = clusters[line[0]]
cluster2 = clusters[line[1]]
except KeyError:
continue
if cluster1 < cluster2:
cluster_pair = (cluster1, cluster2)
else:
cluster_pair = (cluster2, cluster1)
if cluster_pair not in existing_cluster_pairs:
existing_cluster_pairs.add(cluster_pair)
f.write('\t'.join(line) + '\n')
print('Number of proteins: {}'.format(len(clusters)))
print('Number of interactions: {}'.format(len(existing_cluster_pairs)))
print('Intermediate preprocessing done.')
print('You can find the preprocessed file in {}. Once the dataset creation is fully done, '
'it will be deleted.'.format(self.intermediate_file))
def final_preprocessing_positives(self):
# This function generates the final preprocessed file.
data = pd.read_csv(self.intermediate_file, sep='\t')
# Here you can put further constraints on the interactions.
# For example, you can further remove unreliable interactions.
data = data[['protein1', 'protein2', 'combined_score']]
if self.max_positive_pairs is not None:
self.max_positive_pairs = min(self.max_positive_pairs, len(data))
data = data.sort_values(by=['combined_score'], ascending=False).iloc[:self.max_positive_pairs]
data['combined_score'] = 1
data.to_csv("protein.pairs_{}.tsv.tmp".format(self.species), sep='\t', index=False)
# Create new fasta file with only the proteins that are in the interactions file
proteins = set(data['protein1'].unique()).union(set(data['protein2'].unique()))
self.preprocessed_protein_set = proteins
with open("sequences_{}.fasta".format(self.species), 'w') as f:
for record in tqdm(SeqIO.parse(self.sequences_file, "fasta")):
if record.id in proteins:
SeqIO.write(record, f, "fasta")
logging.info('Final preprocessing for only positive pairs done.')
def create_negatives(self):
if not os.path.isfile('clusters_preprocessed.tsv'):
logging.info(
'Running mmseqs to compute pairwise sequence similarity for all proteins in the preprocessed file.')
logging.info('This might take a while.')
commands = "; ".join(["mkdir mmseqDBs",
"mmseqs createdb sequences_{}.fasta mmseqDBs/DB".format(self.species),
"mmseqs cluster mmseqDBs/DB mmseqDBs/clusterDB tmp --min-seq-id 0.4 --alignment-mode 3 --cov-mode 1 --threads 8",
"mmseqs createtsv mmseqDBs/DB mmseqDBs/DB mmseqDBs/clusterDB clusters_preprocessed.tsv",
"rm -r mmseqDBs",
"rm -r tmp"])
ret = subprocess.run(commands, shell=True, capture_output=True)
print(ret.stdout.decode())
print(ret.stderr.decode())
logging.info('Clusters file created')
clusters_preprocessed = pd.read_csv('clusters_preprocessed.tsv', sep='\t', header=None,
names=['cluster', 'protein']).set_index('protein')
clusters_preprocessed = clusters_preprocessed.to_dict()['cluster']
# Create the new protein.pairs.tsv file that will be used for training. It contains both positive and
# negative pairs with a 1:10 ratio. The negative pairs are generated using the clusters file, making sure
# that a paired protein is not in the same cluster as proteins already interacting with the given one.
# This ensures that the negative pairs are not too similar to the positive ones.
proteins = list(clusters_preprocessed.keys())
interactions = pd.read_csv("protein.pairs_{}.tsv.tmp".format(self.species), sep='\t')
logging.info('Generating negative pairs.')
tqdm.pandas()
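# Sample 12x as many candidate negative pairs as there are positives, so that enough pairs survive
# the de-duplication and homology filters below to keep the final 1:10 positive:negative ratio.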
proteins1 = random.choices(proteins, k=len(interactions) * 12)
proteins2 = random.choices(proteins, k=len(interactions) * 12)
negative_pairs = pd.DataFrame({'protein1': proteins1, 'protein2': proteins2, 'combined_score': 0})
logging.info('Negative pairs generated. Filtering out duplicates.')
# Make protein1 and protein2 in alphabetical order
negative_pairs['protein1'], negative_pairs['protein2'] = zip(*negative_pairs.progress_apply(
lambda x: (x['protein1'], x['protein2']) if x['protein1'] < x['protein2'] else (
x['protein2'], x['protein1']), axis=1))
negative_pairs = negative_pairs.drop_duplicates()
logging.info('Duplicates filtered out. Filtering out pairs that are already in the positive interactions file.')
negative_pairs = negative_pairs[
~negative_pairs.progress_apply(lambda x: len(interactions[(interactions['protein1'] == x[
'protein1']) & (interactions['protein2'] == x['protein2'])]) > 0, axis=1)]
negative_pairs = negative_pairs[
~negative_pairs.progress_apply(lambda x: len(interactions[(interactions['protein1'] == x[
'protein2']) & (interactions['protein2'] == x['protein1'])]) > 0, axis=1)]
logging.info(
'Pairs that are already in the positive interactions file filtered out. Filtering out pairs that are in '
'the same cluster with proteins interacting with a given one already.')
negative_pairs = negative_pairs[~negative_pairs.progress_apply(
lambda x: clusters_preprocessed[x['protein2']] in [clusters_preprocessed[i] for i in
interactions[interactions['protein1'] == x['protein1']][
'protein2'].unique()], axis=1)]
assert len(negative_pairs) > len(interactions) * 10, 'Not enough negative pairs generated. ' \
'Please try again and increase the number of pairs (>1000).'
negative_pairs = negative_pairs.iloc[:len(interactions) * 10]
logging.info('Negative pairs generated. Saving to file.')
interactions = pd.concat([interactions, negative_pairs], ignore_index=True)
interactions.to_csv(os.path.join("protein.pairs_{}.tsv".format(self.species)), sep='\t',
index=False,
header=False)
os.remove("protein.pairs_{}.tsv.tmp".format(self.species))
os.remove(self.intermediate_file)
os.remove("clusters_preprocessed.tsv")
os.remove("clusters.tsv")
os.remove(self.interactions_file)
os.remove(self.sequences_file)
# A method to remove sequences of inappropriate length from a fasta file
def process_fasta_file(self):
logging.info('Getting protein names out of fasta file.')
logging.info(
'Removing proteins that are shorter than {}aa or longer than {}aa.'.format(self.min_length,
self.max_length))
with open('seqs.tmp', 'w') as f:
for record in tqdm(SeqIO.parse(self.sequences_file, "fasta")):
if len(record.seq) < self.min_length or len(record.seq) > self.max_length:
continue
record.description = ''
record.name = ''
SeqIO.write(record, f, "fasta")
# Rename the temporary file to the original file
os.rename('seqs.tmp', self.sequences_file)
def add_args(parser):
parser.add_argument("species", type=str,
help="The Taxon identifier of the organism of interest.")
parser.add_argument("--interactions", type=str, default=None,
help="The physical links (full) file from STRING for the "
"organism of interest.")
parser.add_argument("--sequences", type=str, default=None,
help="The sequences file downloaded from the same page of STRING. "
"For both files see https://string-db.org/cgi/download")
parser.add_argument("--not_remove_long_short_proteins", action='store_true',
help="If set, proteins that are too short or too long are NOT removed. "
"By default, such proteins are removed.")
parser.add_argument("--min_length", type=int, default=50,
help="The minimum length of a protein to be included in the dataset.")
parser.add_argument("--max_length", type=int, default=800,
help="The maximum length of a protein to be included in the dataset.")
parser.add_argument("--max_positive_pairs", type=int, default=5000,
help="The maximum number of positive pairs to be included in the dataset. "
"If None, all pairs are included. If specified, the pairs are selected "
"based on the combined score in STRING.")
parser.add_argument("--combined_score", type=int, default=500,
help="The combined score threshold for the pairs extracted from STRING. "
"Ranges from 0 to 1000.")
parser.add_argument("--experimental_score", type=int, default=None,
help="The experimental score threshold for the pairs extracted from STRING. "
"Ranges from 0 to 1000. Default is None, which means that the experimental "
"score is not used.")
return parser
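# Example invocation (hypothetical taxon/values): senseppi string_dataset_create 9606 --max_positive_pairs 5000 --combined_score 700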
def main(params):
if params.interactions is None or params.sequences is None:
logging.info('One or both of the files are not specified (interactions or sequences). '
'Downloading from STRING...')
_, version = get_string_url()
url = "https://stringdb-static.org/download/protein.physical.links.full.v{0}/{1}.protein.physical.links.full.v{0}.txt.gz".format(version, params.species)
string_file_name_links = "{1}.protein.physical.links.full.v{0}.txt".format(version, params.species)
wget.download(url, out=string_file_name_links+'.gz')
with gzip.open(string_file_name_links+'.gz', 'rb') as f_in:
with open(string_file_name_links, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
url = "https://stringdb-static.org/download/protein.sequences.v{0}/{1}.protein.sequences.v{0}.fa.gz".format(version, params.species)
string_file_name_seqs = "{1}.protein.sequences.v{0}.fa".format(version, params.species)
wget.download(url, out=string_file_name_seqs+'.gz')
with gzip.open(string_file_name_seqs+'.gz', 'rb') as f_in:
with open(string_file_name_seqs, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
os.remove(string_file_name_seqs+'.gz')
os.remove(string_file_name_links+'.gz')
params.interactions = string_file_name_links
params.sequences = string_file_name_seqs
data = STRINGDatasetCreation(params)
data.final_preprocessing_positives()
data.create_negatives()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser = add_args(parser)
params = parser.parse_args()
main(params)
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import pandas as pd
import logging
import pathlib
import argparse
from ..dataset import PairSequenceData
from ..model import SensePPIModel
from ..utils import *
from ..esm2_model import add_esm_args, compute_embeddings
def test(params):
eval_data = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=params.pairs_file,
max_len=params.max_len, labels=True)
pretrained_model = SensePPIModel(params)
if params.device == 'gpu':
checkpoint = torch.load(params.model_path)
elif params.device == 'mps':
checkpoint = torch.load(params.model_path, map_location=torch.device('mps'))
else:
checkpoint = torch.load(params.model_path, map_location=torch.device('cpu'))
pretrained_model.load_state_dict(checkpoint['state_dict'])
trainer = pl.Trainer(accelerator=params.device, logger=False)
eval_loader = DataLoader(dataset=eval_data,
batch_size=params.batch_size,
num_workers=4)
return trainer.test(pretrained_model, eval_loader)
def add_args(parser):
parser = add_general_args(parser)
predict_args = parser.add_argument_group(title="Predict args")
parser._action_groups[0].add_argument("model_path", type=str,
help="A path to .ckpt file that contains weights to a pretrained model.")
parser._action_groups[0].add_argument("pairs_file", type=str, default=None,
help="A path to a .tsv file with pairs of proteins to test.")
parser._action_groups[0].add_argument("fasta_file",
type=pathlib.Path,
help="FASTA file on which to extract the ESM2 "
"representations and then evaluate.",
)
predict_args.add_argument("-o", "--output", type=str, default="test_metrics",
help="A path to a file where the test metrics will be saved. "
"(.tsv format will be added automatically)")
predict_args.add_argument("--crop_data_to_model_lims", action="store_true",
help="If set, the data will be cropped to the limits of the model: "
"evaluations will be done only for proteins >50aa and <800aa.")
parser = SensePPIModel.add_model_specific_args(parser)
remove_argument(parser, "--lr")
add_esm_args(parser)
return parser
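# Example invocation (hypothetical paths; positional order as registered above):
# senseppi test model.ckpt protein.pairs.tsv sequences.fasta -o test_metrics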
def main(params):
if params.crop_data_to_model_lims:
process_string_fasta(params.fasta_file, min_len=params.min_len, max_len=params.max_len)
data = pd.read_csv(params.pairs_file, delimiter='\t', names=["seq1", "seq2", "label"])
data = data[data['seq1'].isin(get_fasta_ids(params.fasta_file))]
data = data[data['seq2'].isin(get_fasta_ids(params.fasta_file))]
data.to_csv(params.pairs_file, sep='\t', index=False, header=False)
compute_embeddings(params)
logging.info('Evaluating...')
test_metrics = test(params)[0]
test_metrics_df = pd.DataFrame.from_dict(test_metrics, orient='index')
test_metrics_df.to_csv(params.output + '.tsv', sep='\t', header=False)
if __name__ == '__main__':
test_parser = argparse.ArgumentParser()
parser = add_args(test_parser)
params = test_parser.parse_args()
main(params)
import pytorch_lightning as pl
from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint
from pytorch_lightning.callbacks import ModelCheckpoint
import pathlib
from ..utils import add_general_args
from ..model import SensePPIModel
from ..dataset import PairSequenceData
from ..utils import *
from ..esm2_model import add_esm_args, compute_embeddings
......@@ -45,6 +46,9 @@ def add_args(parser):
"Required format: 3 tab separated columns: first protein, "
"second protein (protein names have to be present in fasta_file), "
"label (0 or 1).")
parser._action_groups[0].add_argument("fasta_file", type=pathlib.Path,
help="FASTA file on which to extract the ESM2 representations and then train.",
)
train_args.add_argument("--valid_size", type=float, default=0.1,
help="Fraction of the training data to use for validation.")
train_args.add_argument("--seed", type=int, default=None, help="Global training seed.")
......
......@@ -26,7 +26,7 @@ class PairSequenceData(Dataset):
dtypes.update({'label': np.float16})
self.actions = pd.read_csv(self.action_path, delimiter='\t', names=["seq1", "seq2", "label"], dtype=dtypes)
else:
self.actions = pd.read_csv(self.action_path, delimiter='\t', names=["seq1", "seq2"], dtype=dtypes)
self.actions = pd.read_csv(self.action_path, delimiter='\t', usecols=[0, 1], names=["seq1", "seq2"], dtype=dtypes)
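# usecols=[0, 1]: when no labels are expected, only the first two columns are read, so pair files
# with extra columns can still be loaded.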
def get_emb(self, emb_id):
f = os.path.join(self.emb_dir, '{}.pt'.format(emb_id))
......
......@@ -104,7 +104,7 @@ def compute_embeddings(params):
# Compute ESM embeddings
logging.info('Computing ESM embeddings if they are not already computed. '
'If all the files already exist in output_dir_esm, this step will be skipped.')
'If all the files already exist in the {} folder, this step will be skipped.'.format(params.output_dir_esm))
if not os.path.exists(params.output_dir_esm):
run(params)
......
import json
from Bio import SeqIO
from itertools import permutations, product
import pandas as pd
import numpy as np
import os
import urllib.request
import time
from tqdm import tqdm
from copy import deepcopy
import requests
import gzip
import shutil
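# Build the list of unique protein pairs from a FASTA file (all-to-all by default), attach the STRING scores
# from string_interactions.tsv as labels and write the result to protein.actions.tsv; proteins listed in
# delete_proteins are excluded.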
def generate_pairs(fasta_file, mode='all_to_all', with_self=False, delete_proteins=None):
ids = []
for record in SeqIO.parse(fasta_file, "fasta"):
ids.append(record.id)
if mode == 'all_to_all':
if with_self:
all_pairs = [p for p in product(ids, repeat=2)]
else:
all_pairs = [p for p in permutations(ids, 2)]
pairs = []
for p in all_pairs:
if (p[1], p[0]) not in pairs and (p[0], p[1]) not in pairs:
pairs.append(p)
pairs = pd.DataFrame(pairs, columns=['seq1', 'seq2'])
data = pd.read_csv('string_interactions.tsv', delimiter='\t')
# Creating a dictionary of string ids and gene names
ids_dict = dict(zip(data['preferredName_A'], data['stringId_A']))
ids_dict.update(dict(zip(data['preferredName_B'], data['stringId_B'])))
data = data[['stringId_A', 'stringId_B', 'score']]
data.columns = ['seq1', 'seq2', 'label']
pairs = pairs.merge(data, on=['seq1', 'seq2'], how='left').fillna(0)
if delete_proteins is not None:
print('Labels removed: ', delete_proteins)
string_ids_to_delete = []
for label in delete_proteins:
string_ids_to_delete.append(ids_dict[label])
print('String ids to delete: ', string_ids_to_delete)
pairs = pairs[~pairs['seq1'].isin(string_ids_to_delete)]
pairs = pairs[~pairs['seq2'].isin(string_ids_to_delete)]
pairs.to_csv('protein.actions.tsv', sep='\t', index=False, header=False)
def generate_dscript_gene_names(file_path,
only_positives=True,
species='9606'):
data = pd.read_csv(file_path, delimiter='\t', names=['seq1', 'seq2', 'label'])
if only_positives:
train_ids = set(data['seq1'][data['label'] == 1].values).union(set(data['seq2'][data['label'] == 1].values))
else:
train_ids = set(data['seq1'].values).union(set(data['seq2'].values))
# train_ids = [train_id.split('.')[1] for train_id in train_ids]
train_ids = [train_id for train_id in train_ids if train_id.startswith(species)]
if len(train_ids) == 0:
return None
# Query the STRING API to get the gene names for the ids in train_ids.
# The request is split into chunks of `chunk_size` (300) ids; a pause between chunks can be enabled below.
chunk_size = 300
genes_string = pd.DataFrame()
for i in tqdm(range(0, len(train_ids), chunk_size)):
chunk = deepcopy(train_ids[i:i + chunk_size])
url = 'https://string-db.org/api/tsv/get_string_ids?identifiers=%s&species={}'.format(species) % '%0d'.join(
[c.split('.')[-1] for c in chunk])
response = urllib.request.urlopen(url)
data = response.read()
text = data.decode('utf-8')
text = text.split('\n')
# text = [t for t in text if t]
text = [t.split('\t') for t in text]
df = pd.DataFrame(text,
columns=['queryIndex', 'stringId', 'ncbiTaxonId', 'taxonName', 'preferredName', 'annotation'])
# Remove line if queryIndex is not int
df = df[df['queryIndex'].apply(lambda x: x.isdigit())]
df['QueryString'] = df['queryIndex'].apply(lambda x: chunk[int(x)])
# add stringId and preferredName to genes_string
genes_string = pd.concat([genes_string, df[['QueryString', 'preferredName']]])
# time.sleep(0.2)
return genes_string
def get_names_from_string(ids, species):
string_api_url, _ = get_string_url()
params = {
"identifiers": "\r".join(ids), # your protein list
"species": species, # species NCBI identifier
"limit": 1, # only one (best) identifier per input protein
"echo_query": 1, # see your input identifiers in the output
}
request_url = "/".join([string_api_url, "tsv", "get_string_ids"])
results = requests.post(request_url, data=params)
lines = results.text.strip().split("\n")
return pd.DataFrame([line.split('\t') for line in lines[1:]], columns=lines[0].split('\t'))
def get_string_url():
# Get stable api and current STRING version
request_url = "/".join(["https://string-db.org/api", "json", "version"])
response = requests.post(request_url)
version = json.loads(response.text)[0]['string_version']
stable_address = json.loads(response.text)[0]['stable_address']
return "/".join([stable_address, "api"]), version
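# get_string_url returns e.g. ('https://version-12-0.string-db.org/api', '12.0'); illustrative values only,
# the actual address and version come from the STRING server.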
def get_interactions_from_string(gene_names, species=9606, add_nodes=10, required_score=500, network_type='physical'):
string_api_url, version = get_string_url()
output_format = "tsv"
method = "network"
# Download protein sequences for given species if not downloaded yet
if not os.path.isfile('{}.protein.sequences.v{}.fa'.format(species, version)):
print('Downloading protein sequences')
url = 'https://stringdb-static.org/download/protein.sequences.v{}/{}.protein.sequences.v{}.fa.gz'.format(
version, species, version)
urllib.request.urlretrieve(url, '{}.protein.sequences.v{}.fa.gz'.format(species, version))
print('Unzipping protein sequences')
with gzip.open('{}.protein.sequences.v{}.fa.gz'.format(species, version), 'rb') as f_in:
with open('{}.protein.sequences.v{}.fa'.format(species, version), 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
os.remove('{}.protein.sequences.v{}.fa.gz'.format(species, version))
print('Done')
request_url = "/".join([string_api_url, output_format, method])
if isinstance(gene_names, str):
gene_names = [gene_names]
params = {
"identifiers": "%0d".join(gene_names),
"species": species,
"required_score": required_score,
"add_nodes": add_nodes,
"network_type": network_type
}
response = requests.post(request_url, data=params)
lines = response.text.strip().split("\n")
string_interactions = pd.DataFrame([line.split('\t') for line in lines[1:]], columns=lines[0].split('\t'))
if 'Error' in string_interactions.columns:
raise Exception(string_interactions['ErrorMessage'].values[0])
if len(string_interactions) == 0:
raise Exception('No interactions found. Please revise your input parameters.')
# Remove duplicated interactions
string_interactions.drop_duplicates(inplace=True)
# Make the interactions symmetric: add the interactions where the first and second columns are swapped
string_interactions = pd.concat([string_interactions, string_interactions.rename(
columns={'stringId_A': 'stringId_B', 'stringId_B': 'stringId_A', 'preferredName_A': 'preferredName_B',
'preferredName_B': 'preferredName_A'})])
# Get the STRING ids for the input genes in case some of them have no connections, and add ghost self-connections so that their gene names are kept in the file
string_names_input_genes = get_names_from_string(gene_names, species)
string_names_input_genes['stringId_A'] = string_names_input_genes['stringId']
string_names_input_genes['preferredName_A'] = string_names_input_genes['preferredName']
string_names_input_genes['stringId_B'] = string_names_input_genes['stringId']
string_names_input_genes['preferredName_B'] = string_names_input_genes['preferredName']
string_interactions = pd.concat([string_interactions, string_names_input_genes[
['stringId_A', 'preferredName_A', 'stringId_B', 'preferredName_B']]])
string_interactions.fillna(0, inplace=True)
# For all proteins in the first and second columns, extract their sequences from the downloaded STRING fasta file and write them to sequences.fasta
ids = list(string_interactions['stringId_A'].values) + list(string_interactions['stringId_B'].values) + \
string_names_input_genes['stringId'].to_list()
ids = set(ids)
with open('sequences.fasta', 'w') as f:
for record in SeqIO.parse('{}.protein.sequences.v{}.fa'.format(species, version), "fasta"):
if record.id in ids:
SeqIO.write(record, f, "fasta")
string_interactions.to_csv('string_interactions.tsv', sep='\t', index=False)
if __name__ == '__main__':
print(generate_dscript_gene_names(
file_path=os.path.join('..', 'STRING_full', 'preprocessed', 'protein.actions_full.tsv'),
only_positives=True,
species='362663'))
\ No newline at end of file
from Bio import SeqIO
import os
import argparse
from senseppi import __version__
import pathlib
import torch
def add_general_args(parser):
parser.add_argument("-v", "--version", action="version", version="SENSE_PPI v{}".format(__version__))
parser.add_argument(
"fasta_file",
type=pathlib.Path,
help="FASTA file on which to extract the ESM2 representations and then train or test.",
)
parser.add_argument("--min_len", type=int, default=50,
help="Minimum length of the protein sequence. "
"The sequences with smaller length will not be considered.")
......@@ -50,6 +43,13 @@ def process_string_fasta(fasta_file, min_len, max_len):
os.rename('file.tmp', fasta_file)
def get_fasta_ids(fasta_file):
ids = []
for record in SeqIO.parse(fasta_file, "fasta"):
ids.append(record.id)
return ids
def remove_argument(parser, arg):
for action in parser._actions:
opts = action.option_strings
......
......@@ -13,19 +13,23 @@ setup(
url="",
license="MIT",
packages=find_packages(),
long_description=long_description,
long_description_content_type="text/markdown",
include_package_data=True,
install_requires=[
"numpy",
"pandas",
"wget",
"torch>=1.12",
"matplotlib",
"seaborn",
"tqdm",
"scikit-learn",
"pytorch-lightning==1.9.0",
"torchmetrics",
"biopython",
"fair-esm"
"fair-esm",
"mmseqs2"
],
)
\ No newline at end of file