0.1.9 minor bugfix

parent a65a0e1a
...@@ -3,7 +3,6 @@ import pytorch_lightning as pl ...@@ -3,7 +3,6 @@ import pytorch_lightning as pl
from itertools import permutations, product from itertools import permutations, product
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import logging
import pathlib import pathlib
import argparse import argparse
from ..dataset import PairSequenceData from ..dataset import PairSequenceData
...@@ -97,11 +96,7 @@ def main(params): ...@@ -97,11 +96,7 @@ def main(params):
compute_embeddings(params) compute_embeddings(params)
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled block_mps(params)
if params.device == 'mps':
logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.')
params.device = 'cpu'
logging.info('Predicting...') logging.info('Predicting...')
preds = predict(params) preds = predict(params)
...@@ -116,8 +111,8 @@ def main(params): ...@@ -116,8 +111,8 @@ def main(params):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() pred_parser = argparse.ArgumentParser()
parser = add_args(parser) pred_parser = add_args(pred_parser)
params = parser.parse_args() pred_params = pred_parser.parse_args()
main(params) main(pred_params)
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import pytorch_lightning as pl import pytorch_lightning as pl
import pandas as pd import pandas as pd
import logging
import pathlib import pathlib
import argparse import argparse
from ..dataset import PairSequenceData from ..dataset import PairSequenceData
...@@ -72,11 +71,7 @@ def main(params): ...@@ -72,11 +71,7 @@ def main(params):
compute_embeddings(params) compute_embeddings(params)
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled block_mps(params)
if params.device == 'mps':
logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.')
params.device = 'cpu'
logging.info('Evaluating...') logging.info('Evaluating...')
test_metrics = test(params)[0] test_metrics = test(params)[0]
...@@ -87,7 +82,7 @@ def main(params): ...@@ -87,7 +82,7 @@ def main(params):
if __name__ == '__main__': if __name__ == '__main__':
test_parser = argparse.ArgumentParser() test_parser = argparse.ArgumentParser()
parser = add_args(test_parser) test_parser = add_args(test_parser)
params = test_parser.parse_args() test_params = test_parser.parse_args()
main(params) main(test_params)
...@@ -2,8 +2,7 @@ import pytorch_lightning as pl ...@@ -2,8 +2,7 @@ import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks import ModelCheckpoint
import pathlib import pathlib
import argparse import argparse
import logging from ..utils import *
from ..utils import add_general_args
from ..model import SensePPIModel from ..model import SensePPIModel
from ..dataset import PairSequenceData from ..dataset import PairSequenceData
from ..esm2_model import add_esm_args, compute_embeddings from ..esm2_model import add_esm_args, compute_embeddings
...@@ -15,11 +14,7 @@ def main(params): ...@@ -15,11 +14,7 @@ def main(params):
compute_embeddings(params) compute_embeddings(params)
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled block_mps(params)
if params.device == 'mps':
logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.')
params.device = 'cpu'
dataset = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=params.pairs_file, dataset = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=params.pairs_file,
max_len=params.max_len, labels=True) max_len=params.max_len, labels=True)
...@@ -80,8 +75,8 @@ def add_args(parser): ...@@ -80,8 +75,8 @@ def add_args(parser):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() train_parser = argparse.ArgumentParser()
parser = add_args(parser) train_parser = add_args(train_parser)
params = parser.parse_args() train_params = train_parser.parse_args()
main(params) main(train_params)
\ No newline at end of file \ No newline at end of file
...@@ -15,18 +15,13 @@ import shutil ...@@ -15,18 +15,13 @@ import shutil
DOWNLOAD_LINK_STRING = "https://stringdb-downloads.org/download/" DOWNLOAD_LINK_STRING = "https://stringdb-downloads.org/download/"
def generate_pairs_string(fasta_file, output_file, with_self=False, delete_proteins=None): def generate_pairs_string(fasta_file, output_file, delete_proteins=None):
ids = [] ids = []
for record in SeqIO.parse(fasta_file, "fasta"): for record in SeqIO.parse(fasta_file, "fasta"):
ids.append(record.id) ids.append(record.id)
if with_self:
all_pairs = [p for p in product(ids, repeat=2)]
else:
all_pairs = [p for p in permutations(ids, 2)]
pairs = [] pairs = []
for p in all_pairs: for p in [p for p in permutations(ids, 2)]:
if (p[1], p[0]) not in pairs and (p[0], p[1]) not in pairs: if (p[1], p[0]) not in pairs and (p[0], p[1]) not in pairs:
pairs.append(p) pairs.append(p)
......
...@@ -2,6 +2,7 @@ from Bio import SeqIO ...@@ -2,6 +2,7 @@ from Bio import SeqIO
import os import os
from senseppi import __version__ from senseppi import __version__
import torch import torch
import logging
def add_general_args(parser): def add_general_args(parser):
...@@ -29,6 +30,18 @@ def determine_device(): ...@@ -29,6 +30,18 @@ def determine_device():
return device return device
def block_mps(params):
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
if hasattr(params, 'device'):
if params.device == 'mps':
logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.')
if torch.cuda.is_available():
params.device = 'gpu'
else:
params.device = 'cpu'
def process_string_fasta(fasta_file, min_len, max_len): def process_string_fasta(fasta_file, min_len, max_len):
with open('file.tmp', 'w') as f: with open('file.tmp', 'w') as f:
for record in SeqIO.parse(fasta_file, "fasta"): for record in SeqIO.parse(fasta_file, "fasta"):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment