0.6.4 edits to esm2 params (max_len for truncation) + edits for predict (exceptions + minor things)

parent 0c517ea9
...@@ -43,7 +43,10 @@ The original SENSE-PPI repository also contains two human-based models pretraine ...@@ -43,7 +43,10 @@ The original SENSE-PPI repository also contains two human-based models pretraine
For information about the other models that can be found in the pretrained_models folder, please refer to the original article. For information about the other models that can be found in the pretrained_models folder, please refer to the original article.
**N.B.**: All pretrained models were made to work with proteins in range 50-800 amino acids. **N.B.: All pretrained models were made to work with proteins in range 50-800 amino acids.**
1. When you run the 'predict' command, the model by default uses 1 as the minimum length and the length of the longest protein in the dataset as the maximum length. However, it is **strongly recommended** to use proteins in the range of 50-800 amino acids for the best performance.
2. If you use the --min_len and --max_len arguments, your FASTA file will be filtered automatically, so make sure you have a backup.
In order to cite the original SENSE-PPI paper, please use the following link: https://doi.org/10.1101/2023.09.19.558413 In order to cite the original SENSE-PPI paper, please use the following link: https://doi.org/10.1101/2023.09.19.558413
......
__version__ = "0.6.3" __version__ = "0.6.4"
__author__ = "Konstantin Volzhenin" __author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils from . import model, commands, esm2_model, dataset, utils, network_utils
......
...@@ -34,7 +34,11 @@ def predict(params): ...@@ -34,7 +34,11 @@ def predict(params):
num_workers=4) num_workers=4)
preds = trainer.predict(pretrained_model, test_loader) preds = trainer.predict(pretrained_model, test_loader)
preds = [batch.squeeze().tolist() for batch in preds] try:
preds = [batch.squeeze().tolist() for batch in preds]
except TypeError:
raise Exception("It looks like the dataset is empty. Check the sequence length restriction, "
"it might be that due to the values of min_len and max_len, no pairs were left in the dataset.")
if any(isinstance(i, list) for i in preds): if any(isinstance(i, list) for i in preds):
preds = [item for batch in preds for item in batch] preds = [item for batch in preds for item in batch]
preds = np.asarray(preds) preds = np.asarray(preds)
...@@ -87,15 +91,15 @@ def add_args(parser): ...@@ -87,15 +91,15 @@ def add_args(parser):
help="Prediction threshold to determine interacting pairs that " help="Prediction threshold to determine interacting pairs that "
"will be written to a separate file. Range: (0, 1).") "will be written to a separate file. Range: (0, 1).")
predict_args.add_argument("--num_nodes", type=int, default=1, predict_args.add_argument("--num_nodes", type=int, default=1,
help="Number of nodes to use for launching on a cluster.") help="Number of nodes to use for launching on a cluster.")
parser = SensePPIModel.add_model_specific_args(parser) parser = SensePPIModel.add_model_specific_args(parser)
remove_argument(parser, "--lr") remove_argument(parser, "--lr")
add_esm_args(parser) add_esm_args(parser)
parser.set_defaults(max_len=None) parser.set_defaults(max_len=None) # later will be set to the max length of the sequences in the fasta file
parser.set_defaults(min_len=0) parser.set_defaults(min_len=1)
return parser return parser
...@@ -116,51 +120,53 @@ def get_protein_names(fasta_file): ...@@ -116,51 +120,53 @@ def get_protein_names(fasta_file):
def main(params): def main(params):
tmp_pairs = 'protein.pairs.tsv' tmp_pairs = 'senseppi_pairs_for_prediction.tmp'
try:
fasta_max_len = get_max_len(params.fasta_file) fasta_max_len = get_max_len(params.fasta_file)
if params.max_len is None: if params.max_len is None:
params.max_len = fasta_max_len params.max_len = fasta_max_len
process_string_fasta(params.fasta_file, min_len=params.min_len, max_len=params.max_len) process_string_fasta(params.fasta_file, min_len=params.min_len, max_len=params.max_len)
if params.pairs_file is None: if params.pairs_file is None:
generate_pairs(params.fasta_file, tmp_pairs, with_self=params.with_self) generate_pairs(params.fasta_file, tmp_pairs, with_self=params.with_self)
params.pairs_file = tmp_pairs
else:
if params.max_len < fasta_max_len:
proteins_in_fasta = get_protein_names(params.fasta_file)
data_tmp = pd.read_csv(params.pairs_file, delimiter='\t', header=None)
data_tmp = data_tmp[data_tmp.iloc[:, 0].isin(proteins_in_fasta) &
data_tmp.iloc[:, 1].isin(proteins_in_fasta)]
data_tmp.to_csv(tmp_pairs, sep='\t', index=False, header=False)
params.pairs_file = tmp_pairs params.pairs_file = tmp_pairs
else:
compute_embeddings(params) if params.max_len < fasta_max_len:
proteins_in_fasta = get_protein_names(params.fasta_file)
logging.info('Predicting...')
preds = predict(params) data_tmp = pd.read_csv(params.pairs_file, delimiter='\t', header=None)
data_tmp = data_tmp[data_tmp.iloc[:, 0].isin(proteins_in_fasta) &
data = pd.read_csv(params.pairs_file, delimiter='\t', header=None) data_tmp.iloc[:, 1].isin(proteins_in_fasta)]
data_tmp.to_csv(tmp_pairs, sep='\t', index=False, header=False)
if len(data.columns) == 3: params.pairs_file = tmp_pairs
data.columns = ['seq1', 'seq2', 'label']
elif len(data.columns) == 2: compute_embeddings(params)
data.columns = ['seq1', 'seq2']
else: logging.info('Predicting...')
raise ValueError('The tab-separated pairs file must have 2 or 3 columns (without header): ' preds = predict(params)
'protein name 1, protein name 2 and label(optional)')
data['preds'] = preds data = pd.read_csv(params.pairs_file, delimiter='\t', header=None)
data = data.sort_values(by=['preds'], ascending=False) if len(data.columns) == 3:
data.to_csv(params.output + '.tsv', sep='\t', index=False, header=True) data.columns = ['seq1', 'seq2', 'label']
elif len(data.columns) == 2:
data_positive = data[data['preds'] >= params.pred_threshold] data.columns = ['seq1', 'seq2']
data_positive = data_positive.sort_values(by=['preds'], ascending=False) else:
data_positive.to_csv(params.output + '_positive_interactions.tsv', sep='\t', index=False, header=True) raise ValueError('The tab-separated pairs file must have 2 or 3 columns (without header): '
'protein name 1, protein name 2 and label(optional)')
if os.path.isfile(tmp_pairs): data['preds'] = preds
os.remove(tmp_pairs)
data = data.sort_values(by=['preds'], ascending=False)
data.to_csv(params.output + '.tsv', sep='\t', index=False, header=True)
data_positive = data[data['preds'] >= params.pred_threshold]
data_positive = data_positive.sort_values(by=['preds'], ascending=False)
data_positive.to_csv(params.output + '_positive_interactions.tsv', sep='\t', index=False, header=True)
except Exception as e:
raise e
finally:
if os.path.isfile(tmp_pairs):
os.remove(tmp_pairs)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -42,13 +42,6 @@ def add_esm_args(parent_parser): ...@@ -42,13 +42,6 @@ def add_esm_args(parent_parser):
# help="layers indices from which to extract representations (0 to num_layers, inclusive)", # help="layers indices from which to extract representations (0 to num_layers, inclusive)",
help=argparse.SUPPRESS help=argparse.SUPPRESS
) )
parser.add_argument(
"--truncation_seq_length_esm",
type=int,
default=1022,
# help="truncate sequences longer than the given value",
help=argparse.SUPPRESS
)
def run(params): def run(params):
...@@ -65,7 +58,7 @@ def run(params): ...@@ -65,7 +58,7 @@ def run(params):
dataset = FastaBatchedDataset.from_file(params.fasta_file) dataset = FastaBatchedDataset.from_file(params.fasta_file)
batches = dataset.get_batch_indices(params.toks_per_batch_esm, extra_toks_per_seq=1) batches = dataset.get_batch_indices(params.toks_per_batch_esm, extra_toks_per_seq=1)
data_loader = torch.utils.data.DataLoader( data_loader = torch.utils.data.DataLoader(
dataset, collate_fn=alphabet.get_batch_converter(params.truncation_seq_length_esm), batch_sampler=batches dataset, collate_fn=alphabet.get_batch_converter(params.max_len), batch_sampler=batches
) )
print(f"Read {params.fasta_file} with {len(dataset)} sequences") print(f"Read {params.fasta_file} with {len(dataset)} sequences")
...@@ -94,7 +87,7 @@ def run(params): ...@@ -94,7 +87,7 @@ def run(params):
params.output_file_esm = params.output_dir_esm / f"{label}.pt" params.output_file_esm = params.output_dir_esm / f"{label}.pt"
params.output_file_esm.parent.mkdir(parents=True, exist_ok=True) params.output_file_esm.parent.mkdir(parents=True, exist_ok=True)
result = {"label": label} result = {"label": label}
truncate_len = min(params.truncation_seq_length_esm, len(strs[i])) truncate_len = min(params.max_len, len(strs[i]))
# Call clone on tensors to ensure tensors are not views into a larger representation # Call clone on tensors to ensure tensors are not views into a larger representation
# See https://github.com/pytorch/pytorch/issues/1995 # See https://github.com/pytorch/pytorch/issues/1995
result["representations"] = { result["representations"] = {
...@@ -138,7 +131,13 @@ def compute_embeddings(params): ...@@ -138,7 +131,13 @@ def compute_embeddings(params):
if __name__ == "__main__": if __name__ == "__main__":
from utils import add_general_args
esm_parser = argparse.ArgumentParser() esm_parser = argparse.ArgumentParser()
add_esm_args(esm_parser) add_esm_args(esm_parser)
esm_parser = add_general_args(esm_parser)
esm_parser.add_argument("fasta_file", type=Path,
help="FASTA file on which to extract the ESM2 representations and then test.",
)
args = esm_parser.parse_args() args = esm_parser.parse_args()
run(args) run(args)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment