0.3.0 predict updated. Things to be changed later: fasta file still is edited inplace

parent 78f70acf
__version__ = "0.2.2" __version__ = "0.3.0"
__author__ = "Konstantin Volzhenin" __author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils from . import model, commands, esm2_model, dataset, utils, network_utils
......
...@@ -32,7 +32,10 @@ def predict(params): ...@@ -32,7 +32,10 @@ def predict(params):
batch_size=params.batch_size, batch_size=params.batch_size,
num_workers=4) num_workers=4)
preds = [pred for batch in trainer.predict(pretrained_model, test_loader) for pred in batch.squeeze().tolist()] preds = trainer.predict(pretrained_model, test_loader)
preds = [batch.squeeze().tolist() for batch in preds]
if any(isinstance(i, list) for i in preds):
preds = [item for batch in preds for item in batch]
preds = np.asarray(preds) preds = np.asarray(preds)
return preds return preds
...@@ -85,28 +88,63 @@ def add_args(parser): ...@@ -85,28 +88,63 @@ def add_args(parser):
remove_argument(parser, "--lr") remove_argument(parser, "--lr")
add_esm_args(parser) add_esm_args(parser)
parser.set_defaults(max_len=None)
parser.set_defaults(min_len=0)
return parser return parser
def get_max_len(fasta_file):
max_len = 0
for record in SeqIO.parse(fasta_file, "fasta"):
if len(record.seq) > max_len:
max_len = len(record.seq)
return max_len
def get_protein_names(fasta_file):
names = []
for record in SeqIO.parse(fasta_file, "fasta"):
names.append(record.id)
return set(names)
def main(params): def main(params):
tmp_pairs = 'protein.pairs.tsv'
fasta_max_len = get_max_len(params.fasta_file)
if params.max_len is None:
params.max_len = fasta_max_len
process_string_fasta(params.fasta_file, min_len=params.min_len, max_len=params.max_len) process_string_fasta(params.fasta_file, min_len=params.min_len, max_len=params.max_len)
if params.pairs_file is None: if params.pairs_file is None:
generate_pairs(params.fasta_file, 'protein.pairs.tsv', with_self=params.with_self) generate_pairs(params.fasta_file, tmp_pairs, with_self=params.with_self)
params.pairs_file = 'protein.pairs.tsv' params.pairs_file = tmp_pairs
else:
if params.max_len < fasta_max_len:
proteins_in_fasta = get_protein_names(params.fasta_file)
data_tmp = pd.read_csv(params.pairs_file, delimiter='\t', header=None)
data_tmp = data_tmp[data_tmp.iloc[:, 0].isin(proteins_in_fasta) &
data_tmp.iloc[:, 1].isin(proteins_in_fasta)]
data_tmp.to_csv(tmp_pairs, sep='\t', index=False, header=False)
params.pairs_file = tmp_pairs
compute_embeddings(params) compute_embeddings(params)
logging.info('Predicting...') logging.info('Predicting...')
preds = predict(params) preds = predict(params)
data = pd.read_csv(params.pairs_file, delimiter='\t') data = pd.read_csv(params.pairs_file, delimiter='\t', header=None)
#if 3 columns, then assign names ['seq1', 'seq2', 'label'] if 2 columns, then names ['seq1', 'seq2']
if len(data.columns) == 3: if len(data.columns) == 3:
data.columns = ['seq1', 'seq2', 'label'] data.columns = ['seq1', 'seq2', 'label']
elif len(data.columns) == 2: elif len(data.columns) == 2:
data.columns = ['seq1', 'seq2'] data.columns = ['seq1', 'seq2']
else: else:
raise ValueError('The pairs file must have 2 or 3 columns: seq1, seq2 and label(optional)') raise ValueError('The tab-separated pairs file must have 2 or 3 columns (without header): '
'protein name 1, protein name 2 and label(optional)')
data['preds'] = preds data['preds'] = preds
data.to_csv(params.output + '.tsv', sep='\t', index=False, header=True) data.to_csv(params.output + '.tsv', sep='\t', index=False, header=True)
...@@ -114,6 +152,9 @@ def main(params): ...@@ -114,6 +152,9 @@ def main(params):
data_positive = data[data['preds'] >= params.pred_threshold] data_positive = data[data['preds'] >= params.pred_threshold]
data_positive.to_csv(params.output + '_positive_interactions.tsv', sep='\t', index=False, header=True) data_positive.to_csv(params.output + '_positive_interactions.tsv', sep='\t', index=False, header=True)
if os.path.isfile(tmp_pairs):
os.remove(tmp_pairs)
if __name__ == '__main__': if __name__ == '__main__':
pred_parser = argparse.ArgumentParser() pred_parser = argparse.ArgumentParser()
......
...@@ -8,11 +8,11 @@ import logging ...@@ -8,11 +8,11 @@ import logging
def add_general_args(parser): def add_general_args(parser):
parser.add_argument("-v", "--version", action="version", version="SENSE_PPI v{}".format(__version__)) parser.add_argument("-v", "--version", action="version", version="SENSE_PPI v{}".format(__version__))
parser.add_argument("--min_len", type=int, default=50, parser.add_argument("--min_len", type=int, default=50,
help="Minimum length of the protein sequence. " help="Minimum length of the protein sequence. The sequences with smaller length will not be "
"The sequences with smaller length will not be considered.") "considered and will be deleted from the fasta file.")
parser.add_argument("--max_len", type=int, default=800, parser.add_argument("--max_len", type=int, default=800,
help="Maximum length of the protein sequence. " help="Maximum length of the protein sequence. The sequences with larger length will not be "
"The sequences with larger length will not be considered.") "considered and will be deleted from the fasta file.")
parser.add_argument("--device", type=str, default=determine_device(), choices=['cpu', 'gpu', 'mps'], parser.add_argument("--device", type=str, default=determine_device(), choices=['cpu', 'gpu', 'mps'],
help="Device to used for computations. Options include: cpu, gpu, mps (for MacOS)." help="Device to used for computations. Options include: cpu, gpu, mps (for MacOS)."
"If not selected the device is set by torch automatically.") "If not selected the device is set by torch automatically.")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment