0.6.4 edits to esm2 params (max_len for truncation) + edits for predict (exceptions + minor things)

parent 0c517ea9
......@@ -43,7 +43,10 @@ The original SENSE-PPI repository also contains two human-based models pretraine
For information about the other models that can be found in the pretrained_models folder, please refer to the original article.
**N.B.**: All pretrained models were made to work with proteins in range 50-800 amino acids.
**N.B.: All pretrained models were made to work with proteins in range 50-800 amino acids.**
1. By running the 'predict' command the model will automatically take 1 as the minimum length and the maximum length will be the length of the longest protein in the dataset. However, it is **strongly recommended** to use the proteins in range 50-800 amino acids for the best performance.
2. if you use --min_len and --max_len arguments your fasta file will be filtered automatically, so make sure you have a backup.
In order to cite the original SENSE-PPI paper, please use the following link: https://doi.org/10.1101/2023.09.19.558413
......
__version__ = "0.6.3"
__version__ = "0.6.4"
__author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils
......
......@@ -34,7 +34,11 @@ def predict(params):
num_workers=4)
preds = trainer.predict(pretrained_model, test_loader)
try:
preds = [batch.squeeze().tolist() for batch in preds]
except TypeError:
raise Exception("It looks like the dataset is empty. Check the sequence length restriction, "
"it might be that due to the values of min_len and max_len, no pairs were left in the dataset.")
if any(isinstance(i, list) for i in preds):
preds = [item for batch in preds for item in batch]
preds = np.asarray(preds)
......@@ -94,8 +98,8 @@ def add_args(parser):
add_esm_args(parser)
parser.set_defaults(max_len=None)
parser.set_defaults(min_len=0)
parser.set_defaults(max_len=None) # later will be set to the max length of the sequences in the fasta file
parser.set_defaults(min_len=1)
return parser
......@@ -116,8 +120,8 @@ def get_protein_names(fasta_file):
def main(params):
tmp_pairs = 'protein.pairs.tsv'
tmp_pairs = 'senseppi_pairs_for_prediction.tmp'
try:
fasta_max_len = get_max_len(params.fasta_file)
if params.max_len is None:
params.max_len = fasta_max_len
......@@ -158,7 +162,9 @@ def main(params):
data_positive = data[data['preds'] >= params.pred_threshold]
data_positive = data_positive.sort_values(by=['preds'], ascending=False)
data_positive.to_csv(params.output + '_positive_interactions.tsv', sep='\t', index=False, header=True)
except Exception as e:
raise e
finally:
if os.path.isfile(tmp_pairs):
os.remove(tmp_pairs)
......
......@@ -42,13 +42,6 @@ def add_esm_args(parent_parser):
# help="layers indices from which to extract representations (0 to num_layers, inclusive)",
help=argparse.SUPPRESS
)
parser.add_argument(
"--truncation_seq_length_esm",
type=int,
default=1022,
# help="truncate sequences longer than the given value",
help=argparse.SUPPRESS
)
def run(params):
......@@ -65,7 +58,7 @@ def run(params):
dataset = FastaBatchedDataset.from_file(params.fasta_file)
batches = dataset.get_batch_indices(params.toks_per_batch_esm, extra_toks_per_seq=1)
data_loader = torch.utils.data.DataLoader(
dataset, collate_fn=alphabet.get_batch_converter(params.truncation_seq_length_esm), batch_sampler=batches
dataset, collate_fn=alphabet.get_batch_converter(params.max_len), batch_sampler=batches
)
print(f"Read {params.fasta_file} with {len(dataset)} sequences")
......@@ -94,7 +87,7 @@ def run(params):
params.output_file_esm = params.output_dir_esm / f"{label}.pt"
params.output_file_esm.parent.mkdir(parents=True, exist_ok=True)
result = {"label": label}
truncate_len = min(params.truncation_seq_length_esm, len(strs[i]))
truncate_len = min(params.max_len, len(strs[i]))
# Call clone on tensors to ensure tensors are not views into a larger representation
# See https://github.com/pytorch/pytorch/issues/1995
result["representations"] = {
......@@ -138,7 +131,13 @@ def compute_embeddings(params):
if __name__ == "__main__":
from utils import add_general_args
esm_parser = argparse.ArgumentParser()
add_esm_args(esm_parser)
esm_parser = add_general_args(esm_parser)
esm_parser.add_argument("fasta_file", type=Path,
help="FASTA file on which to extract the ESM2 representations and then test.",
)
args = esm_parser.parse_args()
run(args)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment