First fully assembled version 0.1.0

Comments, arguments and outdated code still have to be cleaned up; this version also needs more testing.
parent 59a6e327
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
import os
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from torchmetrics import AUROC, Accuracy, Precision, Recall, F1Score, MatthewsCorrCoef
......@@ -9,51 +11,60 @@ from scipy.cluster.hierarchy import linkage, fcluster
from matplotlib.patches import Rectangle
import matplotlib.pyplot as plt
from pathlib import Path
import glob
from ..model import SensePPIModel
from ..utils import *
from ..network_utils import *
from ..esm2_model import add_esm_args
from ..esm2_model import add_esm_args, compute_embeddings
from ..dataset import PairSequenceData
def main(hparams):
LABEL_THRESHOLD = hparams.score / 1000.
def main(params):
LABEL_THRESHOLD = params.score / 1000.
PRED_THRESHOLD = params.pred_threshold / 1000.
pairs_file = 'protein.pairs_string.tsv'
fasta_file = 'sequences.fasta'
test_data = DscriptData(emb_dir='esm_emb_3B', max_len=800, dir_path='', actions_file='protein.actions.tsv')
print('Fetching interactions from STRING database...')
get_interactions_from_string(params.genes, species=params.species, add_nodes=params.nodes,
required_score=params.score, network_type=params.network_type)
process_string_fasta(fasta_file, min_len=params.min_len, max_len=params.max_len)
generate_pairs_string(fasta_file, output_file=pairs_file, with_self=False, delete_proteins=params.delete_proteins)
actions_path = os.path.join('..', 'Data', 'Dscript', 'preprocessed', 'human_train.tsv')
loadpath = os.path.join('..', DSCRIPT_PATH)
params.fasta_file = fasta_file
compute_embeddings(params)
test_data = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=pairs_file,
max_len=params.max_len, labels=False)
model = SensePPIModel(hparams)
pretrained_model = SensePPIModel(params)
if hparams.nogpu:
checkpoint = torch.load(loadpath, map_location=torch.device('cpu'))
if params.device == 'gpu':
checkpoint = torch.load(params.model_path)
elif params.device == 'mps':
checkpoint = torch.load(params.model_path, map_location=torch.device('mps'))
else:
checkpoint = torch.load(loadpath)
checkpoint = torch.load(params.model_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['state_dict'])
pretrained_model.load_state_dict(checkpoint['state_dict'])
trainer = pl.Trainer(accelerator="cpu" if hparams.nogpu else 'gpu', logger=False)
trainer = pl.Trainer(accelerator=params.device, logger=False)
test_loader = DataLoader(dataset=test_data,
batch_size=64,
batch_size=params.batch_size,
num_workers=4)
preds = [pred for batch in trainer.predict(model, test_loader) for pred in batch.squeeze().tolist()]
preds = [pred for batch in trainer.predict(pretrained_model, test_loader) for pred in batch.squeeze().tolist()]
preds = np.asarray(preds)
# open the actions tsv file as dataframe and add the last column with the predictions
data = pd.read_csv('protein.actions.tsv', delimiter='\t', names=["seq1", "seq2", "label"])
data['binary_label'] = data['label'].apply(lambda x: 1 if x > LABEL_THRESHOLD else 0)
data = pd.read_csv('protein.pairs_string.tsv', delimiter='\t', names=["seq1", "seq2", "string_label"])
data['binary_label'] = data['string_label'].apply(lambda x: 1 if x > LABEL_THRESHOLD else 0)
data['preds'] = preds
if hparams.normalize:
data['preds'] = (data['preds'] - data['preds'].min()) / (data['preds'].max() - data['preds'].min())
print(data.sort_values(by=['preds'], ascending=False).to_string())
data.to_csv(params.output + '.tsv', sep='\t', index=False)
# Calculate torch metrics based on data['binary_label'] and data['preds']
torch_labels = torch.tensor(data['binary_label'])
......@@ -64,9 +75,8 @@ def main(hparams):
print('F1Score: ', F1Score(threshold=PRED_THRESHOLD, task='binary')(torch_preds, torch_labels))
print('MatthewsCorrCoef: ',
MatthewsCorrCoef(num_classes=2, threshold=PRED_THRESHOLD, task='binary')(torch_preds, torch_labels))
print('ROCAUC: ', AUROC()(torch_preds, torch_labels))
print('ROCAUC: ', AUROC(task='binary')(torch_preds, torch_labels))
# Create a dictionary mapping STRING ids to gene names from string_interactions.tsv
string_ids = {}
string_tsv = pd.read_csv('string_interactions.tsv', delimiter='\t')[
['preferredName_A', 'preferredName_B', 'stringId_A', 'stringId_B']]
......@@ -74,46 +84,43 @@ def main(hparams):
string_ids[row['stringId_A']] = row['preferredName_A']
string_ids[row['stringId_B']] = row['preferredName_B']
print('Fetching gene names for training set from STRING...')
if not os.path.exists('all_genes_train.tsv'):
all_genes = generate_dscript_gene_names(
file_path=actions_path,
only_positives=True,
species=str(hparams.species))
all_genes.to_csv('all_genes_train.tsv', sep='\t', index=False)
else:
all_genes = pd.read_csv('all_genes_train.tsv', sep='\t')
# Create tuples of the gene pairs present in the training data; the corresponding gene names are found in the 'genes' DataFrame
full_train_data = pd.read_csv(actions_path,
delimiter='\t', names=['seq1', 'seq2', 'label'])
# To make sure that we do not use the test species in the training data
full_train_data = full_train_data[full_train_data.seq1.str.startswith('6239') == False]
full_train_data = full_train_data[full_train_data.seq2.str.startswith('6239') == False]
if all_genes is not None:
full_train_data = full_train_data.merge(all_genes, left_on='seq1', right_on='QueryString', how='left').merge(
all_genes, left_on='seq2', right_on='QueryString', how='left')
full_train_data = full_train_data[['preferredName_x', 'preferredName_y', 'label']]
positive_train_data = full_train_data[full_train_data['label'] == 1][['preferredName_x', 'preferredName_y']]
full_train_data = full_train_data[['preferredName_x', 'preferredName_y']]
full_train_data = [tuple(x) for x in full_train_data.values]
positive_train_data = [tuple(x) for x in positive_train_data.values]
else:
full_train_data = None
positive_train_data = None
if not hparams.no_graphs:
# This part was needed to color the pairs belonging to the train data, temporarily removed
# print('Fetching gene names for training set from STRING...')
#
# if not os.path.exists('all_genes_train.tsv'):
# all_genes = generate_dscript_gene_names(
# file_path=actions_path,
# only_positives=True,
# species=str(hparams.species))
# all_genes.to_csv('all_genes_train.tsv', sep='\t', index=False)
# else:
# all_genes = pd.read_csv('all_genes_train.tsv', sep='\t')
# full_train_data = pd.read_csv(actions_path,
# delimiter='\t', names=['seq1', 'seq2', 'label'])
#
# if all_genes is not None:
# full_train_data = full_train_data.merge(all_genes, left_on='seq1', right_on='QueryString', how='left').merge(
# all_genes, left_on='seq2', right_on='QueryString', how='left')
#
# full_train_data = full_train_data[['preferredName_x', 'preferredName_y', 'label']]
#
# positive_train_data = full_train_data[full_train_data['label'] == 1][['preferredName_x', 'preferredName_y']]
# full_train_data = full_train_data[['preferredName_x', 'preferredName_y']]
#
# full_train_data = [tuple(x) for x in full_train_data.values]
# positive_train_data = [tuple(x) for x in positive_train_data.values]
# else:
# full_train_data = None
# positive_train_data = None
if params.graphs:
# Create two subplots with a small gap between them
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 5), gridspec_kw={'width_ratios': [1, 1], 'wspace': 0.2})
# Plot the predictions as matrix but do not sort the labels
data_heatmap = data.pivot(index='seq1', columns='seq2', values='preds')
labels_heatmap = data.pivot(index='seq1', columns='seq2', values='label')
labels_heatmap = data.pivot(index='seq1', columns='seq2', values='string_label')
# Produce the list of protein names from sequences.fasta file
protein_names = [line.strip()[1:] for line in open('sequences.fasta', 'r') if line.startswith('>')]
# Produce the list of gene names from genes Dataframe
......@@ -128,8 +135,8 @@ def main(hparams):
labels_heatmap.columns = gene_names
# Remove genes that are listed in params.delete_proteins
if hparams.delete_proteins is not None:
for protein in hparams.delete_proteins:
if params.delete_proteins is not None:
for protein in params.delete_proteins:
if protein in data_heatmap.index:
data_heatmap = data_heatmap.drop(protein, axis=0)
data_heatmap = data_heatmap.drop(protein, axis=1)
......@@ -157,12 +164,13 @@ def main(hparams):
np.triu_indices_from(data_heatmap.values)]
labels_heatmap.fillna(value=-1, inplace=True)
# In (labels+data)_heatmap, if a pair of genes is in train data, color it in black in the upper triangle
if full_train_data is not None:
for i, row in labels_heatmap.iterrows():
for j, _ in row.items():
if (i, j) in full_train_data or (j, i) in full_train_data:
labels_heatmap.loc[i, j] = -1
# This part was needed to color the pairs belonging to the train data, temporarily removed
# if full_train_data is not None:
# for i, row in labels_heatmap.iterrows():
# for j, _ in row.items():
# if (i, j) in full_train_data or (j, i) in full_train_data:
# labels_heatmap.loc[i, j] = -1
cmap = matplotlib.cm.get_cmap('coolwarm').copy()
cmap.set_bad("black")
......@@ -178,96 +186,89 @@ def main(hparams):
ax1.set_ylabel('String interactions', weight='bold', fontsize=18)
ax1.set_title('Predictions', weight='bold', fontsize=18)
ax1.yaxis.tick_right()
# ax1.plot([0, len(labels_heatmap)], [0, len(labels_heatmap)], color='white')
for i in range(len(labels_heatmap)):
ax1.add_patch(Rectangle((i, i), 1, 1, fill=True, color='white', alpha=1, zorder=100))
# print(data[data['label'] > 0].sort_values(by=['preds'], ascending=False))
# Build a graph from the data: STRING interactions above the label threshold are dotted black edges,
# predictions confirmed by STRING become solid green, and predictions without STRING support are solid red
G = nx.Graph()
for i, row in data.iterrows():
if row['label'] > LABEL_THRESHOLD:
G.add_edge(row['seq1'], row['seq2'], color='black', weight=row['label'], style='dotted')
if row['string_label'] > LABEL_THRESHOLD:
G.add_edge(row['seq1'], row['seq2'], color='black', weight=row['string_label'], style='dotted')
if row['preds'] > PRED_THRESHOLD and G.has_edge(row['seq1'], row['seq2']):
G[row['seq1']][row['seq2']]['style'] = 'solid'
G[row['seq1']][row['seq2']]['color'] = 'limegreen'
if row['preds'] > PRED_THRESHOLD and row['label'] <= LABEL_THRESHOLD:
if row['preds'] > PRED_THRESHOLD and row['string_label'] <= LABEL_THRESHOLD:
G.add_edge(row['seq1'], row['seq2'], color='red', weight=row['preds'], style='solid')
# Replace the string ids with gene names
G = nx.relabel_nodes(G, string_ids)
# If edge is present in training data make it blue
if positive_train_data is not None:
for edge in G.edges():
if (edge[0], edge[1]) in positive_train_data or (edge[1], edge[0]) in positive_train_data:
print('TRAINING EDGE: ', edge)
G[edge[0]][edge[1]]['color'] = 'darkblue'
# G[edge[0]][edge[1]]['weight'] = 1
# This part was needed to color the pairs belonging to the train data, temporarily removed
# if positive_train_data is not None:
# for edge in G.edges():
# if (edge[0], edge[1]) in positive_train_data or (edge[1], edge[0]) in positive_train_data:
# print('TRAINING EDGE: ', edge)
# G[edge[0]][edge[1]]['color'] = 'darkblue'
# # G[edge[0]][edge[1]]['weight'] = 1
# Nodes present in the training data used to be highlighted in orange; this is temporarily disabled, so every node is light grey
for node in G.nodes():
if all_genes is not None and node in all_genes['preferredName'].values:
G.nodes[node]['color'] = 'orange'
else:
G.nodes[node]['color'] = 'lightgrey'
# if all_genes is not None and node in all_genes['preferredName'].values:
# G.nodes[node]['color'] = 'orange'
# else:
G.nodes[node]['color'] = 'lightgrey'
# nx.draw_networkx_edge_labels(G, pos=nx.spring_layout(G), edge_labels=nx.get_edge_attributes(G, 'weight'))
# Make edges longer and do not put nodes too close to each other
pos = nx.spring_layout(G, k=2., iterations=100)
# Call the same function nx.draw with the same arguments as above but make sure it is plotted in ax2
nx.draw(G, pos=pos, with_labels=True, ax=ax2,
edge_color=[G[u][v]['color'] for u, v in G.edges()], width=[G[u][v]['weight'] for u, v in G.edges()],
style=[G[u][v]['style'] for u, v in G.edges()],
node_color=[G.nodes[node]['color'] for node in G.nodes()])
# Put a legend for colors
legend_elements = [
Line2D([0], [0], marker='_', color='darkblue', label='PP from training data', markerfacecolor='darkblue',
markersize=10),
Line2D([0], [0], marker='_', color='limegreen', label='PP', markerfacecolor='limegreen', markersize=10),
Line2D([0], [0], marker='_', color='red', label='FP', markerfacecolor='red', markersize=10),
Line2D([0], [0], marker='_', color='black', label='FN - based on STRING', markerfacecolor='black',
markersize=10, linestyle='dotted')]
Line2D([0], [0], marker='_', color='red', label='FP', markerfacecolor='red', markersize=10)]
# Line2D([0], [0], marker='_', color='black', label='FN - based on STRING', markerfacecolor='black',
# markersize=10, linestyle='dotted')]
plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.2, 0.0), ncol=1, fontsize=8)
savepath = 'graph_{}_{}'.format('_'.join(hparams.genes), hparams.species)
if 'Dscript' in actions_path:
savepath += '_Dscript'
if 'HVLSTM' in loadpath:
savepath += '_HVLSTM'
else:
version = re.search('version_([0-9]+)', loadpath)
if version is not None:
savepath += '_v' + version.group(1)
savepath += '.pdf'
savepath = '{}_graph_{}_{}.pdf'.format(params.output, '_'.join(params.genes), params.species)
plt.savefig(savepath, bbox_inches='tight', dpi=600)
print("The graphs were saved in: ", savepath)
print("The graphs were saved to: ", savepath)
plt.show()
plt.close()
os.remove(fasta_file)
os.remove(pairs_file)
for f in glob.glob('{}.protein.sequences*'.format(params.species)):
os.remove(f)
os.remove('string_interactions.tsv')
def add_args(parser):
parser = add_general_args(parser)
parser2 = parser.add_argument_group(title="General options")
parser2.add_argument("--no_graphs", action='store_true', help="No plotting testing graphs.")
parser2.add_argument("-g", "--genes", type=str, nargs="+", default="RFC5",
help="Name of gene to fetch from STRING database. Several names can be typed (separated by whitespaces). Default: RFC5")
parser2.add_argument("-s", "--species", type=int, default=9606,
string_pred_args = parser.add_argument_group(title="General options")
parser._action_groups[0].add_argument("model_path", type=str,
help="A path to .ckpt file that contains weights to a pretrained model.")
parser._action_groups[0].add_argument("genes", type=str, nargs="+",
help="Name of gene to fetch from STRING database. Several names can be typed (separated by "
"whitespaces)")
string_pred_args.add_argument("-s", "--species", type=int, default=9606,
help="Species from STRING database. Default: 9606 (H. Sapiens)")
parser2.add_argument("-n", "--nodes", type=int, default=10,
string_pred_args.add_argument("-n", "--nodes", type=int, default=10,
help="Number of nodes to fetch from STRING database. Default: 10")
parser2.add_argument("-r", "--score", type=int, default=500,
string_pred_args.add_argument("-r", "--score", type=int, default=0,
help="Score threshold for STRING connections. Range: (0, 1000). Default: 500")
parser2.add_argument("-p", "--pred_threshold", type=int, default=500,
string_pred_args.add_argument("-p", "--pred_threshold", type=int, default=500,
help="Prediction threshold. Range: (0, 1000). Default: 500")
parser2.add_argument("--network_type", type=str, default="physical",
string_pred_args.add_argument("--graphs", action='store_true', help="Enables plotting the heatmap and a network graph.")
string_pred_args.add_argument("-o", "--output", type=str, default="preds_from_string",
help="A path to a file where the predictions will be saved. "
"(.tsv format will be added automatically)")
string_pred_args.add_argument("--network_type", type=str, default="physical",
help="Network type: \"physical\" or \"functional\". Default: \"physical\"")
parser2.add_argument("--normalize", action='store_true', help="Normalize the predictions.")
parser2.add_argument("--delete_proteins", type=str, nargs="+", default=None,
string_pred_args.add_argument("--delete_proteins", type=str, nargs="+", default=None,
help="List of proteins to delete from the graph. Default: None")
parser = SensePPIModel.add_model_specific_args(parser)
......@@ -277,50 +278,6 @@ def add_args(parser):
return parser
if __name__ == '__main__':
params = parser.parse_args()
if torch.cuda.is_available():
# torch.cuda.set_per_process_memory_fraction(0.9, 0)
print('Number of devices: ', torch.cuda.device_count())
print('GPU used: ', torch.cuda.get_device_name(0))
torch.set_float32_matmul_precision('high')
else:
print('No GPU available, using the CPU instead.')
params.nogpu = True
print('Fetching interactions from STRING database...')
get_interactions_from_string(params.genes, species=params.species, add_nodes=params.nodes,
required_score=params.score, network_type=params.network_type)
process_string_fasta('sequences.fasta')
generate_pairs('sequences.fasta', mode='all_to_all', with_self=False, delete_proteins=params.delete_proteins)
# Compute ESM embeddings
# First, check if all embeddings are already computed
params.model_location = 'esm2_t36_3B_UR50D'
params.fasta_file = 'sequences.fasta'
params.output_dir = Path('esm_emb_3B')
params.include = 'per_tok'
with open(params.fasta_file, 'r') as f:
seq_ids = [line.strip().split(' ')[0].replace('>', '') for line in f.readlines() if line.startswith('>')]
if not os.path.exists(params.output_dir):
print('Computing ESM embeddings...')
esm_extract.run(params)
else:
for seq_id in seq_ids:
if not os.path.exists(os.path.join(params.output_dir, seq_id + '.pt')):
print('Computing ESM embeddings...')
esm_extract.run(params)
break
print('Predicting...')
main(params)
os.remove('sequences.fasta')
os.remove('protein.actions.tsv')
os.remove('string_interactions.tsv')
# srun --gres gpu:1 python network_test.py -n 30 --pred_threshold 500 -r 0 -s 9606 -g C1R RFC5 --delete_proteins C1S RFC3 RFC4 DSCC1 CHTF8 RAD17 RPA1 RPA2
\ No newline at end of file
parser = argparse.ArgumentParser()
parser = add_args(parser)
params = parser.parse_args()
\ No newline at end of file
......@@ -37,7 +37,7 @@ def test(params):
def add_args(parser):
parser = add_general_args(parser)
predict_args = parser.add_argument_group(title="Predict args")
test_args = parser.add_argument_group(title="Predict args")
parser._action_groups[0].add_argument("model_path", type=str,
help="A path to .ckpt file that contains weights to a pretrained model.")
parser._action_groups[0].add_argument("pairs_file", type=str, default=None,
......@@ -47,10 +47,10 @@ def add_args(parser):
help="FASTA file on which to extract the ESM2 "
"representations and then evaluate.",
)
predict_args.add_argument("-o", "--output", type=str, default="test_metrics",
test_args.add_argument("-o", "--output", type=str, default="test_metrics",
help="A path to a file where the test metrics will be saved. "
"(.tsv format will be added automatically)")
predict_args.add_argument("--crop_data_to_model_lims", action="store_true",
test_args.add_argument("--crop_data_to_model_lims", action="store_true",
help="If set, the data will be cropped to the limits of the model: "
"evaluations will be done only for proteins >50aa and <800aa.")
......
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import pathlib
import argparse
from ..utils import add_general_args
from ..model import SensePPIModel
from ..dataset import PairSequenceData
......
......@@ -13,45 +13,44 @@ import gzip
import shutil
def generate_pairs(fasta_file, mode='all_to_all', with_self=False, delete_proteins=None):
def generate_pairs_string(fasta_file, output_file, with_self=False, delete_proteins=None,
                          interactions_file='string_interactions.tsv'):
    """Write all unordered pairs of FASTA records to a TSV, labelled with STRING scores.

    Every record id found in ``fasta_file`` is paired with every other id
    (and with itself when ``with_self`` is true). Each unordered pair is kept
    once, in the order in which it is first generated. The ``score`` column of
    ``interactions_file`` is attached as the label; pairs that are not listed
    there get the label 0.

    Args:
        fasta_file: Path to the FASTA file whose record ids are paired.
        output_file: Path of the headerless, tab-separated file to write
            (columns: first id, second id, label).
        with_self: If true, also generate the pair of each id with itself.
        delete_proteins: Optional iterable of gene names (values of the
            ``preferredName_A`` / ``preferredName_B`` columns); every pair
            involving one of them is dropped. Names that do not occur in
            ``interactions_file`` are reported and ignored.
        interactions_file: Tab-separated file with the columns
            ``preferredName_A``, ``preferredName_B``, ``stringId_A``,
            ``stringId_B`` and ``score``.
    """
    ids = [record.id for record in SeqIO.parse(fasta_file, "fasta")]

    candidates = product(ids, repeat=2) if with_self else permutations(ids, 2)

    # Keep the first occurrence of every unordered pair. A set of frozensets
    # makes the membership test O(1) instead of scanning the growing list.
    seen = set()
    unique_pairs = []
    for first, second in candidates:
        key = frozenset((first, second))
        if key not in seen:
            seen.add(key)
            unique_pairs.append((first, second))

    pairs = pd.DataFrame(unique_pairs, columns=['seq1', 'seq2'])

    interactions = pd.read_csv(interactions_file, delimiter='\t')

    # Map gene names to STRING ids (used to resolve ``delete_proteins``).
    name_to_id = dict(zip(interactions['preferredName_A'], interactions['stringId_A']))
    name_to_id.update(zip(interactions['preferredName_B'], interactions['stringId_B']))

    labels = interactions[['stringId_A', 'stringId_B', 'score']].copy()
    labels.columns = ['seq1', 'seq2', 'label']

    # The interactions file may list a pair in the opposite orientation to the
    # one generated above, so make the lookup symmetric. Rows in the original
    # orientation come first and win when both orientations are present.
    flipped = labels.rename(columns={'seq1': 'seq2', 'seq2': 'seq1'})
    labels = pd.concat([labels, flipped], ignore_index=True, sort=False)
    labels = labels.drop_duplicates(subset=['seq1', 'seq2'], keep='first')

    pairs = pairs.merge(labels, on=['seq1', 'seq2'], how='left').fillna(0)

    if delete_proteins is not None:
        print('Labels removed: ', delete_proteins)
        unknown = [name for name in delete_proteins if name not in name_to_id]
        if unknown:
            # A name missing from the interactions file cannot be mapped to a
            # STRING id; report it instead of failing with a bare KeyError.
            print('Labels not found in STRING interactions (ignored): ', unknown)
        string_ids_to_delete = [name_to_id[name] for name in delete_proteins if name in name_to_id]
        print('String ids to delete: ', string_ids_to_delete)
        keep = ~pairs['seq1'].isin(string_ids_to_delete) & ~pairs['seq2'].isin(string_ids_to_delete)
        pairs = pairs[keep]

    pairs.to_csv(output_file, sep='\t', index=False, header=False)
def generate_dscript_gene_names(file_path,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment