Commit 102d58ee by Gianluca LOMBARDI

Update types and imports, add train script and utils and example input files

parent f3163134
...@@ -52,7 +52,7 @@ We provide several command line interfaces for quick usage of MuLAN different ap ...@@ -52,7 +52,7 @@ We provide several command line interfaces for quick usage of MuLAN different ap
- `mulan-predict` for $\Delta \Delta G$ prediction of single and multiple-point mutations; - `mulan-predict` for $\Delta \Delta G$ prediction of single and multiple-point mutations;
- `mulan-att` to extract residues weights, related to interface regions; - `mulan-att` to extract residues weights, related to interface regions;
- `mulan-landscape` to produce a full mutational landscape for a given complex; - `mulan-landscape` to produce a full mutational landscape for a given complex;
- `mulan-train` to re-train the model on a custom dataset or to run cross validation (not supported yet); - `mulan-train` to re-train the model on a custom dataset or to run cross validation;
- `plm-embed` to extract embeddings from protein language models. - `plm-embed` to extract embeddings from protein language models.
Since the script uses the `transformers` interface, only models that are saved on HuggingFace 🤗 Hub can be loaded. Since the script uses the `transformers` interface, only models that are saved on HuggingFace 🤗 Hub can be loaded.
......
1A22 FPTIPLSRLFDNAMLRAHRLHQLAFDTYQEFEEAYIPKEQKYSFLQNPQTSLCFSESIPTPSNREETQQKSNLELLRISLLLIQSWLEPVQFLRSVFANSLVYGASDSNVYDLLKDLEERIQTLMGRLEGQIFKQTYSKFDTDALLKNYGLLYCFRKDMDKVETFLRIVQCRSVEGSCGF PKFTKCRSPERETFSCHWTLGPIQLFYTRRNTQEWTQEWKECPDYVSAGENSCYFNSSFTSIWIPYCIKLTSNGGTVDEKCFSVDEIVQPDPPIALNWTLLNGIHADIQVRWEAPRNADIQKGWMVLEYELQYKEVNETKWKMMDPILTTSVPVYSLKVDKEYEVRVRSKQRNSGNYGEFSEVLYVTLPQMS CA171A,CB67A,EB12A,FA25A,PA61A,RB11M
\ No newline at end of file
from .config import MulanConfig # noqa from .config import MulanConfig
from .modules import LightAttModel # noqa from .modules import LightAttModel
from .utils import ( from .utils import (
load_pretrained, load_pretrained,
load_pretrained_plm, load_pretrained_plm,
get_available_models, get_available_models,
get_available_plms, get_available_plms,
) # noqa )
\ No newline at end of file \ No newline at end of file
from typing import Dict, List, NamedTuple, Tuple, Optional
import os
from tqdm import tqdm
import pandas as pd
import torch
from torch.utils.data import Dataset, default_collate
from torch.nn.utils.rnn import pad_sequence
import numpy as np
from mulan import utils
class MutatedComplex(NamedTuple):
    """One mutated complex record: the two partner identifiers (keys into the
    wild-type sequence mapping) and the mutation codes applied to the complex
    (e.g. ``("CA171A", "CB67A")`` — second character is the chain, 'A' or 'B')."""
    sequence_A: str  # identifier of the first partner sequence
    sequence_B: str  # identifier of the second partner sequence
    mutations: Tuple[str, ...]  # one or more mutation codes
class MutatedComplexEmbeds(NamedTuple):
    """Per-sequence embeddings for one mutated complex: wild-type and mutated
    embeddings for both partners, in the order produced by
    ``MulanDataset._sequences_ids``."""
    seq1: torch.Tensor  # wild-type embedding, first partner
    seq2: torch.Tensor  # wild-type embedding, second partner
    mut_seq1: torch.Tensor  # mutated embedding, first partner
    mut_seq2: torch.Tensor  # mutated embedding, second partner
class MulanDataset(Dataset):
    """Torch dataset of mutated protein complexes with pre-computed
    per-sequence embeddings.

    Each item bundles the raw :class:`MutatedComplex` record, a
    :class:`MutatedComplexEmbeds` of four embeddings loaded from
    ``embeddings_dir`` (wild-type and mutated, for both partners), and optional
    zero-shot scores and labels. Missing embeddings are generated on
    construction with the given protein language model.
    """
    def __init__(
        self,
        mutated_complexes: List[MutatedComplex],
        wt_sequences: Dict[str, str],
        embeddings_dir: str,
        plm_model_name: Optional[str] = None,
        scores: Optional[List[float]] = None,
        zs_scores: Optional[List[float]] = None,
    ):
        """
        Args:
            mutated_complexes: records to serve as dataset items.
            wt_sequences: mapping from sequence identifier to wild-type
                sequence; mutated sequences are added to it in place.
            embeddings_dir: directory of per-sequence ``<id>.pt`` embeddings;
                created and filled if embeddings are missing.
            plm_model_name: language model used to generate missing
                embeddings; required only when some are missing.
            scores: optional per-item regression labels.
            zs_scores: optional per-item zero-shot scores.

        Raises:
            ValueError: if embeddings are missing and `plm_model_name` is not given.
        """
        self.sequences = wt_sequences
        self.embeddings_dir = embeddings_dir
        self.mutated_complexes = mutated_complexes
        self.zs_scores = zs_scores
        self.scores = scores
        self._sequences_ids = []
        self._fill_metadata(mutated_complexes)
        # generate embeddings if not provided
        all_ids = set([id_ for ids in self._sequences_ids for id_ in ids])
        provided_embeddings_ids = (
            [os.path.splitext(file)[0] for file in os.listdir(self.embeddings_dir)]
            if os.path.exists(self.embeddings_dir)
            else []
        )
        missing_ids = all_ids - set(provided_embeddings_ids)
        if missing_ids:
            if not plm_model_name:
                raise ValueError(
                    "`plm_model_name` must be provided if embeddings were not pre-computed."
                )
            self._generate_missing_embeddings(plm_model_name, missing_ids)
    def __len__(self):
        """Number of mutated complexes in the dataset."""
        return len(self.mutated_complexes)
    def __getitem__(self, index):
        """Return a dict with keys ``data`` (MutatedComplex), ``inputs_embeds``
        (MutatedComplexEmbeds), ``zs_scores`` and ``labels`` (float32 tensors,
        or None when the corresponding list was not provided)."""
        return {
            "data": self.mutated_complexes[index],
            "inputs_embeds": self._load_embeddings(index),
            "zs_scores": (
                torch.tensor(self.zs_scores[index], dtype=torch.float32)
                if self.zs_scores
                else None
            ),
            "labels": (
                torch.tensor(self.scores[index], dtype=torch.float32) if self.scores else None
            ),
        }
    @classmethod
    def from_table(
        cls,
        mutated_complexes_file: str,
        wt_sequences_file: str,
        embeddings_dir: str,
        plm_model_name: Optional[str] = None,
    ):
        """Alternate constructor: build a dataset from a whitespace-separated
        table (sequence_A, sequence_B, mutations[, score[, zs_score]]) and a
        FASTA file of wild-type sequences."""
        wt_sequences = utils.parse_fasta(wt_sequences_file)
        # parse table file
        data = pd.read_table(mutated_complexes_file, sep=r"\s+", header=None)
        mutated_complexes = [
            MutatedComplex(row[0], row[1], tuple(row[2].split(",")))
            for row in data.itertuples(index=False)
        ]
        # columns 4 and 5 (scores and zero-shot scores) are optional
        scores, zs_scores = None, None
        if len(data.columns) > 3:
            scores = data[3].astype(float).tolist()
        if len(data.columns) > 4:
            zs_scores = data[4].astype(float).tolist()
        return cls(
            mutated_complexes, wt_sequences, embeddings_dir, plm_model_name, scores, zs_scores
        )
    def _fill_metadata(self, mutated_complexes):
        """Build mutated sequences and record, per item, the four sequence ids
        (wt_1, wt_2, mut_1, mut_2) used to load embeddings later."""
        for seq1_label, seq2_label, mutations in mutated_complexes:
            seq1 = self.sequences[seq1_label]
            seq2 = self.sequences[seq2_label]
            mut_seq1, mut_seq2 = utils.parse_mutations(mutations, seq1, seq2)
            # mutations are assigned to a partner by their second character ('A' or 'B')
            mut_seq1_label = (
                f"{seq1_label}_{'-'.join([mut for mut in mutations if mut[1] == 'A'])}"
            )
            mut_seq2_label = (
                f"{seq2_label}_{'-'.join([mut for mut in mutations if mut[1] == 'B'])}"
            )
            # NOTE: mutated sequences are added to the shared `sequences` dict in place
            self.sequences.update({mut_seq1_label: mut_seq1, mut_seq2_label: mut_seq2})
            self._sequences_ids.append((seq1_label, seq2_label, mut_seq1_label, mut_seq2_label))
        return
    def _generate_missing_embeddings(self, plm_model_name, missing_ids):
        """Embed every sequence in `missing_ids` with the given language model
        and save one ``<id>.pt`` file per sequence into `embeddings_dir`."""
        plm_model, plm_tokenizer = utils.load_pretrained_plm(plm_model_name)
        os.makedirs(self.embeddings_dir, exist_ok=True)
        for id_ in tqdm(missing_ids, desc="Generating embeddings"):
            seq = self.sequences[id_]
            embedding = utils.embed_sequence(plm_model, plm_tokenizer, seq)
            utils.save_embedding(embedding, self.embeddings_dir, id_)
        # del plm_model, plm_tokenizer
        return
    def _load_embeddings(self, index):
        """Load the four per-sequence embedding tensors for item `index`."""
        return MutatedComplexEmbeds(
            *[
                torch.load(os.path.join(self.embeddings_dir, f"{id_}.pt"), weights_only=True)
                for id_ in self._sequences_ids[index]
            ]
        )
class MulanDataCollator(object):
    """Collates MulanDataset samples into padded batches.

    Dicts are collated key by key, MutatedComplexEmbeds fields are padded to a
    common length with `padding_value`, None passes through, and anything else
    is delegated to torch's default collation.
    """
    def __init__(self, padding_value: float = 0.0):
        # value used to right-pad variable-length embedding tensors
        self.padding_value = padding_value
    def __call__(self, batch):
        return self._collate_fn(batch)
    def _collate_fn(self, batch):
        first = batch[0]
        if isinstance(first, dict):
            # collate each key across the batch, recursing into the values
            collated = {}
            for key in first:
                collated[key] = self._collate_fn([sample[key] for sample in batch])
            return collated
        if isinstance(first, MutatedComplexEmbeds):
            # pad each of the four embedding fields across the batch
            padded_fields = (
                pad_sequence(group, batch_first=True, padding_value=self.padding_value)
                for group in zip(*batch)
            )
            return MutatedComplexEmbeds(*padded_fields)
        if first is None:
            return None
        return default_collate(batch)
def split_data(
    mutated_complexes_file: str,
    output_dir: Optional[str] = None,
    add_validation_set: bool = True,
    validation_size: float = 0.15,
    test_size: float = 0.15,
    num_folds: int = 1,
    random_state: int = 42,
):
    """Split data into train, validation and test sets for training or cross-validation.

    Args:
        mutated_complexes_file: whitespace-separated table of mutated complexes.
        output_dir: if given, splits are also written there as header-less TSV
            files (in one ``fold_<i>`` sub-directory per fold when cross-validating).
        add_validation_set: whether to carve out a validation split.
        validation_size: fraction of samples for validation (single split only).
        test_size: fraction of samples for testing (single split only).
        num_folds: 1 for a single random split, otherwise the number of
            cross-validation folds.
        random_state: seed for the random number generator.

    Returns:
        Tuple ``(train_data_all, test_data_all, val_data_all)`` of lists of
        DataFrames, one entry per fold; ``val_data_all`` is None when
        `add_validation_set` is False.

    Raises:
        ValueError: if `num_folds` is not positive, or equals 2 while a
            validation set is requested.
    """

    def _save_data(data, output_file):
        # splits are stored without header/index, matching the input format
        data.to_csv(output_file, sep="\t", index=False, header=False)

    train_data_all, test_data_all = [], []
    val_data_all = [] if add_validation_set else None
    files_basename = os.path.splitext(os.path.basename(mutated_complexes_file))[0]
    data = pd.read_table(mutated_complexes_file, sep=r"\s+", header=None)
    rng = np.random.default_rng(random_state)
    if num_folds <= 0:
        raise ValueError("`num_folds` must be greater than 0.")
    elif num_folds == 2 and add_validation_set:
        raise ValueError("`num_folds` must be greater than 2 to add a validation set.")
    elif num_folds == 1:
        # single random split: 0 -> test, 1 -> validation, 2 -> train
        split_index = rng.choice(
            [0, 1, 2],
            size=len(data),
            p=[test_size, validation_size, 1 - test_size - validation_size],
        )
        test_data = data[split_index == 0]
        if add_validation_set:
            val_data = data[split_index == 1]
            train_data = data[split_index == 2]
            val_data_all.append(val_data)
        else:
            # BUGFIX: without a validation set, every non-test row belongs to
            # train (the previous mask `(split_index == 1) & (data[split_index == 2])`
            # indexed the DataFrame with a row-filtered DataFrame and was invalid)
            train_data = data[split_index != 0]
        train_data_all.append(train_data)
        test_data_all.append(test_data)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
            _save_data(train_data, os.path.join(output_dir, f"{files_basename}_train.tsv"))
            _save_data(test_data, os.path.join(output_dir, f"{files_basename}_test.tsv"))
            if add_validation_set:
                _save_data(val_data, os.path.join(output_dir, f"{files_basename}_val.tsv"))
    else:
        # cross-validation: assign each sample to one of `num_folds` folds
        fold_index = rng.integers(low=0, high=num_folds, size=len(data))
        for test_fold_index in range(num_folds):
            test_data = data[fold_index == test_fold_index]
            if add_validation_set:
                # the fold cyclically preceding the test fold serves as validation
                val_fold_index = (test_fold_index - 1) % num_folds
                val_data = data[fold_index == val_fold_index]
                val_data_all.append(val_data)
                train_data = data[(fold_index != test_fold_index) & (fold_index != val_fold_index)]
            else:
                train_data = data[fold_index != test_fold_index]
            train_data_all.append(train_data)
            test_data_all.append(test_data)
            if output_dir:
                os.makedirs(os.path.join(output_dir, f"fold_{test_fold_index}"), exist_ok=True)
                _save_data(
                    train_data,
                    os.path.join(
                        output_dir, f"fold_{test_fold_index}", f"{files_basename}_train.tsv"
                    ),
                )
                _save_data(
                    test_data,
                    os.path.join(
                        output_dir, f"fold_{test_fold_index}", f"{files_basename}_test.tsv"
                    ),
                )
                if add_validation_set:
                    _save_data(
                        val_data,
                        os.path.join(
                            output_dir, f"fold_{test_fold_index}", f"{files_basename}_val.tsv"
                        ),
                    )
    return train_data_all, test_data_all, val_data_all
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
from dataclasses import dataclass, field
import torch
import torch.nn as nn
from transformers import Trainer
from sklearn.metrics import mean_absolute_error, root_mean_squared_error
from scipy.stats import pearsonr, spearmanr
from mulan.data import MutatedComplexEmbeds, MutatedComplex
def _metric_spearmanr(y_true, y_pred):
return spearmanr(y_true, y_pred, nan_policy="omit")[0]
def _metric_pearsonr(y_true, y_pred):
return pearsonr(y_true, y_pred)[0]
# Regression metrics computed at evaluation time: mae/rmse measure error
# magnitude, pcc/scc measure linear and rank correlation respectively.
_DEFAULT_METRICS = {
    "mae": mean_absolute_error,
    "rmse": root_mean_squared_error,
    "pcc": _metric_pearsonr,
    "scc": _metric_spearmanr,
}
def default_compute_metrics(eval_pred):
    """Compute every metric in `_DEFAULT_METRICS` on flattened predictions/labels.

    `eval_pred` is a (predictions, labels) pair of arrays; returns a dict
    mapping metric name to its value.
    """
    predictions, labels = eval_pred
    flat_predictions = predictions.flatten()
    flat_labels = labels.flatten()
    return {
        name: metric(flat_labels, flat_predictions)
        for name, metric in _DEFAULT_METRICS.items()
    }
@dataclass
class DatasetArguments:
    """Command-line arguments describing the dataset files (parsed by HfArgumentParser)."""
    train_data: str = field(
        metadata={
            "help": (
                "Training data in TSV format. Must contain columns for sequence_A, sequence_B,"
                " mutations (separated with comma if multiple), score and (optionally) zero-shot"
                " score."
            )
        }
    )
    train_fasta_file: str = field(
        metadata={
            "help": (
                "Fasta file containing wild-type sequences for training data. Identifiers must"
                " match the training data and, if present, the evaluation data."
            )
        }
    )
    embeddings_dir: str = field(
        metadata={
            "help": (
                "Directory containing pre-computed embeddings in PT format, or where new"
                " embeddings will be stored. In the latter case, `plm_model_name` must be"
                " provided."
            )
        }
    )
    eval_data: Optional[str] = field(
        default=None,
        metadata={"help": "Evaluation data file, with the same format of training data."},
    )
    test_data: Optional[str] = field(
        default=None,
        metadata={"help": "Test data file, with the same format of training data."},
    )
    test_fasta_file: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Fasta file containing wild-type sequences. Identifiers must match the test data."
            )
        },
    )
    plm_model_name: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Name of the pre-trained protein language model to use for embedding generation."
            )
        },
    )
@dataclass
class ModelArguments:
    """Command-line arguments selecting the model and whether to save it after training."""
    model_name_or_config_path: str = field(
        metadata={
            "help": (
                "Name of the pre-trained model to fine-tune, or path to config file in JSON"
                " format."
            )
        }
    )
    save_model: bool = field(
        default=False,
        metadata={"help": "Whether to save the model after training."},
    )
@dataclass
class CustomisableTrainingArguments:
    """The subset of training hyper-parameters exposed on the command line
    (mapped onto transformers.TrainingArguments by the train script)."""
    output_dir: str = field(metadata={"help": "Directory where the trained model will be saved."})
    num_epochs: int = field(default=30, metadata={"help": "Number of training epochs."})
    batch_size: int = field(default=8, metadata={"help": "Batch size."})
    learning_rate: float = field(default=5e-4, metadata={"help": "Learning rate."})
    disable_tqdm: bool = field(
        default=False, metadata={"help": "Whether to disable tqdm progress bars."}
    )
    report_to: Union[None, str, List[str]] = field(
        default="none",
        metadata={"help": "The list of integrations to report the results and logs to."},
    )
    early_stopping_patience: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Number of epochs without improvement before early stopping. If not set, early"
                " stopping is disabled."
            )
        },
    )
class MulanTrainer(Trainer):
    """Custom Trainer class adapted for Mulan model training.

    The overrides handle a model that returns raw predictions (no loss) and
    batches that carry custom NamedTuple inputs (MutatedComplexEmbeds).
    """
    def compute_loss(self, model, inputs, return_outputs=False):
        """
        Computes the MSE loss for MulanDataset inputs, for a model that does not return
        loss values itself.
        """
        # drop the raw MutatedComplex records: the model consumes only tensors
        inputs.pop("data")
        outputs = model(inputs["inputs_embeds"], inputs.get("zs_scores"))
        labels = inputs.get("labels")
        loss = torch.nn.functional.mse_loss(outputs.view(-1), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
    def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
        """
        Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
        Adapted from the parent class to handle the case where the input is a custom type.
        """
        if isinstance(data, Mapping):
            return type(data)({k: self._prepare_input(v) for k, v in data.items()})
        elif isinstance(data, (MutatedComplexEmbeds, MutatedComplex)):
            # rebuild the custom NamedTuple types field by field
            return type(data)(*[self._prepare_input(v) for v in data])
        elif isinstance(data, (tuple, list)):
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
            kwargs = {"device": self.args.device}
            if self.is_deepspeed_enabled and (
                torch.is_floating_point(data) or torch.is_complex(data)
            ):
                # NLP models inputs are int/uint and those get adjusted to the right dtype of the
                # embedding. Other models such as wav2vec2's inputs are already float and thus
                # may need special handling to match the dtypes of the model
                kwargs.update(
                    {"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()}
                )
            return data.to(**kwargs)
        return data
    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.
        Overridden from the parent class to handle the case where the model does not return loss values.
        Support for sagemaker was removed!
        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.
                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        # labels are present only when every configured label name is in the batch
        has_labels = (
            False
            if len(self.label_names) == 0
            else all(inputs.get(k) is not None for k in self.label_names)
        )
        # For CLIP-like models capable of returning loss values.
        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
        # is `True` in `model.forward`.
        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        # print("return_loss", return_loss, "has_labels", has_labels)
        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False
        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []
        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels or loss_without_labels:
            labels = inputs.get("labels")
        else:
            labels = None
        with torch.no_grad():
            if has_labels or loss_without_labels:
                with self.compute_loss_context_manager():
                    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                loss = loss.mean().detach()
                logits = outputs
            else:
                loss = None
                # no labels: run a plain forward pass and report the raw predictions
                with self.compute_loss_context_manager():
                    outputs = model(inputs["inputs_embeds"], inputs.get("zs_scores"))
                logits = outputs
                # TODO: this needs to be fixed and made cleaner later.
                if self.args.past_index >= 0:
                    self._past = outputs[self.args.past_index - 1]
        if prediction_loss_only:
            return (loss, None, None)
        return (loss, logits, labels)
...@@ -9,7 +9,7 @@ import numpy as np ...@@ -9,7 +9,7 @@ import numpy as np
from scipy.stats import rankdata from scipy.stats import rankdata
import mulan.constants as C import mulan.constants as C
from mulan.constants import AAs, aa2idx, idx2aa, one2three, three2one # noqa from mulan.constants import AAs, aa2idx, idx2aa, one2three, three2one
def mutation_generator(sequence): def mutation_generator(sequence):
...@@ -83,7 +83,9 @@ def get_available_models(): ...@@ -83,7 +83,9 @@ def get_available_models():
return list(C.MODELS.keys()) return list(C.MODELS.keys())
def load_pretrained_plm(model_name, device="cpu"): def load_pretrained_plm(model_name, device=None):
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_id = C.PLM_ENCODERS.get(model_name) model_id = C.PLM_ENCODERS.get(model_name)
if model_id is None: if model_id is None:
raise ValueError( raise ValueError(
......
...@@ -3,6 +3,8 @@ scipy ...@@ -3,6 +3,8 @@ scipy
pandas pandas
tqdm tqdm
h5py h5py
scikit-learn
torch>=2.0 torch>=2.0
transformers<4.45,>=4.27 transformers<4.45,>=4.27
accelerate
sentencepiece sentencepiece
\ No newline at end of file
"""Train Mulan model on custom data. Not implemented yet!""" """Train Mulan model on custom data using HuggingFace Trainer API"""
import os
import pandas as pd
import torch
from transformers import (
HfArgumentParser,
TrainingArguments,
logging,
EarlyStoppingCallback,
set_seed,
)
import mulan
from mulan.data import MulanDataset, MulanDataCollator
from mulan.train_utils import (
DatasetArguments,
ModelArguments,
CustomisableTrainingArguments,
MulanTrainer,
default_compute_metrics,
)
# Configure transformers' logging utilities so our messages share the
# Trainer's verbosity and formatting.
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
from argparse import ArgumentParser
def get_args(): def get_args():
parser = ArgumentParser( parser = HfArgumentParser(
dataclass_types=[DatasetArguments, ModelArguments, CustomisableTrainingArguments],
prog="mulan-train", prog="mulan-train",
description=__doc__, description=__doc__,
) )
args = parser.parse_args() data_args, model_args, custom_training_args = parser.parse_args_into_dataclasses()
return args return (data_args, model_args, custom_training_args)
def load_data(data_args):
    """Build the train / eval / test MulanDataset objects from parsed CLI arguments.

    Eval data shares the training FASTA file; test data uses its own. Returns
    a (train, eval, test) tuple where eval/test are None when not configured.
    """
    logger.info("Loading training data")
    train_dataset = MulanDataset.from_table(
        data_args.train_data,
        data_args.train_fasta_file,
        data_args.embeddings_dir,
        data_args.plm_model_name,
    )
    eval_dataset = (
        MulanDataset.from_table(
            data_args.eval_data,
            data_args.train_fasta_file,
            data_args.embeddings_dir,
            data_args.plm_model_name,
        )
        if data_args.eval_data
        else None
    )
    test_dataset = None
    if data_args.test_data:
        logger.info("Loading test data...")
        test_dataset = MulanDataset.from_table(
            data_args.test_data,
            data_args.test_fasta_file,
            data_args.embeddings_dir,
            data_args.plm_model_name,
        )
    return train_dataset, eval_dataset, test_dataset
def dummy_forward_call(model, dataset, data_collator):
    """Run one forward pass on a single collated sample.

    Used before training to initialize any lazily-built model modules; returns
    the model output for the first dataset item.
    """
    single_batch = data_collator([dataset[0]])
    embeds = single_batch["inputs_embeds"]
    zs = single_batch.get("zs_scores")
    return model(embeds, zs)
def save_predictions(output_dir, dataset, predictions):
    """Write test-set predictions to `<output_dir>/test_predictions.tsv`.

    One row per mutated complex (no header/index), with the mutation tuple
    re-joined into the comma-separated input format and the predicted score
    appended as the last column.
    """
    table = pd.DataFrame(dataset.mutated_complexes)
    # collapse the mutation tuple back into the comma-separated text form
    table["mutations"] = table["mutations"].apply(",".join)
    table["score"] = predictions
    out_path = os.path.join(output_dir, "test_predictions.tsv")
    table.to_csv(out_path, sep="\t", index=False, header=False)
    return
def save_model_ckpt(model, output_dir):
    """Save the model as a plain torch checkpoint at `<output_dir>/model.ckpt`.

    The checkpoint is a dict with the model's `state_dict` and its config as a
    plain dict, rather than the HF Trainer's own save format.
    """
    torch.save(
        {"state_dict": model.state_dict(), "config": model.config.__dict__},
        os.path.join(output_dir, "model.ckpt"),
    )
def train(data_args, model_args, custom_training_args):
    """End-to-end training: load data and model, fit with MulanTrainer,
    optionally predict on the test set, save metrics and (optionally) the model.

    Args:
        data_args: DatasetArguments with data/FASTA/embedding paths.
        model_args: ModelArguments selecting the model and save behavior.
        custom_training_args: CustomisableTrainingArguments with the exposed
            training hyper-parameters.
    """
    # set global seed for reproducibility across data order and init
    set_seed(42)
    # load data
    train_dataset, eval_dataset, test_dataset = load_data(data_args)
    # load model: either a named pre-trained model or a fresh one from a JSON config
    if model_args.model_name_or_config_path in mulan.get_available_models():
        model = mulan.load_pretrained(model_args.model_name_or_config_path, device="cpu")
    else:
        config = mulan.MulanConfig.from_json(model_args.model_name_or_config_path)
        model = mulan.LightAttModel(config)
    # map the exposed CLI options onto the full HF TrainingArguments;
    # evaluation/saving run per-epoch only when enabled by the CLI flags
    training_args = TrainingArguments(
        output_dir=custom_training_args.output_dir,
        num_train_epochs=custom_training_args.num_epochs,
        per_device_train_batch_size=custom_training_args.batch_size,
        per_device_eval_batch_size=custom_training_args.batch_size,
        logging_dir=custom_training_args.output_dir,
        report_to=custom_training_args.report_to,
        remove_unused_columns=False,  # batches carry extra keys ("data", "zs_scores")
        label_names=["labels"],
        logging_strategy="epoch",
        eval_strategy="epoch" if eval_dataset else "no",
        save_strategy="epoch" if model_args.save_model else "no",
        load_best_model_at_end=(eval_dataset and model_args.save_model),  # to be tested
        metric_for_best_model="loss",
        save_total_limit=2,
    )
    data_collator = MulanDataCollator(padding_value=model.config.padding_value)
    optimizer = torch.optim.AdamW(model.parameters(), lr=custom_training_args.learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )
    early_stopping = (
        EarlyStoppingCallback(custom_training_args.early_stopping_patience)
        if custom_training_args.early_stopping_patience
        else None
    )
    dummy_forward_call(model, train_dataset, data_collator)  # to initialize lazy modules
    # instantiate Trainer
    trainer = MulanTrainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=default_compute_metrics,
        optimizers=(optimizer, scheduler),
        callbacks=[early_stopping] if early_stopping else None,
    )
    # train model
    train_results = trainer.train()
    metrics = train_results.metrics
    # optionally evaluate on the held-out test set and persist predictions
    if test_dataset:
        prediction_results = trainer.predict(test_dataset)
        save_predictions(
            custom_training_args.output_dir, test_dataset, prediction_results.predictions
        )
        metrics.update(prediction_results.metrics)
    trainer.save_metrics("all", metrics)
    # TODO
    # remove logging message `Trainer.model is not a `PreTrainedModel`, only saving its state dict.``
    if model_args.save_model:
        save_model_ckpt(model, custom_training_args.output_dir)
    return
def main(): def main():
get_args() data_args, model_args, custom_training_args = get_args()
raise NotImplementedError("This script has not been implemented yet.") train(data_args, model_args, custom_training_args)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment