0.1.9: mps now works, but only to compute the ESM2 embeddings; removed redundancy in predict_string.py
parent f7ddda97
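
For context, the pattern this commit applies across the commands below is: the ESM2 embedding step keeps whatever device the user requested (including mps on Apple Silicon), and the device is switched to cpu just before prediction, evaluation, or training. The following is a minimal sketch of that flow, not the package code itself; `params`, `compute_embeddings`, and `predict` here are stand-ins for the CLI namespace and the functions of the same names that appear in the diff.

    import logging

    def run_with_mps_fallback(params, compute_embeddings, predict):
        # The ESM2 embeddings are computed on the requested device, including 'mps'.
        compute_embeddings(params)

        # Afterwards 'mps' is swapped for 'cpu', mirroring the warning block that
        # this commit adds to each command before prediction/training starts.
        if params.device == 'mps':
            logging.warning('mps is only used to compute the ESM2 embeddings; '
                            'falling back to cpu.')
            params.device = 'cpu'

        return predict(params)
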
-__version__ = "0.1.8"
+__version__ = "0.1.9"
 __author__ = "Konstantin Volzhenin"

 from . import model, commands, esm2_model, dataset, utils, network_utils
...
@@ -33,12 +33,6 @@ def main():
     params = parser.parse_args()

     if hasattr(params, 'device'):
-        # # WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
-        # if params.device == 'mps':
-        #     logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
-        #                     'The cpu backend will be used instead.')
-        #     params.device = 'cpu'
         if params.device == 'gpu':
             torch.set_float32_matmul_precision('high')
...
@@ -36,10 +36,7 @@ def predict(params):
     preds = [pred for batch in trainer.predict(pretrained_model, test_loader) for pred in batch.squeeze().tolist()]
     preds = np.asarray(preds)
-    data = pd.read_csv(params.pairs_file, delimiter='\t', names=["seq1", "seq2"])
-    data['preds'] = preds
-    return data
+    return preds


 def generate_pairs(fasta_file, output_path, with_self=False):
@@ -100,8 +97,17 @@ def main(params):
     compute_embeddings(params)

+    # WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
+    if params.device == 'mps':
+        logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
+                        'The cpu backend will be used instead.')
+        params.device = 'cpu'
+
     logging.info('Predicting...')
-    data = predict(params)
+    preds = predict(params)
+
+    data = pd.read_csv(params.pairs_file, delimiter='\t', names=["seq1", "seq2"])
+    data['preds'] = preds

     data.to_csv(params.output + '.tsv', sep='\t', index=False, header=False)
...
@@ -10,12 +10,14 @@ from matplotlib.patches import Rectangle
 import argparse
 import matplotlib.pyplot as plt
 import glob
+import logging

 from ..model import SensePPIModel
 from ..utils import *
 from ..network_utils import *
 from ..esm2_model import add_esm_args, compute_embeddings
 from ..dataset import PairSequenceData
+from .predict import predict


 def main(params):
@@ -33,28 +35,13 @@ def main(params):
     params.fasta_file = fasta_file

     compute_embeddings(params)

-    test_data = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=pairs_file,
-                                 max_len=params.max_len, labels=False)
-
-    pretrained_model = SensePPIModel(params)
-
-    if params.device == 'gpu':
-        checkpoint = torch.load(params.model_path)
-    elif params.device == 'mps':
-        checkpoint = torch.load(params.model_path, map_location=torch.device('mps'))
-    else:
-        checkpoint = torch.load(params.model_path, map_location=torch.device('cpu'))
-
-    pretrained_model.load_state_dict(checkpoint['state_dict'])
-
-    trainer = pl.Trainer(accelerator=params.device, logger=False)
-    test_loader = DataLoader(dataset=test_data,
-                             batch_size=params.batch_size,
-                             num_workers=4)
-
-    preds = [pred for batch in trainer.predict(pretrained_model, test_loader) for pred in batch.squeeze().tolist()]
-    preds = np.asarray(preds)
+    # WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
+    if params.device == 'mps':
+        logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
+                        'The cpu backend will be used instead.')
+        params.device = 'cpu'
+
+    preds = predict(params)

     # open the actions tsv file as dataframe and add the last column with the predictions
     data = pd.read_csv('protein.pairs_string.tsv', delimiter='\t', names=["seq1", "seq2", "string_label"])
...
@@ -72,6 +72,12 @@ def main(params):
     compute_embeddings(params)

+    # WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
+    if params.device == 'mps':
+        logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
+                        'The cpu backend will be used instead.')
+        params.device = 'cpu'
+
     logging.info('Evaluating...')
     test_metrics = test(params)[0]
...
@@ -2,6 +2,7 @@ import pytorch_lightning as pl
 from pytorch_lightning.callbacks import ModelCheckpoint
 import pathlib
 import argparse
+import logging

 from ..utils import add_general_args
 from ..model import SensePPIModel
 from ..dataset import PairSequenceData
@@ -14,6 +15,12 @@ def main(params):
     compute_embeddings(params)

+    # WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
+    if params.device == 'mps':
+        logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
+                        'The cpu backend will be used instead.')
+        params.device = 'cpu'
+
     dataset = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=params.pairs_file,
                                max_len=params.max_len, labels=True)
...