Args changed (num_nodes) + esm2 path bugfix

parent a6c3a6aa
...@@ -26,7 +26,8 @@ def predict(params): ...@@ -26,7 +26,8 @@ def predict(params):
pretrained_model.load_state_dict(checkpoint['state_dict']) pretrained_model.load_state_dict(checkpoint['state_dict'])
trainer = pl.Trainer(accelerator=params.device, logger=False) trainer = pl.Trainer(accelerator=params.device, logger=False,
num_nodes=params.num_nodes if hasattr(params, 'num_nodes') else 1)
test_loader = DataLoader(dataset=test_data, test_loader = DataLoader(dataset=test_data,
batch_size=params.batch_size, batch_size=params.batch_size,
...@@ -85,6 +86,8 @@ def add_args(parser): ...@@ -85,6 +86,8 @@ def add_args(parser):
predict_args.add_argument("-p", "--pred_threshold", type=float, default=0.5, predict_args.add_argument("-p", "--pred_threshold", type=float, default=0.5,
help="Prediction threshold to determine interacting pairs that " help="Prediction threshold to determine interacting pairs that "
"will be written to a separate file. Range: (0, 1).") "will be written to a separate file. Range: (0, 1).")
predict_args.add_argument("--num_nodes", type=int, default=1,
help="Number of nodes to use for launching on a cluster.")
parser = SensePPIModel.add_model_specific_args(parser) parser = SensePPIModel.add_model_specific_args(parser)
remove_argument(parser, "--lr") remove_argument(parser, "--lr")
......
...@@ -24,7 +24,8 @@ def test(params): ...@@ -24,7 +24,8 @@ def test(params):
pretrained_model.load_state_dict(checkpoint['state_dict']) pretrained_model.load_state_dict(checkpoint['state_dict'])
trainer = pl.Trainer(accelerator=params.device, logger=False) trainer = pl.Trainer(accelerator=params.device, logger=False,
num_nodes=params.num_nodes)
eval_loader = DataLoader(dataset=eval_data, eval_loader = DataLoader(dataset=eval_data,
batch_size=params.batch_size, batch_size=params.batch_size,
...@@ -55,6 +56,8 @@ def add_args(parser): ...@@ -55,6 +56,8 @@ def add_args(parser):
help="If set, the data will be cropped to the limits of the model: " help="If set, the data will be cropped to the limits of the model: "
"evaluations will be done only for proteins >50aa and <800aa. WARNING: " "evaluations will be done only for proteins >50aa and <800aa. WARNING: "
"this will modify the original input files.") "this will modify the original input files.")
test_args.add_argument("--num_nodes", type=int, default=1,
help="Number of nodes to use for launching on a cluster.")
parser = SensePPIModel.add_model_specific_args(parser) parser = SensePPIModel.add_model_specific_args(parser)
remove_argument(parser, "--lr") remove_argument(parser, "--lr")
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# Modified by Konstantin Volzhenin, Sorbonne University, 2023 # Modified by Konstantin Volzhenin, Sorbonne University, 2023
import argparse import argparse
import pathlib from pathlib import Path
import torch import torch
import os import os
import logging import logging
...@@ -29,7 +29,7 @@ def add_esm_args(parent_parser): ...@@ -29,7 +29,7 @@ def add_esm_args(parent_parser):
) )
parser.add_argument( parser.add_argument(
"--output_dir_esm", "--output_dir_esm",
type=pathlib.Path, default=pathlib.Path('esm2_embs_3B'), type=Path, default=Path('esm2_embs_3B'),
help="output directory for extracted representations", help="output directory for extracted representations",
) )
...@@ -126,7 +126,7 @@ def compute_embeddings(params): ...@@ -126,7 +126,7 @@ def compute_embeddings(params):
seq_dict.pop(seq_id) seq_dict.pop(seq_id)
if len(seq_dict) > 0: if len(seq_dict) > 0:
params_esm = copy(params) params_esm = copy(params)
params_esm.fasta_file = 'tmp_for_esm.fasta' params_esm.fasta_file = Path(str(params.fasta_file).replace('fasta', 'tmp.fasta'))
with open(params_esm.fasta_file, 'w') as f: with open(params_esm.fasta_file, 'w') as f:
for seq_id in seq_dict.keys(): for seq_id in seq_dict.keys():
f.write('>' + seq_id + '\n') f.write('>' + seq_id + '\n')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment