0.1.9 mps now works but only to compute the ESM2 embeddings, removed redundancy…

0.1.9: MPS now works, but only for computing the ESM2 embeddings; removed redundancy in predict_string.py
parent f7ddda97
__version__ = "0.1.8"
__version__ = "0.1.9"
__author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils
......
......@@ -33,12 +33,6 @@ def main():
params = parser.parse_args()
if hasattr(params, 'device'):
# # WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
# if params.device == 'mps':
# logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
# 'The cpu backend will be used instead.')
# params.device = 'cpu'
if params.device == 'gpu':
torch.set_float32_matmul_precision('high')
......
......@@ -36,10 +36,7 @@ def predict(params):
preds = [pred for batch in trainer.predict(pretrained_model, test_loader) for pred in batch.squeeze().tolist()]
preds = np.asarray(preds)
data = pd.read_csv(params.pairs_file, delimiter='\t', names=["seq1", "seq2"])
data['preds'] = preds
return data
return preds
def generate_pairs(fasta_file, output_path, with_self=False):
......@@ -100,8 +97,17 @@ def main(params):
compute_embeddings(params)
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
if params.device == 'mps':
logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.')
params.device = 'cpu'
logging.info('Predicting...')
data = predict(params)
preds = predict(params)
data = pd.read_csv(params.pairs_file, delimiter='\t', names=["seq1", "seq2"])
data['preds'] = preds
data.to_csv(params.output + '.tsv', sep='\t', index=False, header=False)
......
......@@ -10,12 +10,14 @@ from matplotlib.patches import Rectangle
import argparse
import matplotlib.pyplot as plt
import glob
import logging
from ..model import SensePPIModel
from ..utils import *
from ..network_utils import *
from ..esm2_model import add_esm_args, compute_embeddings
from ..dataset import PairSequenceData
from predict import predict
def main(params):
......@@ -33,28 +35,13 @@ def main(params):
params.fasta_file = fasta_file
compute_embeddings(params)
test_data = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=pairs_file,
max_len=params.max_len, labels=False)
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
if params.device == 'mps':
logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.')
params.device = 'cpu'
pretrained_model = SensePPIModel(params)
if params.device == 'gpu':
checkpoint = torch.load(params.model_path)
elif params.device == 'mps':
checkpoint = torch.load(params.model_path, map_location=torch.device('mps'))
else:
checkpoint = torch.load(params.model_path, map_location=torch.device('cpu'))
pretrained_model.load_state_dict(checkpoint['state_dict'])
trainer = pl.Trainer(accelerator=params.device, logger=False)
test_loader = DataLoader(dataset=test_data,
batch_size=params.batch_size,
num_workers=4)
preds = [pred for batch in trainer.predict(pretrained_model, test_loader) for pred in batch.squeeze().tolist()]
preds = np.asarray(preds)
preds = predict(params)
# open the actions tsv file as dataframe and add the last column with the predictions
data = pd.read_csv('protein.pairs_string.tsv', delimiter='\t', names=["seq1", "seq2", "string_label"])
......
......@@ -72,6 +72,12 @@ def main(params):
compute_embeddings(params)
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
if params.device == 'mps':
logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.')
params.device = 'cpu'
logging.info('Evaluating...')
test_metrics = test(params)[0]
......
......@@ -2,6 +2,7 @@ import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import pathlib
import argparse
import logging
from ..utils import add_general_args
from ..model import SensePPIModel
from ..dataset import PairSequenceData
......@@ -14,6 +15,12 @@ def main(params):
compute_embeddings(params)
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
if params.device == 'mps':
logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.')
params.device = 'cpu'
dataset = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=params.pairs_file,
max_len=params.max_len, labels=True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment