0.1.9 minor bugfix

parent a65a0e1a
......@@ -3,7 +3,6 @@ import pytorch_lightning as pl
from itertools import permutations, product
import numpy as np
import pandas as pd
import logging
import pathlib
import argparse
from ..dataset import PairSequenceData
......@@ -97,11 +96,7 @@ def main(params):
compute_embeddings(params)
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
if params.device == 'mps':
logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.')
params.device = 'cpu'
block_mps(params)
logging.info('Predicting...')
preds = predict(params)
......@@ -116,8 +111,8 @@ def main(params):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser = add_args(parser)
params = parser.parse_args()
pred_parser = argparse.ArgumentParser()
pred_parser = add_args(pred_parser)
pred_params = pred_parser.parse_args()
main(params)
main(pred_params)
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import pandas as pd
import logging
import pathlib
import argparse
from ..dataset import PairSequenceData
......@@ -72,11 +71,7 @@ def main(params):
compute_embeddings(params)
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
if params.device == 'mps':
logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.')
params.device = 'cpu'
block_mps(params)
logging.info('Evaluating...')
test_metrics = test(params)[0]
......@@ -87,7 +82,7 @@ def main(params):
if __name__ == '__main__':
test_parser = argparse.ArgumentParser()
parser = add_args(test_parser)
params = test_parser.parse_args()
test_parser = add_args(test_parser)
test_params = test_parser.parse_args()
main(params)
main(test_params)
......@@ -2,8 +2,7 @@ import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import pathlib
import argparse
import logging
from ..utils import add_general_args
from ..utils import *
from ..model import SensePPIModel
from ..dataset import PairSequenceData
from ..esm2_model import add_esm_args, compute_embeddings
......@@ -15,11 +14,7 @@ def main(params):
compute_embeddings(params)
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
if params.device == 'mps':
logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.')
params.device = 'cpu'
block_mps(params)
dataset = PairSequenceData(emb_dir=params.output_dir_esm, actions_file=params.pairs_file,
max_len=params.max_len, labels=True)
......@@ -80,8 +75,8 @@ def add_args(parser):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser = add_args(parser)
params = parser.parse_args()
train_parser = argparse.ArgumentParser()
train_parser = add_args(train_parser)
train_params = train_parser.parse_args()
main(params)
\ No newline at end of file
main(train_params)
\ No newline at end of file
......@@ -15,18 +15,13 @@ import shutil
DOWNLOAD_LINK_STRING = "https://stringdb-downloads.org/download/"
def generate_pairs_string(fasta_file, output_file, with_self=False, delete_proteins=None):
def generate_pairs_string(fasta_file, output_file, delete_proteins=None):
ids = []
for record in SeqIO.parse(fasta_file, "fasta"):
ids.append(record.id)
if with_self:
all_pairs = [p for p in product(ids, repeat=2)]
else:
all_pairs = [p for p in permutations(ids, 2)]
pairs = []
for p in all_pairs:
for p in [p for p in permutations(ids, 2)]:
if (p[1], p[0]) not in pairs and (p[0], p[1]) not in pairs:
pairs.append(p)
......
......@@ -2,6 +2,7 @@ from Bio import SeqIO
import os
from senseppi import __version__
import torch
import logging
def add_general_args(parser):
......@@ -29,6 +30,18 @@ def determine_device():
return device
def block_mps(params):
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
if hasattr(params, 'device'):
if params.device == 'mps':
logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.')
if torch.cuda.is_available():
params.device = 'gpu'
else:
params.device = 'cpu'
def process_string_fasta(fasta_file, min_len, max_len):
with open('file.tmp', 'w') as f:
for record in SeqIO.parse(fasta_file, "fasta"):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment