0.6.5 updated names for tmp files for predict and predict string

parent 04ca5c16
__version__ = "0.6.4"
__version__ = "0.6.5"
__author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils
......
......@@ -9,6 +9,7 @@ from ..dataset import PairSequenceData
from ..model import SensePPIModel
from ..utils import *
from ..esm2_model import add_esm_args, compute_embeddings
from datetime import datetime
def predict(params):
......@@ -120,7 +121,8 @@ def get_protein_names(fasta_file):
def main(params):
tmp_pairs = 'senseppi_pairs_for_prediction.tmp'
current_time = str(datetime.now()).replace(' ', '_')
tmp_pairs = current_time + '_senseppi_pairs_for_prediction.tsv.tmp'
try:
fasta_max_len = get_max_len(params.fasta_file)
if params.max_len is None:
......
......@@ -19,14 +19,13 @@ def main(params):
label_threshold = params.score / 1000.
pred_threshold = params.pred_threshold / 1000.
pairs_file = 'protein.pairs_string.tsv'
fasta_file = 'sequences.fasta'
print('Fetching interactions from STRING database...')
get_interactions_from_string(params.genes, species=params.species, add_nodes=params.nodes,
string_pairs_file, fasta_file = get_interactions_from_string(params.genes, species=params.species, add_nodes=params.nodes,
required_score=params.score, network_type=params.network_type)
pairs_file = string_pairs_file.replace('.tsv', '_for_pred.tsv')
try:
process_string_fasta(fasta_file, min_len=params.min_len, max_len=params.max_len)
generate_pairs_string(fasta_file, output_file=pairs_file, delete_proteins=params.delete_proteins)
generate_pairs_string(fasta_file, string_pairs_file, output_file=pairs_file, delete_proteins=params.delete_proteins)
params.fasta_file = fasta_file
params.pairs_file = pairs_file
......@@ -35,14 +34,14 @@ def main(params):
preds = predict(params)
# open the actions tsv file as dataframe and add the last column with the predictions
data = pd.read_csv('protein.pairs_string.tsv', delimiter='\t', names=["seq1", "seq2", "string_label"])
data = pd.read_csv(pairs_file, delimiter='\t', names=["seq1", "seq2", "string_label"])
data['binary_label'] = data['string_label'].apply(lambda x: 1 if x > label_threshold else 0)
data['preds'] = preds
print(data.sort_values(by=['preds'], ascending=False).to_string())
string_ids = {}
string_tsv = pd.read_csv('string_interactions.tsv', delimiter='\t')[
string_tsv = pd.read_csv(string_pairs_file, delimiter='\t')[
['preferredName_A', 'preferredName_B', 'stringId_A', 'stringId_B']]
for i, row in string_tsv.iterrows():
string_ids[row['stringId_A']] = row['preferredName_A']
......@@ -60,8 +59,8 @@ def main(params):
# Plot the predictions as matrix but do not sort the labels
data_heatmap = data.pivot(index='seq1', columns='seq2', values='preds')
labels_heatmap = data.pivot(index='seq1', columns='seq2', values='string_label')
# Produce the list of protein names from sequences.fasta file
protein_names = [line.strip()[1:] for line in open('sequences.fasta', 'r') if line.startswith('>')]
# Produce the list of protein names from fasta file
protein_names = [line.strip()[1:] for line in open(fasta_file, 'r') if line.startswith('>')]
# Produce the list of gene names from genes Dataframe
gene_names = [string_ids[prot] for prot in protein_names]
# Sort the preds for index and columns based on the order of proteins in the fasta file
......@@ -156,12 +155,15 @@ def main(params):
print("The graphs were saved to: ", save_path)
plt.show()
plt.close()
except Exception as e:
raise e
finally:
# Remove the temporary files
os.remove(fasta_file)
os.remove(pairs_file)
os.remove(string_pairs_file)
for f in glob.glob('{}.protein.sequences*'.format(params.species)):
os.remove(f)
os.remove('string_interactions.tsv')
def add_args(parser):
......
......@@ -7,11 +7,12 @@ import urllib.request
import requests
import gzip
import shutil
from datetime import datetime
DOWNLOAD_LINK_STRING = "https://stringdb-downloads.org/download/"
def generate_pairs_string(fasta_file, output_file, delete_proteins=None):
def generate_pairs_string(fasta_file, pairs_file, output_file, delete_proteins=None):
ids = []
for record in SeqIO.parse(fasta_file, "fasta"):
ids.append(record.id)
......@@ -23,7 +24,7 @@ def generate_pairs_string(fasta_file, output_file, delete_proteins=None):
pairs = pd.DataFrame(pairs, columns=['seq1', 'seq2'])
data = pd.read_csv('string_interactions.tsv', delimiter='\t')
data = pd.read_csv(pairs_file, delimiter='\t')
# Creating a dictionary of string ids and gene names
ids_dict = dict(zip(data['preferredName_A'], data['stringId_A']))
......@@ -70,6 +71,11 @@ def get_string_url():
def get_interactions_from_string(gene_names, species=9606, add_nodes=10, required_score=500, network_type='physical'):
current_time = str(datetime.now()).replace(' ', '_')
pairs_file = current_time + '_protein.pairs_string.tsv'
fasta_file = current_time + '_sequences.fasta'
string_api_url, version = get_string_url()
output_format = "tsv"
method = "network"
......@@ -134,11 +140,13 @@ def get_interactions_from_string(gene_names, species=9606, add_nodes=10, require
string_names_input_genes['stringId'].to_list()
ids = set(ids)
with open('sequences.fasta', 'w') as f:
with open(fasta_file, 'w') as f:
for record in SeqIO.parse('{}.protein.sequences.v{}.fa'.format(species, version), "fasta"):
if record.id in ids:
SeqIO.write(record, f, "fasta")
string_interactions.to_csv('string_interactions.tsv', sep='\t', index=False)
string_interactions.to_csv(pairs_file, sep='\t', index=False)
return pairs_file, fasta_file
if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment