0.1.6

bugfix in params.device
parent 15fc56cf
-__version__ = "0.1.5"
+__version__ = "0.1.6"
 __author__ = "Konstantin Volzhenin"
 from . import model, commands, esm2_model, dataset, utils, network_utils
......
 import argparse
 import logging
+import torch
 from .commands import *
 from senseppi import __version__
@@ -32,11 +32,15 @@ def main():
     params = parser.parse_args()

-    #WARNING: due to some internal issues of torch, the mps backend is temporarily disabled
-    if hasattr(params, 'device') and params.device == 'mps':
-        logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
-                        'The cpu backend will be used instead.')
-        params.device = 'cpu'
+    if hasattr(params, 'device'):
+        # WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
+        if params.device == 'mps':
+            logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
+                            'The cpu backend will be used instead.')
+            params.device = 'cpu'
+
+        if params.device == 'gpu':
+            torch.set_float32_matmul_precision('high')

     params.func(params)
......
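For context, here is how the restructured block behaves end to end. This is a minimal, self-contained sketch: the argparse.Namespace stands in for the project's parsed CLI arguments and is purely illustrative, not the actual senseppi parser setup.

    import argparse
    import logging

    import torch

    # Hypothetical stand-in for the parsed CLI arguments; the real values come
    # from the senseppi argument parser.
    params = argparse.Namespace(device='mps')

    if hasattr(params, 'device'):
        # mps requests are rerouted to cpu while the backend is disabled upstream
        if params.device == 'mps':
            logging.warning('The mps backend is temporarily disabled. '
                            'The cpu backend will be used instead.')
            params.device = 'cpu'

        # New in 0.1.6: allow faster TF32-style float32 matmul kernels on CUDA GPUs
        if params.device == 'gpu':
            torch.set_float32_matmul_precision('high')

Nesting the checks under a single hasattr guard also keeps subcommands without a device argument on the old code path, while the new 'gpu' branch is only reached when a device was actually parsed.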
@@ -21,7 +21,7 @@ def add_general_args(parser):
 def determine_device():
     if torch.cuda.is_available():
-        device = 'cuda'
+        device = 'gpu'
     elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
         device = 'mps'
     else:
......
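The rename from 'cuda' to 'gpu' keeps determine_device() consistent with the string now checked in main() above. A minimal sketch of the full function, assuming (the diff does not show this) that the elided else branch falls back to 'cpu' and that the selected name is returned:

    import torch

    def determine_device():
        # Return an accelerator-style name ('gpu' rather than 'cuda') so it
        # matches the check in main() above.
        if torch.cuda.is_available():
            device = 'gpu'
        elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
            device = 'mps'
        else:
            device = 'cpu'  # assumption: the elided branch falls back to the CPU
        return device       # assumption: the function returns the selected name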