0.6.5 updated names for tmp files for predict and predict string

parent 04ca5c16
__version__ = "0.6.4"
__version__ = "0.6.5"
__author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils
......
......@@ -9,6 +9,7 @@ from ..dataset import PairSequenceData
from ..model import SensePPIModel
from ..utils import *
from ..esm2_model import add_esm_args, compute_embeddings
from datetime import datetime
def predict(params):
......@@ -120,7 +121,8 @@ def get_protein_names(fasta_file):
def main(params):
tmp_pairs = 'senseppi_pairs_for_prediction.tmp'
current_time = str(datetime.now()).replace(' ', '_')
tmp_pairs = current_time + '_senseppi_pairs_for_prediction.tsv.tmp'
try:
fasta_max_len = get_max_len(params.fasta_file)
if params.max_len is None:
......
......@@ -19,149 +19,151 @@ def main(params):
label_threshold = params.score / 1000.
pred_threshold = params.pred_threshold / 1000.
pairs_file = 'protein.pairs_string.tsv'
fasta_file = 'sequences.fasta'
print('Fetching interactions from STRING database...')
get_interactions_from_string(params.genes, species=params.species, add_nodes=params.nodes,
required_score=params.score, network_type=params.network_type)
process_string_fasta(fasta_file, min_len=params.min_len, max_len=params.max_len)
generate_pairs_string(fasta_file, output_file=pairs_file, delete_proteins=params.delete_proteins)
params.fasta_file = fasta_file
params.pairs_file = pairs_file
compute_embeddings(params)
preds = predict(params)
# open the actions tsv file as dataframe and add the last column with the predictions
data = pd.read_csv('protein.pairs_string.tsv', delimiter='\t', names=["seq1", "seq2", "string_label"])
data['binary_label'] = data['string_label'].apply(lambda x: 1 if x > label_threshold else 0)
data['preds'] = preds
print(data.sort_values(by=['preds'], ascending=False).to_string())
string_ids = {}
string_tsv = pd.read_csv('string_interactions.tsv', delimiter='\t')[
['preferredName_A', 'preferredName_B', 'stringId_A', 'stringId_B']]
for i, row in string_tsv.iterrows():
string_ids[row['stringId_A']] = row['preferredName_A']
string_ids[row['stringId_B']] = row['preferredName_B']
data_to_save = data.copy()
data_to_save['seq1'] = data_to_save['seq1'].apply(lambda x: string_ids[x])
data_to_save['seq2'] = data_to_save['seq2'].apply(lambda x: string_ids[x])
data_to_save = data_to_save.sort_values(by=['preds'], ascending=False)
data_to_save.to_csv(params.output + '.tsv', sep='\t', index=False)
if params.graphs:
# Create two subpolots but make a short gap between them
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 5), gridspec_kw={'width_ratios': [1, 1], 'wspace': 0.2})
# Plot the predictions as matrix but do not sort the labels
data_heatmap = data.pivot(index='seq1', columns='seq2', values='preds')
labels_heatmap = data.pivot(index='seq1', columns='seq2', values='string_label')
# Produce the list of protein names from sequences.fasta file
protein_names = [line.strip()[1:] for line in open('sequences.fasta', 'r') if line.startswith('>')]
# Produce the list of gene names from genes Dataframe
gene_names = [string_ids[prot] for prot in protein_names]
# Sort the preds for index and columns based on the order of proteins in the fasta file
data_heatmap = data_heatmap.reindex(index=protein_names, columns=protein_names)
labels_heatmap = labels_heatmap.reindex(index=protein_names, columns=protein_names)
# Replace labels with gene names
data_heatmap.index = gene_names
data_heatmap.columns = gene_names
labels_heatmap.index = gene_names
labels_heatmap.columns = gene_names
# Remove genes that are in hparams.delete_proteins
if params.delete_proteins is not None:
for protein in params.delete_proteins:
if protein in data_heatmap.index:
data_heatmap = data_heatmap.drop(protein, axis=0)
data_heatmap = data_heatmap.drop(protein, axis=1)
labels_heatmap = labels_heatmap.drop(protein, axis=0)
labels_heatmap = labels_heatmap.drop(protein, axis=1)
# Make sure that the matrices are symmetric for clustering
labels_heatmap = labels_heatmap.fillna(value=0)
labels_heatmap = labels_heatmap + labels_heatmap.T
np.fill_diagonal(labels_heatmap.values, -1)
data_heatmap = data_heatmap.fillna(value=0)
data_heatmap = data_heatmap + data_heatmap.T
np.fill_diagonal(data_heatmap.values, -1)
linkages = linkage(labels_heatmap, method='complete', metric='euclidean')
new_labels = np.argsort(fcluster(linkages, 0.05, criterion='distance'))
col_order = labels_heatmap.columns[new_labels]
row_order = labels_heatmap.index[new_labels]
labels_heatmap = labels_heatmap.reindex(index=row_order, columns=col_order)
data_heatmap = data_heatmap.reindex(index=row_order, columns=col_order)
# Fill the upper triangle of labels_heatmap with values from data_heatmap
labels_heatmap.values[np.triu_indices_from(labels_heatmap.values)] = data_heatmap.values[
np.triu_indices_from(data_heatmap.values)]
labels_heatmap.fillna(value=-1, inplace=True)
cmap = matplotlib.cm.get_cmap('coolwarm').copy()
cmap.set_bad("black")
sns.heatmap(labels_heatmap, cmap=cmap, vmin=0, vmax=1,
ax=ax1, mask=labels_heatmap == -1,
cbar=False, square=True)
cbar = ax1.figure.colorbar(ax1.collections[0], ax=ax1, location='right', pad=0.15)
cbar.ax.yaxis.set_ticks_position('right')
ax1.set_yticklabels(ax1.get_yticklabels(), rotation=0)
ax1.set_ylabel('String interactions', weight='bold', fontsize=18)
ax1.set_title('Predictions', weight='bold', fontsize=18)
ax1.yaxis.tick_right()
for i in range(len(labels_heatmap)):
ax1.add_patch(Rectangle((i, i), 1, 1, fill=True, color='white', alpha=1, zorder=100))
pred_graph = nx.Graph()
for i, row in data.iterrows():
if row['string_label'] > label_threshold:
pred_graph.add_edge(row['seq1'], row['seq2'], color='black', weight=row['string_label'], style='dotted')
if row['preds'] > pred_threshold and pred_graph.has_edge(row['seq1'], row['seq2']):
pred_graph[row['seq1']][row['seq2']]['style'] = 'solid'
pred_graph[row['seq1']][row['seq2']]['color'] = 'limegreen'
if row['preds'] > pred_threshold and row['string_label'] <= label_threshold:
pred_graph.add_edge(row['seq1'], row['seq2'], color='red', weight=row['preds'], style='solid')
for node in pred_graph.nodes():
pred_graph.nodes[node]['color'] = 'lightgrey'
# Replace the string ids with gene names
pred_graph = nx.relabel_nodes(pred_graph, string_ids)
pos = nx.spring_layout(pred_graph, k=2., iterations=100)
nx.draw(pred_graph, pos=pos, with_labels=True, ax=ax2,
edge_color=[pred_graph[u][v]['color'] for u, v in pred_graph.edges()],
width=[pred_graph[u][v]['weight'] for u, v in pred_graph.edges()],
style=[pred_graph[u][v]['style'] for u, v in pred_graph.edges()],
node_color=[pred_graph.nodes[node]['color'] for node in pred_graph.nodes()])
legend_elements = [
Line2D([0], [0], marker='_', color='limegreen', label='PP', markerfacecolor='limegreen', markersize=10),
Line2D([0], [0], marker='_', color='red', label='FP', markerfacecolor='red', markersize=10),
Line2D([0], [0], marker='_', color='black', label='FN - based on STRING', markerfacecolor='black',
markersize=10, linestyle='dotted')]
plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.2, 0.0), ncol=1, fontsize=8)
save_path = '{}_graph_{}_{}.pdf'.format(params.output, '_'.join(params.genes), params.species)
plt.savefig(save_path, bbox_inches='tight', dpi=600)
print("The graphs were saved to: ", save_path)
plt.show()
plt.close()
os.remove(fasta_file)
os.remove(pairs_file)
for f in glob.glob('{}.protein.sequences*'.format(params.species)):
os.remove(f)
os.remove('string_interactions.tsv')
string_pairs_file, fasta_file = get_interactions_from_string(params.genes, species=params.species, add_nodes=params.nodes,
required_score=params.score, network_type=params.network_type)
pairs_file = string_pairs_file.replace('.tsv', '_for_pred.tsv')
try:
process_string_fasta(fasta_file, min_len=params.min_len, max_len=params.max_len)
generate_pairs_string(fasta_file, string_pairs_file, output_file=pairs_file, delete_proteins=params.delete_proteins)
params.fasta_file = fasta_file
params.pairs_file = pairs_file
compute_embeddings(params)
preds = predict(params)
# open the actions tsv file as dataframe and add the last column with the predictions
data = pd.read_csv(pairs_file, delimiter='\t', names=["seq1", "seq2", "string_label"])
data['binary_label'] = data['string_label'].apply(lambda x: 1 if x > label_threshold else 0)
data['preds'] = preds
print(data.sort_values(by=['preds'], ascending=False).to_string())
string_ids = {}
string_tsv = pd.read_csv(string_pairs_file, delimiter='\t')[
['preferredName_A', 'preferredName_B', 'stringId_A', 'stringId_B']]
for i, row in string_tsv.iterrows():
string_ids[row['stringId_A']] = row['preferredName_A']
string_ids[row['stringId_B']] = row['preferredName_B']
data_to_save = data.copy()
data_to_save['seq1'] = data_to_save['seq1'].apply(lambda x: string_ids[x])
data_to_save['seq2'] = data_to_save['seq2'].apply(lambda x: string_ids[x])
data_to_save = data_to_save.sort_values(by=['preds'], ascending=False)
data_to_save.to_csv(params.output + '.tsv', sep='\t', index=False)
if params.graphs:
# Create two subpolots but make a short gap between them
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 5), gridspec_kw={'width_ratios': [1, 1], 'wspace': 0.2})
# Plot the predictions as matrix but do not sort the labels
data_heatmap = data.pivot(index='seq1', columns='seq2', values='preds')
labels_heatmap = data.pivot(index='seq1', columns='seq2', values='string_label')
# Produce the list of protein names from fasta file
protein_names = [line.strip()[1:] for line in open(fasta_file, 'r') if line.startswith('>')]
# Produce the list of gene names from genes Dataframe
gene_names = [string_ids[prot] for prot in protein_names]
# Sort the preds for index and columns based on the order of proteins in the fasta file
data_heatmap = data_heatmap.reindex(index=protein_names, columns=protein_names)
labels_heatmap = labels_heatmap.reindex(index=protein_names, columns=protein_names)
# Replace labels with gene names
data_heatmap.index = gene_names
data_heatmap.columns = gene_names
labels_heatmap.index = gene_names
labels_heatmap.columns = gene_names
# Remove genes that are in hparams.delete_proteins
if params.delete_proteins is not None:
for protein in params.delete_proteins:
if protein in data_heatmap.index:
data_heatmap = data_heatmap.drop(protein, axis=0)
data_heatmap = data_heatmap.drop(protein, axis=1)
labels_heatmap = labels_heatmap.drop(protein, axis=0)
labels_heatmap = labels_heatmap.drop(protein, axis=1)
# Make sure that the matrices are symmetric for clustering
labels_heatmap = labels_heatmap.fillna(value=0)
labels_heatmap = labels_heatmap + labels_heatmap.T
np.fill_diagonal(labels_heatmap.values, -1)
data_heatmap = data_heatmap.fillna(value=0)
data_heatmap = data_heatmap + data_heatmap.T
np.fill_diagonal(data_heatmap.values, -1)
linkages = linkage(labels_heatmap, method='complete', metric='euclidean')
new_labels = np.argsort(fcluster(linkages, 0.05, criterion='distance'))
col_order = labels_heatmap.columns[new_labels]
row_order = labels_heatmap.index[new_labels]
labels_heatmap = labels_heatmap.reindex(index=row_order, columns=col_order)
data_heatmap = data_heatmap.reindex(index=row_order, columns=col_order)
# Fill the upper triangle of labels_heatmap with values from data_heatmap
labels_heatmap.values[np.triu_indices_from(labels_heatmap.values)] = data_heatmap.values[
np.triu_indices_from(data_heatmap.values)]
labels_heatmap.fillna(value=-1, inplace=True)
cmap = matplotlib.cm.get_cmap('coolwarm').copy()
cmap.set_bad("black")
sns.heatmap(labels_heatmap, cmap=cmap, vmin=0, vmax=1,
ax=ax1, mask=labels_heatmap == -1,
cbar=False, square=True)
cbar = ax1.figure.colorbar(ax1.collections[0], ax=ax1, location='right', pad=0.15)
cbar.ax.yaxis.set_ticks_position('right')
ax1.set_yticklabels(ax1.get_yticklabels(), rotation=0)
ax1.set_ylabel('String interactions', weight='bold', fontsize=18)
ax1.set_title('Predictions', weight='bold', fontsize=18)
ax1.yaxis.tick_right()
for i in range(len(labels_heatmap)):
ax1.add_patch(Rectangle((i, i), 1, 1, fill=True, color='white', alpha=1, zorder=100))
pred_graph = nx.Graph()
for i, row in data.iterrows():
if row['string_label'] > label_threshold:
pred_graph.add_edge(row['seq1'], row['seq2'], color='black', weight=row['string_label'], style='dotted')
if row['preds'] > pred_threshold and pred_graph.has_edge(row['seq1'], row['seq2']):
pred_graph[row['seq1']][row['seq2']]['style'] = 'solid'
pred_graph[row['seq1']][row['seq2']]['color'] = 'limegreen'
if row['preds'] > pred_threshold and row['string_label'] <= label_threshold:
pred_graph.add_edge(row['seq1'], row['seq2'], color='red', weight=row['preds'], style='solid')
for node in pred_graph.nodes():
pred_graph.nodes[node]['color'] = 'lightgrey'
# Replace the string ids with gene names
pred_graph = nx.relabel_nodes(pred_graph, string_ids)
pos = nx.spring_layout(pred_graph, k=2., iterations=100)
nx.draw(pred_graph, pos=pos, with_labels=True, ax=ax2,
edge_color=[pred_graph[u][v]['color'] for u, v in pred_graph.edges()],
width=[pred_graph[u][v]['weight'] for u, v in pred_graph.edges()],
style=[pred_graph[u][v]['style'] for u, v in pred_graph.edges()],
node_color=[pred_graph.nodes[node]['color'] for node in pred_graph.nodes()])
legend_elements = [
Line2D([0], [0], marker='_', color='limegreen', label='PP', markerfacecolor='limegreen', markersize=10),
Line2D([0], [0], marker='_', color='red', label='FP', markerfacecolor='red', markersize=10),
Line2D([0], [0], marker='_', color='black', label='FN - based on STRING', markerfacecolor='black',
markersize=10, linestyle='dotted')]
plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.2, 0.0), ncol=1, fontsize=8)
save_path = '{}_graph_{}_{}.pdf'.format(params.output, '_'.join(params.genes), params.species)
plt.savefig(save_path, bbox_inches='tight', dpi=600)
print("The graphs were saved to: ", save_path)
plt.show()
plt.close()
except Exception as e:
raise e
finally:
# Remove the temporary files
os.remove(fasta_file)
os.remove(pairs_file)
os.remove(string_pairs_file)
for f in glob.glob('{}.protein.sequences*'.format(params.species)):
os.remove(f)
def add_args(parser):
......
......@@ -7,11 +7,12 @@ import urllib.request
import requests
import gzip
import shutil
from datetime import datetime
DOWNLOAD_LINK_STRING = "https://stringdb-downloads.org/download/"
def generate_pairs_string(fasta_file, output_file, delete_proteins=None):
def generate_pairs_string(fasta_file, pairs_file, output_file, delete_proteins=None):
ids = []
for record in SeqIO.parse(fasta_file, "fasta"):
ids.append(record.id)
......@@ -23,7 +24,7 @@ def generate_pairs_string(fasta_file, output_file, delete_proteins=None):
pairs = pd.DataFrame(pairs, columns=['seq1', 'seq2'])
data = pd.read_csv('string_interactions.tsv', delimiter='\t')
data = pd.read_csv(pairs_file, delimiter='\t')
# Creating a dictionary of string ids and gene names
ids_dict = dict(zip(data['preferredName_A'], data['stringId_A']))
......@@ -70,6 +71,11 @@ def get_string_url():
def get_interactions_from_string(gene_names, species=9606, add_nodes=10, required_score=500, network_type='physical'):
current_time = str(datetime.now()).replace(' ', '_')
pairs_file = current_time + '_protein.pairs_string.tsv'
fasta_file = current_time + '_sequences.fasta'
string_api_url, version = get_string_url()
output_format = "tsv"
method = "network"
......@@ -134,11 +140,13 @@ def get_interactions_from_string(gene_names, species=9606, add_nodes=10, require
string_names_input_genes['stringId'].to_list()
ids = set(ids)
with open('sequences.fasta', 'w') as f:
with open(fasta_file, 'w') as f:
for record in SeqIO.parse('{}.protein.sequences.v{}.fa'.format(species, version), "fasta"):
if record.id in ids:
SeqIO.write(record, f, "fasta")
string_interactions.to_csv('string_interactions.tsv', sep='\t', index=False)
string_interactions.to_csv(pairs_file, sep='\t', index=False)
return pairs_file, fasta_file
if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment