0.6.5: updated the names of the temporary files used by the ESM2 and FASTA utilities

parent f5945547
...@@ -10,6 +10,7 @@ import logging ...@@ -10,6 +10,7 @@ import logging
from esm import FastaBatchedDataset, pretrained from esm import FastaBatchedDataset, pretrained
from copy import copy from copy import copy
from Bio import SeqIO from Bio import SeqIO
from datetime import datetime
def add_esm_args(parent_parser): def add_esm_args(parent_parser):
...@@ -119,12 +120,18 @@ def compute_embeddings(params): ...@@ -119,12 +120,18 @@ def compute_embeddings(params):
seq_dict.pop(seq_id) seq_dict.pop(seq_id)
if len(seq_dict) > 0: if len(seq_dict) > 0:
params_esm = copy(params) params_esm = copy(params)
params_esm.fasta_file = Path(str(params.fasta_file).replace('fasta', 'tmp.fasta')) try:
current_time = str(datetime.now()).replace(' ', '_')
params_esm.fasta_file = Path(current_time + '_' + str(params.fasta_file).replace('fasta', 'tmp.fasta'))
with open(params_esm.fasta_file, 'w') as f: with open(params_esm.fasta_file, 'w') as f:
for seq_id in seq_dict.keys(): for seq_id in seq_dict.keys():
f.write('>' + seq_id + '\n') f.write('>' + seq_id + '\n')
f.write(str(seq_dict[seq_id].seq) + '\n') f.write(str(seq_dict[seq_id].seq) + '\n')
run(params_esm) run(params_esm)
except Exception as e:
raise e
finally:
if os.path.exists(params_esm.fasta_file):
os.remove(params_esm.fasta_file) os.remove(params_esm.fasta_file)
else: else:
logging.info('All ESM embeddings already computed') logging.info('All ESM embeddings already computed')
......
...@@ -4,6 +4,7 @@ from senseppi import __version__ ...@@ -4,6 +4,7 @@ from senseppi import __version__
import torch import torch
import logging import logging
import argparse import argparse
from datetime import datetime
class ArgumentParserWithDefaults(argparse.ArgumentParser): class ArgumentParserWithDefaults(argparse.ArgumentParser):
...@@ -68,7 +69,8 @@ def block_mps(params): ...@@ -68,7 +69,8 @@ def block_mps(params):
def process_string_fasta(fasta_file, min_len, max_len): def process_string_fasta(fasta_file, min_len, max_len):
with open('file.tmp', 'w') as f: tmp_name = str(datetime.now()).replace(' ', '_') + '_fasta_processed.tmp'
with open(tmp_name, 'w') as f:
for record in SeqIO.parse(fasta_file, "fasta"): for record in SeqIO.parse(fasta_file, "fasta"):
if len(record.seq) < min_len or len(record.seq) > max_len: if len(record.seq) < min_len or len(record.seq) > max_len:
continue continue
...@@ -78,7 +80,7 @@ def process_string_fasta(fasta_file, min_len, max_len): ...@@ -78,7 +80,7 @@ def process_string_fasta(fasta_file, min_len, max_len):
SeqIO.write(record, f, "fasta") SeqIO.write(record, f, "fasta")
# Rename the temporary file to the original file # Rename the temporary file to the original file
os.remove(fasta_file) os.remove(fasta_file)
os.rename('file.tmp', fasta_file) os.rename(tmp_name, fasta_file)
def get_fasta_ids(fasta_file): def get_fasta_ids(fasta_file):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment