Args changed (num_nodes) + esm2 path bugfix

parent a6c3a6aa
......@@ -26,7 +26,8 @@ def predict(params):
pretrained_model.load_state_dict(checkpoint['state_dict'])
trainer = pl.Trainer(accelerator=params.device, logger=False)
trainer = pl.Trainer(accelerator=params.device, logger=False,
num_nodes=params.num_nodes if hasattr(params, 'num_nodes') else 1)
test_loader = DataLoader(dataset=test_data,
batch_size=params.batch_size,
......@@ -85,6 +86,8 @@ def add_args(parser):
predict_args.add_argument("-p", "--pred_threshold", type=float, default=0.5,
help="Prediction threshold to determine interacting pairs that "
"will be written to a separate file. Range: (0, 1).")
predict_args.add_argument("--num_nodes", type=int, default=1,
help="Number of nodes to use for launching on a cluster.")
parser = SensePPIModel.add_model_specific_args(parser)
remove_argument(parser, "--lr")
......
......@@ -24,7 +24,8 @@ def test(params):
pretrained_model.load_state_dict(checkpoint['state_dict'])
trainer = pl.Trainer(accelerator=params.device, logger=False)
trainer = pl.Trainer(accelerator=params.device, logger=False,
num_nodes=params.num_nodes)
eval_loader = DataLoader(dataset=eval_data,
batch_size=params.batch_size,
......@@ -55,6 +56,8 @@ def add_args(parser):
help="If set, the data will be cropped to the limits of the model: "
"evaluations will be done only for proteins >50aa and <800aa. WARNING: "
"this will modify the original input files.")
test_args.add_argument("--num_nodes", type=int, default=1,
help="Number of nodes to use for launching on a cluster.")
parser = SensePPIModel.add_model_specific_args(parser)
remove_argument(parser, "--lr")
......
......@@ -3,7 +3,7 @@
# Modified by Konstantin Volzhenin, Sorbonne University, 2023
import argparse
import pathlib
from pathlib import Path
import torch
import os
import logging
......@@ -29,7 +29,7 @@ def add_esm_args(parent_parser):
)
parser.add_argument(
"--output_dir_esm",
type=pathlib.Path, default=pathlib.Path('esm2_embs_3B'),
type=Path, default=Path('esm2_embs_3B'),
help="output directory for extracted representations",
)
......@@ -126,7 +126,7 @@ def compute_embeddings(params):
seq_dict.pop(seq_id)
if len(seq_dict) > 0:
params_esm = copy(params)
params_esm.fasta_file = 'tmp_for_esm.fasta'
params_esm.fasta_file = Path(str(params.fasta_file).replace('fasta', 'tmp.fasta'))
with open(params_esm.fasta_file, 'w') as f:
for seq_id in seq_dict.keys():
f.write('>' + seq_id + '\n')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment