First fully assembled version 0.1.0

Comments, arguments, and outdated code still have to be cleaned up; this version also needs more testing.
parent 59a6e327
Six additional source diffs could not be displayed because they are too large; view the blobs instead.
+import os
 from torch.utils.data import DataLoader
 import pytorch_lightning as pl
 from torchmetrics import AUROC, Accuracy, Precision, Recall, F1Score, MatthewsCorrCoef
@@ -9,51 +11,60 @@ from scipy.cluster.hierarchy import linkage, fcluster
 from matplotlib.patches import Rectangle
 import matplotlib.pyplot as plt
-from pathlib import Path
+import glob
 from ..model import SensePPIModel
 from ..utils import *
 from ..network_utils import *
-from ..esm2_model import add_esm_args
+from ..esm2_model import add_esm_args, compute_embeddings
+from ..dataset import PairSequenceData


-def main(hparams):
-    LABEL_THRESHOLD = hparams.score / 1000.
+def main(params):
+    LABEL_THRESHOLD = params.score / 1000.
     PRED_THRESHOLD = params.pred_threshold / 1000.

+    pairs_file = 'protein.pairs_string.tsv'
+    fasta_file = 'sequences.fasta'

-    test_data = DscriptData(emb_dir='esm_emb_3B', max_len=800, dir_path='', actions_file='protein.actions.tsv')
+    print('Fetching interactions from STRING database...')
+    get_interactions_from_string(params.genes, species=params.species, add_nodes=params.nodes,
+                                 required_score=params.score, network_type=params.network_type)
+    process_string_fasta(fasta_file, min_len=params.min_len, max_len=params.max_len)
+    generate_pairs_string(fasta_file, output_file=pairs_file, with_self=False, delete_proteins=params.delete_proteins)

-    actions_path = os.path.join('..', 'Data', 'Dscript', 'preprocessed', 'human_train.tsv')
-    loadpath = os.path.join('..', DSCRIPT_PATH)
+    params.fasta_file = fasta_file
+    compute_embeddings(params)

+    test_data = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=pairs_file,
+                                 max_len=params.max_len, labels=False)

-    model = SensePPIModel(hparams)
+    pretrained_model = SensePPIModel(params)

-    if hparams.nogpu:
-        checkpoint = torch.load(loadpath, map_location=torch.device('cpu'))
+    if params.device == 'gpu':
+        checkpoint = torch.load(params.model_path)
+    elif params.device == 'mps':
+        checkpoint = torch.load(params.model_path, map_location=torch.device('mps'))
     else:
-        checkpoint = torch.load(loadpath)
+        checkpoint = torch.load(params.model_path, map_location=torch.device('cpu'))

-    model.load_state_dict(checkpoint['state_dict'])
+    pretrained_model.load_state_dict(checkpoint['state_dict'])

-    trainer = pl.Trainer(accelerator="cpu" if hparams.nogpu else 'gpu', logger=False)
+    trainer = pl.Trainer(accelerator=params.device, logger=False)

     test_loader = DataLoader(dataset=test_data,
-                             batch_size=64,
+                             batch_size=params.batch_size,
                              num_workers=4)

-    preds = [pred for batch in trainer.predict(model, test_loader) for pred in batch.squeeze().tolist()]
+    preds = [pred for batch in trainer.predict(pretrained_model, test_loader) for pred in batch.squeeze().tolist()]
     preds = np.asarray(preds)

     # open the actions tsv file as dataframe and add the last column with the predictions
-    data = pd.read_csv('protein.actions.tsv', delimiter='\t', names=["seq1", "seq2", "label"])
-    data['binary_label'] = data['label'].apply(lambda x: 1 if x > LABEL_THRESHOLD else 0)
+    data = pd.read_csv('protein.pairs_string.tsv', delimiter='\t', names=["seq1", "seq2", "string_label"])
+    data['binary_label'] = data['string_label'].apply(lambda x: 1 if x > LABEL_THRESHOLD else 0)
     data['preds'] = preds

-    if hparams.normalize:
-        data['preds'] = (data['preds'] - data['preds'].min()) / (data['preds'].max() - data['preds'].min())

     print(data.sort_values(by=['preds'], ascending=False).to_string())
+    data.to_csv(params.output + '.tsv', sep='\t', index=False)

     # Calculate torch metrics based on data['binary_label'] and data['preds']
     torch_labels = torch.tensor(data['binary_label'])
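For reference, the checkpoint-loading logic in this hunk reduces to the following minimal sketch. It assumes `device` is one of `'gpu'`, `'mps'`, or `'cpu'` and that the `.ckpt` was written by PyTorch Lightning, so the weights live under the `'state_dict'` key; `load_pretrained` is a hypothetical helper name, not part of the package:

```python
import torch

def load_pretrained(model, model_path, device):
    # CUDA-saved checkpoints need an explicit map_location on CPU/MPS machines.
    if device == 'gpu':
        checkpoint = torch.load(model_path)
    else:  # 'mps' or 'cpu'
        checkpoint = torch.load(model_path, map_location=torch.device(device))
    model.load_state_dict(checkpoint['state_dict'])
    return model
```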
@@ -64,9 +75,8 @@ def main(hparams):
     print('F1Score: ', F1Score(threshold=PRED_THRESHOLD, task='binary')(torch_preds, torch_labels))
     print('MatthewsCorrCoef: ',
           MatthewsCorrCoef(num_classes=2, threshold=PRED_THRESHOLD, task='binary')(torch_preds, torch_labels))
-    print('ROCAUC: ', AUROC()(torch_preds, torch_labels))
+    print('ROCAUC: ', AUROC(task='binary')(torch_preds, torch_labels))

-    # Create a dictionary of string ids and gene names from string_interactions_short.tsv
     string_ids = {}
     string_tsv = pd.read_csv('string_interactions.tsv', delimiter='\t')[
         ['preferredName_A', 'preferredName_B', 'stringId_A', 'stringId_B']]
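As a self-contained illustration of the metric calls in this hunk (the tensor values are made up; this uses the torchmetrics ≥ 0.11 style with the `task` argument, matching the `AUROC(task='binary')` change above):

```python
import torch
from torchmetrics import AUROC, F1Score, MatthewsCorrCoef

preds = torch.tensor([0.9, 0.2, 0.7, 0.4])   # model outputs in [0, 1]
labels = torch.tensor([1, 0, 1, 1])          # binarised STRING labels

threshold = 0.5  # counterpart of PRED_THRESHOLD = params.pred_threshold / 1000.
print('F1Score: ', F1Score(task='binary', threshold=threshold)(preds, labels))
print('MatthewsCorrCoef: ', MatthewsCorrCoef(task='binary', threshold=threshold)(preds, labels))
print('ROCAUC: ', AUROC(task='binary')(preds, labels))  # AUROC needs no threshold
```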
@@ -74,46 +84,43 @@ def main(hparams):
         string_ids[row['stringId_A']] = row['preferredName_A']
         string_ids[row['stringId_B']] = row['preferredName_B']

-    print('Fetching gene names for training set from STRING...')
-
-    if not os.path.exists('all_genes_train.tsv'):
-        all_genes = generate_dscript_gene_names(
-            file_path=actions_path,
-            only_positives=True,
-            species=str(hparams.species))
-        all_genes.to_csv('all_genes_train.tsv', sep='\t', index=False)
-    else:
-        all_genes = pd.read_csv('all_genes_train.tsv', sep='\t')
-
-    # Create a tuple of gene pairs presented in training data, corresponding gene names are found in 'genes' DataFrame
-    full_train_data = pd.read_csv(actions_path,
-                                  delimiter='\t', names=['seq1', 'seq2', 'label'])
-
-    # To make sure that we do not use the test species in the training data
-    full_train_data = full_train_data[full_train_data.seq1.str.startswith('6239') == False]
-    full_train_data = full_train_data[full_train_data.seq2.str.startswith('6239') == False]
-
-    if all_genes is not None:
-        full_train_data = full_train_data.merge(all_genes, left_on='seq1', right_on='QueryString', how='left').merge(
-            all_genes, left_on='seq2', right_on='QueryString', how='left')
-
-        full_train_data = full_train_data[['preferredName_x', 'preferredName_y', 'label']]
-
-        positive_train_data = full_train_data[full_train_data['label'] == 1][['preferredName_x', 'preferredName_y']]
-        full_train_data = full_train_data[['preferredName_x', 'preferredName_y']]
-
-        full_train_data = [tuple(x) for x in full_train_data.values]
-        positive_train_data = [tuple(x) for x in positive_train_data.values]
-    else:
-        full_train_data = None
-        positive_train_data = None
-
-    if not hparams.no_graphs:
+    # This part was needed to color the pairs belonging to the train data, temporarily removed
+    # print('Fetching gene names for training set from STRING...')
+    #
+    # if not os.path.exists('all_genes_train.tsv'):
+    #     all_genes = generate_dscript_gene_names(
+    #         file_path=actions_path,
+    #         only_positives=True,
+    #         species=str(hparams.species))
+    #     all_genes.to_csv('all_genes_train.tsv', sep='\t', index=False)
+    # else:
+    #     all_genes = pd.read_csv('all_genes_train.tsv', sep='\t')
+
+    # full_train_data = pd.read_csv(actions_path,
+    #                               delimiter='\t', names=['seq1', 'seq2', 'label'])
+    #
+    # if all_genes is not None:
+    #     full_train_data = full_train_data.merge(all_genes, left_on='seq1', right_on='QueryString', how='left').merge(
+    #         all_genes, left_on='seq2', right_on='QueryString', how='left')
+    #
+    #     full_train_data = full_train_data[['preferredName_x', 'preferredName_y', 'label']]
+    #
+    #     positive_train_data = full_train_data[full_train_data['label'] == 1][['preferredName_x', 'preferredName_y']]
+    #     full_train_data = full_train_data[['preferredName_x', 'preferredName_y']]
+    #
+    #     full_train_data = [tuple(x) for x in full_train_data.values]
+    #     positive_train_data = [tuple(x) for x in positive_train_data.values]
+    # else:
+    #     full_train_data = None
+    #     positive_train_data = None

+    if params.graphs:
         # Create two subplots but make a short gap between them
         fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 5), gridspec_kw={'width_ratios': [1, 1], 'wspace': 0.2})

         # Plot the predictions as matrix but do not sort the labels
         data_heatmap = data.pivot(index='seq1', columns='seq2', values='preds')
-        labels_heatmap = data.pivot(index='seq1', columns='seq2', values='label')
+        labels_heatmap = data.pivot(index='seq1', columns='seq2', values='string_label')

         # Produce the list of protein names from sequences.fasta file
         protein_names = [line.strip()[1:] for line in open('sequences.fasta', 'r') if line.startswith('>')]
         # Produce the list of gene names from genes Dataframe
@@ -128,8 +135,8 @@ def main(hparams):
         labels_heatmap.columns = gene_names

         # Remove genes that are in params.delete_proteins
-        if hparams.delete_proteins is not None:
-            for protein in hparams.delete_proteins:
+        if params.delete_proteins is not None:
+            for protein in params.delete_proteins:
                 if protein in data_heatmap.index:
                     data_heatmap = data_heatmap.drop(protein, axis=0)
                     data_heatmap = data_heatmap.drop(protein, axis=1)
@@ -157,12 +164,13 @@ def main(hparams):
             np.triu_indices_from(data_heatmap.values)]
         labels_heatmap.fillna(value=-1, inplace=True)

-        # In (labels+data)_heatmap, if a pair of genes is in train data, color it in black in the upper triangle
-        if full_train_data is not None:
-            for i, row in labels_heatmap.iterrows():
-                for j, _ in row.items():
-                    if (i, j) in full_train_data or (j, i) in full_train_data:
-                        labels_heatmap.loc[i, j] = -1
+        # This part was needed to color the pairs belonging to the train data, temporarily removed
+        # if full_train_data is not None:
+        #     for i, row in labels_heatmap.iterrows():
+        #         for j, _ in row.items():
+        #             if (i, j) in full_train_data or (j, i) in full_train_data:
+        #                 labels_heatmap.loc[i, j] = -1

         cmap = matplotlib.cm.get_cmap('coolwarm').copy()
         cmap.set_bad("black")
@@ -178,96 +186,89 @@ def main(hparams):
         ax1.set_ylabel('String interactions', weight='bold', fontsize=18)
         ax1.set_title('Predictions', weight='bold', fontsize=18)
         ax1.yaxis.tick_right()

-        # ax1.plot([0, len(labels_heatmap)], [0, len(labels_heatmap)], color='white')
         for i in range(len(labels_heatmap)):
             ax1.add_patch(Rectangle((i, i), 1, 1, fill=True, color='white', alpha=1, zorder=100))

-        # print(data[data['label'] > 0].sort_values(by=['preds'], ascending=False))
-        # Build a multigraph from the data with the predictions as edge weights and edges in red and the labels as secondary edges in black
         G = nx.Graph()
         for i, row in data.iterrows():
-            if row['label'] > LABEL_THRESHOLD:
-                G.add_edge(row['seq1'], row['seq2'], color='black', weight=row['label'], style='dotted')
+            if row['string_label'] > LABEL_THRESHOLD:
+                G.add_edge(row['seq1'], row['seq2'], color='black', weight=row['string_label'], style='dotted')
             if row['preds'] > PRED_THRESHOLD and G.has_edge(row['seq1'], row['seq2']):
                 G[row['seq1']][row['seq2']]['style'] = 'solid'
                 G[row['seq1']][row['seq2']]['color'] = 'limegreen'
-            if row['preds'] > PRED_THRESHOLD and row['label'] <= LABEL_THRESHOLD:
+            if row['preds'] > PRED_THRESHOLD and row['string_label'] <= LABEL_THRESHOLD:
                 G.add_edge(row['seq1'], row['seq2'], color='red', weight=row['preds'], style='solid')

         # Replace the string ids with gene names
         G = nx.relabel_nodes(G, string_ids)

-        # If edge is present in training data make it blue
-        if positive_train_data is not None:
-            for edge in G.edges():
-                if (edge[0], edge[1]) in positive_train_data or (edge[1], edge[0]) in positive_train_data:
-                    print('TRAINING EDGE: ', edge)
-                    G[edge[0]][edge[1]]['color'] = 'darkblue'
-                    # G[edge[0]][edge[1]]['weight'] = 1
+        # This part was needed to color the pairs belonging to the train data, temporarily removed
+        # if positive_train_data is not None:
+        #     for edge in G.edges():
+        #         if (edge[0], edge[1]) in positive_train_data or (edge[1], edge[0]) in positive_train_data:
+        #             print('TRAINING EDGE: ', edge)
+        #             G[edge[0]][edge[1]]['color'] = 'darkblue'
+        #             # G[edge[0]][edge[1]]['weight'] = 1

         # Make nodes red if they are present in training data
         for node in G.nodes():
-            if all_genes is not None and node in all_genes['preferredName'].values:
-                G.nodes[node]['color'] = 'orange'
-            else:
-                G.nodes[node]['color'] = 'lightgrey'
+            # if all_genes is not None and node in all_genes['preferredName'].values:
+            #     G.nodes[node]['color'] = 'orange'
+            # else:
+            G.nodes[node]['color'] = 'lightgrey'

-        # nx.draw_networkx_edge_labels(G, pos=nx.spring_layout(G), edge_labels=nx.get_edge_attributes(G, 'weight'))
-        # Make edges longer and do not put nodes too close to each other
         pos = nx.spring_layout(G, k=2., iterations=100)
-        # Call the same function nx.draw with the same arguments as above but make sure it is plotted in ax2
         nx.draw(G, pos=pos, with_labels=True, ax=ax2,
                 edge_color=[G[u][v]['color'] for u, v in G.edges()], width=[G[u][v]['weight'] for u, v in G.edges()],
                 style=[G[u][v]['style'] for u, v in G.edges()],
                 node_color=[G.nodes[node]['color'] for node in G.nodes()])

-        # Put a legend for colors
         legend_elements = [
             Line2D([0], [0], marker='_', color='darkblue', label='PP from training data', markerfacecolor='darkblue',
                    markersize=10),
             Line2D([0], [0], marker='_', color='limegreen', label='PP', markerfacecolor='limegreen', markersize=10),
-            Line2D([0], [0], marker='_', color='red', label='FP', markerfacecolor='red', markersize=10),
-            Line2D([0], [0], marker='_', color='black', label='FN - based on STRING', markerfacecolor='black',
-                   markersize=10, linestyle='dotted')]
+            Line2D([0], [0], marker='_', color='red', label='FP', markerfacecolor='red', markersize=10)]
+            # Line2D([0], [0], marker='_', color='black', label='FN - based on STRING', markerfacecolor='black',
+            #        markersize=10, linestyle='dotted')]

         plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.2, 0.0), ncol=1, fontsize=8)

-        savepath = 'graph_{}_{}'.format('_'.join(hparams.genes), hparams.species)
-
-        if 'Dscript' in actions_path:
-            savepath += '_Dscript'
-        if 'HVLSTM' in loadpath:
-            savepath += '_HVLSTM'
-        else:
-            version = re.search('version_([0-9]+)', loadpath)
-            if version is not None:
-                savepath += '_v' + version.group(1)
-        savepath += '.pdf'
+        savepath = '{}_graph_{}_{}.pdf'.format(params.output, '_'.join(params.genes), params.species)

         plt.savefig(savepath, bbox_inches='tight', dpi=600)
-        print("The graphs were saved in: ", savepath)
+        print("The graphs were saved to: ", savepath)
         plt.show()
         plt.close()

+    os.remove(fasta_file)
+    os.remove(pairs_file)
+    for f in glob.glob('{}.protein.sequences*'.format(params.species)):
+        os.remove(f)
+    os.remove('string_interactions.tsv')
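The edge attributes set in the loop above encode a fixed colour convention: black dotted = STRING-only edge, limegreen solid = predicted and present in STRING, red solid = predicted but absent from STRING. A condensed sketch of just that logic (the thresholds and `build_graph` helper are illustrative, not part of the package):

```python
import networkx as nx

def build_graph(rows, label_thr=0.5, pred_thr=0.5):
    # rows: iterable of (seq1, seq2, string_label, pred) tuples
    G = nx.Graph()
    for seq1, seq2, string_label, pred in rows:
        if string_label > label_thr:
            # known STRING interaction, initially drawn as a dotted black edge
            G.add_edge(seq1, seq2, color='black', weight=string_label, style='dotted')
        if pred > pred_thr and G.has_edge(seq1, seq2):
            # predicted AND in STRING: upgrade to solid green
            G[seq1][seq2].update(style='solid', color='limegreen')
        if pred > pred_thr and string_label <= label_thr:
            # predicted but not in STRING: solid red
            G.add_edge(seq1, seq2, color='red', weight=pred, style='solid')
    return G
```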
 def add_args(parser):
     parser = add_general_args(parser)
-    parser2 = parser.add_argument_group(title="General options")
-    parser2.add_argument("--no_graphs", action='store_true', help="No plotting testing graphs.")
-    parser2.add_argument("-g", "--genes", type=str, nargs="+", default="RFC5",
-                         help="Name of gene to fetch from STRING database. Several names can be typed (separated by whitespaces). Default: RFC5")
-    parser2.add_argument("-s", "--species", type=int, default=9606,
+    string_pred_args = parser.add_argument_group(title="General options")
+    parser._action_groups[0].add_argument("model_path", type=str,
+                                          help="A path to .ckpt file that contains weights to a pretrained model.")
+    parser._action_groups[0].add_argument("genes", type=str, nargs="+",
+                                          help="Name of gene to fetch from STRING database. Several names can be typed (separated by "
+                                               "whitespaces)")
+    string_pred_args.add_argument("-s", "--species", type=int, default=9606,
                                   help="Species from STRING database. Default: 9606 (H. Sapiens)")
-    parser2.add_argument("-n", "--nodes", type=int, default=10,
+    string_pred_args.add_argument("-n", "--nodes", type=int, default=10,
                                   help="Number of nodes to fetch from STRING database. Default: 10")
-    parser2.add_argument("-r", "--score", type=int, default=500,
+    string_pred_args.add_argument("-r", "--score", type=int, default=0,
                                   help="Score threshold for STRING connections. Range: (0, 1000). Default: 500")
-    parser2.add_argument("-p", "--pred_threshold", type=int, default=500,
+    string_pred_args.add_argument("-p", "--pred_threshold", type=int, default=500,
                                   help="Prediction threshold. Range: (0, 1000). Default: 500")
-    parser2.add_argument("--network_type", type=str, default="physical",
+    string_pred_args.add_argument("--graphs", action='store_true', help="Enables plotting the heatmap and a network graph.")
+    string_pred_args.add_argument("-o", "--output", type=str, default="preds_from_string",
+                                  help="A path to a file where the predictions will be saved. "
+                                       "(.tsv format will be added automatically)")
+    string_pred_args.add_argument("--network_type", type=str, default="physical",
                                   help="Network type: \"physical\" or \"functional\". Default: \"physical\"")
-    parser2.add_argument("--normalize", action='store_true', help="Normalize the predictions.")
-    parser2.add_argument("--delete_proteins", type=str, nargs="+", default=None,
+    string_pred_args.add_argument("--delete_proteins", type=str, nargs="+", default=None,
                                   help="List of proteins to delete from the graph. Default: None")

     parser = SensePPIModel.add_model_specific_args(parser)
@@ -277,50 +278,6 @@ def add_args(parser):
     return parser


 if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser = add_args(parser)
     params = parser.parse_args()
\ No newline at end of file
-
-    if torch.cuda.is_available():
-        # torch.cuda.set_per_process_memory_fraction(0.9, 0)
-        print('Number of devices: ', torch.cuda.device_count())
-        print('GPU used: ', torch.cuda.get_device_name(0))
-        torch.set_float32_matmul_precision('high')
-    else:
-        print('No GPU available, using the CPU instead.')
-        params.nogpu = True
-
-    print('Fetching interactions from STRING database...')
-    get_interactions_from_string(params.genes, species=params.species, add_nodes=params.nodes,
-                                 required_score=params.score, network_type=params.network_type)
-    process_string_fasta('sequences.fasta')
-    generate_pairs('sequences.fasta', mode='all_to_all', with_self=False, delete_proteins=params.delete_proteins)
-
-    # Compute ESM embeddings
-    # First, check if all embeddings are already computed
-    params.model_location = 'esm2_t36_3B_UR50D'
-    params.fasta_file = 'sequences.fasta'
-    params.output_dir = Path('esm_emb_3B')
-    params.include = 'per_tok'
-
-    with open(params.fasta_file, 'r') as f:
-        seq_ids = [line.strip().split(' ')[0].replace('>', '') for line in f.readlines() if line.startswith('>')]
-
-    if not os.path.exists(params.output_dir):
-        print('Computing ESM embeddings...')
-        esm_extract.run(params)
-    else:
-        for seq_id in seq_ids:
-            if not os.path.exists(os.path.join(params.output_dir, seq_id + '.pt')):
-                print('Computing ESM embeddings...')
-                esm_extract.run(params)
-                break
-
-    print('Predicting...')
-    main(params)
-
-    os.remove('sequences.fasta')
-    os.remove('protein.actions.tsv')
-    os.remove('string_interactions.tsv')
-
-# srun --gres gpu:1 python network_test.py -n 30 --pred_threshold 500 -r 0 -s 9606 -g C1R RFC5 --delete_proteins C1S RFC3 RFC4 DSCC1 CHTF8 RAD17 RPA1 RPA2
\ No newline at end of file
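Putting the pieces together, here is a hypothetical way to drive this command programmatically. Only the flags come from `add_args` above; the checkpoint filename is an assumption, and `add_general_args` / `add_esm_args` may register further required options, so treat this as a sketch rather than a working invocation:

```python
import argparse

parser = add_args(argparse.ArgumentParser())
params = parser.parse_args(['senseppi.ckpt',           # model_path (positional)
                            'C1R', 'RFC5',             # genes to fetch from STRING
                            '-s', '9606', '-n', '30',  # species and node count
                            '-r', '0', '-p', '500',    # STRING score / prediction thresholds
                            '--graphs',
                            '--delete_proteins', 'C1S', 'RFC3'])
main(params)
```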
@@ -37,7 +37,7 @@ def test(params):

 def add_args(parser):
     parser = add_general_args(parser)
-    predict_args = parser.add_argument_group(title="Predict args")
+    test_args = parser.add_argument_group(title="Predict args")
     parser._action_groups[0].add_argument("model_path", type=str,
                                           help="A path to .ckpt file that contains weights to a pretrained model.")
     parser._action_groups[0].add_argument("pairs_file", type=str, default=None,
@@ -47,10 +47,10 @@ def add_args(parser):
                                           help="FASTA file on which to extract the ESM2 "
                                                "representations and then evaluate.",
                                           )
-    predict_args.add_argument("-o", "--output", type=str, default="test_metrics",
+    test_args.add_argument("-o", "--output", type=str, default="test_metrics",
                               help="A path to a file where the test metrics will be saved. "
                                    "(.tsv format will be added automatically)")
-    predict_args.add_argument("--crop_data_to_model_lims", action="store_true",
+    test_args.add_argument("--crop_data_to_model_lims", action="store_true",
                               help="If set, the data will be cropped to the limits of the model: "
                                    "evaluations will be done only for proteins >50aa and <800aa.")
...
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import ModelCheckpoint
 import pathlib
-import argparse
 from ..utils import add_general_args
 from ..model import SensePPIModel
 from ..dataset import PairSequenceData
...
@@ -13,45 +13,44 @@ import gzip
 import shutil


-def generate_pairs(fasta_file, mode='all_to_all', with_self=False, delete_proteins=None):
+def generate_pairs_string(fasta_file, output_file, with_self=False, delete_proteins=None):
     ids = []
     for record in SeqIO.parse(fasta_file, "fasta"):
         ids.append(record.id)

-    if mode == 'all_to_all':
-        if with_self:
-            all_pairs = [p for p in product(ids, repeat=2)]
-        else:
-            all_pairs = [p for p in permutations(ids, 2)]
+    if with_self:
+        all_pairs = [p for p in product(ids, repeat=2)]
+    else:
+        all_pairs = [p for p in permutations(ids, 2)]

     pairs = []
     for p in all_pairs:
         if (p[1], p[0]) not in pairs and (p[0], p[1]) not in pairs:
             pairs.append(p)
     pairs = pd.DataFrame(pairs, columns=['seq1', 'seq2'])

     data = pd.read_csv('string_interactions.tsv', delimiter='\t')

     # Creating a dictionary of string ids and gene names
     ids_dict = dict(zip(data['preferredName_A'], data['stringId_A']))
     ids_dict.update(dict(zip(data['preferredName_B'], data['stringId_B'])))

     data = data[['stringId_A', 'stringId_B', 'score']]
     data.columns = ['seq1', 'seq2', 'label']

     pairs = pairs.merge(data, on=['seq1', 'seq2'], how='left').fillna(0)

     if delete_proteins is not None:
         print('Labels removed: ', delete_proteins)
         string_ids_to_delete = []
         for label in delete_proteins:
             string_ids_to_delete.append(ids_dict[label])
         print('String ids to delete: ', string_ids_to_delete)
         pairs = pairs[~pairs['seq1'].isin(string_ids_to_delete)]
         pairs = pairs[~pairs['seq2'].isin(string_ids_to_delete)]

-    pairs.to_csv('protein.actions.tsv', sep='\t', index=False, header=False)
+    pairs.to_csv(output_file, sep='\t', index=False, header=False)


 def generate_dscript_gene_names(file_path,
...
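Hypothetical usage of the renamed `generate_pairs_string`, assuming `sequences.fasta` and `string_interactions.tsv` (written by the STRING fetching step) are already present in the working directory; the protein names passed to `delete_proteins` are illustrative:

```python
generate_pairs_string('sequences.fasta',
                      output_file='protein.pairs_string.tsv',
                      with_self=False,
                      delete_proteins=['C1S', 'RFC3'])
# The output is a header-less .tsv with columns seq1, seq2 (STRING ids) and the
# combined STRING score, read back in main() as 'string_label'.
```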