0.1.7: embeddings are computed only for missing sequences instead of from scratch

parent ccdf1277
__version__ = "0.1.7"
__version__ = "0.1.8"
__author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils
@@ -33,11 +33,11 @@ def main():
     params = parser.parse_args()
     if hasattr(params, 'device'):
-        # WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
-        if params.device == 'mps':
-            logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
-                            'The cpu backend will be used instead.')
-            params.device = 'cpu'
+        # # WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
+        # if params.device == 'mps':
+        #     logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
+        #                     'The cpu backend will be used instead.')
+        #     params.device = 'cpu'
         if params.device == 'gpu':
             torch.set_float32_matmul_precision('high')
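Note (not part of the commit): the hunk above re-enables the 'mps' device choice by commenting out the forced fallback to CPU. A minimal sketch of an alternative approach, assuming a PyTorch build recent enough to expose torch.backends.mps, would validate the requested device at runtime and fall back only when the backend is genuinely unavailable:

import logging
import torch

def resolve_device(requested: str) -> str:
    # Hypothetical helper, not part of this commit: map a 'cpu'/'gpu'/'mps'
    # request to a device that is actually usable on this machine.
    if requested == 'gpu' and not torch.cuda.is_available():
        logging.warning('CUDA is not available, falling back to cpu.')
        return 'cpu'
    if requested == 'mps' and not torch.backends.mps.is_available():
        logging.warning('MPS is not available, falling back to cpu.')
        return 'cpu'
    return requested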
@@ -7,6 +7,8 @@ import torch
 import os
 import logging
 from esm import FastaBatchedDataset, pretrained
+from copy import copy
+from Bio import SeqIO
 def add_esm_args(parent_parser):
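Note (not part of the commit): the two new imports support the reworked compute_embeddings further down. Bio.SeqIO parses the input FASTA into a dict of records, and copy() makes a shallow copy of the argparse namespace so a temporary FASTA path can be set without mutating the caller's params. A small sketch of that second point, with hypothetical attribute values:

import argparse
from copy import copy

params = argparse.Namespace(fasta_file='all_sequences.fasta', device='cpu')  # hypothetical values
params_esm = copy(params)                    # shallow copy: new namespace, same attribute values
params_esm.fasta_file = 'tmp_for_esm.fasta'  # overriding the copy leaves the original alone
print(params.fasta_file)                     # 'all_sequences.fasta'
print(params_esm.fasta_file)                 # 'tmp_for_esm.fasta'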
@@ -43,37 +45,37 @@ def add_esm_args(parent_parser):
     )
-def run(args):
-    model, alphabet = pretrained.load_model_and_alphabet(args.model_location_esm)
+def run(params):
+    model, alphabet = pretrained.load_model_and_alphabet(params.model_location_esm)
     model.eval()
-    if args.device == 'gpu':
+    if params.device == 'gpu':
         model = model.cuda()
         print("Transferred the ESM2 model to GPU")
-    elif args.device == 'mps':
+    elif params.device == 'mps':
         model = model.to('mps')
         print("Transferred the ESM2 model to MPS")
-    dataset = FastaBatchedDataset.from_file(args.fasta_file)
-    batches = dataset.get_batch_indices(args.toks_per_batch_esm, extra_toks_per_seq=1)
+    dataset = FastaBatchedDataset.from_file(params.fasta_file)
+    batches = dataset.get_batch_indices(params.toks_per_batch_esm, extra_toks_per_seq=1)
     data_loader = torch.utils.data.DataLoader(
-        dataset, collate_fn=alphabet.get_batch_converter(args.truncation_seq_length_esm), batch_sampler=batches
+        dataset, collate_fn=alphabet.get_batch_converter(params.truncation_seq_length_esm), batch_sampler=batches
     )
-    print(f"Read {args.fasta_file} with {len(dataset)} sequences")
+    print(f"Read {params.fasta_file} with {len(dataset)} sequences")
-    args.output_dir_esm.mkdir(parents=True, exist_ok=True)
+    params.output_dir_esm.mkdir(parents=True, exist_ok=True)
-    assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in args.repr_layers_esm)
-    repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in args.repr_layers_esm]
+    assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in params.repr_layers_esm)
+    repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in params.repr_layers_esm]
     with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
-            if args.device == 'gpu':
+            if params.device == 'gpu':
                toks = toks.to(device="cuda", non_blocking=True)
-            elif args.device == 'mps':
+            elif params.device == 'mps':
                toks = toks.to(device="mps", non_blocking=True)
            out = model(toks, repr_layers=repr_layers, return_contacts=False)
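Note (not part of the commit): apart from the args -> params rename, the hunk above is the standard ESM2 extraction loop. The repr_layers expression wraps negative layer indices into absolute layer numbers; a quick illustration, assuming a 33-layer ESM2 checkpoint such as esm2_t33_650M_UR50D:

num_layers = 33                    # assumption: a 33-layer ESM2 model
requested = [-1, 0, 33]            # hypothetical --repr_layers_esm values
resolved = [(i + num_layers + 1) % (num_layers + 1) for i in requested]
print(resolved)                    # [33, 0, 33] -- -1 maps to the last layer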
@@ -83,42 +85,54 @@ def run(args):
            }
            for i, label in enumerate(labels):
-                args.output_file_esm = args.output_dir_esm / f"{label}.pt"
-                args.output_file_esm.parent.mkdir(parents=True, exist_ok=True)
+                params.output_file_esm = params.output_dir_esm / f"{label}.pt"
+                params.output_file_esm.parent.mkdir(parents=True, exist_ok=True)
                result = {"label": label}
-                truncate_len = min(args.truncation_seq_length_esm, len(strs[i]))
+                truncate_len = min(params.truncation_seq_length_esm, len(strs[i]))
                # Call clone on tensors to ensure tensors are not views into a larger representation
                # See https://github.com/pytorch/pytorch/issues/1995
                result["representations"] = {
-                    layer: t[i, 1 : truncate_len + 1].clone()
+                    layer: t[i, 1: truncate_len + 1].clone()
                    for layer, t in representations.items()
                }
                torch.save(
                    result,
-                    args.output_file_esm,
+                    params.output_file_esm,
                )
 def compute_embeddings(params):
     # Compute ESM embeddings
-    logging.info('Computing ESM embeddings if they are not already computed. '
-                 'If all the files alreaady exist in {} folder, this step will be skipped.'.format(params.output_dir_esm))
+    logging.info('Computing ESM embeddings. If all the files already exist in {} folder, '
+                 'this step will be skipped.'.format(params.output_dir_esm))
     if not os.path.exists(params.output_dir_esm):
         run(params)
     else:
         with open(params.fasta_file, 'r') as f:
-            seq_ids = [line.strip().split(' ')[0].replace('>', '') for line in f.readlines() if line.startswith('>')]
+            # dict of only id and sequences from parsing fasta file
+            seq_dict = SeqIO.to_dict(SeqIO.parse(f, 'fasta'))
+        seq_ids = list(seq_dict.keys())
         for seq_id in seq_ids:
-            if not os.path.exists(os.path.join(params.output_dir_esm, seq_id + '.pt')):
-                run(params)
-                break
+            if os.path.exists(os.path.join(params.output_dir_esm, seq_id + '.pt')):
+                seq_dict.pop(seq_id)
+        if len(seq_dict) > 0:
+            params_esm = copy(params)
+            params_esm.fasta_file = 'tmp_for_esm.fasta'
+            with open(params_esm.fasta_file, 'w') as f:
+                for seq_id in seq_dict.keys():
+                    f.write('>' + seq_id + '\n')
+                    f.write(str(seq_dict[seq_id].seq) + '\n')
+            run(params_esm)
+            os.remove(params_esm.fasta_file)
+        else:
+            logging.info('All ESM embeddings already computed')
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    add_esm_args(parser)
-    args = parser.parse_args()
+    esm_parser = argparse.ArgumentParser()
+    add_esm_args(esm_parser)
+    args = esm_parser.parse_args()
     run(args)
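Note (not part of the commit): with this change, compute_embeddings() first drops every sequence that already has a <seq_id>.pt file in params.output_dir_esm, writes the remaining records to a temporary FASTA, and runs ESM2 only on those, instead of re-running on the whole input as soon as a single file was missing. Each saved file keeps the dict structure written by run(), so it can be read back as sketched below; the folder name and sequence id are hypothetical, and in practice the values come from the add_esm_args options:

import torch
from pathlib import Path

output_dir_esm = Path('esm2_embs')       # hypothetical embeddings folder
emb_file = output_dir_esm / 'P12345.pt'  # hypothetical sequence id from the FASTA headers

data = torch.load(emb_file)
print(data['label'])                     # 'P12345'
for layer, tensor in data['representations'].items():
    # one (truncated_length x embedding_dim) tensor per requested layer
    print(layer, tuple(tensor.shape))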