Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
S
SENSE-PPI
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Konstantin Volzhenin
SENSE-PPI
Commits
66c1ef29
Commit
66c1ef29
authored
Nov 23, 2023
by
Konstantin Volzhenin
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Args changed (num_nodes) + esm2 path bugfix
parent
a6c3a6aa
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
11 additions
and
5 deletions
+11
-5
predict.py
senseppi/commands/predict.py
+4
-1
test.py
senseppi/commands/test.py
+4
-1
esm2_model.py
senseppi/esm2_model.py
+3
-3
No files found.
senseppi/commands/predict.py
View file @
66c1ef29
...
...
@@ -26,7 +26,8 @@ def predict(params):
pretrained_model
.
load_state_dict
(
checkpoint
[
'state_dict'
])
trainer
=
pl
.
Trainer
(
accelerator
=
params
.
device
,
logger
=
False
)
trainer
=
pl
.
Trainer
(
accelerator
=
params
.
device
,
logger
=
False
,
num_nodes
=
params
.
num_nodes
if
hasattr
(
params
,
'num_nodes'
)
else
1
)
test_loader
=
DataLoader
(
dataset
=
test_data
,
batch_size
=
params
.
batch_size
,
...
...
@@ -85,6 +86,8 @@ def add_args(parser):
predict_args
.
add_argument
(
"-p"
,
"--pred_threshold"
,
type
=
float
,
default
=
0.5
,
help
=
"Prediction threshold to determine interacting pairs that "
"will be written to a separate file. Range: (0, 1)."
)
predict_args
.
add_argument
(
"--num_nodes"
,
type
=
int
,
default
=
1
,
help
=
"Number of nodes to use for launching on a cluster."
)
parser
=
SensePPIModel
.
add_model_specific_args
(
parser
)
remove_argument
(
parser
,
"--lr"
)
...
...
senseppi/commands/test.py
View file @
66c1ef29
...
...
@@ -24,7 +24,8 @@ def test(params):
pretrained_model
.
load_state_dict
(
checkpoint
[
'state_dict'
])
trainer
=
pl
.
Trainer
(
accelerator
=
params
.
device
,
logger
=
False
)
trainer
=
pl
.
Trainer
(
accelerator
=
params
.
device
,
logger
=
False
,
num_nodes
=
params
.
num_nodes
)
eval_loader
=
DataLoader
(
dataset
=
eval_data
,
batch_size
=
params
.
batch_size
,
...
...
@@ -55,6 +56,8 @@ def add_args(parser):
help
=
"If set, the data will be cropped to the limits of the model: "
"evaluations will be done only for proteins >50aa and <800aa. WARNING: "
"this will modify the original input files."
)
test_args
.
add_argument
(
"--num_nodes"
,
type
=
int
,
default
=
1
,
help
=
"Number of nodes to use for launching on a cluster."
)
parser
=
SensePPIModel
.
add_model_specific_args
(
parser
)
remove_argument
(
parser
,
"--lr"
)
...
...
senseppi/esm2_model.py
View file @
66c1ef29
...
...
@@ -3,7 +3,7 @@
# Modified by Konstantin Volzhenin, Sorbonne University, 2023
import
argparse
import
pathlib
from
pathlib
import
Path
import
torch
import
os
import
logging
...
...
@@ -29,7 +29,7 @@ def add_esm_args(parent_parser):
)
parser
.
add_argument
(
"--output_dir_esm"
,
type
=
pathlib
.
Path
,
default
=
pathlib
.
Path
(
'esm2_embs_3B'
),
type
=
Path
,
default
=
Path
(
'esm2_embs_3B'
),
help
=
"output directory for extracted representations"
,
)
...
...
@@ -126,7 +126,7 @@ def compute_embeddings(params):
seq_dict
.
pop
(
seq_id
)
if
len
(
seq_dict
)
>
0
:
params_esm
=
copy
(
params
)
params_esm
.
fasta_file
=
'tmp_for_esm.fasta'
params_esm
.
fasta_file
=
Path
(
str
(
params
.
fasta_file
)
.
replace
(
'fasta'
,
'tmp.fasta'
))
with
open
(
params_esm
.
fasta_file
,
'w'
)
as
f
:
for
seq_id
in
seq_dict
.
keys
():
f
.
write
(
'>'
+
seq_id
+
'
\n
'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment