Coverage for src/sensai/torch/torch_models/residualffn/residualffn_models.py: 100%

27 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-29 18:29 +0000

1import logging 

2from typing import Sequence, Union, Optional 

3 

4import torch 

5 

6from .residualffn_modules import ResidualFeedForwardNetwork 

7from ...torch_base import VectorTorchModel, TorchVectorRegressionModel 

8from ...torch_opt import NNOptimiserParams 

9from ....normalisation import NormalisationMode 

10 

11log: logging.Logger = logging.getLogger(__name__) 

12 

13 

class ResidualFeedForwardNetworkTorchModel(VectorTorchModel):
    """
    A VectorTorchModel whose torch module is a ResidualFeedForwardNetwork.

    The network hyperparameters are stored at construction time; the module itself is created in
    create_torch_module_for_dims once the input and output dimensions are supplied.
    """

    def __init__(self, cuda: bool, hidden_dims: Sequence[int], bottleneck_dimension_factor: float = 1,
            p_dropout: Optional[float] = None, use_batch_normalisation: bool = False) -> None:
        """
        :param cuda: forwarded to the VectorTorchModel base class constructor
        :param hidden_dims: the hidden dimensions, passed on to ResidualFeedForwardNetwork
        :param bottleneck_dimension_factor: passed on to ResidualFeedForwardNetwork as its bottleneck dimension factor
        :param p_dropout: passed on to ResidualFeedForwardNetwork as ``p_dropout``; None by default
        :param use_batch_normalisation: passed on to ResidualFeedForwardNetwork as ``use_batch_normalisation``
        """
        super().__init__(cuda=cuda)
        # NOTE: these camelCase attribute names differ from the snake_case convention used by
        # ResidualFeedForwardNetworkVectorRegressionModel; they are kept unchanged because code outside
        # this module (or previously persisted instances) may reference them.
        self.hiddenDims = hidden_dims
        self.bottleneckDimensionFactor = bottleneck_dimension_factor
        self.pDropout = p_dropout
        self.useBatchNormalisation = use_batch_normalisation

    def create_torch_module_for_dims(self, input_dim: int, output_dim: int) -> torch.nn.Module:
        """
        Create the residual feed-forward network for the given dimensions.

        :param input_dim: the input dimension of the network
        :param output_dim: the output dimension of the network
        :return: a new ResidualFeedForwardNetwork configured with the hyperparameters stored on this instance
        """
        return ResidualFeedForwardNetwork(input_dim, output_dim, self.hiddenDims, self.bottleneckDimensionFactor,
            p_dropout=self.pDropout, use_batch_normalisation=self.useBatchNormalisation)

27 

28 

class ResidualFeedForwardNetworkVectorRegressionModel(TorchVectorRegressionModel):
    """
    A torch-based vector regression model backed by a residual feed-forward network.

    The hyperparameters given at construction are stored on the instance. The bound method
    _create_torch_model is handed to the base class as its torch model factory; each call of it
    builds a new ResidualFeedForwardNetworkTorchModel from those stored values.
    """

    def __init__(self,
            hidden_dims: Sequence[int],
            bottleneck_dimension_factor: float = 1,
            cuda: bool = True,
            p_dropout: Optional[float] = None,
            use_batch_normalisation: bool = False,
            normalisation_mode: NormalisationMode = NormalisationMode.NONE,
            nn_optimiser_params: Union[NNOptimiserParams, dict, None] = None) -> None:
        """
        :param hidden_dims: the hidden dimensions of the residual network
        :param bottleneck_dimension_factor: the bottleneck dimension factor of the residual network
        :param cuda: passed on as the ``cuda`` argument of ResidualFeedForwardNetworkTorchModel
        :param p_dropout: passed on as the ``p_dropout`` argument of ResidualFeedForwardNetworkTorchModel
        :param use_batch_normalisation: passed on as the ``use_batch_normalisation`` argument of
            ResidualFeedForwardNetworkTorchModel
        :param normalisation_mode: forwarded to the base class constructor
        :param nn_optimiser_params: forwarded to the base class constructor; an NNOptimiserParams instance,
            a dict, or None
        """
        # These are assigned (in this order) before the base-class constructor runs, since the factory
        # reads them. NOTE(review): unverified whether the base constructor itself invokes the factory.
        self.hidden_dims, self.bottleneck_dimension_factor = hidden_dims, bottleneck_dimension_factor
        self.cuda = cuda
        self.p_dropout, self.use_batch_normalisation = p_dropout, use_batch_normalisation
        model_factory = self._create_torch_model
        super().__init__(model_factory, normalisation_mode=normalisation_mode,
            nn_optimiser_params=nn_optimiser_params)

    def _create_torch_model(self) -> ResidualFeedForwardNetworkTorchModel:
        """
        Build a fresh torch model from the hyperparameters stored on this instance.

        :return: a new ResidualFeedForwardNetworkTorchModel
        """
        network_options = dict(
            bottleneck_dimension_factor=self.bottleneck_dimension_factor,
            p_dropout=self.p_dropout,
            use_batch_normalisation=self.use_batch_normalisation)
        return ResidualFeedForwardNetworkTorchModel(self.cuda, self.hidden_dims, **network_options)