Konstantin Volzhenin / SENSE-PPI · Commit f7ddda97
Authored Jul 27, 2023 by Konstantin Volzhenin

0.1.7 embeddings are computed not from the beginning but only for missing sequences

Parent: ccdf1277

Showing 3 changed files with 47 additions and 33 deletions:

  senseppi/__init__.py      +1   -1
  senseppi/__main__.py      +5   -5
  senseppi/esm2_model.py    +41  -27
senseppi/__init__.py

-__version__ = "0.1.7"
+__version__ = "0.1.8"
 __author__ = "Konstantin Volzhenin"
 from . import model, commands, esm2_model, dataset, utils, network_utils
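The only change here is the package version bump from 0.1.7 to 0.1.8. A trivial way to confirm which version is installed (a sketch, assuming the package is importable as senseppi):

    import senseppi
    print(senseppi.__version__)  # prints "0.1.8" once this commit is installed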
senseppi/__main__.py

@@ -33,11 +33,11 @@ def main():
     params = parser.parse_args()

     if hasattr(params, 'device'):
-        # WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
-        if params.device == 'mps':
-            logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
-                            'The cpu backend will be used instead.')
-            params.device = 'cpu'
+        # # WARNING: due to some internal issues of pytorch, the mps backend is temporarily disabled
+        # if params.device == 'mps':
+        #     logging.warning('WARNING: due to some internal issues of torch, the mps backend is temporarily disabled.'
+        #                     'The cpu backend will be used instead.')
+        #     params.device = 'cpu'

         if params.device == 'gpu':
             torch.set_float32_matmul_precision('high')
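The hunk above only comments out the forced fallback from 'mps' to 'cpu', so the MPS device can be requested again. Not part of this commit, but a minimal sketch of how such a fallback could instead be made conditional on the backend actually being available (the helper name and message are illustrative):

    import logging
    import torch

    def resolve_device(requested):
        # Hypothetical helper: only fall back to CPU when the MPS backend is truly unavailable.
        if requested == 'mps' and not torch.backends.mps.is_available():
            logging.warning('MPS requested but not available; using the cpu backend instead.')
            return 'cpu'
        return requested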
senseppi/esm2_model.py

@@ -7,6 +7,8 @@ import torch
 import os
 import logging
 from esm import FastaBatchedDataset, pretrained
+from copy import copy
+from Bio import SeqIO


 def add_esm_args(parent_parser):
@@ -43,37 +45,37 @@ def add_esm_args(parent_parser):
     )


-def run(args):
-    model, alphabet = pretrained.load_model_and_alphabet(args.model_location_esm)
+def run(params):
+    model, alphabet = pretrained.load_model_and_alphabet(params.model_location_esm)
     model.eval()

-    if args.device == 'gpu':
+    if params.device == 'gpu':
         model = model.cuda()
         print("Transferred the ESM2 model to GPU")
-    elif args.device == 'mps':
+    elif params.device == 'mps':
         model = model.to('mps')
         print("Transferred the ESM2 model to MPS")

-    dataset = FastaBatchedDataset.from_file(args.fasta_file)
-    batches = dataset.get_batch_indices(args.toks_per_batch_esm, extra_toks_per_seq=1)
+    dataset = FastaBatchedDataset.from_file(params.fasta_file)
+    batches = dataset.get_batch_indices(params.toks_per_batch_esm, extra_toks_per_seq=1)
     data_loader = torch.utils.data.DataLoader(
-        dataset, collate_fn=alphabet.get_batch_converter(args.truncation_seq_length_esm), batch_sampler=batches
+        dataset, collate_fn=alphabet.get_batch_converter(params.truncation_seq_length_esm), batch_sampler=batches
     )
-    print(f"Read {args.fasta_file} with {len(dataset)} sequences")
+    print(f"Read {params.fasta_file} with {len(dataset)} sequences")

-    args.output_dir_esm.mkdir(parents=True, exist_ok=True)
+    params.output_dir_esm.mkdir(parents=True, exist_ok=True)

-    assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in args.repr_layers_esm)
-    repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in args.repr_layers_esm]
+    assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in params.repr_layers_esm)
+    repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in params.repr_layers_esm]

     with torch.no_grad():
         for batch_idx, (labels, strs, toks) in enumerate(data_loader):
             print(f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)")
-            if args.device == 'gpu':
+            if params.device == 'gpu':
                 toks = toks.to(device="cuda", non_blocking=True)
-            elif args.device == 'mps':
+            elif params.device == 'mps':
                 toks = toks.to(device="mps", non_blocking=True)

             out = model(toks, repr_layers=repr_layers, return_contacts=False)
@@ -83,42 +85,54 @@ def run(args):
             }

             for i, label in enumerate(labels):
-                args.output_file_esm = args.output_dir_esm / f"{label}.pt"
-                args.output_file_esm.parent.mkdir(parents=True, exist_ok=True)
+                params.output_file_esm = params.output_dir_esm / f"{label}.pt"
+                params.output_file_esm.parent.mkdir(parents=True, exist_ok=True)
                 result = {"label": label}
-                truncate_len = min(args.truncation_seq_length_esm, len(strs[i]))
+                truncate_len = min(params.truncation_seq_length_esm, len(strs[i]))
                 # Call clone on tensors to ensure tensors are not views into a larger representation
                 # See https://github.com/pytorch/pytorch/issues/1995
                 result["representations"] = {
-                    layer: t[i, 1: truncate_len + 1].clone()
+                    layer: t[i, 1: truncate_len + 1].clone()
                     for layer, t in representations.items()
                 }

                 torch.save(
                     result,
-                    args.output_file_esm,
+                    params.output_file_esm,
                 )


 def compute_embeddings(params):
     # Compute ESM embeddings
-    logging.info('Computing ESM embeddings if they are not already computed. '
-                 'If all the files alreaady exist in {} folder, this step will be skipped.'.format(params.output_dir_esm))
+    logging.info('Computing ESM embeddings. If all the files already exist in {} folder, '
+                 'this step will be skipped.'.format(params.output_dir_esm))

     if not os.path.exists(params.output_dir_esm):
         run(params)
     else:
         with open(params.fasta_file, 'r') as f:
-            seq_ids = [line.strip().split(' ')[0].replace('>', '') for line in f.readlines() if line.startswith('>')]
+            # dict of only id and sequences from parsing fasta file
+            seq_dict = SeqIO.to_dict(SeqIO.parse(f, 'fasta'))
+            seq_ids = list(seq_dict.keys())

         for seq_id in seq_ids:
-            if not os.path.exists(os.path.join(params.output_dir_esm, seq_id + '.pt')):
-                run(params)
-                break
+            if os.path.exists(os.path.join(params.output_dir_esm, seq_id + '.pt')):
+                seq_dict.pop(seq_id)

+        if len(seq_dict) > 0:
+            params_esm = copy(params)
+            params_esm.fasta_file = 'tmp_for_esm.fasta'
+            with open(params_esm.fasta_file, 'w') as f:
+                for seq_id in seq_dict.keys():
+                    f.write('>' + seq_id + '\n')
+                    f.write(str(seq_dict[seq_id].seq) + '\n')
+            run(params_esm)
+            os.remove(params_esm.fasta_file)
+        else:
+            logging.info('All ESM embeddings already computed')


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    add_esm_args(parser)
-    args = parser.parse_args()
+    esm_parser = argparse.ArgumentParser()
+    add_esm_args(esm_parser)
+    args = esm_parser.parse_args()

     run(args)
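With these changes, compute_embeddings first checks output_dir_esm for an existing <id>.pt file for every record in the input FASTA, writes the still-missing sequences to a temporary FASTA file, and runs the ESM2 model only on those. A minimal usage sketch; all argument values are illustrative and the attribute names simply mirror the ones used in the diff above:

    from argparse import Namespace
    from pathlib import Path

    from senseppi.esm2_model import compute_embeddings

    # Hypothetical parameter set. output_dir_esm must be a Path because run()
    # calls .mkdir() on it and joins file names with the "/" operator.
    params = Namespace(
        fasta_file='proteins.fasta',             # illustrative input FASTA
        output_dir_esm=Path('esm2_embs'),        # illustrative folder of per-sequence .pt files
        model_location_esm='esm2_t36_3B_UR50D',  # an ESM2 checkpoint name accepted by esm.pretrained
        toks_per_batch_esm=4096,                 # illustrative batching limit
        truncation_seq_length_esm=1000,          # illustrative truncation length
        repr_layers_esm=[-1],                    # last layer, consistent with the assert above
        device='cpu',
    )

    compute_embeddings(params)  # first call: embeds every sequence in the FASTA
    compute_embeddings(params)  # second call: all .pt files exist, so the ESM2 step is skipped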