Source code for sensai.torch.torch_models.residualffn.residualffn_models

import logging
from typing import Sequence, Union, Optional

import torch

from .residualffn_modules import ResidualFeedForwardNetwork
from ...torch_base import VectorTorchModel, TorchVectorRegressionModel
from ...torch_opt import NNOptimiserParams
from ....normalisation import NormalisationMode

log: logging.Logger = logging.getLogger(__name__)


class ResidualFeedForwardNetworkTorchModel(VectorTorchModel):
    """
    Vector torch model whose underlying torch module is a :class:`ResidualFeedForwardNetwork`.

    The constructor only records the network configuration; the actual module is built
    by :meth:`create_torch_module_for_dims` once input and output dimensions are supplied.
    """

    def __init__(self,
            cuda: bool,
            hidden_dims: Sequence[int],
            bottleneck_dimension_factor: float = 1,
            p_dropout: Optional[float] = None,
            use_batch_normalisation: bool = False) -> None:
        """
        :param cuda: forwarded unchanged to the ``VectorTorchModel`` constructor
            (presumably selects GPU execution - confirm against the base class)
        :param hidden_dims: forwarded to ``ResidualFeedForwardNetwork`` as its hidden dimensions argument
        :param bottleneck_dimension_factor: forwarded to ``ResidualFeedForwardNetwork``
            (exact semantics are defined by that module)
        :param p_dropout: forwarded to ``ResidualFeedForwardNetwork`` as ``p_dropout``; defaults to None
        :param use_batch_normalisation: forwarded to ``ResidualFeedForwardNetwork`` as ``use_batch_normalisation``
        """
        super().__init__(cuda=cuda)
        # NOTE: the camelCase attribute names are kept as they are, since they are part of
        # the instance state that other code (or previously persisted objects) may rely on.
        self.hiddenDims = hidden_dims
        self.bottleneckDimensionFactor = bottleneck_dimension_factor
        self.pDropout = p_dropout
        self.useBatchNormalisation = use_batch_normalisation

    def create_torch_module_for_dims(self, input_dim: int, output_dim: int) -> torch.nn.Module:
        """
        Builds the residual feed-forward network for the given dimensions.

        :param input_dim: the input dimension passed to ``ResidualFeedForwardNetwork``
        :param output_dim: the output dimension passed to ``ResidualFeedForwardNetwork``
        :return: a new ``ResidualFeedForwardNetwork`` instance configured with the stored settings
        """
        return ResidualFeedForwardNetwork(
            input_dim,
            output_dim,
            self.hiddenDims,
            self.bottleneckDimensionFactor,
            p_dropout=self.pDropout,
            use_batch_normalisation=self.useBatchNormalisation,
        )
class ResidualFeedForwardNetworkVectorRegressionModel(TorchVectorRegressionModel):
    """
    Vector regression model backed by a :class:`ResidualFeedForwardNetworkTorchModel`.

    The network configuration is stored on the instance, and the bound method
    ``_create_torch_model`` is handed to the base class as the torch model factory.
    """

    def __init__(self,
            hidden_dims: Sequence[int],
            bottleneck_dimension_factor: float = 1,
            cuda: bool = True,
            p_dropout: Optional[float] = None,
            use_batch_normalisation: bool = False,
            normalisation_mode: NormalisationMode = NormalisationMode.NONE,
            nn_optimiser_params: Union[NNOptimiserParams, dict, None] = None) -> None:
        """
        :param hidden_dims: passed on to ``ResidualFeedForwardNetworkTorchModel`` as ``hidden_dims``
        :param bottleneck_dimension_factor: passed on to ``ResidualFeedForwardNetworkTorchModel``
        :param cuda: passed on to ``ResidualFeedForwardNetworkTorchModel`` as ``cuda``
        :param p_dropout: passed on to ``ResidualFeedForwardNetworkTorchModel`` as ``p_dropout``
        :param use_batch_normalisation: passed on to ``ResidualFeedForwardNetworkTorchModel``
        :param normalisation_mode: forwarded to the ``TorchVectorRegressionModel`` constructor
        :param nn_optimiser_params: forwarded to the ``TorchVectorRegressionModel`` constructor
        """
        # The configuration is stored before the base constructor runs, so that the
        # factory method handed to it can already read these attributes.
        self.hidden_dims = hidden_dims
        self.bottleneck_dimension_factor = bottleneck_dimension_factor
        self.cuda = cuda
        self.p_dropout = p_dropout
        self.use_batch_normalisation = use_batch_normalisation
        super().__init__(
            self._create_torch_model,
            normalisation_mode=normalisation_mode,
            nn_optimiser_params=nn_optimiser_params,
        )

    def _create_torch_model(self):
        """
        :return: a new ``ResidualFeedForwardNetworkTorchModel`` built from the stored configuration
        """
        return ResidualFeedForwardNetworkTorchModel(
            cuda=self.cuda,
            hidden_dims=self.hidden_dims,
            bottleneck_dimension_factor=self.bottleneck_dimension_factor,
            p_dropout=self.p_dropout,
            use_batch_normalisation=self.use_batch_normalisation,
        )