Commit 516e97e8 by Gianluca Lombardi

Add README, remove torch models from utils

parent 86348a1f
# LoRA-DR-suite
![LoRA-DR-Suite](./assets/graphical_abstract.png)
## Model details
LoRA-DR-suite is a family of models for the identification of disordered regions (DR) in proteins, built upon state-of-the-art Protein Language Models (PLMs) trained on protein sequences only. They leverage Low-Rank Adaptation (LoRA) fine-tuning for binary classification of intrinsic and soft disorder.
Intrinsically-disordered residues are experimentally detected through circular dichroism and X-ray cristallography, while soft disorder is characterized by high B-factor, or intermittently
missing residues across different X-ray crystal structures of the same sequence.
Models for intrinsic disorder are trained on DisProt 7.0 data only (DisProt7 suffix) or on additional data from the first and second edition of the Critical Assesment of Intrinsic Disorder (CAID), indicated with the ID suffix.
Models for soft disorder classification are trained instead on the SoftDis dataset, derived from an extensive analysis of clusters of alternative structures for the same protein
sequence in the Protein Data Bank (PDB). For each position in the represantitive sequence of each cluster, it provides the frequency of closely-related homologs for which the corresponding residue is higly flexible or missing.
## Repository content
This repository provides code for hyperparameters optimization (HPO) and training data for the LoRA-DR-suite models. In addition, jupyter notebooks show how to reproduce performances for intrinsic and soft disorder prediction tests, and how to load and process data from SoftDis dataset.
In order to run the code, clone the repository and install the required packages in a dedicated `conda` environment with `pip install -r requirements.txt`. Referenced PyTorch versions requires CUDA 12.1.
HPO and training main script is contained in `hpo_optuna.py`. Required parameters can be inspected with the `--help` option. Example HPO configurations for the different pre-trained models and tasks are saved in `config` folder.
## Model checkpoints
We provide different model checkpoints for LoRA-DR-suite models,, based on training data and pre-trained PLM, that can be downloaded and accessed on HuggingFace Hub.
| Checkpoint name | Training dataset | Pre-trained checkpoint |
|-----------------|------------------|------------------------|
| [esm2_650M-LoRA-DisProt7](https://huggingface.co/CQSB/esm2_650M-LoRA-DisProt7) | DisProt 7.0 | [esm2_t33_650M_UR50D](https://huggingface.co/facebook/esm2_t33_650M_UR50D) |
| [esm2_35M-LoRA-DisProt7](https://huggingface.co/CQSB/esm2_35M-LoRA-DisProt7) | DisProt 7.0 | [esm2_t12_35M_UR50D](https://huggingface.co/facebook/esm2_t12_35M_UR50D) |
| [Ankh-LoRA-DisProt7](https://huggingface.co/CQSB/Ankh-LoRA-DisProt7) | DisProt 7.0 | [ankh-large](https://huggingface.co/ElnaggarLab/ankh-large) |
| [PortT5-LoRA-DisProt7](https://huggingface.co/CQSB/ProtT5-LoRA-DisProt7) | DisProt 7.0 | [prot_t5_xl_uniref5](Rostlab/prot_t5_xl_uniref50) |
| [esm2_650M-LoRA-ID](https://huggingface.co/CQSB/esm2_650M-LoRA-ID) | Intrinsic dis.* | [esm2_t33_650M_UR50D](https://huggingface.co/facebook/esm2_t33_650M_UR50D) |
| [esm2_35M-LoRA-ID](https://huggingface.co/CQSB/esm2_35M-LoRA-ID) | Intrinsic dis.* | [esm2_t12_35M_UR50D](https://huggingface.co/facebook//esm2_t12_35M_UR50D) |
| [Ankh-LoRA-ID](https://huggingface.co/CQSB/Ankh-LoRA-ID) | Intrinsic dis.* | [ankh-large](https://huggingface.co/ElnaggarLab/ankh-large) |
| [PortT5-LoRA-ID](https://huggingface.co/CQSB/ProtT5-LoRA-ID) | Intrinsic dis.* | [prot_t5_xl_uniref5](Rostlab/prot_t5_xl_uniref50) |
| [esm2_650M-LoRA-SD](https://huggingface.co/CQSB/esm2_650M-LoRA-SD) | SoftDis | [esm2_t33_650M_UR50D](https://huggingface.co/facebook/esm2_t33_650M_UR50D) |
| [esm2_35M-LoRA-SD](https://huggingface.co/CQSB/esm2_35M-LoRA-SD) | SoftDis | [esm2_t12_35M_UR50D](https://huggingface.co/facebook//esm2_t12_35M_UR50D) |
| [Ankh-LoRA-SD](https://huggingface.co/CQSB/Ankh-LoRA-SD) | SoftDis | [ankh-large](https://huggingface.co/ElnaggarLab/ankh-large) |
| [PortT5-LoRA-SD](https://huggingface.co/CQSB/ProtT5-LoRA-SD) | SoftDis | [prot_t5_xl_uniref5](Rostlab/prot_t5_xl_uniref50) |
\* DisProt7, CAID1 and CAID2 data
## Intended uses & limitations
The models are intended to be used for classification of different disorder types.
Models for intrinsic disorder trained on DisProt 7.0 were evaluated on CAID1 and CAID2 challenge, but we suggest to use "ID" models for classification of new sequences, as they show better generalization.
In addition to its relation to flexibility and assembly pathways, soft disorder can be used to infer confidence score for structure prediciton tools, as we found high negative Spearman correlation between soft disorder probabilities and pLDDT from AlphaFold2 predicitons.
### Model usage
All models can be loaded as PyTorch Modules, together with their associated tokenizer, with the following code:
```python
from transformers import AutoModelForTokenClassification, AutoTokenizer
model_id = "CQSB/ProtT5-LoRA-SD" # model_id for selected model
model = AutoModelForTokenClassification.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
```
Once the model is loadded, disorder profile for all residues in a sequence can be obtained as follow:
```python
import torch
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# example sequence
sequence = "TAIWEQHTVTLHRAPGFGFGIAISGGRDNPHFQSGETSIVISDVLKG"
# each pre-trained model adds its own special tokens to the tokenized sequence,
# special_tokens_mask allows to deal with them (padding included, for batched
# inputs) without changing the code
inputs = tokenizer(
[sequence], return_tensors="pt", return_special_tokens_mask=True
)
input_ids = inputs['input_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)
special_tokens_mask = inputs['special_tokens_mask'].bool()
# extract predicted disorder probability
with torch.inference_mode():
output = model(input_ids, attention_mask).logits.cpu()
output = output[~special_tokens_mask, :]
disorder_proba = F.softmax(output, dim=-1)[:, 1]
```
## How to cite
Coming soon...
\ No newline at end of file
...@@ -26,7 +26,7 @@ from utils import ( ...@@ -26,7 +26,7 @@ from utils import (
process_config_file, process_config_file,
HF_PRETRAINED_MODELS, HF_PRETRAINED_MODELS,
) )
from torch_models import TorchModelHelper # from torch_models import TorchModelHelper
import data_utils import data_utils
...@@ -256,17 +256,17 @@ def get_model_helper(model_name, config): ...@@ -256,17 +256,17 @@ def get_model_helper(model_name, config):
model_helper = HFModelHelper(model_name, device=device) model_helper = HFModelHelper(model_name, device=device)
if config["model"]["is_peft"]: if config["model"]["is_peft"]:
model_helper.convert_to_peft() model_helper.convert_to_peft()
else: # else:
logging.info(f"Loading custom torch model: {model_name}") # logging.info(f"Loading custom torch model: {model_name}")
model_config = config["model"] # model_config = config["model"]
model_config.pop("model_name", None) # model_config.pop("model_name", None)
try: # try:
pretrained_model = model_config.pop("pretrained_model") # pretrained_model = model_config.pop("pretrained_model")
except KeyError: # except KeyError:
raise ValueError( # raise ValueError(
"Pretrained model for downstrem torch model not found in configuration file" # "Pretrained model for downstrem torch model not found in configuration file"
) # )
model_helper = TorchModelHelper(model_name, pretrained_model, device, **model_config) # model_helper = TorchModelHelper(model_name, pretrained_model, device, **model_config)
logging.debug(f"Model: {model_helper.model}") logging.debug(f"Model: {model_helper.model}")
return model_helper return model_helper
......
...@@ -19,5 +19,4 @@ peft==0.13.1 ...@@ -19,5 +19,4 @@ peft==0.13.1
datasets==3.0.1 datasets==3.0.1
optuna==4.0.0 optuna==4.0.0
wandb==0.18.5 wandb==0.18.5
tqdm==4.67.1 tqdm==4.67.1
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment