Commit f3163134 by Gianluca LOMBARDI

Update imports and README

parent f812f2cc
# MuLAN: MUtational effects with Light Attention Networks # MuLAN: Mutational effects with Light Attention Networks
![mulan abstract](./images/visual_abstract.png) ![mulan abstract](./images/visual_abstract.png)
...@@ -13,8 +13,8 @@ Attention weights extracted from the model can give insights on protein interfac ...@@ -13,8 +13,8 @@ Attention weights extracted from the model can give insights on protein interfac
### Installation ### Installation
As a prerequisite, you must have PyTorch installed to use this repository. If not, it can be installed with conda running the following: As a prerequisite, you must have PyTorch installed to use this repository. If not, it can be installed with conda running the following:
```bash ```bash
# PyTorch 2.1.0, CUDA 12.1 # PyTorch 2.4.0, CUDA 12.1
conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=12.1 -c pytorch -c nvidia conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
``` ```
For other versions or installation methods, please refer to [PyTorch documentation](https://pytorch.org/get-started/locally/). For other versions or installation methods, please refer to [PyTorch documentation](https://pytorch.org/get-started/locally/).
MuLAN and its dependencies can then be installed with MuLAN and its dependencies can then be installed with
...@@ -23,15 +23,36 @@ git clone https://github.com/GianLMB/mulan ...@@ -23,15 +23,36 @@ git clone https://github.com/GianLMB/mulan
cd mulan cd mulan
pip install . pip install .
``` ```
We suggest to do it in a dedicated conda environment. We suggest to do it in a dedicated conda environment.
Cloning the repository will also download weights for different model versions trained on SKEMPI dataset. They are available in the folder
`models/pretrained`. To be able to load them within the package, the environmental variable must be set:
```bash
export MULAN="path/to/mulan/folder"
```
Alternatively, if you prefer to move the checkpoint files to a different folder, you can acces them in the new location by setting the
`MULAN_MODELS_DIR` variable pointing to the corresponding folder.
### Usage ### Usage
Available mulan models and pre-trained protein language models can be easily loaded through `mulan` interface. For MuLAN models:
```python
import mulan
print(mulan.get_available_models())
model = mulan.load_pretrained("mulan-ankh")
```
For some supported PLMs, instead:
```python
import mulan
print(mulan.get_available_plms())
model = mulan.load_pretrained_plm("ankh")
```
If the corresponding PLMs are not found on disk, they will be downloaded from HuggingFace Hub and stored by default in `~/.cache/huggingface/hub` folder.
We provide several command line interfaces for quick usage of MuLAN different applications: We provide several command line interfaces for quick usage of MuLAN different applications:
- `mulan-predict` for $\Delta \Delta G$ prediction of single and multiple-point mutations; - `mulan-predict` for $\Delta \Delta G$ prediction of single and multiple-point mutations;
- `mulan-att` to extract residues weights, related to interface regions; - `mulan-att` to extract residues weights, related to interface regions;
- `mulan-landscape` to produce a full mutational landscape for a given complex; - `mulan-landscape` to produce a full mutational landscape for a given complex;
- `mulan-train` to re-train the model on a custom dataset or to run cross validation (not added yet); - `mulan-train` to re-train the model on a custom dataset or to run cross validation (not supported yet);
- `plm-embed` to extract embeddings from protein language models. - `plm-embed` to extract embeddings from protein language models.
Since the script uses the `transformers` interface, only models that are saved on HuggingFace 🤗 Hub can be loaded. Since the script uses the `transformers` interface, only models that are saved on HuggingFace 🤗 Hub can be loaded.
...@@ -39,5 +60,16 @@ Information about required inputs and usage examples for each command are provid ...@@ -39,5 +60,16 @@ Information about required inputs and usage examples for each command are provid
## Citation ## Citation
If you find this useful, please cite
```bibtex
@article {Lombardi2024.08.24.609515,
author = {Lombardi, Gianluca and Carbone, Alessandra},
title = {MuLAN: Mutation-driven Light Attention Networks for investigating protein-protein interactions from sequences},
year = {2024},
doi = {10.1101/2024.08.24.609515},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2024/08/26/2024.08.24.609515},
journal = {bioRxiv}
}
```
from .config import MulanConfig # noqa
from .modules import LightAttModel # noqa
from .utils import (
load_pretrained,
load_pretrained_plm,
get_available_models,
get_available_plms,
) # noqa
\ No newline at end of file
...@@ -49,9 +49,9 @@ three2one = {v: k for k, v in one2three.items()} ...@@ -49,9 +49,9 @@ three2one = {v: k for k, v in one2three.items()}
# Models names and paths # Models names and paths
_BASE_DIR = os.getcwd() # os.path.dirname(os.path.dirname(os.path.abspath(__file__))) _BASE_DIR = os.environ.get("MULAN", os.getcwd())
_DEFAULT_MODELS_DIR = os.path.join(_BASE_DIR, "models/pretrained") _DEFAULT_MODELS_DIR = os.path.join(_BASE_DIR, "models/pretrained")
MODELS_DIR = os.environ.get("MULAN_MODELS_DIR", _DEFAULT_MODELS_DIR) MODELS_DIR = os.environ.get("MULAN_MODELS_PATH", _DEFAULT_MODELS_DIR)
MODELS = { MODELS = {
"mulan-esm": f"{MODELS_DIR}/mulan_esm.ckpt", "mulan-esm": f"{MODELS_DIR}/mulan_esm.ckpt",
"mulan-esm-multiple": f"{MODELS_DIR}/mulan_esm_multiple.ckpt", "mulan-esm-multiple": f"{MODELS_DIR}/mulan_esm_multiple.ckpt",
......
...@@ -69,9 +69,7 @@ class AttentionMeanK(nn.Module): ...@@ -69,9 +69,7 @@ class AttentionMeanK(nn.Module):
], ],
dim=1, dim=1,
) )
# print(attn_weights.shape, o.shape)
o1 = torch.sum(o * attn_weights, dim=-1).view(batch_size, -1) o1 = torch.sum(o * attn_weights, dim=-1).view(batch_size, -1)
# print(o1.shape)
# max pooling # max pooling
o2, _ = torch.max(o, dim=-1) o2, _ = torch.max(o, dim=-1)
...@@ -80,7 +78,6 @@ class AttentionMeanK(nn.Module): ...@@ -80,7 +78,6 @@ class AttentionMeanK(nn.Module):
# mlp # mlp
o = torch.cat([o1, o2], dim=-1) o = torch.cat([o1, o2], dim=-1)
o = self.fc(o) o = self.fc(o)
# attn_mean = torch.softmax(attn_weigths.mean(dim=-1)
output = OutputWithAttention(o, attn_weights) output = OutputWithAttention(o, attn_weights)
return output return output
...@@ -130,7 +127,7 @@ class LightAttModel(nn.Module): ...@@ -130,7 +127,7 @@ class LightAttModel(nn.Module):
) )
output = x_mut - x_wt output = x_mut - x_wt
if zs_scores is not None and (self.config.add_scores or self.config.add_columns_scores): if zs_scores is not None and self.config.add_scores:
output = torch.cat((output, zs_scores.view(batch_size, -1)), dim=-1) output = torch.cat((output, zs_scores.view(batch_size, -1)), dim=-1)
output = self.linear(output).squeeze(-1) output = self.linear(output).squeeze(-1)
......
...@@ -9,6 +9,7 @@ import numpy as np ...@@ -9,6 +9,7 @@ import numpy as np
from scipy.stats import rankdata from scipy.stats import rankdata
import mulan.constants as C import mulan.constants as C
from mulan.constants import AAs, aa2idx, idx2aa, one2three, three2one # noqa
def mutation_generator(sequence): def mutation_generator(sequence):
...@@ -72,28 +73,39 @@ def dict_to_fasta(fasta_dict, fasta_file): ...@@ -72,28 +73,39 @@ def dict_to_fasta(fasta_dict, fasta_file):
f.write(f"{value}\n") f.write(f"{value}\n")
def get_plm_names(): def get_available_plms():
"""Return the names of the available pretrained language models.""" """Return the names of the available pretrained language models."""
return list(C.PLM_ENCODERS.keys()) return list(C.PLM_ENCODERS.keys())
def get_available_models():
"""Return the names of the available Mulan models."""
return list(C.MODELS.keys())
def load_pretrained_plm(model_name, device="cpu"): def load_pretrained_plm(model_name, device="cpu"):
model_id = C.PLM_ENCODERS.get(model_name) model_id = C.PLM_ENCODERS.get(model_name)
if model_id is None: if model_id is None:
raise ValueError(f"Invalid model_name: {model_name}. Must be one of {get_plm_names()}") raise ValueError(
f"Invalid model_name: {model_name}. Must be one of {get_available_plms()}"
)
if "t5" in model_id or "ankh" in model_id: if "t5" in model_id or "ankh" in model_id:
from transformers import T5EncoderModel from transformers import T5EncoderModel
model = T5EncoderModel.from_pretrained(model_id) model = T5EncoderModel.from_pretrained(model_id)
if "ankh" in model_id: if "ankh" in model_id:
from transformers import AutoTokenizer from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id)
else: else:
from transformers import T5Tokenizer from transformers import T5Tokenizer
tokenizer = T5Tokenizer.from_pretrained(model_id)
tokenizer = T5Tokenizer.from_pretrained(model_id)
else: else:
try: try:
from transformers import AutoTokenizer, AutoModel from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id) model = AutoModel.from_pretrained(model_id)
except Exception as e: except Exception as e:
...@@ -103,6 +115,16 @@ def load_pretrained_plm(model_name, device="cpu"): ...@@ -103,6 +115,16 @@ def load_pretrained_plm(model_name, device="cpu"):
return model, tokenizer return model, tokenizer
def load_pretrained(pretrained_model_name, device=None, **kwargs):
"""Load a pretrained model from disk."""
model_path = C.MODELS.get(pretrained_model_name)
if model_path is None:
raise ValueError(f"Invalid model_name: {pretrained_model_name}")
from mulan.modules import LightAttModel
return LightAttModel.from_pretrained(model_path, device=device, **kwargs)
@torch.inference_mode() @torch.inference_mode()
def embed_sequence(plm_model, plm_tokenizer, sequence): def embed_sequence(plm_model, plm_tokenizer, sequence):
"""Embed a sequence using a pretrained model.""" """Embed a sequence using a pretrained model."""
......
...@@ -8,8 +8,8 @@ import pandas as pd ...@@ -8,8 +8,8 @@ import pandas as pd
import torch import torch
import numpy as np import numpy as np
from mulan import utils, constants as C import mulan
from mulan.modules import LightAttModel import mulan.utils as utils
def get_args(): def get_args():
...@@ -24,10 +24,11 @@ def get_args(): ...@@ -24,10 +24,11 @@ def get_args():
The first is the sequence to be scored, the other is the partner.""", The first is the sequence to be scored, the other is the partner.""",
) )
parser.add_argument( parser.add_argument(
"-m", "--model-name", "-m",
"--model-name",
type=str, type=str,
default="mulan-ankh", default="mulan-ankh",
help=f"""Name of the pre-trained model. Must be one of: {C.MODELS.keys()}. help=f"""Name of the pre-trained model. Must be one of: {mulan.get_available_models()}.
If the model is a version of imulan, a TXT file containing zero-shot scores If the model is a version of imulan, a TXT file containing zero-shot scores
must be provided, where lines correspond to mutated positions and columns must be provided, where lines correspond to mutated positions and columns
to amino acids, in alphabetic order for the single letter name, separated to amino acids, in alphabetic order for the single letter name, separated
...@@ -58,8 +59,8 @@ def get_args(): ...@@ -58,8 +59,8 @@ def get_args():
help="If not None, directory to store embeddings in PT format.", help="If not None, directory to store embeddings in PT format.",
) )
args = parser.parse_args() args = parser.parse_args()
if args.model_name not in C.MODELS: if args.model_name not in mulan.get_available_models():
raise ValueError(f"Model name must be one of: {C.MODELS.keys()}") raise ValueError(f"Invalid model name: {args.model_name}")
args.sequences = args.sequences.upper() args.sequences = args.sequences.upper()
if "imulan" in args.model_name and args.scores_file is None: if "imulan" in args.model_name and args.scores_file is None:
raise ValueError("Zero-shot scores file must be provided for imulan models.") raise ValueError("Zero-shot scores file must be provided for imulan models.")
...@@ -86,7 +87,7 @@ def score_mutation( ...@@ -86,7 +87,7 @@ def score_mutation(
wildtype_embedding, wildtype_embedding,
partner_embedding, partner_embedding,
mutation_embedding, mutation_embedding,
partner_embedding partner_embedding,
], ],
zs_scores=zs_score, zs_scores=zs_score,
) )
...@@ -104,7 +105,7 @@ def run( ...@@ -104,7 +105,7 @@ def run(
"""Compute full mutational landscape of a protein in a given complex.""" """Compute full mutational landscape of a protein in a given complex."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seq1, seq2 = re.sub(r"[UZOB]", "X", sequences).split(":") seq1, seq2 = re.sub(r"[UZOB]", "X", sequences).split(":")
output = torch.zeros(len(seq1), len(C.AAs)) output = torch.zeros(len(seq1), len(utils.AAs))
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
if embeddings_dir is not None: if embeddings_dir is not None:
os.makedirs(embeddings_dir, exist_ok=True) os.makedirs(embeddings_dir, exist_ok=True)
...@@ -112,7 +113,7 @@ def run( ...@@ -112,7 +113,7 @@ def run(
# load models # load models
plm_name = model_name.split("-")[1] plm_name = model_name.split("-")[1]
plm_model, plm_tokenizer = utils.load_pretrained_plm(plm_name, device=device) plm_model, plm_tokenizer = utils.load_pretrained_plm(plm_name, device=device)
model = LightAttModel.from_pretrained(C.MODELS[model_name]) model = mulan.load_pretrained(model_name)
model.eval() model.eval()
if scores_file is not None: if scores_file is not None:
...@@ -127,12 +128,12 @@ def run( ...@@ -127,12 +128,12 @@ def run(
utils.save_embedding(partner_embedding, embeddings_dir, "PARTNER") utils.save_embedding(partner_embedding, embeddings_dir, "PARTNER")
# iterate over single-point mutations # iterate over single-point mutations
num_iterations = len(seq1) * (len(C.AAs) - 1) num_iterations = len(seq1) * (len(utils.AAs) - 1)
pbar = tqdm(initial=0, total=num_iterations, colour="red", dynamic_ncols=True, ascii="-#") pbar = tqdm(initial=0, total=num_iterations, colour="red", dynamic_ncols=True, ascii="-#")
pbar.set_description("Scoring mutations") pbar.set_description("Scoring mutations")
for mut_name, mutation in utils.mutation_generator(seq1): for mut_name, mutation in utils.mutation_generator(seq1):
i, aa = int(mut_name[1:-1]) - 1, mut_name[-1] i, aa = int(mut_name[1:-1]) - 1, mut_name[-1]
zs_score = None if scores_file is None else zs_scores[i, C.aa2idx[aa]] zs_score = None if scores_file is None else zs_scores[i, utils.aa2idx[aa]]
score = score_mutation( score = score_mutation(
model, model,
plm_model, plm_model,
...@@ -144,16 +145,18 @@ def run( ...@@ -144,16 +145,18 @@ def run(
zs_score, zs_score,
embeddings_dir, embeddings_dir,
) )
output[i, C.aa2idx[aa]] = score output[i, utils.aa2idx[aa]] = score
pbar.update(1) pbar.update(1)
# save output # save output
output = output.cpu().numpy() output = output.cpu().numpy()
if ranksort_output: if ranksort_output:
wt_index = (range(len(seq1)), tuple([C.aa2idx[aa] for aa in seq1])) wt_index = (range(len(seq1)), tuple([utils.aa2idx[aa] for aa in seq1]))
output[wt_index] = -10000 # set to very low value to give lowest scores to wt residues output[wt_index] = -10000 # set to very low value to give lowest scores to wt residues
output = utils.ranksort(output) output = utils.ranksort(output)
output = pd.DataFrame(output, columns=C.AAs, index=[f"{i+1}{aa}" for i, aa in enumerate(seq1)]) output = pd.DataFrame(
output, columns=utils.AAs, index=[f"{i+1}{aa}" for i, aa in enumerate(seq1)]
)
output.to_csv(os.path.join(output_dir, "landscape.csv"), float_format="%.3f") output.to_csv(os.path.join(output_dir, "landscape.csv"), float_format="%.3f")
...@@ -170,4 +173,4 @@ def main(): ...@@ -170,4 +173,4 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()
\ No newline at end of file
...@@ -4,15 +4,14 @@ import os ...@@ -4,15 +4,14 @@ import os
from argparse import ArgumentParser from argparse import ArgumentParser
from tqdm import tqdm from tqdm import tqdm
import torch import torch
import h5py # type: ignore import h5py
from mulan import utils, constants as C import mulan
from mulan.modules import LightAttModel
def get_args(): def get_args():
parser = ArgumentParser( parser = ArgumentParser(
prog="extract_attentions", prog="mulan-att",
description=__doc__, description=__doc__,
) )
parser.add_argument( parser.add_argument(
...@@ -24,9 +23,7 @@ def get_args(): ...@@ -24,9 +23,7 @@ def get_args():
"--model-name", "--model-name",
type=str, type=str,
default="mulan-ankh", default="mulan-ankh",
help=f""" help=f"Name of the pre-trained model. Must be one of: {mulan.get_available_models()}.",
Name of the pre-trained model. Must be one of:
{list(C.MODELS.keys())}.""",
) )
parser.add_argument( parser.add_argument(
"-o", "--output-file", type=str, default="attentions.h5", help="Output H5 file" "-o", "--output-file", type=str, default="attentions.h5", help="Output H5 file"
...@@ -39,7 +36,7 @@ def get_args(): ...@@ -39,7 +36,7 @@ def get_args():
help="If not None, directory to store embeddings in PT format.", help="If not None, directory to store embeddings in PT format.",
) )
args = parser.parse_args() args = parser.parse_args()
if args.model_name not in C.MODELS: if args.model_name not in mulan.get_available_models():
raise ValueError(f"Invalid model name: {args.model_name}") raise ValueError(f"Invalid model name: {args.model_name}")
return args return args
...@@ -47,10 +44,10 @@ def get_args(): ...@@ -47,10 +44,10 @@ def get_args():
@torch.inference_mode @torch.inference_mode
def run(model_name, fasta_file, output_file, embeddings_dir=None): def run(model_name, fasta_file, output_file, embeddings_dir=None):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = utils.parse_fasta(fasta_file) dataset = mulan.utils.parse_fasta(fasta_file)
plm_name = model_name.split("-")[1] plm_name = model_name.split("-")[1]
plm_model, plm_tokenizer = utils.load_pretrained_plm(plm_name, device=device) plm_model, plm_tokenizer = mulan.load_pretrained_plm(plm_name, device=device)
model = LightAttModel.from_pretrained(C.MODELS[model_name]) model = mulan.load_pretrained(model_name)
os.makedirs(os.path.dirname(output_file), exist_ok=True) os.makedirs(os.path.dirname(output_file), exist_ok=True)
if embeddings_dir is not None: if embeddings_dir is not None:
os.makedirs(embeddings_dir, exist_ok=True) os.makedirs(embeddings_dir, exist_ok=True)
...@@ -59,15 +56,15 @@ def run(model_name, fasta_file, output_file, embeddings_dir=None): ...@@ -59,15 +56,15 @@ def run(model_name, fasta_file, output_file, embeddings_dir=None):
with h5py.File(output_file, "w") as f: with h5py.File(output_file, "w") as f:
for name, sequence in dataset.items(): for name, sequence in dataset.items():
# embed sequence # embed sequence
embedding = utils.embed_sequence(plm_model, plm_tokenizer, sequence) embedding = mulan.utils.embed_sequence(plm_model, plm_tokenizer, sequence)
# compute attention weights # compute attention weights
attention = model([embedding] * 4, output_attentions=True).attention[0] attention = model([embedding] * 4, output_attentions=True).attention[0]
attention = attention.squeeze(0).cpu().numpy().mean(axis=(-2, -3)) attention = attention.squeeze(0).cpu().numpy().mean(axis=(-2, -3))
attention = utils.minmax_scale(attention) attention = mulan.utils.minmax_scale(attention)
# save attention weights # save attention weights
f.create_dataset(name, data=attention) f.create_dataset(name, data=attention)
if embeddings_dir is not None: if embeddings_dir is not None:
utils.save_embedding(embedding, embeddings_dir, name) mulan.utils.save_embedding(embedding, embeddings_dir, name)
pbar.update(1) pbar.update(1)
......
...@@ -6,12 +6,12 @@ from argparse import ArgumentParser ...@@ -6,12 +6,12 @@ from argparse import ArgumentParser
from tqdm import tqdm from tqdm import tqdm
import torch import torch
from mulan import utils, constants as C import mulan
def get_args(): def get_args():
parser = ArgumentParser( parser = ArgumentParser(
prog="generate_embeddings", prog="plm-embed",
description=__doc__, description=__doc__,
) )
parser.add_argument( parser.add_argument(
...@@ -22,8 +22,8 @@ def get_args(): ...@@ -22,8 +22,8 @@ def get_args():
parser.add_argument( parser.add_argument(
"model_name", "model_name",
type=str, type=str,
default="ankh_large", default="ankh",
help=f"PLM name to be loaded. Must be one of {C.PLM_ENCODERS.keys()}", help=f"PLM name to be loaded. Must be one of {mulan.get_available_plms()}",
) )
parser.add_argument( parser.add_argument(
"-o", "-o",
...@@ -33,21 +33,23 @@ def get_args(): ...@@ -33,21 +33,23 @@ def get_args():
help="Output directory. Defaults to './embeddings'", help="Output directory. Defaults to './embeddings'",
) )
args = parser.parse_args() args = parser.parse_args()
if args.model_name not in mulan.get_available_plms():
raise ValueError(f"Invalid model name: {args.model_name}")
return args return args
@torch.inference_mode() @torch.inference_mode()
def run(model_name, fasta_file, output_dir): def run(model_name, fasta_file, output_dir):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, tokenizer = utils.load_pretrained_plm(model_name, device=device) model, tokenizer = mulan.load_pretrained_plm(model_name, device=device)
dataset = utils.parse_fasta(fasta_file) dataset = mulan.utils.parse_fasta(fasta_file)
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
pbar = tqdm(initial=0, total=len(dataset), colour="red", dynamic_ncols=True, ascii="-#") pbar = tqdm(initial=0, total=len(dataset), colour="red", dynamic_ncols=True, ascii="-#")
pbar.set_description("Embedding sequences") pbar.set_description("Embedding sequences")
for name, sequence in dataset.items(): for name, sequence in dataset.items():
# embed sequence # embed sequence
embedding = utils.embed_sequence(model, tokenizer, sequence) embedding = mulan.utils.embed_sequence(model, tokenizer, sequence)
utils.save_embedding(embedding, output_dir, name) mulan.utils.save_embedding(embedding, output_dir, name)
pbar.update(1) pbar.update(1)
......
...@@ -7,21 +7,20 @@ from tqdm import tqdm ...@@ -7,21 +7,20 @@ from tqdm import tqdm
import torch import torch
import pandas as pd import pandas as pd
import mulan.constants as C import mulan
from mulan import utils import mulan.utils as utils
from mulan.modules import LightAttModel
def get_args(): def get_args():
parser = ArgumentParser( parser = ArgumentParser(
prog="predict", prog="mulan-predict",
description=__doc__, description=__doc__,
) )
parser.add_argument( parser.add_argument(
"model_name", "--model-name",
type=str, type=str,
default="mulan-ankh", default="mulan-ankh",
help=f"Name of the pre-trained model. Must be one of: {C.MODELS.keys()}.", help=f"Name of the pre-trained model. Must be one of: {mulan.get_available_models()}.",
) )
parser.add_argument( parser.add_argument(
"input_file", "input_file",
...@@ -54,9 +53,11 @@ def get_args(): ...@@ -54,9 +53,11 @@ def get_args():
parser.add_argument( parser.add_argument(
"--store-embeddings", "--store-embeddings",
action="store_true", action="store_true",
help="Store embeddings in PT format. Only wild type embeddings are stored.", help="Store embeddings in PT format. Output directory is the same of 'output_file'.",
) )
args = parser.parse_args() args = parser.parse_args()
if args.model_name not in mulan.get_available_models():
raise ValueError(f"Invalid model name: {args.model_name}")
return args return args
...@@ -97,8 +98,9 @@ def run(model_name, input_file, scores_file, output_file, store_embeddings): ...@@ -97,8 +98,9 @@ def run(model_name, input_file, scores_file, output_file, store_embeddings):
num_iterations = sum(len(d["mutations"]) for d in data) num_iterations = sum(len(d["mutations"]) for d in data)
scores = [] scores = []
plm_name = model_name.split("-")[1] plm_name = model_name.split("-")[1]
plm_model, plm_tokenizer = utils.load_pretrained_plm(plm_name, device=device) plm_model, plm_tokenizer = mulan.load_pretrained_plm(plm_name, device=device)
model = LightAttModel.from_pretrained(C.MODELS[model_name]) model = mulan.load_pretrained(model_name)
model = model.eval()
if "imulan" in model_name and scores_file is None: if "imulan" in model_name and scores_file is None:
raise ValueError("Zero-shot scores file must be provided for imulan models.") raise ValueError("Zero-shot scores file must be provided for imulan models.")
zs_scores = parse_zs_scores(scores_file) if scores_file is not None else {} zs_scores = parse_zs_scores(scores_file) if scores_file is not None else {}
...@@ -138,8 +140,6 @@ def run(model_name, input_file, scores_file, output_file, store_embeddings): ...@@ -138,8 +140,6 @@ def run(model_name, input_file, scores_file, output_file, store_embeddings):
) )
if store_embeddings: if store_embeddings:
utils.save_embedding(seq1_embedding, embeddings_dir, f"{complex_name}_A")
utils.save_embedding(seq2_embedding, embeddings_dir, f"{complex_name}_B")
utils.save_embedding( utils.save_embedding(
mut_seq1_embedding, mut_seq1_embedding,
embeddings_dir, embeddings_dir,
...@@ -151,9 +151,13 @@ def run(model_name, input_file, scores_file, output_file, store_embeddings): ...@@ -151,9 +151,13 @@ def run(model_name, input_file, scores_file, output_file, store_embeddings):
f"{complex_name}_{'-'.join([mut for mut in mutations if mut[1] == 'B'])}" f"{complex_name}_{'-'.join([mut for mut in mutations if mut[1] == 'B'])}"
) )
pbar.update(1) pbar.update(1)
if store_embeddings:
utils.save_embedding(seq1_embedding, embeddings_dir, f"{complex_name}_A")
utils.save_embedding(seq2_embedding, embeddings_dir, f"{complex_name}_B")
df = pd.DataFrame.from_records(scores) df = pd.DataFrame.from_records(scores)
df.to_csv(output_file, index=False, sep="\t") df.to_csv(output_file, index=False, sep="\t", float_format="%.3f")
def main(): def main():
......
...@@ -4,7 +4,7 @@ from argparse import ArgumentParser ...@@ -4,7 +4,7 @@ from argparse import ArgumentParser
def get_args(): def get_args():
parser = ArgumentParser( parser = ArgumentParser(
prog="train", prog="mulan-train",
description=__doc__, description=__doc__,
) )
args = parser.parse_args() args = parser.parse_args()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment