0.6.4 edits to esm2 params (max_len for truncation) + edits for predict (exceptions + minor things)

parent 0c517ea9
...@@ -43,7 +43,10 @@ The original SENSE-PPI repository also contains two human-based models pretraine ...@@ -43,7 +43,10 @@ The original SENSE-PPI repository also contains two human-based models pretraine
For information about the other models that can be found in the pretrained_models folder, please refer to the original article. For information about the other models that can be found in the pretrained_models folder, please refer to the original article.
**N.B.**: All pretrained models were made to work with proteins in range 50-800 amino acids. **N.B.: All pretrained models were made to work with proteins in range 50-800 amino acids.**
1. By running the 'predict' command the model will automatically take 1 as the minimum length and the maximum length will be the length of the longest protein in the dataset. However, it is **strongly recommended** to use the proteins in range 50-800 amino acids for the best performance.
2. if you use --min_len and --max_len arguments your fasta file will be filtered automatically, so make sure you have a backup.
In order to cite the original SENSE-PPI paper, please use the following link: https://doi.org/10.1101/2023.09.19.558413 In order to cite the original SENSE-PPI paper, please use the following link: https://doi.org/10.1101/2023.09.19.558413
......
__version__ = "0.6.3" __version__ = "0.6.4"
__author__ = "Konstantin Volzhenin" __author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils from . import model, commands, esm2_model, dataset, utils, network_utils
......
...@@ -34,7 +34,11 @@ def predict(params): ...@@ -34,7 +34,11 @@ def predict(params):
num_workers=4) num_workers=4)
preds = trainer.predict(pretrained_model, test_loader) preds = trainer.predict(pretrained_model, test_loader)
try:
preds = [batch.squeeze().tolist() for batch in preds] preds = [batch.squeeze().tolist() for batch in preds]
except TypeError:
raise Exception("It looks like the dataset is empty. Check the sequence length restriction, "
"it might be that due to the values of min_len and max_len, no pairs were left in the dataset.")
if any(isinstance(i, list) for i in preds): if any(isinstance(i, list) for i in preds):
preds = [item for batch in preds for item in batch] preds = [item for batch in preds for item in batch]
preds = np.asarray(preds) preds = np.asarray(preds)
...@@ -94,8 +98,8 @@ def add_args(parser): ...@@ -94,8 +98,8 @@ def add_args(parser):
add_esm_args(parser) add_esm_args(parser)
parser.set_defaults(max_len=None) parser.set_defaults(max_len=None) # later will be set to the max length of the sequences in the fasta file
parser.set_defaults(min_len=0) parser.set_defaults(min_len=1)
return parser return parser
...@@ -116,8 +120,8 @@ def get_protein_names(fasta_file): ...@@ -116,8 +120,8 @@ def get_protein_names(fasta_file):
def main(params): def main(params):
tmp_pairs = 'protein.pairs.tsv' tmp_pairs = 'senseppi_pairs_for_prediction.tmp'
try:
fasta_max_len = get_max_len(params.fasta_file) fasta_max_len = get_max_len(params.fasta_file)
if params.max_len is None: if params.max_len is None:
params.max_len = fasta_max_len params.max_len = fasta_max_len
...@@ -158,7 +162,9 @@ def main(params): ...@@ -158,7 +162,9 @@ def main(params):
data_positive = data[data['preds'] >= params.pred_threshold] data_positive = data[data['preds'] >= params.pred_threshold]
data_positive = data_positive.sort_values(by=['preds'], ascending=False) data_positive = data_positive.sort_values(by=['preds'], ascending=False)
data_positive.to_csv(params.output + '_positive_interactions.tsv', sep='\t', index=False, header=True) data_positive.to_csv(params.output + '_positive_interactions.tsv', sep='\t', index=False, header=True)
except Exception as e:
raise e
finally:
if os.path.isfile(tmp_pairs): if os.path.isfile(tmp_pairs):
os.remove(tmp_pairs) os.remove(tmp_pairs)
......
...@@ -42,13 +42,6 @@ def add_esm_args(parent_parser): ...@@ -42,13 +42,6 @@ def add_esm_args(parent_parser):
# help="layers indices from which to extract representations (0 to num_layers, inclusive)", # help="layers indices from which to extract representations (0 to num_layers, inclusive)",
help=argparse.SUPPRESS help=argparse.SUPPRESS
) )
parser.add_argument(
"--truncation_seq_length_esm",
type=int,
default=1022,
# help="truncate sequences longer than the given value",
help=argparse.SUPPRESS
)
def run(params): def run(params):
...@@ -65,7 +58,7 @@ def run(params): ...@@ -65,7 +58,7 @@ def run(params):
dataset = FastaBatchedDataset.from_file(params.fasta_file) dataset = FastaBatchedDataset.from_file(params.fasta_file)
batches = dataset.get_batch_indices(params.toks_per_batch_esm, extra_toks_per_seq=1) batches = dataset.get_batch_indices(params.toks_per_batch_esm, extra_toks_per_seq=1)
data_loader = torch.utils.data.DataLoader( data_loader = torch.utils.data.DataLoader(
dataset, collate_fn=alphabet.get_batch_converter(params.truncation_seq_length_esm), batch_sampler=batches dataset, collate_fn=alphabet.get_batch_converter(params.max_len), batch_sampler=batches
) )
print(f"Read {params.fasta_file} with {len(dataset)} sequences") print(f"Read {params.fasta_file} with {len(dataset)} sequences")
...@@ -94,7 +87,7 @@ def run(params): ...@@ -94,7 +87,7 @@ def run(params):
params.output_file_esm = params.output_dir_esm / f"{label}.pt" params.output_file_esm = params.output_dir_esm / f"{label}.pt"
params.output_file_esm.parent.mkdir(parents=True, exist_ok=True) params.output_file_esm.parent.mkdir(parents=True, exist_ok=True)
result = {"label": label} result = {"label": label}
truncate_len = min(params.truncation_seq_length_esm, len(strs[i])) truncate_len = min(params.max_len, len(strs[i]))
# Call clone on tensors to ensure tensors are not views into a larger representation # Call clone on tensors to ensure tensors are not views into a larger representation
# See https://github.com/pytorch/pytorch/issues/1995 # See https://github.com/pytorch/pytorch/issues/1995
result["representations"] = { result["representations"] = {
...@@ -138,7 +131,13 @@ def compute_embeddings(params): ...@@ -138,7 +131,13 @@ def compute_embeddings(params):
if __name__ == "__main__": if __name__ == "__main__":
from utils import add_general_args
esm_parser = argparse.ArgumentParser() esm_parser = argparse.ArgumentParser()
add_esm_args(esm_parser) add_esm_args(esm_parser)
esm_parser = add_general_args(esm_parser)
esm_parser.add_argument("fasta_file", type=Path,
help="FASTA file on which to extract the ESM2 representations and then test.",
)
args = esm_parser.parse_args() args = esm_parser.parse_args()
run(args) run(args)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment