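1. MPS still does not work (torch incompatibility); a request for the mps device now falls back to cpu with a warning.

2. Train function transferred (needs to be checked)
3. Minor changes: --pairs renamed to --pairs_file, pytorch-lightning pinned to 1.9.0, *.sh added to the ignore list, long help strings rewrapped.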
parent 4db6ea76
...@@ -2,3 +2,4 @@ ...@@ -2,3 +2,4 @@
/pretrained_models /pretrained_models
*.tsv *.tsv
*.fasta *.fasta
*.sh
\ No newline at end of file
import argparse import argparse
import logging
from .commands import * from .commands import *
from senseppi import __version__ from senseppi import __version__
...@@ -22,6 +23,13 @@ def main(): ...@@ -22,6 +23,13 @@ def main():
sp.set_defaults(func=module.main) sp.set_defaults(func=module.main)
params = parser.parse_args() params = parser.parse_args()
#WARNING: due to some internal issues of torch, the mps backend is temporarily disabled
if params.device == 'mps':
logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.')
params.device = 'cpu'
params.func(params) params.func(params)
......
...@@ -11,7 +11,7 @@ from ..esm2_model import add_esm_args, compute_embeddings ...@@ -11,7 +11,7 @@ from ..esm2_model import add_esm_args, compute_embeddings
def predict(params): def predict(params):
test_data = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=params.pairs, test_data = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=params.pairs_file,
max_len=params.max_len, labels=False) max_len=params.max_len, labels=False)
pretrained_model = SensePPIModel(params) pretrained_model = SensePPIModel(params)
...@@ -34,7 +34,7 @@ def predict(params): ...@@ -34,7 +34,7 @@ def predict(params):
preds = [pred for batch in trainer.predict(pretrained_model, test_loader) for pred in batch.squeeze().tolist()] preds = [pred for batch in trainer.predict(pretrained_model, test_loader) for pred in batch.squeeze().tolist()]
preds = np.asarray(preds) preds = np.asarray(preds)
data = pd.read_csv(params.pairs, delimiter='\t', names=["seq1", "seq2"]) data = pd.read_csv(params.pairs_file, delimiter='\t', names=["seq1", "seq2"])
data['preds'] = preds data['preds'] = preds
return data return data
...@@ -66,16 +66,19 @@ def add_args(parser): ...@@ -66,16 +66,19 @@ def add_args(parser):
predict_args = parser.add_argument_group(title="Predict args") predict_args = parser.add_argument_group(title="Predict args")
parser._action_groups[0].add_argument("model_path", type=str, parser._action_groups[0].add_argument("model_path", type=str,
help="A path to .ckpt file that contains weights to a pretrained model.") help="A path to .ckpt file that contains weights to a pretrained model.")
predict_args.add_argument("--pairs", type=str, default=None, predict_args.add_argument("--pairs_file", type=str, default=None,
help="A path to a .tsv file with pairs of proteins to test (Optional). If not provided, all-to-all pairs will be generated.") help="A path to a .tsv file with pairs of proteins to test (Optional). If not provided, "
"all-to-all pairs will be generated.")
predict_args.add_argument("-o", "--output", type=str, default="predictions", predict_args.add_argument("-o", "--output", type=str, default="predictions",
help="A path to a file where the predictions will be saved. (.tsv format will be added automatically)") help="A path to a file where the predictions will be saved. "
"(.tsv format will be added automatically)")
predict_args.add_argument("--with_self", action='store_true', predict_args.add_argument("--with_self", action='store_true',
help="Include self-interactions in the predictions." help="Include self-interactions in the predictions."
"By default they are not included since they were not part of training but" "By default they are not included since they were not part of training but"
"they can be included by setting this flag to True.") "they can be included by setting this flag to True.")
predict_args.add_argument("-p", "--pred_threshold", type=float, default=0.5, predict_args.add_argument("-p", "--pred_threshold", type=float, default=0.5,
help="Prediction threshold to determine interacting pairs that will be written to a separate file. Range: (0, 1).") help="Prediction threshold to determine interacting pairs that "
"will be written to a separate file. Range: (0, 1).")
parser = SensePPIModel.add_model_specific_args(parser) parser = SensePPIModel.add_model_specific_args(parser)
remove_argument(parser, "--lr") remove_argument(parser, "--lr")
...@@ -88,9 +91,9 @@ def main(params): ...@@ -88,9 +91,9 @@ def main(params):
logging.info("Device used: ", params.device) logging.info("Device used: ", params.device)
process_string_fasta(params.fasta_file, min_len=params.min_len, max_len=params.max_len) process_string_fasta(params.fasta_file, min_len=params.min_len, max_len=params.max_len)
if params.pairs is None: if params.pairs_file is None:
generate_pairs(params.fasta_file, 'protein.pairs.tsv', with_self=params.with_self) generate_pairs(params.fasta_file, 'protein.pairs.tsv', with_self=params.with_self)
params.pairs = 'protein.pairs.tsv' params.pairs_file = 'protein.pairs.tsv'
compute_embeddings(params) compute_embeddings(params)
...@@ -104,7 +107,8 @@ def main(params): ...@@ -104,7 +107,8 @@ def main(params):
if __name__ == '__main__': if __name__ == '__main__':
parser = add_args() parser = argparse.ArgumentParser()
parser = add_args(parser)
params = parser.parse_args() params = parser.parse_args()
main(params) main(params)
import os
from pathlib import Path
import torch
import pytorch_lightning as pl import pytorch_lightning as pl
import sys
from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint
from ..model import SensePPIModel from ..model import SensePPIModel
from ..dataset import PairSequenceData from ..dataset import PairSequenceData
...@@ -10,72 +6,63 @@ from ..utils import * ...@@ -10,72 +6,63 @@ from ..utils import *
from ..esm2_model import add_esm_args, compute_embeddings from ..esm2_model import add_esm_args, compute_embeddings
# training_th.py
def main(params): def main(params):
if params.seed is not None: if params.seed is not None:
pl.seed_everything(params.seed, workers=True) pl.seed_everything(params.seed, workers=True)
dataset = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=params.pairs, compute_embeddings(params)
max_len=params.max_len, labels=False)
model = SensePPIModel(params) dataset = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=params.pairs_file,
max_len=params.max_len, labels=True)
model.load_data(dataset=dataset, valid_size=0.1) model = SensePPIModel(params)
model.load_data(dataset=dataset, valid_size=params.valid_size)
train_set = model.train_dataloader() train_set = model.train_dataloader()
val_set = model.val_dataloader() val_set = model.val_dataloader()
logger = pl.loggers.TensorBoardLogger("logs", name='SENSE-PPI') # logger = pl.loggers.TensorBoardLogger("logs", name='SENSE-PPI')
logger = None
callbacks = [ callbacks = [
TQDMProgressBar(refresh_rate=250), # TQDMProgressBar(refresh_rate=250),
ModelCheckpoint(filename='chkpt_loss_based_{epoch}-{val_loss:.3f}-{val_BinaryF1Score:.3f}', verbose=True, ModelCheckpoint(filename='chkpt_loss_based_{epoch}-{val_loss:.3f}-{val_BinaryF1Score:.3f}', verbose=True,
monitor='val_loss', mode='min', save_top_k=1) monitor='val_loss', mode='min', save_top_k=1)
] ]
trainer = pl.Trainer(accelerator="gpu" if torch.cuda.is_available() else "cpu", devices=params.devices, num_nodes=params.num_nodes, max_epochs=100, trainer = pl.Trainer(accelerator=params.device, devices=params.num_devices, num_nodes=params.num_nodes,
logger=logger, callbacks=callbacks, strategy=params.strategy) max_epochs=params.num_epochs, logger=logger, callbacks=callbacks)
trainer.fit(model, train_set, val_set) trainer.fit(model, train_set, val_set)
def esm_check(fasta_file, output_dir, params):
params.model_location = 'esm2_t36_3B_UR50D'
params.fasta_file = fasta_file
params.output_dir = output_dir
with open(params.fasta_file, 'r') as f:
seq_ids = [line.strip().split(' ')[0].replace('>', '') for line in f.readlines() if line.startswith('>')]
if not os.path.exists(params.output_dir):
print('Computing ESM embeddings...')
esm2_model.run(params)
else:
for seq_id in seq_ids:
if not os.path.exists(os.path.join(params.output_dir, seq_id + '.pt')):
print('Computing ESM embeddings...')
esm2_model.run(params)
break
def add_args(parser): def add_args(parser):
parser = add_general_args(parser) parser = add_general_args(parser)
predict_args = parser.add_argument_group(title="Training args") train_args = parser.add_argument_group(title="Training args")
parser._action_groups[0].add_argument("pairs_file", type=str,
help="A path to a .tsv file containing training pairs. "
"Required format: 3 tab separated columns: first protein, "
"second protein (protein names have to be present in fasta_file), "
"label (0 or 1).")
train_args.add_argument("--valid_size", type=float, default=0.1,
help="Fraction of the training data to use for validation.")
train_args.add_argument("--seed", type=int, default=None, help="Global training seed.")
train_args.add_argument("--num_epochs", type=int, default=100, help="Number of training epochs.")
train_args.add_argument("--num_devices", type=int, default=1,
help="Number of devices to use for multi GPU training.")
train_args.add_argument("--num_nodes", type=int, default=1,
help="Number of nodes to use for training on a cluster.")
parser = SensePPIModel.add_model_specific_args(parser) parser = SensePPIModel.add_model_specific_args(parser)
remove_argument(parser, "--lr")
add_esm_args(parser) add_esm_args(parser)
return parser return parser
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = add_args(parser) parser = add_args(parser)
params = parser.parse_args() params = parser.parse_args()
esm_check(Path(os.path.join('Data', 'Dscript', 'human.fasta')),
Path(os.path.join('Data', 'Dscript', 'esm_emb_3B_human')),
params)
main(params) main(params)
\ No newline at end of file
...@@ -10,7 +10,7 @@ class PairSequenceData(Dataset): ...@@ -10,7 +10,7 @@ class PairSequenceData(Dataset):
def __init__(self, def __init__(self,
actions_file, actions_file,
emb_dir, emb_dir,
max_len=800, max_len,
pad_inputs=True, pad_inputs=True,
labels=True): labels=True):
...@@ -34,9 +34,8 @@ class PairSequenceData(Dataset): ...@@ -34,9 +34,8 @@ class PairSequenceData(Dataset):
try: try:
emb = torch.load(f) emb = torch.load(f)
except FileNotFoundError as _: except FileNotFoundError as _:
raise Exception( raise Exception('Embedding file {} not found. Check your fasta file and make sure it contains '
'Embedding file {} not found. Check your fasta file and make sure it contains all the sequences used in training/testing.'.format( 'all the sequences used in training/testing.'.format(f))
f))
tensor_emb = emb['representations'][36] # [33] tensor_emb = emb['representations'][36] # [33]
tensor_len = tensor_emb.size(0) tensor_len = tensor_emb.size(0)
......
...@@ -47,7 +47,6 @@ class DynamicLSTM(pl.LightningModule): ...@@ -47,7 +47,6 @@ class DynamicLSTM(pl.LightningModule):
def forward(self, x, seq_lens): def forward(self, x, seq_lens):
# sort input by descending length # sort input by descending length
_, idx_sort = torch.sort(seq_lens, dim=0, descending=True) _, idx_sort = torch.sort(seq_lens, dim=0, descending=True)
_, idx_unsort = torch.sort(idx_sort, dim=0) _, idx_unsort = torch.sort(idx_sort, dim=0)
x_sort = torch.index_select(x, dim=0, index=idx_sort) x_sort = torch.index_select(x, dim=0, index=idx_sort)
...@@ -163,12 +162,12 @@ class BaselineModel(pl.LightningModule): ...@@ -163,12 +162,12 @@ class BaselineModel(pl.LightningModule):
self.valid_metrics.reset() self.valid_metrics.reset()
self.log_dict(result, on_epoch=True, sync_dist=self.hparams.sync_dist) self.log_dict(result, on_epoch=True, sync_dist=self.hparams.sync_dist)
def load_data(self, dataset, valid_size=0.2, indices=None): def load_data(self, dataset, valid_size=0.1, indices=None):
if indices is None: if indices is None:
dataset_length = len(dataset) dataset_length = len(dataset)
valid_length = int(valid_size * dataset_length) valid_length = int(valid_size * dataset_length)
train_length = dataset_length - valid_length train_length = dataset_length - valid_length
self.train_set, self.val_set = data.random_split(dataset, [train_length, valid_length]) # , test_size]) self.train_set, self.val_set = data.random_split(dataset, [train_length, valid_length])
print('Data has been randomly divided into train/val sets with sizes {} and {}'.format(len(self.train_set), print('Data has been randomly divided into train/val sets with sizes {} and {}'.format(len(self.train_set),
len(self.val_set))) len(self.val_set)))
else: else:
...@@ -216,7 +215,7 @@ class SensePPIModel(BaselineModel): ...@@ -216,7 +215,7 @@ class SensePPIModel(BaselineModel):
def __init__(self, params): def __init__(self, params):
super(SensePPIModel, self).__init__(params) super(SensePPIModel, self).__init__(params)
self.encoder_features = self.hparams.encoder_features # 2560 self.encoder_features = self.hparams.encoder_features
self.hidden_dim = 256 self.hidden_dim = 256
self.lstm = DynamicLSTM(self.encoder_features, hidden_size=128, num_layers=3, dropout=0.5, bidirectional=True) self.lstm = DynamicLSTM(self.encoder_features, hidden_size=128, num_layers=3, dropout=0.5, bidirectional=True)
......
...@@ -15,10 +15,10 @@ def add_general_args(parser): ...@@ -15,10 +15,10 @@ def add_general_args(parser):
) )
parser.add_argument("--min_len", type=int, default=50, parser.add_argument("--min_len", type=int, default=50,
help="Minimum length of the protein sequence. " help="Minimum length of the protein sequence. "
"The sequences with smaller length will not be considered. Default: 50") "The sequences with smaller length will not be considered.")
parser.add_argument("--max_len", type=int, default=800, parser.add_argument("--max_len", type=int, default=800,
help="Maximum length of the protein sequence. " help="Maximum length of the protein sequence. "
"The sequences with larger length will not be considered. Default: 800") "The sequences with larger length will not be considered.")
parser.add_argument("--device", type=str, default=determine_device(), choices=['cpu', 'gpu', 'mps'], parser.add_argument("--device", type=str, default=determine_device(), choices=['cpu', 'gpu', 'mps'],
help="Device to used for computations. Options include: cpu, gpu, mps (for MacOS)." help="Device to used for computations. Options include: cpu, gpu, mps (for MacOS)."
"If not selected the device is set by torch automatically.") "If not selected the device is set by torch automatically.")
......
...@@ -23,7 +23,7 @@ setup( ...@@ -23,7 +23,7 @@ setup(
"matplotlib", "matplotlib",
"tqdm", "tqdm",
"scikit-learn", "scikit-learn",
"pytorch-lightning", "pytorch-lightning==1.9.0",
"torchmetrics", "torchmetrics",
"biopython", "biopython",
"fair-esm" "fair-esm"
......
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment