Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
S
SENSE-PPI
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Konstantin Volzhenin
SENSE-PPI
Commits
2fe8d626
Commit
2fe8d626
authored
Jul 28, 2023
by
Konstantin Volzhenin
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
0.1.9 minor bugfix
parent
a65a0e1a
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
81 additions
and
140 deletions
+81
-140
predict.py
senseppi/commands/predict.py
+5
-10
predict_string.py
senseppi/commands/predict_string.py
+50
-102
test.py
senseppi/commands/test.py
+4
-9
train.py
senseppi/commands/train.py
+7
-12
network_utils.py
senseppi/network_utils.py
+2
-7
utils.py
senseppi/utils.py
+13
-0
No files found.
senseppi/commands/predict.py
View file @
2fe8d626
...
...
@@ -3,7 +3,6 @@ import pytorch_lightning as pl
from
itertools
import
permutations
,
product
import
numpy
as
np
import
pandas
as
pd
import
logging
import
pathlib
import
argparse
from
..dataset
import
PairSequenceData
...
...
@@ -97,11 +96,7 @@ def main(params):
compute_embeddings
(
params
)
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
if
params
.
device
==
'mps'
:
logging
.
warning
(
'WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.'
)
params
.
device
=
'cpu'
block_mps
(
params
)
logging
.
info
(
'Predicting...'
)
preds
=
predict
(
params
)
...
...
@@ -116,8 +111,8 @@ def main(params):
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
p
arser
=
add_args
(
parser
)
p
arams
=
parser
.
parse_args
()
p
red_p
arser
=
argparse
.
ArgumentParser
()
p
red_parser
=
add_args
(
pred_
parser
)
p
red_params
=
pred_
parser
.
parse_args
()
main
(
params
)
main
(
p
red_p
arams
)
senseppi/commands/predict_string.py
View file @
2fe8d626
from
torch.utils.data
import
DataLoader
import
pytorch_lightning
as
pl
from
torchmetrics
import
AUROC
,
Accuracy
,
Precision
,
Recall
,
F1Score
,
MatthewsCorrCoef
import
networkx
as
nx
import
seaborn
as
sns
...
...
@@ -10,19 +8,17 @@ from matplotlib.patches import Rectangle
import
argparse
import
matplotlib.pyplot
as
plt
import
glob
import
logging
from
..model
import
SensePPIModel
from
..utils
import
*
from
..network_utils
import
*
from
..esm2_model
import
add_esm_args
,
compute_embeddings
from
..dataset
import
PairSequenceData
from
.predict
import
predict
def
main
(
params
):
LABEL_THRESHOLD
=
params
.
score
/
1000.
PRED_THRESHOLD
=
params
.
pred_threshold
/
1000.
label_threshold
=
params
.
score
/
1000.
pred_threshold
=
params
.
pred_threshold
/
1000.
pairs_file
=
'protein.pairs_string.tsv'
fasta_file
=
'sequences.fasta'
...
...
@@ -30,22 +26,18 @@ def main(params):
get_interactions_from_string
(
params
.
genes
,
species
=
params
.
species
,
add_nodes
=
params
.
nodes
,
required_score
=
params
.
score
,
network_type
=
params
.
network_type
)
process_string_fasta
(
fasta_file
,
min_len
=
params
.
min_len
,
max_len
=
params
.
max_len
)
generate_pairs_string
(
fasta_file
,
output_file
=
pairs_file
,
with_self
=
False
,
delete_proteins
=
params
.
delete_proteins
)
generate_pairs_string
(
fasta_file
,
output_file
=
pairs_file
,
delete_proteins
=
params
.
delete_proteins
)
params
.
fasta_file
=
fasta_file
compute_embeddings
(
params
)
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
if
params
.
device
==
'mps'
:
logging
.
warning
(
'WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.'
)
params
.
device
=
'cpu'
block_mps
(
params
)
preds
=
predict
(
params
)
# open the actions tsv file as dataframe and add the last column with the predictions
data
=
pd
.
read_csv
(
'protein.pairs_string.tsv'
,
delimiter
=
'
\t
'
,
names
=
[
"seq1"
,
"seq2"
,
"string_label"
])
data
[
'binary_label'
]
=
data
[
'string_label'
]
.
apply
(
lambda
x
:
1
if
x
>
LABEL_THRESHOLD
else
0
)
data
[
'binary_label'
]
=
data
[
'string_label'
]
.
apply
(
lambda
x
:
1
if
x
>
label_threshold
else
0
)
data
[
'preds'
]
=
preds
print
(
data
.
sort_values
(
by
=
[
'preds'
],
ascending
=
False
)
.
to_string
())
...
...
@@ -53,13 +45,18 @@ def main(params):
# Calculate torch metrics based on data['binary_label'] and data['preds']
torch_labels
=
torch
.
tensor
(
data
[
'binary_label'
])
torch_preds
=
torch
.
tensor
(
data
[
'preds'
])
print
(
'Accuracy: '
,
Accuracy
(
threshold
=
PRED_THRESHOLD
,
task
=
'binary'
)(
torch_preds
,
torch_labels
))
print
(
'Precision: '
,
Precision
(
threshold
=
PRED_THRESHOLD
,
task
=
'binary'
)(
torch_preds
,
torch_labels
))
print
(
'Recall: '
,
Recall
(
threshold
=
PRED_THRESHOLD
,
task
=
'binary'
)(
torch_preds
,
torch_labels
))
print
(
'F1Score: '
,
F1Score
(
threshold
=
PRED_THRESHOLD
,
task
=
'binary'
)(
torch_preds
,
torch_labels
))
print
(
'Accuracy: '
,
Accuracy
(
threshold
=
pred_threshold
,
task
=
'binary'
)(
torch_preds
,
torch_labels
))
print
(
'Precision: '
,
Precision
(
threshold
=
pred_threshold
,
task
=
'binary'
)(
torch_preds
,
torch_labels
))
print
(
'Recall: '
,
Recall
(
threshold
=
pred_threshold
,
task
=
'binary'
)(
torch_preds
,
torch_labels
))
print
(
'F1Score: '
,
F1Score
(
threshold
=
pred_threshold
,
task
=
'binary'
)(
torch_preds
,
torch_labels
))
print
(
'MatthewsCorrCoef: '
,
MatthewsCorrCoef
(
num_classes
=
2
,
threshold
=
PRED_THRESHOLD
,
task
=
'binary'
)(
torch_preds
,
torch_labels
))
print
(
'ROCAUC: '
,
AUROC
(
task
=
'binary'
)(
torch_preds
,
torch_labels
))
MatthewsCorrCoef
(
num_classes
=
2
,
threshold
=
pred_threshold
,
task
=
'binary'
)(
torch_preds
,
torch_labels
))
print
(
'ROCAUC: '
,
AUROC
(
task
=
'binary'
)(
torch_preds
,
torch_labels
))
string_ids
=
{}
string_tsv
=
pd
.
read_csv
(
'string_interactions.tsv'
,
delimiter
=
'
\t
'
)[
...
...
@@ -74,37 +71,6 @@ def main(params):
data_to_save
=
data_to_save
.
sort_values
(
by
=
[
'preds'
],
ascending
=
False
)
data_to_save
.
to_csv
(
params
.
output
+
'.tsv'
,
sep
=
'
\t
'
,
index
=
False
)
# This part was needed to color the pairs belonging to the train data, temporarily removed
# print('Fetching gene names for training set from STRING...')
#
# if not os.path.exists('all_genes_train.tsv'):
# all_genes = generate_dscript_gene_names(
# file_path=actions_path,
# only_positives=True,
# species=str(hparams.species))
# all_genes.to_csv('all_genes_train.tsv', sep='\t', index=False)
# else:
# all_genes = pd.read_csv('all_genes_train.tsv', sep='\t')
# full_train_data = pd.read_csv(actions_path,
# delimiter='\t', names=['seq1', 'seq2', 'label'])
#
# if all_genes is not None:
# full_train_data = full_train_data.merge(all_genes, left_on='seq1', right_on='QueryString', how='left').merge(
# all_genes, left_on='seq2', right_on='QueryString', how='left')
#
# full_train_data = full_train_data[['preferredName_x', 'preferredName_y', 'label']]
#
# positive_train_data = full_train_data[full_train_data['label'] == 1][['preferredName_x', 'preferredName_y']]
# full_train_data = full_train_data[['preferredName_x', 'preferredName_y']]
#
# full_train_data = [tuple(x) for x in full_train_data.values]
# positive_train_data = [tuple(x) for x in positive_train_data.values]
# else:
# full_train_data = None
# positive_train_data = None
if
params
.
graphs
:
# Create two subpolots but make a short gap between them
fig
,
(
ax1
,
ax2
)
=
plt
.
subplots
(
1
,
2
,
figsize
=
(
18
,
5
),
gridspec_kw
=
{
'width_ratios'
:
[
1
,
1
],
'wspace'
:
0.2
})
...
...
@@ -154,14 +120,6 @@ def main(params):
np
.
triu_indices_from
(
data_heatmap
.
values
)]
labels_heatmap
.
fillna
(
value
=-
1
,
inplace
=
True
)
# This part was needed to color the pairs belonging to the train data, temporarily removed
# if full_train_data is not None:
# for i, row in labels_heatmap.iterrows():
# for j, _ in row.items():
# if (i, j) in full_train_data or (j, i) in full_train_data:
# labels_heatmap.loc[i, j] = -1
cmap
=
matplotlib
.
cm
.
get_cmap
(
'coolwarm'
)
.
copy
()
cmap
.
set_bad
(
"black"
)
...
...
@@ -179,53 +137,40 @@ def main(params):
for
i
in
range
(
len
(
labels_heatmap
)):
ax1
.
add_patch
(
Rectangle
((
i
,
i
),
1
,
1
,
fill
=
True
,
color
=
'white'
,
alpha
=
1
,
zorder
=
100
))
G
=
nx
.
Graph
()
pred_graph
=
nx
.
Graph
()
for
i
,
row
in
data
.
iterrows
():
if
row
[
'string_label'
]
>
LABEL_THRESHOLD
:
G
.
add_edge
(
row
[
'seq1'
],
row
[
'seq2'
],
color
=
'black'
,
weight
=
row
[
'string_label'
],
style
=
'dotted'
)
if
row
[
'preds'
]
>
PRED_THRESHOLD
and
G
.
has_edge
(
row
[
'seq1'
],
row
[
'seq2'
]):
G
[
row
[
'seq1'
]][
row
[
'seq2'
]][
'style'
]
=
'solid'
G
[
row
[
'seq1'
]][
row
[
'seq2'
]][
'color'
]
=
'limegreen'
if
row
[
'preds'
]
>
PRED_THRESHOLD
and
row
[
'string_label'
]
<=
LABEL_THRESHOLD
:
G
.
add_edge
(
row
[
'seq1'
],
row
[
'seq2'
],
color
=
'red'
,
weight
=
row
[
'preds'
],
style
=
'solid'
)
if
row
[
'string_label'
]
>
label_threshold
:
pred_graph
.
add_edge
(
row
[
'seq1'
],
row
[
'seq2'
],
color
=
'black'
,
weight
=
row
[
'string_label'
],
style
=
'dotted'
)
if
row
[
'preds'
]
>
pred_threshold
and
pred_graph
.
has_edge
(
row
[
'seq1'
],
row
[
'seq2'
]):
pred_graph
[
row
[
'seq1'
]][
row
[
'seq2'
]][
'style'
]
=
'solid'
pred_graph
[
row
[
'seq1'
]][
row
[
'seq2'
]][
'color'
]
=
'limegreen'
if
row
[
'preds'
]
>
pred_threshold
and
row
[
'string_label'
]
<=
label_threshold
:
pred_graph
.
add_edge
(
row
[
'seq1'
],
row
[
'seq2'
],
color
=
'red'
,
weight
=
row
[
'preds'
],
style
=
'solid'
)
for
node
in
pred_graph
.
nodes
():
pred_graph
.
nodes
[
node
][
'color'
]
=
'lightgrey'
# Replace the string ids with gene names
G
=
nx
.
relabel_nodes
(
G
,
string_ids
)
# This part was needed to color the pairs belonging to the train data, temporarily removed
# if positive_train_data is not None:
# for edge in G.edges():
# if (edge[0], edge[1]) in positive_train_data or (edge[1], edge[0]) in positive_train_data:
# print('TRAINING EDGE: ', edge)
# G[edge[0]][edge[1]]['color'] = 'darkblue'
# # G[edge[0]][edge[1]]['weight'] = 1
# Make nodes red if they are present in training data
for
node
in
G
.
nodes
():
# if all_genes is not None and node in all_genes['preferredName'].values:
# G.nodes[node]['color'] = 'orange'
# else:
G
.
nodes
[
node
][
'color'
]
=
'lightgrey'
pos
=
nx
.
spring_layout
(
G
,
k
=
2.
,
iterations
=
100
)
nx
.
draw
(
G
,
pos
=
pos
,
with_labels
=
True
,
ax
=
ax2
,
edge_color
=
[
G
[
u
][
v
][
'color'
]
for
u
,
v
in
G
.
edges
()],
width
=
[
G
[
u
][
v
][
'weight'
]
for
u
,
v
in
G
.
edges
()],
style
=
[
G
[
u
][
v
][
'style'
]
for
u
,
v
in
G
.
edges
()],
node_color
=
[
G
.
nodes
[
node
][
'color'
]
for
node
in
G
.
nodes
()])
pred_graph
=
nx
.
relabel_nodes
(
pred_graph
,
string_ids
)
pos
=
nx
.
spring_layout
(
pred_graph
,
k
=
2.
,
iterations
=
100
)
nx
.
draw
(
pred_graph
,
pos
=
pos
,
with_labels
=
True
,
ax
=
ax2
,
edge_color
=
[
pred_graph
[
u
][
v
][
'color'
]
for
u
,
v
in
pred_graph
.
edges
()],
width
=
[
pred_graph
[
u
][
v
][
'weight'
]
for
u
,
v
in
pred_graph
.
edges
()],
style
=
[
pred_graph
[
u
][
v
][
'style'
]
for
u
,
v
in
pred_graph
.
edges
()],
node_color
=
[
pred_graph
.
nodes
[
node
][
'color'
]
for
node
in
pred_graph
.
nodes
()])
legend_elements
=
[
# Line2D([0], [0], marker='_', color='darkblue', label='PP from training data', markerfacecolor='darkblue',
# markersize=10),
Line2D
([
0
],
[
0
],
marker
=
'_'
,
color
=
'limegreen'
,
label
=
'PP'
,
markerfacecolor
=
'limegreen'
,
markersize
=
10
),
Line2D
([
0
],
[
0
],
marker
=
'_'
,
color
=
'red'
,
label
=
'FP'
,
markerfacecolor
=
'red'
,
markersize
=
10
),
Line2D
([
0
],
[
0
],
marker
=
'_'
,
color
=
'black'
,
label
=
'FN - based on STRING'
,
markerfacecolor
=
'black'
,
markersize
=
10
,
linestyle
=
'dotted'
)]
plt
.
legend
(
handles
=
legend_elements
,
loc
=
'upper right'
,
bbox_to_anchor
=
(
1.2
,
0.0
),
ncol
=
1
,
fontsize
=
8
)
savepath
=
'{}_graph_{}_{}.pdf'
.
format
(
params
.
output
,
'_'
.
join
(
params
.
genes
),
params
.
species
)
plt
.
savefig
(
savepath
,
bbox_inches
=
'tight'
,
dpi
=
600
)
print
(
"The graphs were saved to: "
,
savepath
)
save
_
path
=
'{}_graph_{}_{}.pdf'
.
format
(
params
.
output
,
'_'
.
join
(
params
.
genes
),
params
.
species
)
plt
.
savefig
(
save
_
path
,
bbox_inches
=
'tight'
,
dpi
=
600
)
print
(
"The graphs were saved to: "
,
save
_
path
)
plt
.
show
()
plt
.
close
()
...
...
@@ -235,6 +180,7 @@ def main(params):
os
.
remove
(
f
)
os
.
remove
(
'string_interactions.tsv'
)
def
add_args
(
parser
):
parser
=
add_general_args
(
parser
)
...
...
@@ -242,8 +188,8 @@ def add_args(parser):
parser
.
_action_groups
[
0
]
.
add_argument
(
"model_path"
,
type
=
str
,
help
=
"A path to .ckpt file that contains weights to a pretrained model."
)
parser
.
_action_groups
[
0
]
.
add_argument
(
"genes"
,
type
=
str
,
nargs
=
"+"
,
help
=
"Name of gene to fetch from STRING database. Several names can be typed (separated by
"
"whitespaces)
"
)
help
=
"Name of gene to fetch from STRING database. Several names can be
"
"typed (separated by whitespaces).
"
)
string_pred_args
.
add_argument
(
"-s"
,
"--species"
,
type
=
int
,
default
=
9606
,
help
=
"Species from STRING database. Default: 9606 (H. Sapiens)"
)
string_pred_args
.
add_argument
(
"-n"
,
"--nodes"
,
type
=
int
,
default
=
10
,
...
...
@@ -252,7 +198,8 @@ def add_args(parser):
help
=
"Score threshold for STRING connections. Range: (0, 1000). Default: 500"
)
string_pred_args
.
add_argument
(
"-p"
,
"--pred_threshold"
,
type
=
int
,
default
=
500
,
help
=
"Prediction threshold. Range: (0, 1000). Default: 500"
)
string_pred_args
.
add_argument
(
"--graphs"
,
action
=
'store_true'
,
help
=
"Enables plotting the heatmap and a network graph."
)
string_pred_args
.
add_argument
(
"--graphs"
,
action
=
'store_true'
,
help
=
"Enables plotting the heatmap and a network graph."
)
string_pred_args
.
add_argument
(
"-o"
,
"--output"
,
type
=
str
,
default
=
"preds_from_string"
,
help
=
"A path to a file where the predictions will be saved. "
"(.tsv format will be added automatically)"
)
...
...
@@ -269,6 +216,8 @@ def add_args(parser):
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
=
add_args
(
parser
)
params
=
parser
.
parse_args
()
\ No newline at end of file
pred_parser
=
argparse
.
ArgumentParser
()
pred_parser
=
add_args
(
pred_parser
)
pred_params
=
pred_parser
.
parse_args
()
main
(
pred_params
)
senseppi/commands/test.py
View file @
2fe8d626
from
torch.utils.data
import
DataLoader
import
pytorch_lightning
as
pl
import
pandas
as
pd
import
logging
import
pathlib
import
argparse
from
..dataset
import
PairSequenceData
...
...
@@ -72,11 +71,7 @@ def main(params):
compute_embeddings
(
params
)
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
if
params
.
device
==
'mps'
:
logging
.
warning
(
'WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.'
)
params
.
device
=
'cpu'
block_mps
(
params
)
logging
.
info
(
'Evaluating...'
)
test_metrics
=
test
(
params
)[
0
]
...
...
@@ -87,7 +82,7 @@ def main(params):
if
__name__
==
'__main__'
:
test_parser
=
argparse
.
ArgumentParser
()
parser
=
add_args
(
test_parser
)
params
=
test_parser
.
parse_args
()
test_
parser
=
add_args
(
test_parser
)
test_
params
=
test_parser
.
parse_args
()
main
(
params
)
main
(
test_
params
)
senseppi/commands/train.py
View file @
2fe8d626
...
...
@@ -2,8 +2,7 @@ import pytorch_lightning as pl
from
pytorch_lightning.callbacks
import
ModelCheckpoint
import
pathlib
import
argparse
import
logging
from
..utils
import
add_general_args
from
..utils
import
*
from
..model
import
SensePPIModel
from
..dataset
import
PairSequenceData
from
..esm2_model
import
add_esm_args
,
compute_embeddings
...
...
@@ -15,11 +14,7 @@ def main(params):
compute_embeddings
(
params
)
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
if
params
.
device
==
'mps'
:
logging
.
warning
(
'WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.'
)
params
.
device
=
'cpu'
block_mps
(
params
)
dataset
=
PairSequenceData
(
emb_dir
=
params
.
output_dir_esm
,
actions_file
=
params
.
pairs_file
,
max_len
=
params
.
max_len
,
labels
=
True
)
...
...
@@ -80,8 +75,8 @@ def add_args(parser):
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
=
add_args
(
parser
)
params
=
parser
.
parse_args
()
train_
parser
=
argparse
.
ArgumentParser
()
train_parser
=
add_args
(
train_
parser
)
train_params
=
train_
parser
.
parse_args
()
main
(
params
)
\ No newline at end of file
main
(
train_params
)
\ No newline at end of file
senseppi/network_utils.py
View file @
2fe8d626
...
...
@@ -15,18 +15,13 @@ import shutil
DOWNLOAD_LINK_STRING
=
"https://stringdb-downloads.org/download/"
def
generate_pairs_string
(
fasta_file
,
output_file
,
with_self
=
False
,
delete_proteins
=
None
):
def
generate_pairs_string
(
fasta_file
,
output_file
,
delete_proteins
=
None
):
ids
=
[]
for
record
in
SeqIO
.
parse
(
fasta_file
,
"fasta"
):
ids
.
append
(
record
.
id
)
if
with_self
:
all_pairs
=
[
p
for
p
in
product
(
ids
,
repeat
=
2
)]
else
:
all_pairs
=
[
p
for
p
in
permutations
(
ids
,
2
)]
pairs
=
[]
for
p
in
all_pairs
:
for
p
in
[
p
for
p
in
permutations
(
ids
,
2
)]
:
if
(
p
[
1
],
p
[
0
])
not
in
pairs
and
(
p
[
0
],
p
[
1
])
not
in
pairs
:
pairs
.
append
(
p
)
...
...
senseppi/utils.py
View file @
2fe8d626
...
...
@@ -2,6 +2,7 @@ from Bio import SeqIO
import
os
from
senseppi
import
__version__
import
torch
import
logging
def
add_general_args
(
parser
):
...
...
@@ -29,6 +30,18 @@ def determine_device():
return
device
def
block_mps
(
params
):
# WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
if
hasattr
(
params
,
'device'
):
if
params
.
device
==
'mps'
:
logging
.
warning
(
'WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
'The cpu backend will be used instead.'
)
if
torch
.
cuda
.
is_available
():
params
.
device
=
'gpu'
else
:
params
.
device
=
'cpu'
def
process_string_fasta
(
fasta_file
,
min_len
,
max_len
):
with
open
(
'file.tmp'
,
'w'
)
as
f
:
for
record
in
SeqIO
.
parse
(
fasta_file
,
"fasta"
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment