0.2.1 changed name of the command "create_dataset". Mps block is put globally in __main__

parent b73cbabe
...@@ -3,6 +3,7 @@ import logging ...@@ -3,6 +3,7 @@ import logging
import torch import torch
from .commands import * from .commands import *
from senseppi import __version__ from senseppi import __version__
from senseppi.utils import block_mps
def main(): def main():
...@@ -20,7 +21,7 @@ def main(): ...@@ -20,7 +21,7 @@ def main():
modules = {'train': train, modules = {'train': train,
'predict': predict, 'predict': predict,
'string_dataset_create': string_dataset_create, 'create_dataset': create_dataset,
'test': test, 'test': test,
'predict_string': predict_string 'predict_string': predict_string
} }
...@@ -36,6 +37,8 @@ def main(): ...@@ -36,6 +37,8 @@ def main():
if params.device == 'gpu': if params.device == 'gpu':
torch.set_float32_matmul_precision('high') torch.set_float32_matmul_precision('high')
block_mps(params)
logging.info('Device used: {}'.format(params.device)) logging.info('Device used: {}'.format(params.device))
params.func(params) params.func(params)
......
__all__ = ['predict', 'train', 'string_dataset_create', 'test', 'predict_string'] __all__ = ['predict', 'train', 'create_dataset', 'test', 'predict_string']
\ No newline at end of file \ No newline at end of file
...@@ -96,8 +96,6 @@ def main(params): ...@@ -96,8 +96,6 @@ def main(params):
compute_embeddings(params) compute_embeddings(params)
block_mps(params)
logging.info('Predicting...') logging.info('Predicting...')
preds = predict(params) preds = predict(params)
......
...@@ -31,8 +31,6 @@ def main(params): ...@@ -31,8 +31,6 @@ def main(params):
params.fasta_file = fasta_file params.fasta_file = fasta_file
compute_embeddings(params) compute_embeddings(params)
block_mps(params)
preds = predict(params) preds = predict(params)
# open the actions tsv file as dataframe and add the last column with the predictions # open the actions tsv file as dataframe and add the last column with the predictions
......
...@@ -71,8 +71,6 @@ def main(params): ...@@ -71,8 +71,6 @@ def main(params):
compute_embeddings(params) compute_embeddings(params)
block_mps(params)
logging.info('Evaluating...') logging.info('Evaluating...')
test_metrics = test(params)[0] test_metrics = test(params)[0]
......
...@@ -14,8 +14,6 @@ def main(params): ...@@ -14,8 +14,6 @@ def main(params):
compute_embeddings(params) compute_embeddings(params)
block_mps(params)
dataset = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=params.pairs_file, dataset = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=params.pairs_file,
max_len=params.max_len, labels=True) max_len=params.max_len, labels=True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment