0.6.4 edits to esm2 params (max_len for truncation) + edits for predict (exceptions + minor things)

parent 0c517ea9
......@@ -43,7 +43,10 @@ The original SENSE-PPI repository also contains two human-based models pretraine
For information about the other models that can be found in the pretrained_models folder, please refer to the original article.
**N.B.**: All pretrained models were made to work with proteins in range 50-800 amino acids.
**N.B.: All pretrained models were made to work with proteins in range 50-800 amino acids.**
1. By running the 'predict' command the model will automatically take 1 as the minimum length and the maximum length will be the length of the longest protein in the dataset. However, it is **strongly recommended** to use the proteins in range 50-800 amino acids for the best performance.
2. if you use --min_len and --max_len arguments your fasta file will be filtered automatically, so make sure you have a backup.
In order to cite the original SENSE-PPI paper, please use the following link: https://doi.org/10.1101/2023.09.19.558413
......
__version__ = "0.6.3"
__version__ = "0.6.4"
__author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils
......
......@@ -34,7 +34,11 @@ def predict(params):
num_workers=4)
preds = trainer.predict(pretrained_model, test_loader)
preds = [batch.squeeze().tolist() for batch in preds]
try:
preds = [batch.squeeze().tolist() for batch in preds]
except TypeError:
raise Exception("It looks like the dataset is empty. Check the sequence length restriction, "
"it might be that due to the values of min_len and max_len, no pairs were left in the dataset.")
if any(isinstance(i, list) for i in preds):
preds = [item for batch in preds for item in batch]
preds = np.asarray(preds)
......@@ -87,15 +91,15 @@ def add_args(parser):
help="Prediction threshold to determine interacting pairs that "
"will be written to a separate file. Range: (0, 1).")
predict_args.add_argument("--num_nodes", type=int, default=1,
help="Number of nodes to use for launching on a cluster.")
help="Number of nodes to use for launching on a cluster.")
parser = SensePPIModel.add_model_specific_args(parser)
remove_argument(parser, "--lr")
add_esm_args(parser)
parser.set_defaults(max_len=None)
parser.set_defaults(min_len=0)
parser.set_defaults(max_len=None) # later will be set to the max length of the sequences in the fasta file
parser.set_defaults(min_len=1)
return parser
......@@ -116,51 +120,53 @@ def get_protein_names(fasta_file):
def main(params):
tmp_pairs = 'protein.pairs.tsv'
fasta_max_len = get_max_len(params.fasta_file)
if params.max_len is None:
params.max_len = fasta_max_len
process_string_fasta(params.fasta_file, min_len=params.min_len, max_len=params.max_len)
if params.pairs_file is None:
generate_pairs(params.fasta_file, tmp_pairs, with_self=params.with_self)
params.pairs_file = tmp_pairs
else:
if params.max_len < fasta_max_len:
proteins_in_fasta = get_protein_names(params.fasta_file)
data_tmp = pd.read_csv(params.pairs_file, delimiter='\t', header=None)
data_tmp = data_tmp[data_tmp.iloc[:, 0].isin(proteins_in_fasta) &
data_tmp.iloc[:, 1].isin(proteins_in_fasta)]
data_tmp.to_csv(tmp_pairs, sep='\t', index=False, header=False)
tmp_pairs = 'senseppi_pairs_for_prediction.tmp'
try:
fasta_max_len = get_max_len(params.fasta_file)
if params.max_len is None:
params.max_len = fasta_max_len
process_string_fasta(params.fasta_file, min_len=params.min_len, max_len=params.max_len)
if params.pairs_file is None:
generate_pairs(params.fasta_file, tmp_pairs, with_self=params.with_self)
params.pairs_file = tmp_pairs
compute_embeddings(params)
logging.info('Predicting...')
preds = predict(params)
data = pd.read_csv(params.pairs_file, delimiter='\t', header=None)
if len(data.columns) == 3:
data.columns = ['seq1', 'seq2', 'label']
elif len(data.columns) == 2:
data.columns = ['seq1', 'seq2']
else:
raise ValueError('The tab-separated pairs file must have 2 or 3 columns (without header): '
'protein name 1, protein name 2 and label(optional)')
data['preds'] = preds
data = data.sort_values(by=['preds'], ascending=False)
data.to_csv(params.output + '.tsv', sep='\t', index=False, header=True)
data_positive = data[data['preds'] >= params.pred_threshold]
data_positive = data_positive.sort_values(by=['preds'], ascending=False)
data_positive.to_csv(params.output + '_positive_interactions.tsv', sep='\t', index=False, header=True)
if os.path.isfile(tmp_pairs):
os.remove(tmp_pairs)
else:
if params.max_len < fasta_max_len:
proteins_in_fasta = get_protein_names(params.fasta_file)
data_tmp = pd.read_csv(params.pairs_file, delimiter='\t', header=None)
data_tmp = data_tmp[data_tmp.iloc[:, 0].isin(proteins_in_fasta) &
data_tmp.iloc[:, 1].isin(proteins_in_fasta)]
data_tmp.to_csv(tmp_pairs, sep='\t', index=False, header=False)
params.pairs_file = tmp_pairs
compute_embeddings(params)
logging.info('Predicting...')
preds = predict(params)
data = pd.read_csv(params.pairs_file, delimiter='\t', header=None)
if len(data.columns) == 3:
data.columns = ['seq1', 'seq2', 'label']
elif len(data.columns) == 2:
data.columns = ['seq1', 'seq2']
else:
raise ValueError('The tab-separated pairs file must have 2 or 3 columns (without header): '
'protein name 1, protein name 2 and label(optional)')
data['preds'] = preds
data = data.sort_values(by=['preds'], ascending=False)
data.to_csv(params.output + '.tsv', sep='\t', index=False, header=True)
data_positive = data[data['preds'] >= params.pred_threshold]
data_positive = data_positive.sort_values(by=['preds'], ascending=False)
data_positive.to_csv(params.output + '_positive_interactions.tsv', sep='\t', index=False, header=True)
except Exception as e:
raise e
finally:
if os.path.isfile(tmp_pairs):
os.remove(tmp_pairs)
if __name__ == '__main__':
......
......@@ -42,13 +42,6 @@ def add_esm_args(parent_parser):
# help="layers indices from which to extract representations (0 to num_layers, inclusive)",
help=argparse.SUPPRESS
)
parser.add_argument(
"--truncation_seq_length_esm",
type=int,
default=1022,
# help="truncate sequences longer than the given value",
help=argparse.SUPPRESS
)
def run(params):
......@@ -65,7 +58,7 @@ def run(params):
dataset = FastaBatchedDataset.from_file(params.fasta_file)
batches = dataset.get_batch_indices(params.toks_per_batch_esm, extra_toks_per_seq=1)
data_loader = torch.utils.data.DataLoader(
dataset, collate_fn=alphabet.get_batch_converter(params.truncation_seq_length_esm), batch_sampler=batches
dataset, collate_fn=alphabet.get_batch_converter(params.max_len), batch_sampler=batches
)
print(f"Read {params.fasta_file} with {len(dataset)} sequences")
......@@ -94,7 +87,7 @@ def run(params):
params.output_file_esm = params.output_dir_esm / f"{label}.pt"
params.output_file_esm.parent.mkdir(parents=True, exist_ok=True)
result = {"label": label}
truncate_len = min(params.truncation_seq_length_esm, len(strs[i]))
truncate_len = min(params.max_len, len(strs[i]))
# Call clone on tensors to ensure tensors are not views into a larger representation
# See https://github.com/pytorch/pytorch/issues/1995
result["representations"] = {
......@@ -138,7 +131,13 @@ def compute_embeddings(params):
if __name__ == "__main__":
from utils import add_general_args
esm_parser = argparse.ArgumentParser()
add_esm_args(esm_parser)
esm_parser = add_general_args(esm_parser)
esm_parser.add_argument("fasta_file", type=Path,
help="FASTA file on which to extract the ESM2 representations and then test.",
)
args = esm_parser.parse_args()
run(args)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment