Commit f3163134 by Gianluca LOMBARDI

Update imports and README

parent f812f2cc
# MuLAN: MUtational effects with Light Attention Networks
![mulan abstract](./images/visual_abstract.png)
......@@ -13,8 +13,8 @@ Attention weights extracted from the model can give insights on protein interfac
### Installation
As a prerequisite, you must have PyTorch installed to use this repository. If not, it can be installed with conda by running the following:
```bash
# PyTorch 2.4.0, CUDA 12.1
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
```
For other versions or installation methods, please refer to the [PyTorch documentation](https://pytorch.org/get-started/locally/).
MuLAN and its dependencies can then be installed with
......@@ -23,15 +23,36 @@ git clone https://github.com/GianLMB/mulan
cd mulan
pip install .
```
We suggest doing this in a dedicated conda environment.
Cloning the repository will also download weights for different model versions trained on the SKEMPI dataset. They are available in the folder
`models/pretrained`. To load them within the package, the `MULAN` environment variable must be set:
```bash
export MULAN="path/to/mulan/folder"
```
Alternatively, if you prefer to move the checkpoint files to a different folder, you can access them in the new location by setting the
`MULAN_MODELS_DIR` variable to point to the corresponding folder.
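For example, assuming the checkpoints were moved to a hypothetical `/data/checkpoints/mulan`:
```bash
export MULAN_MODELS_DIR="/data/checkpoints/mulan"
```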
### Usage
Available MuLAN models and pre-trained protein language models can be easily loaded through the `mulan` interface. For MuLAN models:
```python
import mulan
print(mulan.get_available_models())
model = mulan.load_pretrained("mulan-ankh")
```
Similarly, for the supported PLMs:
```python
import mulan
print(mulan.get_available_plms())
model = mulan.load_pretrained_plm("ankh")
```
If the corresponding PLMs are not found on disk, they will be downloaded from the HuggingFace Hub and stored by default in the `~/.cache/huggingface/hub` folder.
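The returned model and tokenizer pair can be fed directly to the embedding helper in `mulan.utils`. A minimal sketch (the sequence is an arbitrary placeholder):
```python
import mulan
import mulan.utils as utils

plm_model, plm_tokenizer = mulan.load_pretrained_plm("ankh")
# embed a single sequence with the loaded PLM
embedding = utils.embed_sequence(plm_model, plm_tokenizer, "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")
```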
We provide several command line interfaces for quick use of MuLAN's different applications (example invocations follow the list):
- `mulan-predict` for $\Delta \Delta G$ prediction of single and multiple-point mutations;
- `mulan-att` to extract residue weights, which relate to interface regions;
- `mulan-landscape` to produce a full mutational landscape for a given complex;
- `mulan-train` to re-train the model on a custom dataset or to run cross validation (not supported yet);
- `plm-embed` to extract embeddings from protein language models.
Since the script uses the `transformers` interface, only models saved on the HuggingFace 🤗 Hub can be loaded.
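As an illustration, typical invocations might look like the following sketch. The sequences, FASTA file, and output paths are hypothetical placeholders, and the exact flags should be checked with each command's `--help`:
```bash
# full mutational landscape of the first chain, in complex with the second
mulan-landscape "MKTAYIAKQR:MSGRGKQGGK" -m mulan-ankh

# interface-related residue weights for sequences in a FASTA file, saved to H5
mulan-att sequences.fasta --model-name mulan-ankh -o attentions.h5

# per-residue PLM embeddings for the same FASTA file
plm-embed sequences.fasta ankh -o ./embeddings
```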
......@@ -39,5 +60,16 @@ Information about required inputs and usage examples for each command are provid
## Citation
If you find this useful, please cite
```bibtex
@article{Lombardi2024.08.24.609515,
  author = {Lombardi, Gianluca and Carbone, Alessandra},
  title = {MuLAN: Mutation-driven Light Attention Networks for investigating protein-protein interactions from sequences},
  year = {2024},
  doi = {10.1101/2024.08.24.609515},
  publisher = {Cold Spring Harbor Laboratory},
  URL = {https://www.biorxiv.org/content/early/2024/08/26/2024.08.24.609515},
  journal = {bioRxiv}
}
```
from .config import MulanConfig  # noqa
from .modules import LightAttModel  # noqa
from .utils import (
    load_pretrained,
    load_pretrained_plm,
    get_available_models,
    get_available_plms,
)  # noqa
......@@ -49,9 +49,9 @@ three2one = {v: k for k, v in one2three.items()}
# Model names and paths
_BASE_DIR = os.environ.get("MULAN", os.getcwd())
_DEFAULT_MODELS_DIR = os.path.join(_BASE_DIR, "models/pretrained")
MODELS_DIR = os.environ.get("MULAN_MODELS_DIR", _DEFAULT_MODELS_DIR)
MODELS = {
    "mulan-esm": f"{MODELS_DIR}/mulan_esm.ckpt",
    "mulan-esm-multiple": f"{MODELS_DIR}/mulan_esm_multiple.ckpt",
......
......@@ -69,9 +69,7 @@ class AttentionMeanK(nn.Module):
            ],
            dim=1,
        )
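        # attention-weighted mean pooling over sequence positions (cf. max pooling below)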
        o1 = torch.sum(o * attn_weights, dim=-1).view(batch_size, -1)
        # max pooling
        o2, _ = torch.max(o, dim=-1)
......@@ -80,7 +78,6 @@ class AttentionMeanK(nn.Module):
        # mlp
        o = torch.cat([o1, o2], dim=-1)
        o = self.fc(o)
        output = OutputWithAttention(o, attn_weights)
        return output
......@@ -130,7 +127,7 @@ class LightAttModel(nn.Module):
        )
        output = x_mut - x_wt
        if zs_scores is not None and self.config.add_scores:
            output = torch.cat((output, zs_scores.view(batch_size, -1)), dim=-1)
        output = self.linear(output).squeeze(-1)
......
......@@ -9,6 +9,7 @@ import numpy as np
from scipy.stats import rankdata
import mulan.constants as C
from mulan.constants import AAs, aa2idx, idx2aa, one2three, three2one # noqa
def mutation_generator(sequence):
......@@ -72,28 +73,39 @@ def dict_to_fasta(fasta_dict, fasta_file):
            f.write(f"{value}\n")
def get_available_plms():
    """Return the names of the available pretrained language models."""
    return list(C.PLM_ENCODERS.keys())
def get_available_models():
    """Return the names of the available Mulan models."""
    return list(C.MODELS.keys())
def load_pretrained_plm(model_name, device="cpu"):
    model_id = C.PLM_ENCODERS.get(model_name)
    if model_id is None:
        raise ValueError(
            f"Invalid model_name: {model_name}. Must be one of {get_available_plms()}"
        )
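    # T5-family checkpoints (including Ankh) are loaded as encoder-only models;
    # Ankh ships its own tokenizer config, hence AutoTokenizer instead of T5Tokenizer.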
if "t5" in model_id or "ankh" in model_id:
from transformers import T5EncoderModel
model = T5EncoderModel.from_pretrained(model_id)
if "ankh" in model_id:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
else:
from transformers import T5Tokenizer
tokenizer = T5Tokenizer.from_pretrained(model_id)
tokenizer = T5Tokenizer.from_pretrained(model_id)
    else:
        try:
            from transformers import AutoTokenizer, AutoModel
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModel.from_pretrained(model_id)
        except Exception as e:
......@@ -103,6 +115,16 @@ def load_pretrained_plm(model_name, device="cpu"):
    return model, tokenizer
def load_pretrained(pretrained_model_name, device=None, **kwargs):
    """Load a pretrained model from disk."""
    model_path = C.MODELS.get(pretrained_model_name)
    if model_path is None:
        raise ValueError(f"Invalid model_name: {pretrained_model_name}")
    from mulan.modules import LightAttModel
    return LightAttModel.from_pretrained(model_path, device=device, **kwargs)
@torch.inference_mode()
def embed_sequence(plm_model, plm_tokenizer, sequence):
"""Embed a sequence using a pretrained model."""
......
......@@ -8,8 +8,8 @@ import pandas as pd
import torch
import numpy as np
import mulan
import mulan.utils as utils
def get_args():
......@@ -24,10 +24,11 @@ def get_args():
        The first is the sequence to be scored, the other is the partner.""",
    )
    parser.add_argument(
        "-m",
        "--model-name",
        type=str,
        default="mulan-ankh",
        help=f"""Name of the pre-trained model. Must be one of: {mulan.get_available_models()}.
        If the model is a version of imulan, a TXT file containing zero-shot scores
        must be provided, where lines correspond to mutated positions and columns
        to amino acids, in alphabetic order for the single letter name, separated
......@@ -58,8 +59,8 @@ def get_args():
help="If not None, directory to store embeddings in PT format.",
)
    args = parser.parse_args()
    if args.model_name not in mulan.get_available_models():
        raise ValueError(f"Invalid model name: {args.model_name}")
    args.sequences = args.sequences.upper()
    if "imulan" in args.model_name and args.scores_file is None:
        raise ValueError("Zero-shot scores file must be provided for imulan models.")
......@@ -86,7 +87,7 @@ def score_mutation(
            wildtype_embedding,
            partner_embedding,
            mutation_embedding,
            partner_embedding,
        ],
        zs_scores=zs_score,
    )
......@@ -104,7 +105,7 @@ def run(
"""Compute full mutational landscape of a protein in a given complex."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seq1, seq2 = re.sub(r"[UZOB]", "X", sequences).split(":")
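    # landscape matrix: one row per position of seq1, one column per amino acid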
    output = torch.zeros(len(seq1), len(utils.AAs))
    os.makedirs(output_dir, exist_ok=True)
    if embeddings_dir is not None:
        os.makedirs(embeddings_dir, exist_ok=True)
......@@ -112,7 +113,7 @@ def run(
    # load models
    plm_name = model_name.split("-")[1]
    plm_model, plm_tokenizer = utils.load_pretrained_plm(plm_name, device=device)
    model = mulan.load_pretrained(model_name)
    model.eval()
    if scores_file is not None:
......@@ -127,12 +128,12 @@ def run(
        utils.save_embedding(partner_embedding, embeddings_dir, "PARTNER")
    # iterate over single-point mutations
    num_iterations = len(seq1) * (len(utils.AAs) - 1)
    pbar = tqdm(initial=0, total=num_iterations, colour="red", dynamic_ncols=True, ascii="-#")
    pbar.set_description("Scoring mutations")
    for mut_name, mutation in utils.mutation_generator(seq1):
        i, aa = int(mut_name[1:-1]) - 1, mut_name[-1]
        zs_score = None if scores_file is None else zs_scores[i, utils.aa2idx[aa]]
        score = score_mutation(
            model,
            plm_model,
......@@ -144,16 +145,18 @@ def run(
            zs_score,
            embeddings_dir,
        )
        output[i, utils.aa2idx[aa]] = score
        pbar.update(1)
    # save output
    output = output.cpu().numpy()
    if ranksort_output:
        wt_index = (range(len(seq1)), tuple([utils.aa2idx[aa] for aa in seq1]))
        output[wt_index] = -10000  # set to very low value to give lowest scores to wt residues
        output = utils.ranksort(output)
    output = pd.DataFrame(
        output, columns=utils.AAs, index=[f"{i+1}{aa}" for i, aa in enumerate(seq1)]
    )
    output.to_csv(os.path.join(output_dir, "landscape.csv"), float_format="%.3f")
......@@ -170,4 +173,4 @@ def main():
if __name__ == "__main__":
main()
\ No newline at end of file
main()
......@@ -4,15 +4,14 @@ import os
from argparse import ArgumentParser
from tqdm import tqdm
import torch
import h5py
import mulan
def get_args():
    parser = ArgumentParser(
        prog="mulan-att",
        description=__doc__,
    )
    parser.add_argument(
......@@ -24,9 +23,7 @@ def get_args():
"--model-name",
type=str,
default="mulan-ankh",
help=f"""
Name of the pre-trained model. Must be one of:
{list(C.MODELS.keys())}.""",
help=f"Name of the pre-trained model. Must be one of: {mulan.get_available_models()}.",
)
    parser.add_argument(
        "-o", "--output-file", type=str, default="attentions.h5", help="Output H5 file"
......@@ -39,7 +36,7 @@ def get_args():
help="If not None, directory to store embeddings in PT format.",
)
    args = parser.parse_args()
    if args.model_name not in mulan.get_available_models():
        raise ValueError(f"Invalid model name: {args.model_name}")
    return args
......@@ -47,10 +44,10 @@ def get_args():
@torch.inference_mode()
def run(model_name, fasta_file, output_file, embeddings_dir=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = mulan.utils.parse_fasta(fasta_file)
    plm_name = model_name.split("-")[1]
    plm_model, plm_tokenizer = mulan.load_pretrained_plm(plm_name, device=device)
    model = mulan.load_pretrained(model_name)
    os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)
    if embeddings_dir is not None:
        os.makedirs(embeddings_dir, exist_ok=True)
......@@ -59,15 +56,15 @@ def run(model_name, fasta_file, output_file, embeddings_dir=None):
    with h5py.File(output_file, "w") as f:
        for name, sequence in dataset.items():
            # embed sequence
            embedding = mulan.utils.embed_sequence(plm_model, plm_tokenizer, sequence)
            # compute attention weights
            attention = model([embedding] * 4, output_attentions=True).attention[0]
            attention = attention.squeeze(0).cpu().numpy().mean(axis=(-2, -3))
            attention = mulan.utils.minmax_scale(attention)
            # save attention weights
            f.create_dataset(name, data=attention)
            if embeddings_dir is not None:
                mulan.utils.save_embedding(embedding, embeddings_dir, name)
            pbar.update(1)
......
......@@ -6,12 +6,12 @@ from argparse import ArgumentParser
from tqdm import tqdm
import torch
import mulan
def get_args():
    parser = ArgumentParser(
        prog="plm-embed",
        description=__doc__,
    )
    parser.add_argument(
......@@ -22,8 +22,8 @@ def get_args():
    parser.add_argument(
        "model_name",
        type=str,
        default="ankh",
        help=f"PLM name to be loaded. Must be one of {mulan.get_available_plms()}",
    )
    parser.add_argument(
        "-o",
......@@ -33,21 +33,23 @@ def get_args():
help="Output directory. Defaults to './embeddings'",
)
    args = parser.parse_args()
    if args.model_name not in mulan.get_available_plms():
        raise ValueError(f"Invalid model name: {args.model_name}")
    return args
@torch.inference_mode()
def run(model_name, fasta_file, output_dir):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, tokenizer = mulan.load_pretrained_plm(model_name, device=device)
    dataset = mulan.utils.parse_fasta(fasta_file)
    os.makedirs(output_dir, exist_ok=True)
    pbar = tqdm(initial=0, total=len(dataset), colour="red", dynamic_ncols=True, ascii="-#")
    pbar.set_description("Embedding sequences")
    for name, sequence in dataset.items():
        # embed sequence
        embedding = mulan.utils.embed_sequence(model, tokenizer, sequence)
        mulan.utils.save_embedding(embedding, output_dir, name)
        pbar.update(1)
......
......@@ -7,21 +7,20 @@ from tqdm import tqdm
import torch
import pandas as pd
import mulan
import mulan.utils as utils
def get_args():
    parser = ArgumentParser(
        prog="mulan-predict",
        description=__doc__,
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="mulan-ankh",
        help=f"Name of the pre-trained model. Must be one of: {mulan.get_available_models()}.",
    )
    parser.add_argument(
        "input_file",
......@@ -54,9 +53,11 @@ def get_args():
    parser.add_argument(
        "--store-embeddings",
        action="store_true",
        help="Store embeddings in PT format. The output directory is the same as that of 'output_file'.",
    )
    args = parser.parse_args()
    if args.model_name not in mulan.get_available_models():
        raise ValueError(f"Invalid model name: {args.model_name}")
    return args
......@@ -97,8 +98,9 @@ def run(model_name, input_file, scores_file, output_file, store_embeddings):
num_iterations = sum(len(d["mutations"]) for d in data)
scores = []
plm_name = model_name.split("-")[1]
plm_model, plm_tokenizer = utils.load_pretrained_plm(plm_name, device=device)
model = LightAttModel.from_pretrained(C.MODELS[model_name])
plm_model, plm_tokenizer = mulan.load_pretrained_plm(plm_name, device=device)
model = mulan.load_pretrained(model_name)
model = model.eval()
if "imulan" in model_name and scores_file is None:
raise ValueError("Zero-shot scores file must be provided for imulan models.")
zs_scores = parse_zs_scores(scores_file) if scores_file is not None else {}
......@@ -138,8 +140,6 @@ def run(model_name, input_file, scores_file, output_file, store_embeddings):
        )
        if store_embeddings:
            utils.save_embedding(
                mut_seq1_embedding,
                embeddings_dir,
......@@ -151,9 +151,13 @@ def run(model_name, input_file, scores_file, output_file, store_embeddings):
f"{complex_name}_{'-'.join([mut for mut in mutations if mut[1] == 'B'])}"
)
pbar.update(1)
    if store_embeddings:
        utils.save_embedding(seq1_embedding, embeddings_dir, f"{complex_name}_A")
        utils.save_embedding(seq2_embedding, embeddings_dir, f"{complex_name}_B")
    df = pd.DataFrame.from_records(scores)
    df.to_csv(output_file, index=False, sep="\t", float_format="%.3f")
def main():
......
......@@ -4,7 +4,7 @@ from argparse import ArgumentParser
def get_args():
    parser = ArgumentParser(
        prog="mulan-train",
        description=__doc__,
    )
    args = parser.parse_args()
......