0.1.7: embeddings are no longer computed from scratch on every run; only missing sequences are embedded

parent ccdf1277
__version__ = "0.1.7" __version__ = "0.1.8"
__author__ = "Konstantin Volzhenin" __author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils from . import model, commands, esm2_model, dataset, utils, network_utils
@@ -33,11 +33,11 @@ def main():
     params = parser.parse_args()

     if hasattr(params, 'device'):
-        # WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
-        if params.device == 'mps':
-            logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
-                            'The cpu backend will be used instead.')
-            params.device = 'cpu'
+        # # WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
+        # if params.device == 'mps':
+        #     logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
+        #                     'The cpu backend will be used instead.')
+        #     params.device = 'cpu'

         if params.device == 'gpu':
             torch.set_float32_matmul_precision('high')
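Note on the hunk above: commenting out this block re-enables the user-selected 'mps' device instead of silently forcing it back to 'cpu'. A minimal sketch of how a runtime guard could look if a fallback is ever needed again (hypothetical helper, not part of this commit; it relies only on the standard torch.cuda.is_available() and torch.backends.mps.is_available() checks):

    import torch

    def resolve_device(requested: str) -> str:
        # Hypothetical fallback: keep the requested backend only if it is actually usable.
        if requested == 'gpu' and not torch.cuda.is_available():
            return 'cpu'
        if requested == 'mps' and not torch.backends.mps.is_available():
            return 'cpu'
        return requested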
@@ -7,6 +7,8 @@ import torch
 import os
 import logging
 from esm import FastaBatchedDataset, pretrained
+from copy import copy
+from Bio import SeqIO


 def add_esm_args(parent_parser):
@@ -43,37 +45,37 @@ def add_esm_args(parent_parser):
     )


-def run(args):
-    model, alphabet = pretrained.load_model_and_alphabet(args.model_location_esm)
+def run(params):
+    model, alphabet = pretrained.load_model_and_alphabet(params.model_location_esm)
     model.eval()

-    if args.device == 'gpu':
+    if params.device == 'gpu':
         model = model.cuda()
         print("Transferred the ESM2 model to GPU")
-    elif args.device == 'mps':
+    elif params.device == 'mps':
         model = model.to('mps')
         print("Transferred the ESM2 model to MPS")

-    dataset = FastaBatchedDataset.from_file(args.fasta_file)
-    batches = dataset.get_batch_indices(args.toks_per_batch_esm, extra_toks_per_seq=1)
+    dataset = FastaBatchedDataset.from_file(params.fasta_file)
+    batches = dataset.get_batch_indices(params.toks_per_batch_esm, extra_toks_per_seq=1)
     data_loader = torch.utils.data.DataLoader(
-        dataset, collate_fn=alphabet.get_batch_converter(args.truncation_seq_length_esm), batch_sampler=batches
+        dataset, collate_fn=alphabet.get_batch_converter(params.truncation_seq_length_esm), batch_sampler=batches
     )
-    print(f"Read {args.fasta_file} with {len(dataset)} sequences")
+    print(f"Read {params.fasta_file} with {len(dataset)} sequences")

-    args.output_dir_esm.mkdir(parents=True, exist_ok=True)
+    params.output_dir_esm.mkdir(parents=True, exist_ok=True)

-    assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in args.repr_layers_esm)
-    repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in args.repr_layers_esm]
+    assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in params.repr_layers_esm)
+    repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in params.repr_layers_esm]

     with torch.no_grad():
         for batch_idx, (labels, strs, toks) in enumerate(data_loader):
             print(
                 f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
             )
-            if args.device == 'gpu':
+            if params.device == 'gpu':
                 toks = toks.to(device="cuda", non_blocking=True)
-            elif args.device == 'mps':
+            elif params.device == 'mps':
                 toks = toks.to(device="mps", non_blocking=True)

             out = model(toks, repr_layers=repr_layers, return_contacts=False)
@@ -83,42 +85,54 @@ def run(args):
             }

             for i, label in enumerate(labels):
-                args.output_file_esm = args.output_dir_esm / f"{label}.pt"
-                args.output_file_esm.parent.mkdir(parents=True, exist_ok=True)
+                params.output_file_esm = params.output_dir_esm / f"{label}.pt"
+                params.output_file_esm.parent.mkdir(parents=True, exist_ok=True)
                 result = {"label": label}
-                truncate_len = min(args.truncation_seq_length_esm, len(strs[i]))
+                truncate_len = min(params.truncation_seq_length_esm, len(strs[i]))
                 # Call clone on tensors to ensure tensors are not views into a larger representation
                 # See https://github.com/pytorch/pytorch/issues/1995
                 result["representations"] = {
-                    layer: t[i, 1 : truncate_len + 1].clone()
+                    layer: t[i, 1: truncate_len + 1].clone()
                     for layer, t in representations.items()
                 }

                 torch.save(
                     result,
-                    args.output_file_esm,
+                    params.output_file_esm,
                 )
 def compute_embeddings(params):
     # Compute ESM embeddings

-    logging.info('Computing ESM embeddings if they are not already computed. '
-                 'If all the files alreaady exist in {} folder, this step will be skipped.'.format(params.output_dir_esm))
+    logging.info('Computing ESM embeddings. If all the files already exist in {} folder, '
+                 'this step will be skipped.'.format(params.output_dir_esm))

     if not os.path.exists(params.output_dir_esm):
         run(params)
     else:
         with open(params.fasta_file, 'r') as f:
-            seq_ids = [line.strip().split(' ')[0].replace('>', '') for line in f.readlines() if line.startswith('>')]
+            # dict of only id and sequences from parsing fasta file
+            seq_dict = SeqIO.to_dict(SeqIO.parse(f, 'fasta'))
+            seq_ids = list(seq_dict.keys())
         for seq_id in seq_ids:
-            if not os.path.exists(os.path.join(params.output_dir_esm, seq_id + '.pt')):
-                run(params)
-                break
+            if os.path.exists(os.path.join(params.output_dir_esm, seq_id + '.pt')):
+                seq_dict.pop(seq_id)
+        if len(seq_dict) > 0:
+            params_esm = copy(params)
+            params_esm.fasta_file = 'tmp_for_esm.fasta'
+            with open(params_esm.fasta_file, 'w') as f:
+                for seq_id in seq_dict.keys():
+                    f.write('>' + seq_id + '\n')
+                    f.write(str(seq_dict[seq_id].seq) + '\n')
+            run(params_esm)
+            os.remove(params_esm.fasta_file)
+        else:
+            logging.info('All ESM embeddings already computed')


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    add_esm_args(parser)
-    args = parser.parse_args()
+    esm_parser = argparse.ArgumentParser()
+    add_esm_args(esm_parser)
+    args = esm_parser.parse_args()
     run(args)
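For reference, each per-sequence file written by run() above stores a small dict with the record label and a 'representations' mapping from layer index to a per-residue embedding tensor. A minimal sketch of reading one back (the directory 'esm_emb', the id 'P12345' and the layer index 36 are placeholders; the actual values depend on output_dir_esm, the FASTA record ids and repr_layers_esm):

    import torch

    data = torch.load('esm_emb/P12345.pt')      # one file per FASTA record, named '<label>.pt'
    print(data['label'])                         # the FASTA record id
    reps = data['representations']               # {layer_index: tensor of shape (seq_len, embed_dim)}
    per_residue = reps[36]                       # layer index is an assumption; match repr_layers_esm
    per_protein = per_residue.mean(dim=0)        # e.g. mean-pool to one vector per protein
    print(per_residue.shape, per_protein.shape)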