1. MPS still does not work (torch incompatibility)

2. Train function transferred (needs to be checked)
3. Minor changes
parent 4db6ea76
/esm2_embs_3B
/pretrained_models
*.tsv
*.fasta
\ No newline at end of file
*.fasta
*.sh
\ No newline at end of file
import argparse
import logging
from .commands import *
from senseppi import __version__
......@@ -22,6 +23,13 @@ def main():
sp.set_defaults(func=module.main)
params = parser.parse_args()
#WARNING: due to some internal issues of torch, the mps backend is temporarily disabled
if params.device == 'mps':
logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.')
params.device = 'cpu'
params.func(params)
......
......@@ -11,7 +11,7 @@ from ..esm2_model import add_esm_args, compute_embeddings
def predict(params):
test_data = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=params.pairs,
test_data = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=params.pairs_file,
max_len=params.max_len, labels=False)
pretrained_model = SensePPIModel(params)
......@@ -34,7 +34,7 @@ def predict(params):
preds = [pred for batch in trainer.predict(pretrained_model, test_loader) for pred in batch.squeeze().tolist()]
preds = np.asarray(preds)
data = pd.read_csv(params.pairs, delimiter='\t', names=["seq1", "seq2"])
data = pd.read_csv(params.pairs_file, delimiter='\t', names=["seq1", "seq2"])
data['preds'] = preds
return data
......@@ -66,16 +66,19 @@ def add_args(parser):
predict_args = parser.add_argument_group(title="Predict args")
parser._action_groups[0].add_argument("model_path", type=str,
help="A path to .ckpt file that contains weights to a pretrained model.")
predict_args.add_argument("--pairs", type=str, default=None,
help="A path to a .tsv file with pairs of proteins to test (Optional). If not provided, all-to-all pairs will be generated.")
predict_args.add_argument("--pairs_file", type=str, default=None,
help="A path to a .tsv file with pairs of proteins to test (Optional). If not provided, "
"all-to-all pairs will be generated.")
predict_args.add_argument("-o", "--output", type=str, default="predictions",
help="A path to a file where the predictions will be saved. (.tsv format will be added automatically)")
help="A path to a file where the predictions will be saved. "
"(.tsv format will be added automatically)")
predict_args.add_argument("--with_self", action='store_true',
help="Include self-interactions in the predictions."
"By default they are not included since they were not part of training but"
"they can be included by setting this flag to True.")
predict_args.add_argument("-p", "--pred_threshold", type=float, default=0.5,
help="Prediction threshold to determine interacting pairs that will be written to a separate file. Range: (0, 1).")
help="Prediction threshold to determine interacting pairs that "
"will be written to a separate file. Range: (0, 1).")
parser = SensePPIModel.add_model_specific_args(parser)
remove_argument(parser, "--lr")
......@@ -88,9 +91,9 @@ def main(params):
logging.info("Device used: ", params.device)
process_string_fasta(params.fasta_file, min_len=params.min_len, max_len=params.max_len)
if params.pairs is None:
if params.pairs_file is None:
generate_pairs(params.fasta_file, 'protein.pairs.tsv', with_self=params.with_self)
params.pairs = 'protein.pairs.tsv'
params.pairs_file = 'protein.pairs.tsv'
compute_embeddings(params)
......@@ -104,7 +107,8 @@ def main(params):
if __name__ == '__main__':
parser = add_args()
parser = argparse.ArgumentParser()
parser = add_args(parser)
params = parser.parse_args()
main(params)
import os
from pathlib import Path
import torch
import pytorch_lightning as pl
import sys
from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint
from ..model import SensePPIModel
from ..dataset import PairSequenceData
......@@ -10,72 +6,63 @@ from ..utils import *
from ..esm2_model import add_esm_args, compute_embeddings
# training_th.py
def main(params):
if params.seed is not None:
pl.seed_everything(params.seed, workers=True)
dataset = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=params.pairs,
max_len=params.max_len, labels=False)
compute_embeddings(params)
model = SensePPIModel(params)
dataset = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=params.pairs_file,
max_len=params.max_len, labels=True)
model.load_data(dataset=dataset, valid_size=0.1)
model = SensePPIModel(params)
model.load_data(dataset=dataset, valid_size=params.valid_size)
train_set = model.train_dataloader()
val_set = model.val_dataloader()
logger = pl.loggers.TensorBoardLogger("logs", name='SENSE-PPI')
# logger = pl.loggers.TensorBoardLogger("logs", name='SENSE-PPI')
logger = None
callbacks = [
TQDMProgressBar(refresh_rate=250),
# TQDMProgressBar(refresh_rate=250),
ModelCheckpoint(filename='chkpt_loss_based_{epoch}-{val_loss:.3f}-{val_BinaryF1Score:.3f}', verbose=True,
monitor='val_loss', mode='min', save_top_k=1)
]
trainer = pl.Trainer(accelerator="gpu" if torch.cuda.is_available() else "cpu", devices=params.devices, num_nodes=params.num_nodes, max_epochs=100,
logger=logger, callbacks=callbacks, strategy=params.strategy)
trainer = pl.Trainer(accelerator=params.device, devices=params.num_devices, num_nodes=params.num_nodes,
max_epochs=params.num_epochs, logger=logger, callbacks=callbacks)
trainer.fit(model, train_set, val_set)
def esm_check(fasta_file, output_dir, params):
params.model_location = 'esm2_t36_3B_UR50D'
params.fasta_file = fasta_file
params.output_dir = output_dir
with open(params.fasta_file, 'r') as f:
seq_ids = [line.strip().split(' ')[0].replace('>', '') for line in f.readlines() if line.startswith('>')]
if not os.path.exists(params.output_dir):
print('Computing ESM embeddings...')
esm2_model.run(params)
else:
for seq_id in seq_ids:
if not os.path.exists(os.path.join(params.output_dir, seq_id + '.pt')):
print('Computing ESM embeddings...')
esm2_model.run(params)
break
def add_args(parser):
parser = add_general_args(parser)
predict_args = parser.add_argument_group(title="Training args")
train_args = parser.add_argument_group(title="Training args")
parser._action_groups[0].add_argument("pairs_file", type=str,
help="A path to a .tsv file containing training pairs. "
"Required format: 3 tab separated columns: first protein, "
"second protein (protein names have to be present in fasta_file), "
"label (0 or 1).")
train_args.add_argument("--valid_size", type=float, default=0.1,
help="Fraction of the training data to use for validation.")
train_args.add_argument("--seed", type=int, default=None, help="Global training seed.")
train_args.add_argument("--num_epochs", type=int, default=100, help="Number of training epochs.")
train_args.add_argument("--num_devices", type=int, default=1,
help="Number of devices to use for multi GPU training.")
train_args.add_argument("--num_nodes", type=int, default=1,
help="Number of nodes to use for training on a cluster.")
parser = SensePPIModel.add_model_specific_args(parser)
remove_argument(parser, "--lr")
add_esm_args(parser)
return parser
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser = add_args(parser)
params = parser.parse_args()
esm_check(Path(os.path.join('Data', 'Dscript', 'human.fasta')),
Path(os.path.join('Data', 'Dscript', 'esm_emb_3B_human')),
params)
main(params)
\ No newline at end of file
......@@ -10,7 +10,7 @@ class PairSequenceData(Dataset):
def __init__(self,
actions_file,
emb_dir,
max_len=800,
max_len,
pad_inputs=True,
labels=True):
......@@ -34,9 +34,8 @@ class PairSequenceData(Dataset):
try:
emb = torch.load(f)
except FileNotFoundError as _:
raise Exception(
'Embedding file {} not found. Check your fasta file and make sure it contains all the sequences used in training/testing.'.format(
f))
raise Exception('Embedding file {} not found. Check your fasta file and make sure it contains '
'all the sequences used in training/testing.'.format(f))
tensor_emb = emb['representations'][36] # [33]
tensor_len = tensor_emb.size(0)
......
......@@ -47,7 +47,6 @@ class DynamicLSTM(pl.LightningModule):
def forward(self, x, seq_lens):
# sort input by descending length
_, idx_sort = torch.sort(seq_lens, dim=0, descending=True)
_, idx_unsort = torch.sort(idx_sort, dim=0)
x_sort = torch.index_select(x, dim=0, index=idx_sort)
......@@ -163,12 +162,12 @@ class BaselineModel(pl.LightningModule):
self.valid_metrics.reset()
self.log_dict(result, on_epoch=True, sync_dist=self.hparams.sync_dist)
def load_data(self, dataset, valid_size=0.2, indices=None):
def load_data(self, dataset, valid_size=0.1, indices=None):
if indices is None:
dataset_length = len(dataset)
valid_length = int(valid_size * dataset_length)
train_length = dataset_length - valid_length
self.train_set, self.val_set = data.random_split(dataset, [train_length, valid_length]) # , test_size])
self.train_set, self.val_set = data.random_split(dataset, [train_length, valid_length])
print('Data has been randomly divided into train/val sets with sizes {} and {}'.format(len(self.train_set),
len(self.val_set)))
else:
......@@ -216,7 +215,7 @@ class SensePPIModel(BaselineModel):
def __init__(self, params):
super(SensePPIModel, self).__init__(params)
self.encoder_features = self.hparams.encoder_features # 2560
self.encoder_features = self.hparams.encoder_features
self.hidden_dim = 256
self.lstm = DynamicLSTM(self.encoder_features, hidden_size=128, num_layers=3, dropout=0.5, bidirectional=True)
......
......@@ -15,10 +15,10 @@ def add_general_args(parser):
)
parser.add_argument("--min_len", type=int, default=50,
help="Minimum length of the protein sequence. "
"The sequences with smaller length will not be considered. Default: 50")
"The sequences with smaller length will not be considered.")
parser.add_argument("--max_len", type=int, default=800,
help="Maximum length of the protein sequence. "
"The sequences with larger length will not be considered. Default: 800")
"The sequences with larger length will not be considered.")
parser.add_argument("--device", type=str, default=determine_device(), choices=['cpu', 'gpu', 'mps'],
help="Device to used for computations. Options include: cpu, gpu, mps (for MacOS)."
"If not selected the device is set by torch automatically.")
......
......@@ -23,7 +23,7 @@ setup(
"matplotlib",
"tqdm",
"scikit-learn",
"pytorch-lightning",
"pytorch-lightning==1.9.0",
"torchmetrics",
"biopython",
"fair-esm"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment