Commit f812f2cc by Gianluca Lombardi

Initial commit

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Custom files and directories
log/
cache/
notebooks/
# MuLAN: MUtational effects with Light Attention Networks
![mulan abstract](./images/visual_abstract.png)
MuLAN is a deep learning method that leverages transfer learning from foundational protein language models
and light attention to predict mutational effects in protein complexes.
The only inputs to the model are the sequences of the interacting proteins and (optionally) zero-shot scores for the considered mutations.
Attention weights extracted from the model can give insights into protein interface regions.
## Quick start
### Installation
As a prerequisite, you must have PyTorch installed to use this repository. If it is not already installed, it can be installed with conda by running the following:
```bash
# PyTorch 2.1.0, CUDA 12.1
conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=12.1 -c pytorch -c nvidia
```
For other versions or installation methods, please refer to [PyTorch documentation](https://pytorch.org/get-started/locally/).
MuLAN and its dependencies can then be installed with
```bash
git clone https://github.com/GianLMB/mulan
cd mulan
pip install .
```
We suggest doing this in a dedicated conda environment.
### Usage
We provide several command-line interfaces for quick usage of MuLAN's different applications:
- `mulan-predict` for $\Delta \Delta G$ prediction of single and multiple-point mutations;
- `mulan-att` to extract residue weights, which relate to interface regions;
- `mulan-landscape` to produce a full mutational landscape for a given complex;
- `mulan-train` to re-train the model on a custom dataset or to run cross validation (not added yet);
- `plm-embed` to extract embeddings from protein language models.
Since the script uses the `transformers` interface, only models that are saved on HuggingFace 🤗 Hub can be loaded.
Information about required inputs and usage examples for each command are provided with the `--help` flag.
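For example, a minimal `mulan-predict` run could look like the following (the sequences, mutations, and file names are placeholders, not data shipped with the package):
```bash
# input format, one complex per line: <name> <sequence1> <sequence2> <mutation sets>
echo "C1 MVKQIESKTAF MKMSRLCLSVA IA5T:KA3R,SB4A" > input.txt
mulan-predict mulan-ankh input.txt -o scores.txt
```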
## Citation
>P10599
MVKQIESKTAFQEALDAAGDKLVVVDFSATWCGPCKMIKPFFHSLSEKYSNVIFLEVDVD
DCQDVASECEVKCMPTFQFFKKGQKVGEFSGANKEKLEATINELV
>P00974
MKMSRLCLSVALLVLLGTLAASTPGCDTSNQAKAQRPDFCLEPPYTGPCKARIIRYFYNA
KAGLCQTFVYGGCRAKRNNFKSAEDCMRTCGGAIGPWENL
{
"hidden_size": 64,
"last_hidden_size": 40,
"kernel_sizes": [1, 5, 9],
"hidden_dropout_prob": 0.1,
"conv_dropout": 0.1,
"padding_value": 0,
"add_scores": false
}
{
"hidden_size": 64,
"last_hidden_size": 40,
"kernel_sizes": [1, 5, 9],
"hidden_dropout_prob": 0.1,
"conv_dropout": 0.1,
"padding_value": 0,
"add_scores": true
}
"""Definition of Config class"""
from typing import Union, Tuple, get_type_hints
from dataclasses import dataclass, asdict
import os
import json
@dataclass
class MulanConfig:
hidden_size: int = 64
last_hidden_size: int = 20
hidden_dropout_prob: float = 0.1
padding_value: float = 0
    kernel_sizes: Union[Tuple, int] = (1, 5, 9)
conv_dropout: float = 0.1
add_scores: bool = False
@classmethod
def from_json(cls, json_path, strict=False):
try:
with open(json_path, "r") as f:
args_dict = json.load(f)
return cls.from_dict(args_dict, strict=strict)
except FileNotFoundError:
raise FileNotFoundError(f"File not found: {json_path}")
@classmethod
def from_dict(cls, args_dict, strict=False):
attributes = {**cls.__annotations__, **get_type_hints(cls)}
keys_to_exclude = list(set(args_dict.keys()) - set(attributes))
if not strict:
if keys_to_exclude:
print(f"Keys {keys_to_exclude} do not match {cls.__name__} attributes and are thus ignored")
args_dict = {key: value for key, value in args_dict.items() if key in attributes}
else:
if keys_to_exclude:
raise KeyError(f"Unrecognized keys {keys_to_exclude} found in input dictionary")
keys_defaulted = list(set(attributes) - set(args_dict.keys()))
if keys_defaulted:
print(f"Keys {keys_defaulted} were not found in input dictionary and are initialized to default values")
return cls(**args_dict)
def save(self, json_path=None) -> str:
if json_path is None:
json_path = os.path.join("config", "config.json")
with open(json_path, "w") as f:
json.dump(asdict(self), f, indent=4, default=lambda x: x.__dict__)
return json_path
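# A minimal usage sketch for MulanConfig (illustrative; the file name below is arbitrary).
# Unknown keys are dropped in non-strict mode and missing keys fall back to the defaults above.
if __name__ == "__main__":
    cfg = MulanConfig.from_dict({"hidden_size": 64, "add_scores": True, "unknown_key": 1})
    saved_path = cfg.save("example_config.json")
    print(MulanConfig.from_json(saved_path))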
"""Constants used in the package."""
import os
# Single letter, three letter, and full amino acid names.
aa_names = (
('A', 'ALA', 'alanine'),
('R', 'ARG', 'arginine'),
('N', 'ASN', 'asparagine'),
('D', 'ASP', 'aspartic acid'),
('C', 'CYS', 'cysteine'),
('E', 'GLU', 'glutamic acid'),
('Q', 'GLN', 'glutamine'),
('G', 'GLY', 'glycine'),
('H', 'HIS', 'histidine'),
('I', 'ILE', 'isoleucine'),
('L', 'LEU', 'leucine'),
('K', 'LYS', 'lysine'),
('M', 'MET', 'methionine'),
('F', 'PHE', 'phenylalanine'),
('P', 'PRO', 'proline'),
('S', 'SER', 'serine'),
('T', 'THR', 'threonine'),
('W', 'TRP', 'tryptophan'),
('Y', 'TYR', 'tyrosine'),
('V', 'VAL', 'valine'),
# Extended AAs
('B', 'ASX', 'asparagine or aspartic acid'),
('Z', 'GLX', 'glutamine or glutamic acid'),
('X', 'XAA', 'Any'),
('J', 'XLE', 'Leucine or isoleucine'),
)
# Indices of standard amino acids in `aa_names`.
standard_indices = tuple(range(20))
# Single letter codes of standard amino acids.
standard_aas = tuple(aa_names[i][0] for i in standard_indices)
AAs = tuple(sorted(standard_aas))
# aa_to_idx and idx_to_aa
aa2idx = dict(zip(AAs, standard_indices))
idx2aa = {v: k for k, v in aa2idx.items()}
# dictionaries for aas names conversion
one2three = dict(aa_names[i][:2] for i in standard_indices)
three2one = {v: k for k, v in one2three.items()}
# Models names and paths
_BASE_DIR = os.getcwd() # os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
_DEFAULT_MODELS_DIR = os.path.join(_BASE_DIR, "models/pretrained")
MODELS_DIR = os.environ.get("MULAN_MODELS_DIR", _DEFAULT_MODELS_DIR)
MODELS = {
"mulan-esm": f"{MODELS_DIR}/mulan_esm.ckpt",
"mulan-esm-multiple": f"{MODELS_DIR}/mulan_esm_multiple.ckpt",
"imulan-esm": f"{MODELS_DIR}/imulan_esm.ckpt",
"mulan-ankh": f"{MODELS_DIR}/mulan_ankh.ckpt",
"imulan-ankh": f"{MODELS_DIR}/imulan_ankh.ckpt",
"mulan-ankh-multiple": f"{MODELS_DIR}/mulan_ankh_multiple.ckpt",
}
# PLMs encoders and HuggingFace Hub ids
PLM_ENCODERS = {
"esm": "facebook/esm2_t36_3B_UR50D",
"ankh": "ElnaggarLab/ankh-large",
"esm_35M": "facebook/esm2_t12_35M_UR50D",
"esm_650M": "facebook/esm2_t33_650M_UR50D",
"ankh_base": "ElnaggarLab/ankh-base",
"protbert": "Rostlab/prot_bert",
"prott5_xl_half": "Rostlab/prot_t5_xl_half_uniref50-enc",
}
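# Quick illustrative checks of the lookup tables defined above.
if __name__ == "__main__":
    assert one2three["A"] == "ALA" and three2one["ALA"] == "A"
    assert idx2aa[aa2idx["W"]] == "W"
    print(f"{len(AAs)} standard amino acids; model checkpoints expected under {MODELS_DIR}")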
"""Implementation of Light attention model"""
from typing import Union, Sequence, Optional, List
from dataclasses import dataclass
import torch
import torch.nn as nn
from mulan.config import MulanConfig
@dataclass
class OutputWithAttention:
output: torch.Tensor = None
attention: Union[torch.Tensor, Sequence[torch.Tensor]] = None
class AttentionMeanK(nn.Module):
def __init__(self, config):
super().__init__()
self.padding_value = config.padding_value
kernel_sizes = list(config.kernel_sizes)
self.n_attn = len(kernel_sizes)
self.hidden_size = config.hidden_size
self.last_hidden_size = config.last_hidden_size
self.softmax = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(config.conv_dropout)
self.conv_heads = nn.ModuleList(
[
nn.LazyConv1d(self.hidden_size, kernel_size, stride=1, padding=kernel_size // 2)
for kernel_size in kernel_sizes
]
)
self.attn_heads = nn.ModuleList(
[
nn.LazyConv1d(self.hidden_size, kernel_size, stride=1, padding=kernel_size // 2)
for kernel_size in kernel_sizes
]
)
self.fc = nn.Sequential(
nn.Linear(
self.hidden_size * self.n_attn * 2, self.last_hidden_size * self.n_attn
), # n_attn + concatenated maxpool
nn.Dropout(config.hidden_dropout_prob),
nn.LeakyReLU(0.5),
nn.Linear(self.last_hidden_size * self.n_attn, self.last_hidden_size // 2),
nn.LeakyReLU(0.5),
)
def forward(self, x):
# feature convolution
batch_size, length, hidden = x.shape
mask = x[..., 0] != self.padding_value
x = x.transpose(-1, -2)
o = torch.stack([conv(x) for conv in self.conv_heads], dim=1)
o = self.dropout(o) # [batch_size, embeddings_dim, sequence_length]
# attention weights
attn_weights = torch.stack(
[
                self.softmax(head(x).masked_fill(~mask[:, None, :], -1e9))
for head in self.attn_heads
],
dim=1,
)
# print(attn_weights.shape, o.shape)
o1 = torch.sum(o * attn_weights, dim=-1).view(batch_size, -1)
# print(o1.shape)
# max pooling
o2, _ = torch.max(o, dim=-1)
o2 = o2.view(batch_size, -1)
# mlp
o = torch.cat([o1, o2], dim=-1)
o = self.fc(o)
# attn_mean = torch.softmax(attn_weigths.mean(dim=-1)
output = OutputWithAttention(o, attn_weights)
return output
class LightAttModel(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.encoder = AttentionMeanK(config)
last_hidden_size = config.last_hidden_size
if config.add_scores:
last_hidden_size = last_hidden_size + 1
self.linear = nn.Linear(last_hidden_size, 1)
def forward(
self,
inputs_embeds: List[torch.FloatTensor],
zs_scores: Optional[torch.FloatTensor] = None,
output_attentions=False,
):
batch_size = inputs_embeds[0].shape[0]
# Siamese encoder
encodings = [
self.encoder(emb) for emb in inputs_embeds
] # The order is [wt1, wt2, mut1, mut2]
# features combination
x_wt = torch.cat(
[
encodings[0].output * encodings[1].output,
torch.abs(encodings[0].output - encodings[1].output),
],
dim=1,
)
x_mut = torch.cat(
[
encodings[2].output * encodings[3].output,
torch.abs(encodings[2].output - encodings[3].output),
],
dim=1,
)
output = x_mut - x_wt
        # add_columns_scores may be absent from MulanConfig, so fall back to False
        if zs_scores is not None and (
            self.config.add_scores or getattr(self.config, "add_columns_scores", False)
        ):
output = torch.cat((output, zs_scores.view(batch_size, -1)), dim=-1)
output = self.linear(output).squeeze(-1)
if not output_attentions:
return output
else:
return OutputWithAttention(output, tuple([enc.attention for enc in encodings[:2]]))
@classmethod
def from_pretrained(cls, pretrained_model_path, device=None, **kwargs):
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt = torch.load(pretrained_model_path, map_location=device, weights_only=False)
config = ckpt["config"]
config.update(kwargs)
config = MulanConfig(**config)
state_dict = ckpt["state_dict"]
model = cls(config).to(device)
model.load_state_dict(state_dict, strict=False)
return model
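# A minimal forward-pass sketch with random embeddings (illustrative only; the batch size,
# sequence lengths and embedding dimension are assumptions). The lazy convolutions infer
# their input channels on the first call.
if __name__ == "__main__":
    config = MulanConfig(add_scores=False)
    model = LightAttModel(config)
    wt1, wt2 = torch.randn(1, 12, 1280), torch.randn(1, 15, 1280)
    mut1, mut2 = torch.randn(1, 12, 1280), torch.randn(1, 15, 1280)
    ddg = model([wt1, wt2, mut1, mut2])  # input order: [wt1, wt2, mut1, mut2]
    print(ddg.shape)  # one predicted score per complex in the batch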
"""Util functions to process data and models."""
from typing import List, Tuple
import os
import re
import torch
import numpy as np
from scipy.stats import rankdata
import mulan.constants as C
def mutation_generator(sequence):
"""Generate all possible single-point mutations for a given sequence."""
for i, aa in enumerate(sequence):
for new_aa in C.AAs:
if new_aa != aa:
yield (f"{aa}{i + 1}{new_aa}", sequence[:i] + new_aa + sequence[i + 1 :])
def listed_mutation_generator(sequence1, sequence2, mutations):
"""Generate mutated sequences from a list of mutations."""
for mutation in mutations:
seq1, seq2 = list(sequence1), list(sequence2)
for single_mut in mutation:
chain = single_mut[1]
if chain == "A":
seq1[int(single_mut[2:-1]) - 1] = single_mut[-1]
else:
seq2[int(single_mut[2:-1]) - 1] = single_mut[-1]
yield "".join(seq1), "".join(seq2)
def parse_mutations(mutations: Tuple[str, ...], seq1: str, seq2: str) -> Tuple[str, str]:
seq1, seq2 = list(seq1), list(seq2)
for single_mut in mutations:
chain = single_mut[1]
if chain == "A":
seq1[int(single_mut[2:-1]) - 1] = single_mut[-1]
else:
seq2[int(single_mut[2:-1]) - 1] = single_mut[-1]
return "".join(seq1), "".join(seq2)
def alphabetic_tokens_permutation(tokenizer):
"""Permute the tokenizer vocabulary."""
vocab = tokenizer.get_vocab()
aas_idx = [vocab[tok] for tok in C.AAs]
return aas_idx
def parse_fasta(fasta_file):
"""Parse a fasta file and return a dictionary."""
with open(fasta_file) as f:
lines = f.readlines()
fasta_dict = {}
for line in lines:
if line.startswith(">"):
key = line.strip().split()[0][1:]
fasta_dict[key] = ""
else:
fasta_dict[key] += line.strip().upper()
return fasta_dict
def dict_to_fasta(fasta_dict, fasta_file):
"""Write a dictionary to a fasta file."""
with open(fasta_file, "w") as f:
for key, value in fasta_dict.items():
f.write(f">{key}\n")
f.write(f"{value}\n")
def get_plm_names():
"""Return the names of the available pretrained language models."""
return list(C.PLM_ENCODERS.keys())
def load_pretrained_plm(model_name, device="cpu"):
model_id = C.PLM_ENCODERS.get(model_name)
if model_id is None:
raise ValueError(f"Invalid model_name: {model_name}. Must be one of {get_plm_names()}")
if "t5" in model_id or "ankh" in model_id:
from transformers import T5EncoderModel
model = T5EncoderModel.from_pretrained(model_id)
if "ankh" in model_id:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
else:
from transformers import T5Tokenizer
tokenizer = T5Tokenizer.from_pretrained(model_id)
else:
try:
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)
except Exception as e:
raise e
model = model.to(device)
model = model.eval()
return model, tokenizer
@torch.inference_mode()
def embed_sequence(plm_model, plm_tokenizer, sequence):
"""Embed a sequence using a pretrained model."""
sequence = sequence.upper()
sequence = re.sub(r"[UZOB]", "X", sequence) # always replace non-canonical AAs with X
# Pre-process sequence for ProtTrans models
if "Rostlab/prot" in plm_tokenizer.name_or_path:
sequence = " ".join(sequence)
inputs = plm_tokenizer(
sequence,
return_tensors="pt",
add_special_tokens=True,
return_special_tokens_mask=True,
).to(plm_model.device)
embedding = (
plm_model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
.last_hidden_state[~inputs["special_tokens_mask"].bool()]
.unsqueeze(0)
)
return embedding
def save_embedding(embedding, output_dir, name):
"""Save an embedding to disk."""
embedding = embedding.squeeze(0).cpu()
torch.save(embedding, os.path.join(output_dir, name + ".pt"))
def ranksort(array: np.ndarray) -> np.ndarray:
"""Ranksort an array."""
return (rankdata(array) / array.size).reshape(array.shape)
def minmax_scale(array: np.ndarray) -> np.ndarray:
"""Min-max scale an array."""
return (array - array.min()) / (array.max() - array.min())
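# Brief illustrative checks for the sequence helpers above (toy sequences).
if __name__ == "__main__":
    muts = list(mutation_generator("MKV"))
    print(len(muts), muts[0])  # 57 single mutants; the first is ('M1A', 'AKV')
    # mutations follow the format <wt_aa><chain:A,B><position><mut_aa>
    print(parse_mutations(("KA2R", "VB3A"), "MKV", "MKV"))  # ('MRV', 'MKA')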
numpy
scipy
pandas
tqdm
h5py
torch>=2.0
transformers<4.45,>=4.27
sentencepiece
"""Compute full mutational landscape of a protein in a given complex."""
import os
import re
from argparse import ArgumentParser
from tqdm import tqdm
import pandas as pd
import torch
import numpy as np
from mulan import utils, constants as C
from mulan.modules import LightAttModel
def get_args():
parser = ArgumentParser(
prog="mulan-landscape",
description=__doc__,
)
parser.add_argument(
"sequences",
type=str,
help="""Sequences strings, separated by column character.
The first is the sequence to be scored, the other is the partner.""",
)
parser.add_argument(
"-m", "--model-name",
type=str,
default="mulan-ankh",
help=f"""Name of the pre-trained model. Must be one of: {C.MODELS.keys()}.
If the model is a version of imulan, a TXT file containing zero-shot scores
must be provided, where lines correspond to mutated positions and columns
to amino acids, in alphabetic order for the single letter name, separated
by spaces.""",
)
parser.add_argument(
"-s",
"--scores-file",
type=str,
default=None,
help="TXT File containing zero-shot scores for imulan model.",
)
parser.add_argument(
"-o",
"--output-dir",
type=str,
default="output",
help="Output directory. Set to 'output' by default.",
)
parser.add_argument(
"--no-ranksort", action="store_true", help="Do not ranksort computed scores."
)
parser.add_argument(
"-e",
"--embeddings-dir",
type=str,
default=None,
help="If not None, directory to store embeddings in PT format.",
)
args = parser.parse_args()
if args.model_name not in C.MODELS:
raise ValueError(f"Model name must be one of: {C.MODELS.keys()}")
args.sequences = args.sequences.upper()
if "imulan" in args.model_name and args.scores_file is None:
raise ValueError("Zero-shot scores file must be provided for imulan models.")
return args
@torch.inference_mode()
def score_mutation(
model,
plm_model,
plm_tokenizer,
wildtype_embedding,
partner_embedding,
mut_name,
mutation,
zs_score,
embeddings_dir,
):
"""Score a mutation in a protein sequence."""
mutation_embedding = utils.embed_sequence(plm_model, plm_tokenizer, mutation)
score = (
model(
inputs_embeds=[
wildtype_embedding,
partner_embedding,
mutation_embedding,
partner_embedding
],
zs_scores=zs_score,
)
.squeeze()
.item()
)
if embeddings_dir is not None:
utils.save_embedding(mutation_embedding, embeddings_dir, mut_name)
return score
def run(
model_name, sequences, output_dir, scores_file=None, ranksort_output=True, embeddings_dir=None
):
"""Compute full mutational landscape of a protein in a given complex."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seq1, seq2 = re.sub(r"[UZOB]", "X", sequences).split(":")
output = torch.zeros(len(seq1), len(C.AAs))
os.makedirs(output_dir, exist_ok=True)
if embeddings_dir is not None:
os.makedirs(embeddings_dir, exist_ok=True)
# load models
plm_name = model_name.split("-")[1]
plm_model, plm_tokenizer = utils.load_pretrained_plm(plm_name, device=device)
model = LightAttModel.from_pretrained(C.MODELS[model_name])
model.eval()
if scores_file is not None:
zs_scores = np.loadtxt(scores_file)
zs_scores = torch.tensor(zs_scores, dtype=torch.float32)
# embed wildtype sequences
wildtype_embedding = utils.embed_sequence(plm_model, plm_tokenizer, seq1)
partner_embedding = utils.embed_sequence(plm_model, plm_tokenizer, seq2)
if embeddings_dir is not None:
utils.save_embedding(wildtype_embedding, embeddings_dir, "WT")
utils.save_embedding(partner_embedding, embeddings_dir, "PARTNER")
# iterate over single-point mutations
num_iterations = len(seq1) * (len(C.AAs) - 1)
pbar = tqdm(initial=0, total=num_iterations, colour="red", dynamic_ncols=True, ascii="-#")
pbar.set_description("Scoring mutations")
for mut_name, mutation in utils.mutation_generator(seq1):
i, aa = int(mut_name[1:-1]) - 1, mut_name[-1]
zs_score = None if scores_file is None else zs_scores[i, C.aa2idx[aa]]
score = score_mutation(
model,
plm_model,
plm_tokenizer,
wildtype_embedding,
partner_embedding,
mut_name,
mutation,
zs_score,
embeddings_dir,
)
output[i, C.aa2idx[aa]] = score
pbar.update(1)
# save output
output = output.cpu().numpy()
if ranksort_output:
wt_index = (range(len(seq1)), tuple([C.aa2idx[aa] for aa in seq1]))
output[wt_index] = -10000 # set to very low value to give lowest scores to wt residues
output = utils.ranksort(output)
output = pd.DataFrame(output, columns=C.AAs, index=[f"{i+1}{aa}" for i, aa in enumerate(seq1)])
output.to_csv(os.path.join(output_dir, "landscape.csv"), float_format="%.3f")
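def write_dummy_scores(seq1, path="zs_scores.txt"):
    """Illustrative helper (an assumption, not part of the CLI): write a zero-filled
    scores file with one row per position of the scored sequence and one column per
    amino acid, in alphabetical single-letter order, as expected by --scores-file."""
    np.savetxt(path, np.zeros((len(seq1), len(C.AAs))))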
def main():
args = get_args()
run(
args.model_name,
args.sequences,
args.output_dir,
args.scores_file,
not args.no_ranksort,
args.embeddings_dir,
)
if __name__ == "__main__":
main()
"""Extract normalized attention weights from pre-trained MuLAN model for sequences in a file."""
import os
from argparse import ArgumentParser
from tqdm import tqdm
import torch
import h5py # type: ignore
from mulan import utils, constants as C
from mulan.modules import LightAttModel
def get_args():
parser = ArgumentParser(
prog="extract_attentions",
description=__doc__,
)
parser.add_argument(
"fasta_file",
type=str,
help="Fasta file containing sequences to extract attention weights."
)
parser.add_argument(
"--model-name",
type=str,
default="mulan-ankh",
help=f"""
Name of the pre-trained model. Must be one of:
{list(C.MODELS.keys())}.""",
)
parser.add_argument(
"-o", "--output-file", type=str, default="attentions.h5", help="Output H5 file"
)
parser.add_argument(
"-e",
"--embeddings-dir",
type=str,
default=None,
help="If not None, directory to store embeddings in PT format.",
)
args = parser.parse_args()
if args.model_name not in C.MODELS:
raise ValueError(f"Invalid model name: {args.model_name}")
return args
@torch.inference_mode()
def run(model_name, fasta_file, output_file, embeddings_dir=None):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = utils.parse_fasta(fasta_file)
plm_name = model_name.split("-")[1]
plm_model, plm_tokenizer = utils.load_pretrained_plm(plm_name, device=device)
model = LightAttModel.from_pretrained(C.MODELS[model_name])
    if os.path.dirname(output_file):
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
if embeddings_dir is not None:
os.makedirs(embeddings_dir, exist_ok=True)
pbar = tqdm(initial=0, total=len(dataset), colour="red", dynamic_ncols=True, ascii="-#")
pbar.set_description("Extracting attention weights")
with h5py.File(output_file, "w") as f:
for name, sequence in dataset.items():
# embed sequence
embedding = utils.embed_sequence(plm_model, plm_tokenizer, sequence)
# compute attention weights
attention = model([embedding] * 4, output_attentions=True).attention[0]
attention = attention.squeeze(0).cpu().numpy().mean(axis=(-2, -3))
attention = utils.minmax_scale(attention)
# save attention weights
f.create_dataset(name, data=attention)
if embeddings_dir is not None:
utils.save_embedding(embedding, embeddings_dir, name)
pbar.update(1)
def main():
args = get_args()
run(args.model_name, args.fasta_file, args.output_file, args.embeddings_dir)
if __name__ == "__main__":
main()
"""Generate proteins embeddings with transformers pretrained models.
Embeddings are stored in PT format."""
import os
from argparse import ArgumentParser
from tqdm import tqdm
import torch
from mulan import utils, constants as C
def get_args():
parser = ArgumentParser(
prog="generate_embeddings",
description=__doc__,
)
parser.add_argument(
"fasta_file",
type=str,
help="Path to FASTA file containing sequences to be encoded",
)
parser.add_argument(
"model_name",
type=str,
default="ankh_large",
help=f"PLM name to be loaded. Must be one of {C.PLM_ENCODERS.keys()}",
)
parser.add_argument(
"-o",
"--output-dir",
type=str,
default="./embeddings",
help="Output directory. Defaults to './embeddings'",
)
args = parser.parse_args()
return args
@torch.inference_mode()
def run(model_name, fasta_file, output_dir):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, tokenizer = utils.load_pretrained_plm(model_name, device=device)
dataset = utils.parse_fasta(fasta_file)
os.makedirs(output_dir, exist_ok=True)
pbar = tqdm(initial=0, total=len(dataset), colour="red", dynamic_ncols=True, ascii="-#")
pbar.set_description("Embedding sequences")
for name, sequence in dataset.items():
# embed sequence
embedding = utils.embed_sequence(model, tokenizer, sequence)
utils.save_embedding(embedding, output_dir, name)
pbar.update(1)
def main():
args = get_args()
run(args.model_name, args.fasta_file, args.output_dir)
if __name__ == "__main__":
main()
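# Illustrative library-level sketch of the same pipeline (assumptions: the small
# "esm_35M" encoder key from mulan.constants, CPU device; weights are downloaded
# from the HuggingFace Hub on first use).
def embedding_example():
    model, tokenizer = utils.load_pretrained_plm("esm_35M", device="cpu")
    embedding = utils.embed_sequence(model, tokenizer, "MVKQIESKTAF")
    print(embedding.shape)  # (1, sequence_length, plm_hidden_size)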
"""Predict scores for given mutations in a protein complex."""
import os
from collections import defaultdict
from argparse import ArgumentParser
from tqdm import tqdm
import torch
import pandas as pd
import mulan.constants as C
from mulan import utils
from mulan.modules import LightAttModel
def get_args():
parser = ArgumentParser(
prog="predict",
description=__doc__,
)
parser.add_argument(
"model_name",
type=str,
default="mulan-ankh",
help=f"Name of the pre-trained model. Must be one of: {C.MODELS.keys()}.",
)
parser.add_argument(
"input_file",
type=str,
help="""
Input file containing wild type sequences and mutations to be scored.
        Each line must contain whitespace-separated columns with the name of the complex,
        the two interacting sequences, and the mutations to be scored. Each single mutation
        is given in the format <wt_aa><chain:A,B><position><mut_aa>; mutations within a
        multiple-point mutation are joined by a colon, and different mutation sets are
        separated by a comma.
Example: 'C1 SEQ1 SEQ2 AA1G:AA2T,CA3C:QB4A' """,
)
parser.add_argument(
"-s",
"--scores-file",
type=str,
default=None,
help="""
        TXT file containing zero-shot scores for imulan models. Each line must contain
        the name of the complex, the corresponding mutation in the same format as in
        'input_file', and the score, separated by whitespace.""",
)
parser.add_argument(
"-o",
"--output-file",
type=str,
default="output.txt",
help="Output file. Set to 'output.txt' by default.",
)
parser.add_argument(
"--store-embeddings",
action="store_true",
help="Store embeddings in PT format. Only wild type embeddings are stored.",
)
args = parser.parse_args()
return args
def parse_input(input_file):
"""Parse input file and return a list of records."""
data = []
with open(input_file) as f:
for line in f:
complex_name, seq1, seq2, mutations = line.strip().split()
mutations = [tuple(m.split(":")) for m in mutations.split(",")]
data.append(
{
"complex": complex_name,
"sequences": (seq1, seq2),
"mutations": mutations,
}
)
return data
def parse_zs_scores(scores_file):
"""Parse zero-shot scores file and return a dictionary."""
zs_scores = defaultdict(dict)
with open(scores_file) as f:
for line in f:
complex_name, mutations, score = line.strip().split()
            # one mutation set per line; use a hashable tuple key matching parse_input
            mutations = tuple(mutations.split(":"))
            zs_scores[complex_name].update(
                {mutations: torch.tensor(float(score), dtype=torch.float32).unsqueeze(0)}
            )
return zs_scores
@torch.inference_mode()
def run(model_name, input_file, scores_file, output_file, store_embeddings):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = parse_input(input_file)
num_iterations = sum(len(d["mutations"]) for d in data)
scores = []
plm_name = model_name.split("-")[1]
plm_model, plm_tokenizer = utils.load_pretrained_plm(plm_name, device=device)
model = LightAttModel.from_pretrained(C.MODELS[model_name])
if "imulan" in model_name and scores_file is None:
raise ValueError("Zero-shot scores file must be provided for imulan models.")
zs_scores = parse_zs_scores(scores_file) if scores_file is not None else {}
if store_embeddings:
embeddings_dir = os.path.dirname(output_file)
os.makedirs(embeddings_dir, exist_ok=True)
pbar = tqdm(initial=0, total=num_iterations, colour="red", dynamic_ncols=True, ascii="-#")
pbar.set_description("Running prediction")
for complex in data:
# compute embeddings for wt sequences
complex_name = complex["complex"]
seq1, seq2 = complex["sequences"]
seq1_embedding = utils.embed_sequence(plm_model, plm_tokenizer, seq1)
seq2_embedding = utils.embed_sequence(plm_model, plm_tokenizer, seq2)
for mutations in complex["mutations"]:
mut_seq1, mut_seq2 = utils.parse_mutations(mutations, seq1, seq2)
if mut_seq1 != seq1:
mut_seq1_embedding = utils.embed_sequence(plm_model, plm_tokenizer, mut_seq1)
else:
mut_seq1_embedding = seq1_embedding
if mut_seq2 != seq2:
mut_seq2_embedding = utils.embed_sequence(plm_model, plm_tokenizer, mut_seq2)
else:
mut_seq2_embedding = seq2_embedding
inputs = [seq1_embedding, seq2_embedding, mut_seq1_embedding, mut_seq2_embedding]
score = (
model(
inputs_embeds=inputs,
zs_scores=zs_scores.get(complex_name, {}).get(mutations, None),
)
.squeeze()
.item()
)
scores.append(
{"complex": complex_name, "mutations": ":".join(mutations), "score": score}
)
if store_embeddings:
utils.save_embedding(seq1_embedding, embeddings_dir, f"{complex_name}_A")
utils.save_embedding(seq2_embedding, embeddings_dir, f"{complex_name}_B")
utils.save_embedding(
mut_seq1_embedding,
embeddings_dir,
f"{complex_name}_{'-'.join([mut for mut in mutations if mut[1] == 'A'])}"
)
utils.save_embedding(
mut_seq2_embedding,
embeddings_dir,
f"{complex_name}_{'-'.join([mut for mut in mutations if mut[1] == 'B'])}"
)
pbar.update(1)
df = pd.DataFrame.from_records(scores)
df.to_csv(output_file, index=False, sep="\t")
def main():
args = get_args()
run(
args.model_name, args.input_file, args.scores_file, args.output_file, args.store_embeddings
)
if __name__ == "__main__":
main()
"""Train Mulan model on custom data. Not implemented yet!"""
from argparse import ArgumentParser
def get_args():
parser = ArgumentParser(
prog="train",
description=__doc__,
)
args = parser.parse_args()
return args
def main():
get_args()
raise NotImplementedError("This script has not been implemented yet.")
if __name__ == "__main__":
main()
from setuptools import setup
def parse_requirements(filename):
with open(filename, 'r') as file:
        return [line.strip() for line in file if line.strip()]  # skip blank lines
with open("README.md", "r") as f:
readme = f.read()
sources = {
"mulan": "mulan",
"mulan.scripts": "scripts",
}
setup(
name='mulan',
version='0.1.0',
description="MuLAN: MUtational effects with Light Attention Networks",
long_description=readme,
long_description_content_type='text/markdown',
author="Gianluca Lombardi",
author_email="gianluca.lombardi@sorbonne-universite.fr",
url="https://github.com/GianLMB/mulan",
license="CC BY-NC-SA 4.0",
    packages=list(sources.keys()),
package_dir=sources,
python_requires='>=3.9',
install_requires= parse_requirements('requirements.txt'),
entry_points={
'console_scripts': [
'plm-embed=mulan.scripts.generate_embeddings:main',
'mulan-predict=mulan.scripts.predict:main',
'mulan-att=mulan.scripts.extract_attentions:main',
'mulan-landscape=mulan.scripts.compute_landscape:main',
'mulan-train=mulan.scripts.train:main',
],
},
)