0.2.1 changed name of the command "create_dataset". Mps block is put globally in __main__

parent b73cbabe
......@@ -3,6 +3,7 @@ import logging
import torch
from .commands import *
from senseppi import __version__
from senseppi.utils import block_mps
def main():
......@@ -20,7 +21,7 @@ def main():
modules = {'train': train,
'predict': predict,
'string_dataset_create': string_dataset_create,
'create_dataset': create_dataset,
'test': test,
'predict_string': predict_string
}
......@@ -36,6 +37,8 @@ def main():
if params.device == 'gpu':
torch.set_float32_matmul_precision('high')
block_mps(params)
logging.info('Device used: {}'.format(params.device))
params.func(params)
......
__all__ = ['predict', 'train', 'string_dataset_create', 'test', 'predict_string']
\ No newline at end of file
__all__ = ['predict', 'train', 'create_dataset', 'test', 'predict_string']
\ No newline at end of file
......@@ -96,8 +96,6 @@ def main(params):
compute_embeddings(params)
block_mps(params)
logging.info('Predicting...')
preds = predict(params)
......
......@@ -31,8 +31,6 @@ def main(params):
params.fasta_file = fasta_file
compute_embeddings(params)
block_mps(params)
preds = predict(params)
# open the actions tsv file as dataframe and add the last column with the predictions
......
......@@ -71,8 +71,6 @@ def main(params):
compute_embeddings(params)
block_mps(params)
logging.info('Evaluating...')
test_metrics = test(params)[0]
......
......@@ -14,8 +14,6 @@ def main(params):
compute_embeddings(params)
block_mps(params)
dataset = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=params.pairs_file,
max_len=params.max_len, labels=True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment