0.2.2 fixed tsv bug in predict

parent 4b1847e2
......@@ -127,3 +127,5 @@ dmypy.json
/esm2_embs_3B
*.sh
draft.py
/data/string_species/mmseqs_dbs/
/data/human_virus/all_test_viruses.csv
__version__ = "0.2.1"
__version__ = "0.2.2"
__author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils
......
......@@ -99,13 +99,20 @@ def main(params):
logging.info('Predicting...')
preds = predict(params)
data = pd.read_csv(params.pairs_file, delimiter='\t', names=["seq1", "seq2"])
data = pd.read_csv(params.pairs_file, delimiter='\t')
# Name the columns by count: 3 columns -> ['seq1', 'seq2', 'label']; 2 columns -> ['seq1', 'seq2'].
if len(data.columns) == 3:
data.columns = ['seq1', 'seq2', 'label']
elif len(data.columns) == 2:
data.columns = ['seq1', 'seq2']
else:
raise ValueError('The pairs file must have 2 or 3 columns: seq1, seq2 and label(optional)')
data['preds'] = preds
data.to_csv(params.output + '.tsv', sep='\t', index=False, header=False)
data.to_csv(params.output + '.tsv', sep='\t', index=False, header=True)
data_positive = data[data['preds'] >= params.pred_threshold]
data_positive.to_csv(params.output + '_positive_interactions.tsv', sep='\t', index=False, header=False)
data_positive.to_csv(params.output + '_positive_interactions.tsv', sep='\t', index=False, header=True)
if __name__ == '__main__':
......
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment