0.5.4 some model args and ESM2 args are suppressed

parent 86e49290
__version__ = "0.5.3" __version__ = "0.5.4"
__author__ = "Konstantin Volzhenin" __author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils from . import model, commands, esm2_model, dataset, utils, network_utils
......
 #!/usr/bin/env python3 -u
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+# Modified by Konstantin Volzhenin, Sorbonne University, 2023
 import argparse
 import pathlib
@@ -20,8 +21,9 @@ def add_esm_args(parent_parser):
     parser.add_argument(
         "--model_location_esm",
         type=str, default="esm2_t36_3B_UR50D",
-        help="PyTorch model file OR name of pretrained model to download. If not default, "
-             "the number of encoder_features has to be modified according to the embedding dimensionality. "
+        # help="PyTorch model file OR name of pretrained model to download. If not default, "
+        # "the number of encoder_features has to be modified according to the embedding dimensionality. "
+        help=argparse.SUPPRESS
     )
     parser.add_argument(
         "--output_dir_esm",
@@ -35,13 +37,15 @@ def add_esm_args(parent_parser):
         type=int,
         default=[-1],
         nargs="+",
-        help="layers indices from which to extract representations (0 to num_layers, inclusive)",
+        # help="layers indices from which to extract representations (0 to num_layers, inclusive)",
+        help=argparse.SUPPRESS
     )
     parser.add_argument(
         "--truncation_seq_length_esm",
         type=int,
         default=1022,
-        help="truncate sequences longer than the given value",
+        # help="truncate sequences longer than the given value",
+        help=argparse.SUPPRESS
     )
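The pattern applied throughout this commit is standard argparse behaviour: replacing a `help` string with `help=argparse.SUPPRESS` hides the option from the `--help` listing while leaving it fully functional. A minimal sketch of that behaviour (illustration only, not code from this package):

```python
import argparse

parser = argparse.ArgumentParser(prog="example")
# Visible option: listed in the --help output.
parser.add_argument("--batch_size", type=int, default=32, help="Batch size for training/testing.")
# Hidden option: omitted from --help, but still parsed and still keeps its default.
parser.add_argument("--truncation_seq_length_esm", type=int, default=1022, help=argparse.SUPPRESS)

args = parser.parse_args(["--truncation_seq_length_esm", "2000"])
print(args.truncation_seq_length_esm)  # 2000 -> the suppressed argument can still be set explicitly
```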
...
+import argparse
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 import pytorch_lightning as pl
 import torch.utils.data as data
 from torch.utils.data import Subset
-from torchmetrics import AUROC, ROC, Accuracy, Precision, Recall, F1Score, MatthewsCorrCoef, AveragePrecision
+from torchmetrics import AUROC, Accuracy, Precision, Recall, F1Score, MatthewsCorrCoef, AveragePrecision
 from torchmetrics.collections import MetricCollection
 from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
 import torch.optim as optim
@@ -206,9 +207,10 @@ class BaselineModel(pl.LightningModule):
                                 "Cosine warmup will be applied.")
        parser.add_argument("--batch_size", type=int, default=32, help="Batch size for training/testing.")
        parser.add_argument("--encoder_features", type=int, default=2560,
-                           help="Number of features in the encoder "
-                                "(Corresponds to the dimentionality of per-token embedding of ESM2 model.) "
-                                "If not a 3B version of ESM2 is chosen, this parameter needs to be set accordingly.")
+                           # help="Number of features in the encoder "
+                           # "(Corresponds to the dimentionality of per-token embedding of ESM2 model.) "
+                           # "If not a 3B version of ESM2 is chosen, this parameter needs to be set accordingly."
+                           help=argparse.SUPPRESS)
        return parent_parser
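The now-hidden help text records the constraint this default relies on: `--encoder_features` (default 2560) must match the per-token embedding size of the ESM2 checkpoint selected with `--model_location_esm` (default `esm2_t36_3B_UR50D`, whose embeddings are 2560-dimensional). A rough way to check the embedding size of another checkpoint with the fair-esm package; the model name and `embed_dim` attribute below belong to that library, not to code in this repository:

```python
import esm  # fair-esm

# Load a smaller checkpoint as an illustration; the 3B default behaves the same way.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
print(model.embed_dim)  # 1280 -> --encoder_features would need to be set to this value
# The default esm2_t36_3B_UR50D checkpoint produces 2560-dimensional embeddings,
# matching the --encoder_features default of 2560.
```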
...
 import json
 from Bio import SeqIO
-from itertools import permutations, product
+from itertools import permutations
 import pandas as pd
-import numpy as np
 import os
 import urllib.request
-import time
-from tqdm import tqdm
-from copy import deepcopy
 import requests
 import gzip
 import shutil
...
@@ -15,7 +15,8 @@ def add_general_args(parser):
                             "considered and will be deleted from the fasta file.")
    parser.add_argument("--device", type=str, default=determine_device(), choices=['cpu', 'gpu', 'mps'],
                        help="Device to used for computations. Options include: cpu, gpu, mps (for MacOS)."
-                            "If not selected the device is set by torch automatically.")
+                            "If not selected the device is set by torch automatically. WARNING: mps is temporarily "
+                            "disabled, if it is chosen, cpu will be used instead.")
    return parser
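The updated help string warns that mps is temporarily disabled and silently falls back to cpu. The diff does not show where that fallback happens; a hypothetical helper illustrating the behaviour the warning describes (the name `resolve_device` and its body are assumptions, not code from this package):

```python
import torch

def resolve_device(requested: str) -> str:
    # mps is temporarily disabled: fall back to cpu, as the warning states.
    if requested == "mps":
        return "cpu"
    # gpu only makes sense when CUDA is actually available.
    if requested == "gpu" and not torch.cuda.is_available():
        return "cpu"
    return requested
```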
...