Coverage for src/sensai/torch/torch_models/residualffn/residualffn_modules.py: 82% (84 statements)
from abc import ABC, abstractmethod
from typing import Sequence, Optional

import torch
from torch import nn

class ResidualFeedForwardNetwork(nn.Module):
    """
    A feed-forward network consisting of a fully connected input layer, a configurable number of residual blocks and a
    fully connected output layer. Similar architectures are described in e.g. [1] and [2] and are all inspired by
    ResNet [3]. Each residual block consists of two fully connected layers with (optionally) batch normalisation and
    dropout, which can all be bypassed by a so-called skip connection. The skip path and the non-skip path are added as
    the last step within each block.

    More precisely, the non-skip path consists of the following layers:
    batch normalisation -> ReLU, dropout -> fully-connected -> batch normalisation -> ReLU, dropout -> fully-connected
    The use of the activation function before the fully-connected layers is called "pre-activation" [4].

    The skip path does nothing for the case where the input dimension of the block equals the output dimension. If these
    dimensions are different, the skip path consists of a fully-connected layer, but with no activation, normalisation,
    or dropout.

    Within each block, the dimension can be reduced by a certain factor. This is known as a "bottleneck" design. It has
    been shown for the original ResNet that such a bottleneck design can reduce the number of parameters of the model
    and improve the training behaviour without compromising the results.

    Batch normalisation can be deactivated, but normally it improves the results, since it not only provides some
    regularisation, but also normalises the distribution of the inputs of each layer and therefore addresses the problem
    of "internal covariate shift" [5]. The mechanism behind this is not yet fully understood (see e.g. the Wikipedia
    article on batch normalisation for further references).
    Our batch normalisation module will normalise batches per dimension C in 2D tensors of shape (N, C) or 3D tensors
    of shape (N, L, C).

    References:

    * [1] Chen, Dongwei et al. "Deep Residual Learning for Nonlinear Regression."
      Entropy 22, no. 2 (February 2020): 193. https://doi.org/10.3390/e22020193.
    * [2] Kiprijanovska et al. "HousEEC: Day-Ahead Household Electrical Energy Consumption Forecasting Using Deep Learning."
      Energies 13, no. 10 (January 2020): 2672. https://doi.org/10.3390/en13102672.
    * [3] He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep Residual Learning for Image Recognition."
      ArXiv:1512.03385 [Cs], December 10, 2015. http://arxiv.org/abs/1512.03385.
    * [4] He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Identity Mappings in Deep Residual Networks."
      ArXiv:1603.05027 [Cs], July 25, 2016. http://arxiv.org/abs/1603.05027.
    * [5] Ioffe, Sergey, and Christian Szegedy. "Batch Normalization: Accelerating Deep Network Training by Reducing
      Internal Covariate Shift." ArXiv:1502.03167 [Cs], March 2, 2015. http://arxiv.org/abs/1502.03167.
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dims: Sequence[int], bottleneck_dimension_factor: float = 1,
            p_dropout: Optional[float] = None, use_batch_normalisation: bool = True) -> None:
        """
        :param input_dim: the input dimension of the model
        :param output_dim: the output dimension of the model
        :param hidden_dims: a list of dimensions; for each list item, a residual block with the corresponding dimension
            is created
        :param bottleneck_dimension_factor: a factor that is applied to a block's dimension to obtain the dimension of
            the hidden layer within that block (1 for no reduction)
        :param p_dropout: the dropout probability to use during training (defaults to None for no dropout)
        :param use_batch_normalisation: whether to use batch normalisation (defaults to True)
        """
        super().__init__()
        self.inputDim = input_dim
        self.outputDim = output_dim
        self.hiddenDims = hidden_dims
        self.useBatchNormalisation = use_batch_normalisation

        if p_dropout is not None:
            self.dropout = nn.Dropout(p=p_dropout)
        else:
            self.dropout = None

        self.inputLayer = nn.Linear(self.inputDim, self.hiddenDims[0])
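
        # The inner ("bottleneck") dimension of each residual block is obtained by scaling the block's dimension
        # with bottleneck_dimension_factor. Illustrative example: with bottleneck_dimension_factor=0.5, a block
        # of dimension 64 uses an inner hidden dimension of 32.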
        inner_hidden_dims = lambda x: max(1, round(x * bottleneck_dimension_factor))
        blocks = []
        prev_dim = self.hiddenDims[0]
        for hidden_dim in self.hiddenDims[1:]:
            if hidden_dim == prev_dim:
                block = self._IdentityBlock(hidden_dim, inner_hidden_dims(hidden_dim), self.dropout, use_batch_normalisation)
            else:
                block = self._DenseBlock(prev_dim, inner_hidden_dims(hidden_dim), hidden_dim, self.dropout, use_batch_normalisation)
            blocks.append(block)
            prev_dim = hidden_dim

        self.bnOutput = self._BatchNorm(self.hiddenDims[-1]) if self.useBatchNormalisation else None
        self.outputLayer = nn.Linear(self.hiddenDims[-1], output_dim)
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        x = self.inputLayer(x)

        for block in self.blocks:
            x = block(x)

        x = self.bnOutput(x) if self.useBatchNormalisation else x
        x = self.dropout(x) if self.dropout is not None else x
        x = self.outputLayer(x)
        return x

    class _BatchNorm(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.bn = nn.BatchNorm1d(dim)

        def forward(self, x):
            # nn.BatchNorm1d normalises a 3D tensor per dimension C for shape (N, C, SeqLen).
            # For a 3D tensor of shape (N, SeqLen, C), we thus permute to obtain (N, C, SeqLen), adopting the
            # "regular" broadcasting semantics.
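            # Shape example (illustrative): an input of shape (32, 10, 64) is permuted to (32, 64, 10),
            # normalised over the 64 channels, and permuted back to (32, 10, 64).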
            is_3d = len(x.shape) == 3
            if is_3d:
                x = x.permute((0, 2, 1))
            x = self.bn(x)
            if is_3d:
                x = x.permute((0, 2, 1))
            return x

    class _ResidualBlock(nn.Module, ABC):
        """
        A generic residual block which needs to be specified by defining the skip path.
        """

        def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: Optional[nn.Dropout],
                use_batch_normalisation: bool) -> None:
            super().__init__()
            self.inputDim = input_dim
            self.hiddenDim = hidden_dim
            self.outputDim = output_dim
            self.dropout = dropout
            self.useBatchNormalisation = use_batch_normalisation
            self.bnIn = ResidualFeedForwardNetwork._BatchNorm(self.inputDim) if use_batch_normalisation else None
            self.denseIn = nn.Linear(self.inputDim, self.hiddenDim)
            self.bnOut = ResidualFeedForwardNetwork._BatchNorm(self.hiddenDim) if use_batch_normalisation else None
            self.denseOut = nn.Linear(self.hiddenDim, self.outputDim)

        def forward(self, x):
            x_skipped = self._skip(x)

            x = self.bnIn(x) if self.useBatchNormalisation else x
            x = torch.relu(x)
            x = self.dropout(x) if self.dropout is not None else x
            x = self.denseIn(x)

            x = self.bnOut(x) if self.useBatchNormalisation else x
            x = torch.relu(x)
            x = self.dropout(x) if self.dropout is not None else x
            x = self.denseOut(x)

            return x + x_skipped

        @abstractmethod
        def _skip(self, x):
            """
            Defines the skip path of the residual block. The input is identical to the argument passed to forward.
            """
            pass

    class _IdentityBlock(_ResidualBlock):
        """
        A residual block preserving the dimension of the input.
        """

        def __init__(self, input_output_dim: int, hidden_dim: int, dropout: Optional[nn.Dropout], use_batch_normalisation: bool) -> None:
            super().__init__(input_output_dim, hidden_dim, input_output_dim, dropout, use_batch_normalisation)

        def _skip(self, x):
            """
            Defines the skip path as the identity function.
            """
            return x

    class _DenseBlock(_ResidualBlock):
        """
        A residual block changing the dimension of the input to the given value.
        """

        def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: Optional[nn.Dropout],
                use_batch_normalisation: bool) -> None:
            super().__init__(input_dim, hidden_dim, output_dim, dropout, use_batch_normalisation)
            self.denseSkip = nn.Linear(self.inputDim, self.outputDim)

        def _skip(self, x):
            """
            Defines the skip path as a fully connected linear layer which changes the dimension as required by this
            block.
            """
            return self.denseSkip(x)
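

# Minimal usage sketch (illustrative only; the dimensions below are arbitrary example values). It builds a small
# network with an input layer, two residual blocks and an output layer, then runs a forward pass on a 2D batch
# of shape (N, C) and on a 3D sequence batch of shape (N, L, C), both of which the internal batch normalisation
# supports as described in the class docstring.
if __name__ == "__main__":
    net = ResidualFeedForwardNetwork(
        input_dim=16,
        output_dim=1,
        hidden_dims=[64, 64, 32],
        bottleneck_dimension_factor=0.5,
        p_dropout=0.1,
        use_batch_normalisation=True,
    )
    net.eval()  # disable dropout and use the batch-norm running statistics for this demonstration

    x_2d = torch.randn(8, 16)      # (N, C)
    y_2d = net(x_2d)               # -> shape (8, 1)

    x_3d = torch.randn(8, 10, 16)  # (N, L, C)
    y_3d = net(x_3d)               # -> shape (8, 10, 1)

    print(y_2d.shape, y_3d.shape)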