0.1.7

GPU running fixed
parent 52747bde
__version__ = "0.1.6"
__version__ = "0.1.7"
__author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils
......
......@@ -42,6 +42,8 @@ def main():
if params.device == 'gpu':
torch.set_float32_matmul_precision('high')
logging.info('Device used: {}'.format(params.device))
params.func(params)
......
......@@ -47,7 +47,7 @@ def run(args):
model, alphabet = pretrained.load_model_and_alphabet(args.model_location_esm)
model.eval()
if args.device == 'cuda':
if args.device == 'gpu':
model = model.cuda()
print("Transferred the ESM2 model to GPU")
elif args.device == 'mps':
......@@ -71,7 +71,7 @@ def run(args):
print(
f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
)
if args.device == 'cuda':
if args.device == 'gpu':
toks = toks.to(device="cuda", non_blocking=True)
elif args.device == 'mps':
toks = toks.to(device="mps", non_blocking=True)
......
......@@ -14,6 +14,7 @@ import shutil
DOWNLOAD_LINK_STRING = "https://stringdb-downloads.org/download/"
def generate_pairs_string(fasta_file, output_file, with_self=False, delete_proteins=None):
ids = []
for record in SeqIO.parse(fasta_file, "fasta"):
......@@ -85,7 +86,8 @@ def get_interactions_from_string(gene_names, species=9606, add_nodes=10, require
# Download protein sequences for given species if not downloaded yet
if not os.path.isfile('{}.protein.sequences.v{}.fa'.format(species, version)):
print('Downloading protein sequences')
url = '{0}protein.sequences.v{1}/{2}.protein.sequences.v{1}.fa.gz'.format(DOWNLOAD_LINK_STRING, version, species)
url = '{0}protein.sequences.v{1}/{2}.protein.sequences.v{1}.fa.gz'.format(DOWNLOAD_LINK_STRING, version,
species)
urllib.request.urlretrieve(url, '{}.protein.sequences.v{}.fa.gz'.format(species, version))
print('Unzipping protein sequences')
with gzip.open('{}.protein.sequences.v{}.fa.gz'.format(species, version), 'rb') as f_in:
......@@ -145,5 +147,6 @@ def get_interactions_from_string(gene_names, species=9606, add_nodes=10, require
SeqIO.write(record, f, "fasta")
string_interactions.to_csv('string_interactions.tsv', sep='\t', index=False)
if __name__ == '__main__':
get_interactions_from_string('RFC5')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment