First fully assembled version 0.1.0

Comments, arguments and outdated code have to be cleaned, the version still needs some testing
parent 59a6e327
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -37,7 +37,7 @@ def test(params): ...@@ -37,7 +37,7 @@ def test(params):
def add_args(parser): def add_args(parser):
parser = add_general_args(parser) parser = add_general_args(parser)
predict_args = parser.add_argument_group(title="Predict args") test_args = parser.add_argument_group(title="Predict args")
parser._action_groups[0].add_argument("model_path", type=str, parser._action_groups[0].add_argument("model_path", type=str,
help="A path to .ckpt file that contains weights to a pretrained model.") help="A path to .ckpt file that contains weights to a pretrained model.")
parser._action_groups[0].add_argument("pairs_file", type=str, default=None, parser._action_groups[0].add_argument("pairs_file", type=str, default=None,
...@@ -47,10 +47,10 @@ def add_args(parser): ...@@ -47,10 +47,10 @@ def add_args(parser):
help="FASTA file on which to extract the ESM2 " help="FASTA file on which to extract the ESM2 "
"representations and then evaluate.", "representations and then evaluate.",
) )
predict_args.add_argument("-o", "--output", type=str, default="test_metrics", test_args.add_argument("-o", "--output", type=str, default="test_metrics",
help="A path to a file where the test metrics will be saved. " help="A path to a file where the test metrics will be saved. "
"(.tsv format will be added automatically)") "(.tsv format will be added automatically)")
predict_args.add_argument("--crop_data_to_model_lims", action="store_true", test_args.add_argument("--crop_data_to_model_lims", action="store_true",
help="If set, the data will be cropped to the limits of the model: " help="If set, the data will be cropped to the limits of the model: "
"evaluations will be done only for proteins >50aa and <800aa.") "evaluations will be done only for proteins >50aa and <800aa.")
......
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks import ModelCheckpoint
import pathlib import pathlib
import argparse
from ..utils import add_general_args from ..utils import add_general_args
from ..model import SensePPIModel from ..model import SensePPIModel
from ..dataset import PairSequenceData from ..dataset import PairSequenceData
......
...@@ -13,12 +13,11 @@ import gzip ...@@ -13,12 +13,11 @@ import gzip
import shutil import shutil
def generate_pairs(fasta_file, mode='all_to_all', with_self=False, delete_proteins=None): def generate_pairs_string(fasta_file, output_file, with_self=False, delete_proteins=None):
ids = [] ids = []
for record in SeqIO.parse(fasta_file, "fasta"): for record in SeqIO.parse(fasta_file, "fasta"):
ids.append(record.id) ids.append(record.id)
if mode == 'all_to_all':
if with_self: if with_self:
all_pairs = [p for p in product(ids, repeat=2)] all_pairs = [p for p in product(ids, repeat=2)]
else: else:
...@@ -51,7 +50,7 @@ def generate_pairs(fasta_file, mode='all_to_all', with_self=False, delete_protei ...@@ -51,7 +50,7 @@ def generate_pairs(fasta_file, mode='all_to_all', with_self=False, delete_protei
pairs = pairs[~pairs['seq1'].isin(string_ids_to_delete)] pairs = pairs[~pairs['seq1'].isin(string_ids_to_delete)]
pairs = pairs[~pairs['seq2'].isin(string_ids_to_delete)] pairs = pairs[~pairs['seq2'].isin(string_ids_to_delete)]
pairs.to_csv('protein.actions.tsv', sep='\t', index=False, header=False) pairs.to_csv(output_file, sep='\t', index=False, header=False)
def generate_dscript_gene_names(file_path, def generate_dscript_gene_names(file_path,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment