0.6.5 updated names for tmp files for esm2 and fasta utils

parent f5945547
@@ -10,6 +10,7 @@ import logging
 from esm import FastaBatchedDataset, pretrained
 from copy import copy
 from Bio import SeqIO
+from datetime import datetime
 def add_esm_args(parent_parser):
@@ -119,12 +120,18 @@ def compute_embeddings(params):
                 seq_dict.pop(seq_id)
     if len(seq_dict) > 0:
         params_esm = copy(params)
-        params_esm.fasta_file = Path(str(params.fasta_file).replace('fasta', 'tmp.fasta'))
+        try:
+            current_time = str(datetime.now()).replace(' ', '_')
+            params_esm.fasta_file = Path(current_time + '_' + str(params.fasta_file).replace('fasta', 'tmp.fasta'))
+            with open(params_esm.fasta_file, 'w') as f:
+                for seq_id in seq_dict.keys():
+                    f.write('>' + seq_id + '\n')
+                    f.write(str(seq_dict[seq_id].seq) + '\n')
+            run(params_esm)
+        except Exception as e:
+            raise e
+        finally:
+            if os.path.exists(params_esm.fasta_file):
+                os.remove(params_esm.fasta_file)
     else:
         logging.info('All ESM embeddings already computed')
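In compute_embeddings, the temporary FASTA holding the sequences that still lack ESM2 embeddings is now prefixed with a timestamp instead of reusing a fixed `*.tmp.fasta` name, and the write-and-run steps are wrapped in try/finally so the file is removed even when the embedding run fails. The sketch below isolates that pattern; `write_and_embed`, `run_esm`, and the `seq_dict` argument are illustrative names, not identifiers from the repository, and since the commit's `except Exception as e: raise e` only re-raises, the sketch relies on `finally` alone.

```python
# A minimal sketch of the timestamped temp-FASTA pattern used above.
# `write_and_embed`, `run_esm`, and `seq_dict` are placeholder names.
import os
from datetime import datetime
from pathlib import Path


def write_and_embed(fasta_file: Path, seq_dict: dict, run_esm) -> None:
    # Timestamp prefix so two runs on the same input FASTA
    # do not write to (and then delete) the same temp file.
    current_time = str(datetime.now()).replace(' ', '_')
    tmp_fasta = Path(current_time + '_' + str(fasta_file).replace('fasta', 'tmp.fasta'))
    try:
        with open(tmp_fasta, 'w') as f:
            for seq_id, record in seq_dict.items():
                f.write('>' + seq_id + '\n')
                f.write(str(record.seq) + '\n')
        run_esm(tmp_fasta)  # stands in for run(params_esm) in the real code
    finally:
        # Clean up the temp file whether or not the embedding run succeeded.
        if os.path.exists(tmp_fasta):
            os.remove(tmp_fasta)
```

A `datetime.now()` prefix makes collisions between concurrent runs unlikely rather than impossible; Python's `tempfile` module (e.g. `tempfile.mkstemp`) would guarantee a unique name, at the cost of a less recognisable filename.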
@@ -4,6 +4,7 @@ from senseppi import __version__
 import torch
 import logging
 import argparse
+from datetime import datetime
 class ArgumentParserWithDefaults(argparse.ArgumentParser):
@@ -68,7 +69,8 @@ def block_mps(params):
 def process_string_fasta(fasta_file, min_len, max_len):
-    with open('file.tmp', 'w') as f:
+    tmp_name = str(datetime.now()).replace(' ', '_') + '_fasta_processed.tmp'
+    with open(tmp_name, 'w') as f:
         for record in SeqIO.parse(fasta_file, "fasta"):
             if len(record.seq) < min_len or len(record.seq) > max_len:
                 continue
@@ -78,7 +80,7 @@ def process_string_fasta(fasta_file, min_len, max_len):
             SeqIO.write(record, f, "fasta")
     # Rename the temporary file to the original file
     os.remove(fasta_file)
-    os.rename('file.tmp', fasta_file)
+    os.rename(tmp_name, fasta_file)
 def get_fasta_ids(fasta_file):
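process_string_fasta filters a FASTA file in place: records that pass the length check are written to a temporary file, which then replaces the original. The commit swaps the hard-coded 'file.tmp' for a timestamped name so concurrent runs never share a temp file. Below is a self-contained sketch of that pattern; the function name `filter_fasta_in_place` is illustrative, and the lines collapsed between the two hunks above are omitted, so this is an illustration rather than the repository function.

```python
# A minimal sketch of the filter-to-temp-then-replace pattern from
# process_string_fasta (illustrative name: filter_fasta_in_place).
import os
from datetime import datetime
from Bio import SeqIO


def filter_fasta_in_place(fasta_file: str, min_len: int, max_len: int) -> None:
    # Timestamped temp name, as in the commit, so parallel processes
    # never write to the same 'file.tmp'.
    tmp_name = str(datetime.now()).replace(' ', '_') + '_fasta_processed.tmp'
    with open(tmp_name, 'w') as f:
        for record in SeqIO.parse(fasta_file, "fasta"):
            # Keep only records within the requested length range.
            if len(record.seq) < min_len or len(record.seq) > max_len:
                continue
            SeqIO.write(record, f, "fasta")

    # Replace the original file with the filtered copy.
    os.remove(fasta_file)
    os.rename(tmp_name, fasta_file)
```

On a single filesystem, `os.replace(tmp_name, fasta_file)` would do the remove-and-rename in one atomic step; the two-call form above simply mirrors what the commit keeps.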