Coverage for src/sensai/torch/torch_models/residualffn/residualffn_models.py: 100%
27 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1import logging
2from typing import Sequence, Union, Optional
4import torch
6from .residualffn_modules import ResidualFeedForwardNetwork
7from ...torch_base import VectorTorchModel, TorchVectorRegressionModel
8from ...torch_opt import NNOptimiserParams
9from ....normalisation import NormalisationMode
# Logger for this module, registered under the module's dotted name.
log: logging.Logger = logging.getLogger(name=__name__)
class ResidualFeedForwardNetworkTorchModel(VectorTorchModel):
    """
    A vector torch model whose torch module is a :class:`ResidualFeedForwardNetwork`.

    The configuration passed to the constructor is stored on the instance and used by
    :meth:`create_torch_module_for_dims` to build the network for concrete input/output dimensions.
    """

    def __init__(self, cuda: bool, hidden_dims: Sequence[int], bottleneck_dimension_factor: float = 1,
            p_dropout: Optional[float] = None, use_batch_normalisation: bool = False) -> None:
        """
        :param cuda: forwarded to the base class as ``cuda`` (presumably whether to use a GPU; see ``VectorTorchModel``)
        :param hidden_dims: the hidden dimensions passed on to :class:`ResidualFeedForwardNetwork`
        :param bottleneck_dimension_factor: the bottleneck dimension factor passed on to :class:`ResidualFeedForwardNetwork`
        :param p_dropout: the dropout probability passed on to :class:`ResidualFeedForwardNetwork`;
            ``None`` presumably disables dropout (TODO confirm in ``ResidualFeedForwardNetwork``)
        :param use_batch_normalisation: whether :class:`ResidualFeedForwardNetwork` shall use batch normalisation
        """
        super().__init__(cuda=cuda)
        # NOTE(review): camelCase attribute names are kept unchanged, since existing (e.g. persisted)
        # instances may rely on them.
        self.hiddenDims = hidden_dims
        self.bottleneckDimensionFactor = bottleneck_dimension_factor
        self.pDropout = p_dropout
        self.useBatchNormalisation = use_batch_normalisation

    def create_torch_module_for_dims(self, input_dim: int, output_dim: int) -> torch.nn.Module:
        """
        Creates the residual feed-forward network for the given dimensions.

        :param input_dim: the input dimension of the network
        :param output_dim: the output dimension of the network
        :return: a new :class:`ResidualFeedForwardNetwork` configured with this model's parameters
        """
        return ResidualFeedForwardNetwork(input_dim, output_dim, self.hiddenDims, self.bottleneckDimensionFactor,
            p_dropout=self.pDropout, use_batch_normalisation=self.useBatchNormalisation)
class ResidualFeedForwardNetworkVectorRegressionModel(TorchVectorRegressionModel):
    """
    A vector regression model backed by a :class:`ResidualFeedForwardNetworkTorchModel`.

    The network configuration is kept on the instance. :meth:`_create_torch_model` is handed to the
    base class as the factory that creates the underlying torch model.
    """

    def __init__(self,
            hidden_dims: Sequence[int],
            bottleneck_dimension_factor: float = 1,
            cuda: bool = True,
            p_dropout: Optional[float] = None,
            use_batch_normalisation: bool = False,
            normalisation_mode: NormalisationMode = NormalisationMode.NONE,
            nn_optimiser_params: Union[NNOptimiserParams, dict, None] = None) -> None:
        """
        :param hidden_dims: the hidden dimensions of the residual feed-forward network
        :param bottleneck_dimension_factor: the bottleneck dimension factor of the residual network
        :param cuda: passed on to :class:`ResidualFeedForwardNetworkTorchModel` as ``cuda``
        :param p_dropout: the dropout probability of the residual network (``None`` is the default)
        :param use_batch_normalisation: whether the residual network shall use batch normalisation
        :param normalisation_mode: the normalisation mode, forwarded to the base class
        :param nn_optimiser_params: the optimiser parameters (object, dict or ``None``), forwarded to the base class
        """
        # The factory reads these attributes, so they are assigned before delegating to the base class.
        self.hidden_dims = hidden_dims
        self.bottleneck_dimension_factor = bottleneck_dimension_factor
        self.cuda = cuda
        self.p_dropout = p_dropout
        self.use_batch_normalisation = use_batch_normalisation
        model_factory = self._create_torch_model
        super().__init__(model_factory, normalisation_mode=normalisation_mode, nn_optimiser_params=nn_optimiser_params)

    def _create_torch_model(self) -> "ResidualFeedForwardNetworkTorchModel":
        """
        Builds a new torch model from the configuration stored on this instance.

        :return: a fresh :class:`ResidualFeedForwardNetworkTorchModel`
        """
        return ResidualFeedForwardNetworkTorchModel(
            cuda=self.cuda,
            hidden_dims=self.hidden_dims,
            bottleneck_dimension_factor=self.bottleneck_dimension_factor,
            p_dropout=self.p_dropout,
            use_batch_normalisation=self.use_batch_normalisation,
        )