Source code for sensai.torch.torch_models.seq.seq_modules

import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional, Union, Sequence
from typing_extensions import Protocol

import torch

from ..lstnet.lstnet_modules import LSTNetwork
from ..mlp.mlp_modules import MultiLayerPerceptron
from ... import ActivationFunction
from ....util.string import object_repr, ToStringMixin

log = logging.getLogger(__name__)


class EncoderProtocol(Protocol):
    def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        :param x: a tensor of shape (batch_size, seq_length=max(lengths), history_features) containing the sequence
            of history features to encode
        :param lengths: an optional tensor of shape (batch_size) containing the lengths of the sequences in `x`
        :return: a tensor of shape (batch_size, latent_dim) containing the encodings
        """
        pass


class DecoderProtocol(Protocol):
    def forward(self, latent: torch.Tensor, target_features: Optional[torch.Tensor] = None,
            target_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        :param latent: a tensor of shape (batch_size, latent_dim) containing the latent representations
        :param target_features: a tensor of shape (batch_size, target_seq_length=max(target_lengths), target_feature_dim)
        :param target_lengths: a tensor of shape (batch_size) containing the lengths of sequences in `target_features`
        :return: a tensor of shape (batch_size, output_dim) or (batch_size, target_seq_length, output_dim) containing
            the predictions; the shape depends on the use case
        """
        pass


class PredictorProtocol(Protocol):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: a tensor of shape (batch_size, input_dim) containing an intermediate representation
        :return: a tensor of shape (batch_size, output_dim)
        """
        pass


# TODO: Should use intersection type A & B once we switch to Python 3.9+
TDecoder = Union[DecoderProtocol, torch.nn.Module]
TEncoder = Union[EncoderProtocol, torch.nn.Module]
TPredictor = Union[PredictorProtocol, torch.nn.Module]


class EncoderFactory(ToStringMixin, ABC):
    """
    Represents a factory for encoder modules that map a sequence of items to a latent vector
    """
    @abstractmethod
    def create_encoder(self, input_dim: int, latent_dim: int) -> TEncoder:
        """
        :param input_dim: the input dimension per sequence item
        :param latent_dim: the latent vector dimension that is to be generated by the encoder
        :return: a torch module satisfying :class:`EncoderProtocol`
        """
        pass


class DecoderFactory(ToStringMixin, ABC):
    @abstractmethod
    def create_decoder(self, latent_dim: int, target_feature_dim: int) -> TDecoder:
        """
        :param latent_dim: the latent vector size which is used for the representation of the history
        :param target_feature_dim: the number of dimensions/features that are given for each prediction to be made
            (each future sequence item)
        :return: a torch module satisfying :class:`DecoderProtocol`
        """
        pass


class PredictorFactory(ToStringMixin, ABC):
    """
    Represents a factory for predictor components which map from an intermediate representation to the desired
    output dimension.
    """
    @abstractmethod
    def create_predictor(self, input_dim: int, output_dim: int) -> TPredictor:
        """
        :param input_dim: the input dimension
        :param output_dim: the output dimension
        :return: a module which maps an input with dimension `input_dim` to the desired prediction dimension
            (`output_dim`)
        """
        pass


class LinearPredictorFactory(PredictorFactory):
    """A factory for predictors consisting only of a linear layer (without subsequent activation)"""
    def create_predictor(self, input_dim: int, output_dim: int) -> torch.nn.Module:
        return torch.nn.Linear(input_dim, output_dim)


class MLPPredictorFactory(PredictorFactory):
    """A factory for predictors that are multi-layer perceptrons"""
    def __init__(self, hidden_dims: Sequence[int] = (),
            hid_activation_fn: ActivationFunction = ActivationFunction.RELU,
            output_activation_fn: ActivationFunction = ActivationFunction.NONE,
            p_dropout: Optional[float] = None):
        self.hidden_dims = hidden_dims
        self.hid_activation_fn = hid_activation_fn
        self.output_activation_fn = output_activation_fn
        self.p_dropout = p_dropout

    def create_predictor(self, input_dim: int, output_dim: int) -> TPredictor:
        return MultiLayerPerceptron(input_dim, output_dim, self.hidden_dims,
            hid_activation_fn=self.hid_activation_fn.get_torch_function(),
            output_activation_fn=self.output_activation_fn.get_torch_function(),
            p_dropout=self.p_dropout)
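

# The following is an illustrative usage sketch (not part of the original module) showing how the predictor
# factories above can be used to create prediction heads; all dimensions are arbitrary example values.
def _example_predictor_factories():
    linear_head = LinearPredictorFactory().create_predictor(input_dim=32, output_dim=1)
    mlp_head = MLPPredictorFactory(hidden_dims=(64, 32), p_dropout=0.1).create_predictor(input_dim=32, output_dim=1)
    return linear_head, mlp_head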


class RnnEncoderModule(torch.nn.Module):
    """
    Encodes a sequence of feature vectors, outputting a latent vector.
    The input sequence may either be fixed-length or variable-length.
    """
    class RnnType:
        GRU = "gru"
        """gated recurrent unit"""
        LSTM = "lstm"
        """long short-term memory"""

    def __init__(self, input_dim, latent_dim: int, rnn_type: RnnType = RnnType.LSTM):
        """
        :param input_dim: the input dimension per time slice
        :param latent_dim: the dimension of the latent output vector
        :param rnn_type: the type of recurrent network to use
        """
        super().__init__()
        self.window_dim_per_item = input_dim
        self.latent_dim = latent_dim
        self.rnn_type = rnn_type
        if rnn_type == self.RnnType.GRU:
            self.rnn = torch.nn.GRU(input_size=self.window_dim_per_item, hidden_size=latent_dim, batch_first=True)
        elif rnn_type == self.RnnType.LSTM:
            self.rnn = torch.nn.LSTM(input_size=self.window_dim_per_item, hidden_size=latent_dim, batch_first=True)
        else:
            raise ValueError(f"Unknown rnn type '{rnn_type}', use either 'gru' or 'lstm'")

    def __str__(self):
        return object_repr(self, dict(rnn=self.rnn))

    def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None):
        """
        :param x: a tensor of size (batch_size, seq_length, dim_per_item)
        :param lengths: an optional tensor containing the lengths of the sequences; if None, all sequences are
            assumed to have the same full length
        :return: a tensor of size (batch_size, latent_dim)
        """
        if lengths is not None:
            x = torch.nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        if self.rnn_type == self.RnnType.GRU:
            _, l = self.rnn(x)
        elif self.rnn_type == self.RnnType.LSTM:
            _o, (l, _c) = self.rnn(x)
        else:
            raise ValueError(self.rnn_type)
        # l has shape (1, batch_size, latent_dim)
        l = l.squeeze(0)
        return l  # (batch_size, latent_dim)
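

# The following is an illustrative usage sketch (not part of the original module) showing how RnnEncoderModule
# encodes a padded batch of variable-length sequences; all shapes and lengths are arbitrary example values.
def _example_rnn_encoder_usage():
    encoder = RnnEncoderModule(input_dim=8, latent_dim=16, rnn_type=RnnEncoderModule.RnnType.LSTM)
    x = torch.randn(4, 10, 8)  # (batch_size=4, seq_length=10, dim_per_item=8), padded to the maximum length
    lengths = torch.tensor([10, 7, 5, 3])  # actual sequence lengths (kept on CPU for pack_padded_sequence)
    latent = encoder(x, lengths)  # -> tensor of shape (4, 16)
    return latent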


class RnnEncoderFactory(EncoderFactory):
    def __init__(self, input_dim: int, latent_dim: int,
            rnn_type: RnnEncoderModule.RnnType = RnnEncoderModule.RnnType.GRU):
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.rnn_type = rnn_type

    def create_encoder(self, input_dim: int, latent_dim: int):
        return RnnEncoderModule(input_dim, latent_dim, self.rnn_type)


class LSTNetworkEncoder(torch.nn.Module):
    """
    Adapts an LSTNetwork instance to the encoder interface
    """
    def __init__(self, lstnet: LSTNetwork):
        super().__init__()
        self.lstnet = lstnet

    def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None):
        """
        :param x: a tensor of size (batch_size, seq_length, dim_per_item)
        :param lengths: an optional tensor containing the lengths of the sequences; if None, all sequences are
            assumed to have the same full length
        :return: a tensor of size (batch_size, latent_dim)
        """
        if lengths is not None:
            unique_lengths = torch.unique(lengths)
            if len(unique_lengths) != 1:
                raise ValueError("LSTNetwork does not support variable-length inputs")
        l = self.lstnet(x)
        return l


class LSTNetworkEncoderFactory(EncoderFactory):
    def __init__(self, num_input_time_slices: int, num_convolutions: int, num_cnn_time_slices: int, hid_rnn: int,
            skip: int, hid_skip: int, dropout: float = 0.2):
        self.dropout = dropout
        self.num_input_time_slices = num_input_time_slices
        self.num_convolutions = num_convolutions
        self.num_cnn_time_slices = num_cnn_time_slices
        self.hid_rnn = hid_rnn
        self.skip = skip
        self.hid_skip = hid_skip

    def create_encoder(self, input_dim: int, latent_dim: int) -> torch.nn.Module:
        lstnet = LSTNetwork(num_input_time_slices=self.num_input_time_slices,
            input_dim_per_time_slice=input_dim,
            num_convolutions=self.num_convolutions,
            num_cnn_time_slices=self.num_cnn_time_slices,
            hid_rnn=self.hid_rnn,
            skip=self.skip,
            hid_skip=self.hid_skip,
            dropout=self.dropout,
            mode=LSTNetwork.Mode.ENCODER)
        if lstnet.get_encoder_dim() != latent_dim:
            raise ValueError(f"LSTNetwork produces latent_dim={lstnet.get_encoder_dim()}; please adjust the parameter")
        return LSTNetworkEncoder(lstnet)

    def get_latent_dim(self) -> int:
        return LSTNetwork.compute_encoder_dim(self.hid_rnn, self.skip, self.hid_skip)


class SingleTargetDecoderModule(torch.nn.Module, DecoderProtocol):
    """
    Represents a decoder that outputs a single value for a single target item, taking as input the concatenation
    of the latent tensor (generated by the encoder) and the target item's feature vector.
    """
    def __init__(self, target_feature_dim, latent_dim, predictor_factory: PredictorFactory, output_dim=1):
        """
        :param target_feature_dim: the number of target item features
        :param latent_dim: the dimension of the latent vector generated by the encoder, which we receive as input
        :param predictor_factory: a factory for the creation of the predictor that will map the combined latent
            vector and target feature vector to the prediction of size `output_dim`
        :param output_dim: the output (prediction) dimension
        """
        super().__init__()
        self.target_feature_dim = target_feature_dim
        self.latent_dim = latent_dim
        self.predictor_input_dim = self.latent_dim + self.target_feature_dim
        self.predictor = predictor_factory.create_predictor(self.predictor_input_dim, output_dim)

    def __str__(self):
        return object_repr(self, dict(predictor=self.predictor))

    def forward(self, latent, target_features=None, target_lengths=None):
        if target_features is not None:
            # target_features must have shape (batch_size, 1, target_feature_dim)
            assert target_features.shape[1] == 1, "target_features must contain but one sequence item"
            target_features = torch.squeeze(target_features, 1)
            lf = torch.cat((latent, target_features), dim=1)
        else:
            lf = latent
        return self.predictor(lf)
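

# The following is an illustrative usage sketch (not part of the original module) showing how
# SingleTargetDecoderModule maps a latent vector and a single target item to a prediction;
# all dimensions are arbitrary example values.
def _example_single_target_decoder_usage():
    decoder = SingleTargetDecoderModule(target_feature_dim=3, latent_dim=16,
        predictor_factory=LinearPredictorFactory(), output_dim=1)
    latent = torch.randn(4, 16)  # (batch_size, latent_dim), e.g. as produced by an encoder
    target_features = torch.randn(4, 1, 3)  # exactly one target item per instance
    prediction = decoder(latent, target_features)  # -> tensor of shape (4, 1)
    return prediction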


class TargetSequenceDecoderModule(torch.nn.Module, DecoderProtocol, ToStringMixin):
    """
    Wrapper for decoders that take as input a latent representation (generated by an encoder) and a sequence of
    target features. It can generate either a single prediction for the entire sequence of target features or a
    sequence of predictions (one for each target sequence item), depending on the prediction/output mode.
    """
    class PredictionMode(Enum):
        """Defines how the prediction works"""
        SINGLE_LATENT = "single_latent"
        """
        Use an LSTM to process the target feature sequence and use only the final hidden state for prediction,
        outputting a single prediction only (for OutputMode.SINGLE_OUTPUT only)
        """
        MULTI_LATENT = "multi_latent"
        """Use an LSTM to process the target feature sequence and use all hidden states (full output) for prediction"""
        DIRECT = "direct"
        """
        Directly use the latent vector and target features to make predictions for each target sequence item
        (use with LatentPassOnMode.CONCAT_INPUT or NO_LATENT only)
        """

    class LatentPassOnMode(Enum):
        """Defines how the latent state from the encoder stage is passed on to the decoder"""
        INIT_HIDDEN = "init_hidden"
        """
        Pass on the encoder output as the initial hidden state of the LSTM
        (only possible for PredictionMode in {SINGLE_LATENT, MULTI_LATENT})
        """
        CONCAT_INPUT = "concat_input"
        """Pass on the encoder output by concatenating it with each target feature input vector"""
        NO_LATENT = "no_latent"
        """
        Do not pass on the latent vector at all (it is ignored by the subsequent decoder component).
        This is mostly useful for ablation testing.
        """

    class OutputMode(Enum):
        """Defines how to treat multiple predictions (for PredictionMode != SINGLE_LATENT)"""
        SINGLE_OUTPUT = "single"
        """Output a single result from a single input (for PredictionMode.SINGLE_LATENT only)"""
        SINGLE_OUTPUT_MEAN = "mean"
        """Output the mean of multiple (intermediate) predictions"""
        MULTI_OUTPUT = "multi"
        """Output multiple predictions directly"""

    def __init__(self, target_feature_dim: int, latent_dim: int, predictor_factory: PredictorFactory,
            output_dim: int = 1,
            prediction_mode: PredictionMode = PredictionMode.MULTI_LATENT,
            latent_pass_on_mode: LatentPassOnMode = LatentPassOnMode.CONCAT_INPUT,
            output_mode: OutputMode = OutputMode.MULTI_OUTPUT,
            p_recurrent_dropout: float = 0.0):
        super().__init__()
        # SINGLE_LATENT <=> SINGLE_OUTPUT
        if not ((prediction_mode == self.PredictionMode.SINGLE_LATENT) == (output_mode == self.OutputMode.SINGLE_OUTPUT)):
            raise ValueError(f"{self.PredictionMode.SINGLE_LATENT} must coincide with {self.OutputMode.SINGLE_OUTPUT}; "
                f"got {prediction_mode} and {output_mode}")
        if prediction_mode == self.PredictionMode.DIRECT and \
                latent_pass_on_mode not in (self.LatentPassOnMode.CONCAT_INPUT, self.LatentPassOnMode.NO_LATENT):
            raise ValueError(f"{prediction_mode} requires {self.LatentPassOnMode.CONCAT_INPUT} or "
                f"{self.LatentPassOnMode.NO_LATENT}")
        if latent_pass_on_mode == self.LatentPassOnMode.INIT_HIDDEN and \
                prediction_mode not in (self.PredictionMode.SINGLE_LATENT, self.PredictionMode.MULTI_LATENT):
            raise ValueError(f"{latent_pass_on_mode} requires {self.PredictionMode.SINGLE_LATENT} or "
                f"{self.PredictionMode.MULTI_LATENT}")
        if latent_pass_on_mode == self.LatentPassOnMode.NO_LATENT:
            latent_dim = 0
        self.latent_pass_on_mode = latent_pass_on_mode
        self.prediction_mode = prediction_mode
        self.output_mode = output_mode
        self.target_feature_dim = target_feature_dim
        self.latent_dim = latent_dim
        if prediction_mode == self.PredictionMode.DIRECT:
            self.lstm = None
            predictor_input_dim = self.latent_dim + self.target_feature_dim
        else:
            if latent_pass_on_mode == self.LatentPassOnMode.INIT_HIDDEN:
                rnn_input_dim = target_feature_dim
            else:
                rnn_input_dim = target_feature_dim + latent_dim
            self.lstm = torch.nn.LSTM(rnn_input_dim, self.latent_dim, batch_first=True, dropout=p_recurrent_dropout)
            predictor_input_dim = self.latent_dim
        self.predictor = predictor_factory.create_predictor(predictor_input_dim, output_dim)

    def _tostring_exclude_private(self) -> bool:
        return True

    def forward(self, latent, target_features=None, target_lengths=None):
        """
        :param latent: a tensor of shape (batch_size, latent_dim)
        :param target_features: a tensor of shape (batch_size, max_seq_length, target_feature_dim)
        :param target_lengths: a tensor indicating the lengths of the sequences in target_features
        :return: a tensor of shape (batch_size, output_dim) or (batch_size, max_seq_length, output_dim),
            depending on the output mode
        """
        if target_features is None:
            raise ValueError(f"target_features cannot be None when using {self.__class__}")
        # latent has shape (batch_size, latent_dim)
        # target_features has shape (batch_size, max_seq_length, target_feature_dim)
        batch_size = target_features.shape[0]
        use_lstm = self.prediction_mode != self.PredictionMode.DIRECT

        lstm_input, s0, latent_plus_target_features = None, None, None
        if self.latent_pass_on_mode == self.LatentPassOnMode.INIT_HIDDEN:
            if target_lengths is not None:
                lstm_input = torch.nn.utils.rnn.pack_padded_sequence(target_features, target_lengths,
                    batch_first=True, enforce_sorted=False)
            else:
                lstm_input = target_features
            c0 = torch.zeros((1, batch_size, self.latent_dim)).to(latent.device)
            h0 = latent.unsqueeze(0)
            s0 = (h0, c0)
        elif self.latent_pass_on_mode in (self.LatentPassOnMode.CONCAT_INPUT, self.LatentPassOnMode.NO_LATENT):
            if self.latent_pass_on_mode == self.LatentPassOnMode.NO_LATENT:
                latent_plus_target_features = target_features
            else:
                latent = latent.unsqueeze(1)  # (batch_size, 1, latent_dim)
                latent = latent.expand(-1, target_features.shape[1], -1)  # (batch_size, max_seq_length, latent_dim)
                latent_plus_target_features = torch.cat((latent, target_features), dim=2)
            # latent_plus_target_features has shape (batch_size, max_seq_length, latent_dim + target_feature_dim)
            if use_lstm:
                lstm_input = torch.nn.utils.rnn.pack_padded_sequence(latent_plus_target_features, target_lengths,
                    batch_first=True, enforce_sorted=False)
                s0 = None
        else:
            raise ValueError(f"Unknown latent pass-on mode '{self.latent_pass_on_mode}'")

        if self.prediction_mode == self.PredictionMode.SINGLE_LATENT:
            # use only the final latent state and produce a single output
            _, (hn, _) = self.lstm(lstm_input, s0)
            encoding = hn.squeeze(0)
            result = self.predictor(encoding)
        else:
            # compute multiple predictions (and optionally compute their mean)
            if self.prediction_mode == self.PredictionMode.MULTI_LATENT:
                # use all latent states
                hseq, _ = self.lstm(lstm_input, s0)
                encodings, lengths = torch.nn.utils.rnn.pad_packed_sequence(hseq, batch_first=True)  # (batch_size, max_seq_length, latent_dim)
                predictions = self.predictor(encodings)  # (batch_size, max_seq_length, output_dim)
            elif self.prediction_mode == self.PredictionMode.DIRECT:
                # directly map concatenated values to outputs via predictor
                predictions = self.predictor(latent_plus_target_features)
            else:
                raise ValueError(f"Unknown prediction mode '{self.prediction_mode}'")
            if self.output_mode == self.OutputMode.SINGLE_OUTPUT_MEAN:
                mean_predictions = predictions.data.new(batch_size, 1)
                for i, l in enumerate(target_lengths):
                    mean_predictions[i] = predictions[i][:l].sum() / l
                result = mean_predictions
            elif self.output_mode == self.OutputMode.MULTI_OUTPUT:
                return predictions
            else:
                raise ValueError(self.output_mode)
        return result
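

# The following is an illustrative usage sketch (not part of the original module) showing TargetSequenceDecoderModule
# in MULTI_LATENT/CONCAT_INPUT/MULTI_OUTPUT mode; all dimensions and lengths are arbitrary example values.
def _example_target_sequence_decoder_usage():
    decoder = TargetSequenceDecoderModule(target_feature_dim=3, latent_dim=16,
        predictor_factory=LinearPredictorFactory(), output_dim=1,
        prediction_mode=TargetSequenceDecoderModule.PredictionMode.MULTI_LATENT,
        latent_pass_on_mode=TargetSequenceDecoderModule.LatentPassOnMode.CONCAT_INPUT,
        output_mode=TargetSequenceDecoderModule.OutputMode.MULTI_OUTPUT)
    latent = torch.randn(4, 16)  # (batch_size, latent_dim)
    target_features = torch.randn(4, 5, 3)  # (batch_size, max_seq_length, target_feature_dim)
    target_lengths = torch.tensor([5, 4, 3, 2])  # actual lengths of the target sequences
    predictions = decoder(latent, target_features, target_lengths)  # -> tensor of shape (4, 5, 1)
    return predictions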


class TargetSequenceDecoderFactory(DecoderFactory):
    """
    A factory for :class:`TargetSequenceDecoderModule` which takes the latent encoding and a sequence of target
    items as input
    """
    def __init__(self,
            prediction_mode: TargetSequenceDecoderModule.PredictionMode = TargetSequenceDecoderModule.PredictionMode.MULTI_LATENT,
            output_mode: TargetSequenceDecoderModule.OutputMode = TargetSequenceDecoderModule.OutputMode.MULTI_OUTPUT,
            latent_pass_on_mode: TargetSequenceDecoderModule.LatentPassOnMode = TargetSequenceDecoderModule.LatentPassOnMode.CONCAT_INPUT,
            predictor_factory: Optional[PredictorFactory] = None,
            p_recurrent_dropout: float = 0.0,
            output_dim: int = 1):
        if predictor_factory is None:
            predictor_factory = LinearPredictorFactory()
        self.output_dim = output_dim
        self.p_recurrent_dropout = p_recurrent_dropout
        self.prediction_mode = prediction_mode
        self.output_mode = output_mode
        self.latent_pass_on_mode = latent_pass_on_mode
        self.predictor_factory = predictor_factory

    def create_decoder(self, latent_dim: int, target_feature_dim: int) -> torch.nn.Module:
        return TargetSequenceDecoderModule(target_feature_dim, latent_dim, self.predictor_factory,
            prediction_mode=self.prediction_mode,
            output_mode=self.output_mode,
            latent_pass_on_mode=self.latent_pass_on_mode,
            output_dim=self.output_dim,
            p_recurrent_dropout=self.p_recurrent_dropout)


class SingleTargetDecoderFactory(DecoderFactory):
    """
    A factory for :class:`SingleTargetDecoderModule` which takes the latent encoding and a single-element sequence
    of target items as input, producing a single prediction
    """
    def __init__(self, predictor_factory: PredictorFactory):
        self.predictor_factory = predictor_factory

    def create_decoder(self, latent_dim: int, target_feature_dim: int) -> torch.nn.Module:
        return SingleTargetDecoderModule(target_feature_dim, latent_dim, self.predictor_factory)


class EncoderDecoderModule(torch.nn.Module):
    """
    Represents an encoder-decoder (where both components can be injected).
    It takes a history sequence and a sequence of target feature vectors as input.
    Both sequences are potentially of variable length, and for the target sequence, the common special case where
    there is but one target and thus one prediction to be made is specifically catered for using dedicated decoders
    (see :class:`SingleTargetDecoderModule`).

    The module first encodes the history sequence to a latent vector and then uses the decoder to map this latent
    vector along with the target features to a prediction.
    """
    def __init__(self, encoder: TEncoder, decoder: TDecoder, variable_history_length: bool):
        """
        :param encoder: a torch module satisfying :class:`EncoderProtocol`
        :param decoder: a torch module satisfying :class:`DecoderProtocol`
        :param variable_history_length: whether the history sequence is variable-length. If it is not, then the
            model will not pass on the lengths tensor to the encoder, allowing it to simplify its handling of this
            case (even if the original input provides the lengths).
        """
        super().__init__()
        self.variable_history_length = variable_history_length
        self.encoder = encoder
        self.decoder = decoder

    def __str__(self):
        return object_repr(self, dict(encoder=self.encoder, predictor=self.decoder))

    def forward(self, window_features: torch.Tensor, window_lengths: Optional[torch.Tensor] = None,
            target_features: Optional[torch.Tensor] = None, target_lengths: Optional[torch.Tensor] = None):
        """
        :param window_features: a tensor of size (batch_size, max(window_lengths), dim_per_window_item) containing
            the window features
        :param window_lengths: a tensor containing the lengths of the windows in `window_features`
        :param target_features: an optional tensor containing target features with shape
            (batch_size, max_target_seq_length, target_feature_dim). For the case where there is only one target item
            (no actual sequence), `max_target_seq_length` should be 1.
        :param target_lengths: an optional tensor containing the lengths of the target sequences, allowing the actual
            sequence lengths to differ
        :return: the decoder's prediction
        """
        if self.variable_history_length:
            latent = self.encoder(window_features, window_lengths)
        else:
            latent = self.encoder(window_features)
        return self.decoder(latent, target_features, target_lengths)
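

# The following is an illustrative end-to-end sketch (not part of the original module) wiring an encoder factory and
# a decoder factory into an EncoderDecoderModule; all dimensions and lengths are arbitrary example values.
def _example_encoder_decoder_usage():
    latent_dim = 16
    encoder = RnnEncoderFactory(input_dim=8, latent_dim=latent_dim).create_encoder(input_dim=8, latent_dim=latent_dim)
    decoder = SingleTargetDecoderFactory(LinearPredictorFactory()).create_decoder(latent_dim=latent_dim,
        target_feature_dim=3)
    model = EncoderDecoderModule(encoder, decoder, variable_history_length=True)
    window_features = torch.randn(4, 10, 8)  # (batch_size, max(window_lengths), dim_per_window_item)
    window_lengths = torch.tensor([10, 8, 6, 4])  # actual history lengths
    target_features = torch.randn(4, 1, 3)  # a single target item per instance
    prediction = model(window_features, window_lengths, target_features)  # -> tensor of shape (4, 1)
    return prediction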