0.6.0 minor changes

parent 0999af3c
__version__ = "0.5.9"
__version__ = "0.6.0"
__author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils
......
......@@ -12,9 +12,9 @@ import torch.optim as optim
import numpy as np
-class DynamicLSTM(pl.LightningModule):
+class DynamicGRU(pl.LightningModule):
     """
-    Dynamic LSTM module, which can handle variable length input sequence.
+    Dynamic GRU module, which can handle variable length input sequence.
     Parameters
     ----------
@@ -33,12 +33,12 @@ class DynamicLSTM(pl.LightningModule):
     -------
     output: tensor, shaped [batch, max_step, num_directions * hidden_size],
         tensor containing the output features (h_t) from the last layer
-        of the LSTM, for each t.
+        of the GRU, for each t.
     """
     def __init__(self, input_size, hidden_size=100,
                  num_layers=1, dropout=0., bidirectional=False, return_sequences=False):
-        super(DynamicLSTM, self).__init__()
+        super(DynamicGRU, self).__init__()
         self.lstm = torch.nn.GRU(
             input_size, hidden_size, num_layers, bias=True,
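
The hunks above show only fragments of the renamed class. As a rough orientation, here is a minimal, hypothetical sketch of how a wrapper like DynamicGRU typically handles variable-length input with torch.nn.GRU and packed sequences. Only the constructor arguments and the torch.nn.GRU call mirror the diff; the class name DynamicGRUSketch, the base class (torch.nn.Module instead of pl.LightningModule, to keep the sketch self-contained), and the forward logic are assumptions, not the project's code.

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class DynamicGRUSketch(torch.nn.Module):
    # Hypothetical stand-in for the project's DynamicGRU; only the
    # constructor signature and the torch.nn.GRU call mirror the diff.
    def __init__(self, input_size, hidden_size=100, num_layers=1,
                 dropout=0., bidirectional=False, return_sequences=False):
        super().__init__()
        self.return_sequences = return_sequences
        # The attribute stays `self.lstm` in the diff even though the
        # underlying module is now a GRU.
        self.lstm = torch.nn.GRU(
            input_size, hidden_size, num_layers, bias=True,
            batch_first=True, dropout=dropout, bidirectional=bidirectional)

    def forward(self, x, lengths):
        # Pack so the GRU skips the padded time steps of shorter sequences.
        packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True,
                                      enforce_sorted=False)
        packed_out, h_n = self.lstm(packed)
        if self.return_sequences:
            # [batch, max_step, num_directions * hidden_size]
            output, _ = pad_packed_sequence(packed_out, batch_first=True)
            return output
        # Simplified: last layer's final hidden state (a bidirectional model
        # would normally concatenate both directions, h_n[-2] and h_n[-1]).
        return h_n[-1]

Note that the attribute name self.lstm is unchanged even though the underlying module is now a GRU, so call sites such as SensePPIModel only need the class-name change shown in the next hunk.
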
@@ -221,7 +221,7 @@ class SensePPIModel(BaselineModel):
         self.encoder_features = self.hparams.encoder_features
         self.hidden_dim = 256
-        self.lstm = DynamicLSTM(self.encoder_features, hidden_size=128, num_layers=3, dropout=0.5, bidirectional=True)
+        self.lstm = DynamicGRU(self.encoder_features, hidden_size=128, num_layers=3, dropout=0.5, bidirectional=True)
         self.dense_head = torch.nn.Sequential(
             torch.nn.Dropout(p=0.5),
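
As a dimension check (not part of the commit): with hidden_size=128 and bidirectional=True, the GRU produces 2 * 128 = 256 features per time step, matching the hidden_dim = 256 set just above. The snippet below assumes batch_first=True, consistent with the [batch, max_step, ...] output shape in the docstring, and uses an illustrative input_size of 2560; neither value is taken from the diff.

import torch

# Hypothetical shape check; input_size=2560 is an assumed embedding width.
gru = torch.nn.GRU(input_size=2560, hidden_size=128, num_layers=3, bias=True,
                   batch_first=True, dropout=0.5, bidirectional=True)
x = torch.randn(4, 50, 2560)   # [batch, max_step, encoder_features]
out, h_n = gru(x)
print(out.shape)               # torch.Size([4, 50, 256])
print(h_n.shape)               # torch.Size([6, 4, 128]): 3 layers * 2 directions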