0.1.6

bugfix in params.device
parent 15fc56cf
__version__ = "0.1.5"
__version__ = "0.1.6"
__author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils
...
 import argparse
 import logging
+import torch
 from .commands import *
 from senseppi import __version__
@@ -32,11 +32,15 @@ def main():
     params = parser.parse_args()

-    #WARNING: due to some internal issues of torch, the mps backend is temporarily disabled
-    if hasattr(params, 'device') and params.device == 'mps':
-        logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
-                        'The cpu backend will be used instead.')
-        params.device = 'cpu'
+    if hasattr(params, 'device'):
+        # WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
+        if params.device == 'mps':
+            logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled. '
+                            'The cpu backend will be used instead.')
+            params.device = 'cpu'
+
+        if params.device == 'gpu':
+            torch.set_float32_matmul_precision('high')

     params.func(params)
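For context: torch.set_float32_matmul_precision('high') lets PyTorch run float32 matrix multiplications on faster TF32 (or bfloat16-based) kernels where the hardware supports them, at slightly reduced precision. A minimal standalone sketch of the setting; the tensor shapes are arbitrary and not part of the commit:

import torch

# The default mode is 'highest': full float32 accumulation everywhere.
print(torch.get_float32_matmul_precision())  # -> 'highest'

# 'high' permits TF32-style kernels on supported GPUs (e.g. Ampere);
# on CPUs or older GPUs it is effectively a no-op.
torch.set_float32_matmul_precision('high')

a = torch.randn(256, 256)
b = torch.randn(256, 256)
c = a @ b  # may use the faster, slightly less precise kernels on supported GPUs
print(c.shape)  # torch.Size([256, 256])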
...
@@ -21,7 +21,7 @@ def add_general_args(parser):
 def determine_device():
     if torch.cuda.is_available():
-        device = 'cuda'
+        device = 'gpu'
     elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
         device = 'mps'
     else:
...
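The rename from 'cuda' to 'gpu' makes the detected string match the value main() now checks before raising the matmul precision (and, presumably, the accelerator naming used downstream). A standalone sketch of the detection logic; the diff truncates before the else branch, so the cpu fallback and the return are assumptions:

import torch

def determine_device():
    # Prefer NVIDIA GPUs; note the string is 'gpu', not 'cuda',
    # to match the check `params.device == 'gpu'` in main().
    if torch.cuda.is_available():
        device = 'gpu'
    # Apple Silicon: mps must be available AND built into this torch binary.
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = 'mps'
    else:
        device = 'cpu'  # assumed fallback; not shown in the diff
    return device  # assumed; not shown in the diff

print(determine_device())  # e.g. 'cpu' on a machine with no accelerator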