0.1.9 minor bugfix

parent a65a0e1a
...@@ -3,7 +3,6 @@ import pytorch_lightning as pl ...@@ -3,7 +3,6 @@ import pytorch_lightning as pl
from itertools import permutations, product from itertools import permutations, product
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import logging
import pathlib import pathlib
import argparse import argparse
from ..dataset import PairSequenceData from ..dataset import PairSequenceData
...@@ -97,11 +96,7 @@ def main(params): ...@@ -97,11 +96,7 @@ def main(params):
compute_embeddings(params) compute_embeddings(params)
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled block_mps(params)
if params.device == 'mps':
logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.')
params.device = 'cpu'
logging.info('Predicting...') logging.info('Predicting...')
preds = predict(params) preds = predict(params)
...@@ -116,8 +111,8 @@ def main(params): ...@@ -116,8 +111,8 @@ def main(params):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() pred_parser = argparse.ArgumentParser()
parser = add_args(parser) pred_parser = add_args(pred_parser)
params = parser.parse_args() pred_params = pred_parser.parse_args()
main(params) main(pred_params)
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from torchmetrics import AUROC, Accuracy, Precision, Recall, F1Score, MatthewsCorrCoef from torchmetrics import AUROC, Accuracy, Precision, Recall, F1Score, MatthewsCorrCoef
import networkx as nx import networkx as nx
import seaborn as sns import seaborn as sns
...@@ -10,19 +8,17 @@ from matplotlib.patches import Rectangle ...@@ -10,19 +8,17 @@ from matplotlib.patches import Rectangle
import argparse import argparse
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import glob import glob
import logging
from ..model import SensePPIModel from ..model import SensePPIModel
from ..utils import * from ..utils import *
from ..network_utils import * from ..network_utils import *
from ..esm2_model import add_esm_args, compute_embeddings from ..esm2_model import add_esm_args, compute_embeddings
from ..dataset import PairSequenceData
from .predict import predict from .predict import predict
def main(params): def main(params):
LABEL_THRESHOLD = params.score / 1000. label_threshold = params.score / 1000.
PRED_THRESHOLD = params.pred_threshold / 1000. pred_threshold = params.pred_threshold / 1000.
pairs_file = 'protein.pairs_string.tsv' pairs_file = 'protein.pairs_string.tsv'
fasta_file = 'sequences.fasta' fasta_file = 'sequences.fasta'
...@@ -30,22 +26,18 @@ def main(params): ...@@ -30,22 +26,18 @@ def main(params):
get_interactions_from_string(params.genes, species=params.species, add_nodes=params.nodes, get_interactions_from_string(params.genes, species=params.species, add_nodes=params.nodes,
required_score=params.score, network_type=params.network_type) required_score=params.score, network_type=params.network_type)
process_string_fasta(fasta_file, min_len=params.min_len, max_len=params.max_len) process_string_fasta(fasta_file, min_len=params.min_len, max_len=params.max_len)
generate_pairs_string(fasta_file, output_file=pairs_file, with_self=False, delete_proteins=params.delete_proteins) generate_pairs_string(fasta_file, output_file=pairs_file, delete_proteins=params.delete_proteins)
params.fasta_file = fasta_file params.fasta_file = fasta_file
compute_embeddings(params) compute_embeddings(params)
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled block_mps(params)
if params.device == 'mps':
logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.')
params.device = 'cpu'
preds = predict(params) preds = predict(params)
# open the actions tsv file as dataframe and add the last column with the predictions # open the actions tsv file as dataframe and add the last column with the predictions
data = pd.read_csv('protein.pairs_string.tsv', delimiter='\t', names=["seq1", "seq2", "string_label"]) data = pd.read_csv('protein.pairs_string.tsv', delimiter='\t', names=["seq1", "seq2", "string_label"])
data['binary_label'] = data['string_label'].apply(lambda x: 1 if x > LABEL_THRESHOLD else 0) data['binary_label'] = data['string_label'].apply(lambda x: 1 if x > label_threshold else 0)
data['preds'] = preds data['preds'] = preds
print(data.sort_values(by=['preds'], ascending=False).to_string()) print(data.sort_values(by=['preds'], ascending=False).to_string())
...@@ -53,13 +45,18 @@ def main(params): ...@@ -53,13 +45,18 @@ def main(params):
# Calculate torch metrics based on data['binary_label'] and data['preds'] # Calculate torch metrics based on data['binary_label'] and data['preds']
torch_labels = torch.tensor(data['binary_label']) torch_labels = torch.tensor(data['binary_label'])
torch_preds = torch.tensor(data['preds']) torch_preds = torch.tensor(data['preds'])
print('Accuracy: ', Accuracy(threshold=PRED_THRESHOLD, task='binary')(torch_preds, torch_labels)) print('Accuracy: ',
print('Precision: ', Precision(threshold=PRED_THRESHOLD, task='binary')(torch_preds, torch_labels)) Accuracy(threshold=pred_threshold, task='binary')(torch_preds, torch_labels))
print('Recall: ', Recall(threshold=PRED_THRESHOLD, task='binary')(torch_preds, torch_labels)) print('Precision: ',
print('F1Score: ', F1Score(threshold=PRED_THRESHOLD, task='binary')(torch_preds, torch_labels)) Precision(threshold=pred_threshold, task='binary')(torch_preds, torch_labels))
print('Recall: ',
Recall(threshold=pred_threshold, task='binary')(torch_preds, torch_labels))
print('F1Score: ',
F1Score(threshold=pred_threshold, task='binary')(torch_preds, torch_labels))
print('MatthewsCorrCoef: ', print('MatthewsCorrCoef: ',
MatthewsCorrCoef(num_classes=2, threshold=PRED_THRESHOLD, task='binary')(torch_preds, torch_labels)) MatthewsCorrCoef(num_classes=2, threshold=pred_threshold, task='binary')(torch_preds, torch_labels))
print('ROCAUC: ', AUROC(task='binary')(torch_preds, torch_labels)) print('ROCAUC: ',
AUROC(task='binary')(torch_preds, torch_labels))
string_ids = {} string_ids = {}
string_tsv = pd.read_csv('string_interactions.tsv', delimiter='\t')[ string_tsv = pd.read_csv('string_interactions.tsv', delimiter='\t')[
...@@ -74,37 +71,6 @@ def main(params): ...@@ -74,37 +71,6 @@ def main(params):
data_to_save = data_to_save.sort_values(by=['preds'], ascending=False) data_to_save = data_to_save.sort_values(by=['preds'], ascending=False)
data_to_save.to_csv(params.output + '.tsv', sep='\t', index=False) data_to_save.to_csv(params.output + '.tsv', sep='\t', index=False)
# This part was needed to color the pairs belonging to the train data, temporarily removed
# print('Fetching gene names for training set from STRING...')
#
# if not os.path.exists('all_genes_train.tsv'):
# all_genes = generate_dscript_gene_names(
# file_path=actions_path,
# only_positives=True,
# species=str(hparams.species))
# all_genes.to_csv('all_genes_train.tsv', sep='\t', index=False)
# else:
# all_genes = pd.read_csv('all_genes_train.tsv', sep='\t')
# full_train_data = pd.read_csv(actions_path,
# delimiter='\t', names=['seq1', 'seq2', 'label'])
#
# if all_genes is not None:
# full_train_data = full_train_data.merge(all_genes, left_on='seq1', right_on='QueryString', how='left').merge(
# all_genes, left_on='seq2', right_on='QueryString', how='left')
#
# full_train_data = full_train_data[['preferredName_x', 'preferredName_y', 'label']]
#
# positive_train_data = full_train_data[full_train_data['label'] == 1][['preferredName_x', 'preferredName_y']]
# full_train_data = full_train_data[['preferredName_x', 'preferredName_y']]
#
# full_train_data = [tuple(x) for x in full_train_data.values]
# positive_train_data = [tuple(x) for x in positive_train_data.values]
# else:
# full_train_data = None
# positive_train_data = None
if params.graphs: if params.graphs:
# Create two subpolots but make a short gap between them # Create two subpolots but make a short gap between them
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 5), gridspec_kw={'width_ratios': [1, 1], 'wspace': 0.2}) fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 5), gridspec_kw={'width_ratios': [1, 1], 'wspace': 0.2})
...@@ -154,14 +120,6 @@ def main(params): ...@@ -154,14 +120,6 @@ def main(params):
np.triu_indices_from(data_heatmap.values)] np.triu_indices_from(data_heatmap.values)]
labels_heatmap.fillna(value=-1, inplace=True) labels_heatmap.fillna(value=-1, inplace=True)
# This part was needed to color the pairs belonging to the train data, temporarily removed
# if full_train_data is not None:
# for i, row in labels_heatmap.iterrows():
# for j, _ in row.items():
# if (i, j) in full_train_data or (j, i) in full_train_data:
# labels_heatmap.loc[i, j] = -1
cmap = matplotlib.cm.get_cmap('coolwarm').copy() cmap = matplotlib.cm.get_cmap('coolwarm').copy()
cmap.set_bad("black") cmap.set_bad("black")
...@@ -179,53 +137,40 @@ def main(params): ...@@ -179,53 +137,40 @@ def main(params):
for i in range(len(labels_heatmap)): for i in range(len(labels_heatmap)):
ax1.add_patch(Rectangle((i, i), 1, 1, fill=True, color='white', alpha=1, zorder=100)) ax1.add_patch(Rectangle((i, i), 1, 1, fill=True, color='white', alpha=1, zorder=100))
G = nx.Graph() pred_graph = nx.Graph()
for i, row in data.iterrows(): for i, row in data.iterrows():
if row['string_label'] > LABEL_THRESHOLD: if row['string_label'] > label_threshold:
G.add_edge(row['seq1'], row['seq2'], color='black', weight=row['string_label'], style='dotted') pred_graph.add_edge(row['seq1'], row['seq2'], color='black', weight=row['string_label'], style='dotted')
if row['preds'] > PRED_THRESHOLD and G.has_edge(row['seq1'], row['seq2']): if row['preds'] > pred_threshold and pred_graph.has_edge(row['seq1'], row['seq2']):
G[row['seq1']][row['seq2']]['style'] = 'solid' pred_graph[row['seq1']][row['seq2']]['style'] = 'solid'
G[row['seq1']][row['seq2']]['color'] = 'limegreen' pred_graph[row['seq1']][row['seq2']]['color'] = 'limegreen'
if row['preds'] > PRED_THRESHOLD and row['string_label'] <= LABEL_THRESHOLD: if row['preds'] > pred_threshold and row['string_label'] <= label_threshold:
G.add_edge(row['seq1'], row['seq2'], color='red', weight=row['preds'], style='solid') pred_graph.add_edge(row['seq1'], row['seq2'], color='red', weight=row['preds'], style='solid')
for node in pred_graph.nodes():
pred_graph.nodes[node]['color'] = 'lightgrey'
# Replace the string ids with gene names # Replace the string ids with gene names
G = nx.relabel_nodes(G, string_ids) pred_graph = nx.relabel_nodes(pred_graph, string_ids)
# This part was needed to color the pairs belonging to the train data, temporarily removed pos = nx.spring_layout(pred_graph, k=2., iterations=100)
nx.draw(pred_graph, pos=pos, with_labels=True, ax=ax2,
# if positive_train_data is not None: edge_color=[pred_graph[u][v]['color'] for u, v in pred_graph.edges()],
# for edge in G.edges(): width=[pred_graph[u][v]['weight'] for u, v in pred_graph.edges()],
# if (edge[0], edge[1]) in positive_train_data or (edge[1], edge[0]) in positive_train_data: style=[pred_graph[u][v]['style'] for u, v in pred_graph.edges()],
# print('TRAINING EDGE: ', edge) node_color=[pred_graph.nodes[node]['color'] for node in pred_graph.nodes()])
# G[edge[0]][edge[1]]['color'] = 'darkblue'
# # G[edge[0]][edge[1]]['weight'] = 1
# Make nodes red if they are present in training data
for node in G.nodes():
# if all_genes is not None and node in all_genes['preferredName'].values:
# G.nodes[node]['color'] = 'orange'
# else:
G.nodes[node]['color'] = 'lightgrey'
pos = nx.spring_layout(G, k=2., iterations=100)
nx.draw(G, pos=pos, with_labels=True, ax=ax2,
edge_color=[G[u][v]['color'] for u, v in G.edges()], width=[G[u][v]['weight'] for u, v in G.edges()],
style=[G[u][v]['style'] for u, v in G.edges()],
node_color=[G.nodes[node]['color'] for node in G.nodes()])
legend_elements = [ legend_elements = [
# Line2D([0], [0], marker='_', color='darkblue', label='PP from training data', markerfacecolor='darkblue',
# markersize=10),
Line2D([0], [0], marker='_', color='limegreen', label='PP', markerfacecolor='limegreen', markersize=10), Line2D([0], [0], marker='_', color='limegreen', label='PP', markerfacecolor='limegreen', markersize=10),
Line2D([0], [0], marker='_', color='red', label='FP', markerfacecolor='red', markersize=10), Line2D([0], [0], marker='_', color='red', label='FP', markerfacecolor='red', markersize=10),
Line2D([0], [0], marker='_', color='black', label='FN - based on STRING', markerfacecolor='black', Line2D([0], [0], marker='_', color='black', label='FN - based on STRING', markerfacecolor='black',
markersize=10, linestyle='dotted')] markersize=10, linestyle='dotted')]
plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.2, 0.0), ncol=1, fontsize=8) plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.2, 0.0), ncol=1, fontsize=8)
savepath = '{}_graph_{}_{}.pdf'.format(params.output, '_'.join(params.genes), params.species) save_path = '{}_graph_{}_{}.pdf'.format(params.output, '_'.join(params.genes), params.species)
plt.savefig(savepath, bbox_inches='tight', dpi=600) plt.savefig(save_path, bbox_inches='tight', dpi=600)
print("The graphs were saved to: ", savepath) print("The graphs were saved to: ", save_path)
plt.show() plt.show()
plt.close() plt.close()
...@@ -235,6 +180,7 @@ def main(params): ...@@ -235,6 +180,7 @@ def main(params):
os.remove(f) os.remove(f)
os.remove('string_interactions.tsv') os.remove('string_interactions.tsv')
def add_args(parser): def add_args(parser):
parser = add_general_args(parser) parser = add_general_args(parser)
...@@ -242,8 +188,8 @@ def add_args(parser): ...@@ -242,8 +188,8 @@ def add_args(parser):
parser._action_groups[0].add_argument("model_path", type=str, parser._action_groups[0].add_argument("model_path", type=str,
help="A path to .ckpt file that contains weights to a pretrained model.") help="A path to .ckpt file that contains weights to a pretrained model.")
parser._action_groups[0].add_argument("genes", type=str, nargs="+", parser._action_groups[0].add_argument("genes", type=str, nargs="+",
help="Name of gene to fetch from STRING database. Several names can be typed (separated by " help="Name of gene to fetch from STRING database. Several names can be "
"whitespaces)") "typed (separated by whitespaces).")
string_pred_args.add_argument("-s", "--species", type=int, default=9606, string_pred_args.add_argument("-s", "--species", type=int, default=9606,
help="Species from STRING database. Default: 9606 (H. Sapiens)") help="Species from STRING database. Default: 9606 (H. Sapiens)")
string_pred_args.add_argument("-n", "--nodes", type=int, default=10, string_pred_args.add_argument("-n", "--nodes", type=int, default=10,
...@@ -252,7 +198,8 @@ def add_args(parser): ...@@ -252,7 +198,8 @@ def add_args(parser):
help="Score threshold for STRING connections. Range: (0, 1000). Default: 500") help="Score threshold for STRING connections. Range: (0, 1000). Default: 500")
string_pred_args.add_argument("-p", "--pred_threshold", type=int, default=500, string_pred_args.add_argument("-p", "--pred_threshold", type=int, default=500,
help="Prediction threshold. Range: (0, 1000). Default: 500") help="Prediction threshold. Range: (0, 1000). Default: 500")
string_pred_args.add_argument("--graphs", action='store_true', help="Enables plotting the heatmap and a network graph.") string_pred_args.add_argument("--graphs", action='store_true',
help="Enables plotting the heatmap and a network graph.")
string_pred_args.add_argument("-o", "--output", type=str, default="preds_from_string", string_pred_args.add_argument("-o", "--output", type=str, default="preds_from_string",
help="A path to a file where the predictions will be saved. " help="A path to a file where the predictions will be saved. "
"(.tsv format will be added automatically)") "(.tsv format will be added automatically)")
...@@ -269,6 +216,8 @@ def add_args(parser): ...@@ -269,6 +216,8 @@ def add_args(parser):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() pred_parser = argparse.ArgumentParser()
parser = add_args(parser) pred_parser = add_args(pred_parser)
params = parser.parse_args() pred_params = pred_parser.parse_args()
\ No newline at end of file
main(pred_params)
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import pytorch_lightning as pl import pytorch_lightning as pl
import pandas as pd import pandas as pd
import logging
import pathlib import pathlib
import argparse import argparse
from ..dataset import PairSequenceData from ..dataset import PairSequenceData
...@@ -72,11 +71,7 @@ def main(params): ...@@ -72,11 +71,7 @@ def main(params):
compute_embeddings(params) compute_embeddings(params)
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled block_mps(params)
if params.device == 'mps':
logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.')
params.device = 'cpu'
logging.info('Evaluating...') logging.info('Evaluating...')
test_metrics = test(params)[0] test_metrics = test(params)[0]
...@@ -87,7 +82,7 @@ def main(params): ...@@ -87,7 +82,7 @@ def main(params):
if __name__ == '__main__': if __name__ == '__main__':
test_parser = argparse.ArgumentParser() test_parser = argparse.ArgumentParser()
parser = add_args(test_parser) test_parser = add_args(test_parser)
params = test_parser.parse_args() test_params = test_parser.parse_args()
main(params) main(test_params)
...@@ -2,8 +2,7 @@ import pytorch_lightning as pl ...@@ -2,8 +2,7 @@ import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks import ModelCheckpoint
import pathlib import pathlib
import argparse import argparse
import logging from ..utils import *
from ..utils import add_general_args
from ..model import SensePPIModel from ..model import SensePPIModel
from ..dataset import PairSequenceData from ..dataset import PairSequenceData
from ..esm2_model import add_esm_args, compute_embeddings from ..esm2_model import add_esm_args, compute_embeddings
...@@ -15,11 +14,7 @@ def main(params): ...@@ -15,11 +14,7 @@ def main(params):
compute_embeddings(params) compute_embeddings(params)
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled block_mps(params)
if params.device == 'mps':
logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.')
params.device = 'cpu'
dataset = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=params.pairs_file, dataset = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=params.pairs_file,
max_len=params.max_len, labels=True) max_len=params.max_len, labels=True)
...@@ -80,8 +75,8 @@ def add_args(parser): ...@@ -80,8 +75,8 @@ def add_args(parser):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() train_parser = argparse.ArgumentParser()
parser = add_args(parser) train_parser = add_args(train_parser)
params = parser.parse_args() train_params = train_parser.parse_args()
main(params) main(train_params)
\ No newline at end of file \ No newline at end of file
...@@ -15,18 +15,13 @@ import shutil ...@@ -15,18 +15,13 @@ import shutil
DOWNLOAD_LINK_STRING = "https://stringdb-downloads.org/download/" DOWNLOAD_LINK_STRING = "https://stringdb-downloads.org/download/"
def generate_pairs_string(fasta_file, output_file, with_self=False, delete_proteins=None): def generate_pairs_string(fasta_file, output_file, delete_proteins=None):
ids = [] ids = []
for record in SeqIO.parse(fasta_file, "fasta"): for record in SeqIO.parse(fasta_file, "fasta"):
ids.append(record.id) ids.append(record.id)
if with_self:
all_pairs = [p for p in product(ids, repeat=2)]
else:
all_pairs = [p for p in permutations(ids, 2)]
pairs = [] pairs = []
for p in all_pairs: for p in [p for p in permutations(ids, 2)]:
if (p[1], p[0]) not in pairs and (p[0], p[1]) not in pairs: if (p[1], p[0]) not in pairs and (p[0], p[1]) not in pairs:
pairs.append(p) pairs.append(p)
......
...@@ -2,6 +2,7 @@ from Bio import SeqIO ...@@ -2,6 +2,7 @@ from Bio import SeqIO
import os import os
from senseppi import __version__ from senseppi import __version__
import torch import torch
import logging
def add_general_args(parser): def add_general_args(parser):
...@@ -29,6 +30,18 @@ def determine_device(): ...@@ -29,6 +30,18 @@ def determine_device():
return device return device
def block_mps(params):
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
if hasattr(params, 'device'):
if params.device == 'mps':
logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.')
if torch.cuda.is_available():
params.device = 'gpu'
else:
params.device = 'cpu'
def process_string_fasta(fasta_file, min_len, max_len): def process_string_fasta(fasta_file, min_len, max_len):
with open('file.tmp', 'w') as f: with open('file.tmp', 'w') as f:
for record in SeqIO.parse(fasta_file, "fasta"): for record in SeqIO.parse(fasta_file, "fasta"):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment