0.1.3 fixed max_len bug in dataset.py (was not transfered first due to a wrong…

0.1.3 fixed max_len bug in dataset.py (was not transfered first due to a wrong final version in previous repository)
parent 87fcd885
__version__ = "0.1.2" __version__ = "0.1.3"
__author__ = "Konstantin Volzhenin" __author__ = "Konstantin Volzhenin"
from . import model, commands, esm2_model, dataset, utils, network_utils from . import model, commands, esm2_model, dataset, utils, network_utils
......
...@@ -42,6 +42,7 @@ class PairSequenceData(Dataset): ...@@ -42,6 +42,7 @@ class PairSequenceData(Dataset):
if self.pad_inputs: if self.pad_inputs:
if tensor_emb.shape[0] > self.max_len: if tensor_emb.shape[0] > self.max_len:
tensor_emb = tensor_emb[:self.max_len] tensor_emb = tensor_emb[:self.max_len]
tensor_len = self.max_len
if tensor_emb.shape[0] < self.max_len: if tensor_emb.shape[0] < self.max_len:
tensor_emb = F.pad(tensor_emb, (0, 0, 0, self.max_len - tensor_emb.size(0)), "constant", 0) tensor_emb = F.pad(tensor_emb, (0, 0, 0, self.max_len - tensor_emb.size(0)), "constant", 0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment