Coverage for src/sensai/torch/torch_models/residualffn/residualffn_modules.py: 82%

84 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1from abc import ABC, abstractmethod 

2from typing import Sequence, Optional 

3 

4import torch 

5from torch import nn 

6 

7 

class ResidualFeedForwardNetwork(nn.Module):
    """
    A feed-forward network consisting of a fully connected input layer, a configurable number of residual blocks and a
    fully connected output layer. Similar architectures are described in e.g. [1] and [2] and are all inspired by
    ResNet [3]. Each residual block consists of two fully connected layers with (optionally) batch normalisation and
    dropout, which can all be bypassed by a so-called skip connection. The skip path and the non-skip path are added as
    the last step within each block.

    More precisely, the non-skip path consists of the following layers:
    batch normalisation -> ReLU, dropout -> fully-connected -> batch normalisation -> ReLU, dropout -> fully-connected
    The use of the activation function before the fully-connected layers is called "pre-activation" [4].

    The skip path does nothing for the case where the input dimension of the block equals the output dimension. If these
    dimensions are different, the skip path consists of a fully-connected layer, but with no activation, normalisation,
    or dropout.

    Within each block, the dimension can be reduced by a certain factor. This is known as "bottleneck" design. It has
    been shown for the original ResNet that such a bottleneck design can reduce the number of parameters of the model
    and improve the training behaviour without compromising the results.

    Batch normalisation can be deactivated, but normally it improves the results, since it not only provides some
    regularisation, but also normalises the distribution of the inputs of each layer and therefore addresses the problem
    of "internal covariate shift" [5]. The mechanism behind this is not yet fully understood (see e.g. the Wikipedia
    article on batch normalisation for further references).
    Our batch normalisation module normalises batches per dimension C in 2D tensors of shape (N, C) or 3D tensors
    of shape (N, L, C).

    References:

    * [1] Chen, Dongwei et al. "Deep Residual Learning for Nonlinear Regression."
      Entropy 22, no. 2 (February 2020): 193. https://doi.org/10.3390/e22020193.
    * [2] Kiprijanovska, et al. "HousEEC: Day-Ahead Household Electrical Energy Consumption Forecasting Using Deep Learning."
      Energies 13, no. 10 (January 2020): 2672. https://doi.org/10.3390/en13102672.
    * [3] He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep Residual Learning for Image Recognition."
      ArXiv:1512.03385 [Cs], December 10, 2015. http://arxiv.org/abs/1512.03385.
    * [4] He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Identity Mappings in Deep Residual Networks."
      ArXiv:1603.05027 [Cs], July 25, 2016. http://arxiv.org/abs/1603.05027.
    * [5] Ioffe, Sergey, and Christian Szegedy. "Batch Normalization: Accelerating Deep Network Training by Reducing
      Internal Covariate Shift." ArXiv:1502.03167 [Cs], March 2, 2015. http://arxiv.org/abs/1502.03167.
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dims: Sequence[int], bottleneck_dimension_factor: float = 1,
            p_dropout: Optional[float] = None, use_batch_normalisation: bool = True) -> None:
        """
        :param input_dim: the input dimension of the model
        :param output_dim: the output dimension of the model
        :param hidden_dims: a non-empty sequence of dimensions. The first entry is the output dimension of the input
            layer; for each subsequent entry, a residual block with the corresponding output dimension is created
            (i.e. ``len(hidden_dims) - 1`` blocks in total). A block preserving the dimension of its predecessor uses
            an identity skip path; a block changing the dimension uses a fully-connected skip path.
        :param bottleneck_dimension_factor: a factor that specifies the inner (hidden) dimension within each block
            relative to the block's output dimension; the inner dimension is ``max(1, round(dim * factor))``
        :param p_dropout: the dropout probability to use during training (defaults to None for no dropout)
        :param use_batch_normalisation: whether to use batch normalisation (defaults to True)
        :raises ValueError: if ``hidden_dims`` is empty
        """
        super().__init__()
        # len() rather than truthiness, so that array-like sequences are also handled correctly
        if len(hidden_dims) == 0:
            raise ValueError("hidden_dims must contain at least one dimension (the output dimension of the input layer)")
        self.inputDim = input_dim
        self.outputDim = output_dim
        self.hiddenDims = hidden_dims
        self.useBatchNormalisation = use_batch_normalisation

        # a single dropout module instance is used by the network itself and passed to every residual block
        self.dropout = nn.Dropout(p=p_dropout) if p_dropout is not None else None

        self.inputLayer = nn.Linear(self.inputDim, self.hiddenDims[0])

        def inner_hidden_dim(block_output_dim: int) -> int:
            # bottleneck design: scale the block-internal dimension, but never let it drop below 1
            return max(1, round(block_output_dim * bottleneck_dimension_factor))

        blocks = []
        prev_dim = self.hiddenDims[0]
        for hidden_dim in self.hiddenDims[1:]:
            if hidden_dim == prev_dim:
                block = self._IdentityBlock(hidden_dim, inner_hidden_dim(hidden_dim), self.dropout, use_batch_normalisation)
            else:
                block = self._DenseBlock(prev_dim, inner_hidden_dim(hidden_dim), hidden_dim, self.dropout,
                    use_batch_normalisation)
            blocks.append(block)
            prev_dim = hidden_dim

        # NOTE: attribute names and registration order are kept stable, as they determine the state_dict keys
        self.bnOutput = self._BatchNorm(self.hiddenDims[-1]) if self.useBatchNormalisation else None
        self.outputLayer = nn.Linear(self.hiddenDims[-1], output_dim)
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: input tensor of shape (N, input_dim) or (N, L, input_dim)
        :return: output tensor of shape (N, output_dim) or (N, L, output_dim), respectively
        """
        x = self.inputLayer(x)

        for block in self.blocks:
            x = block(x)

        # NOTE(review): unlike within the blocks, no ReLU is applied between the final batch normalisation and the
        # output layer; confirm this is intended
        x = self.bnOutput(x) if self.useBatchNormalisation else x
        x = self.dropout(x) if self.dropout is not None else x
        x = self.outputLayer(x)
        return x

    class _BatchNorm(nn.Module):
        """
        Batch normalisation per feature dimension C for 2D tensors (N, C) and 3D tensors (N, L, C).
        """

        def __init__(self, dim: int) -> None:
            super().__init__()
            self.bn = nn.BatchNorm1d(dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # BatchNorm1d normalises a 3D tensor per dimension C for shape (N, C, SeqLen).
            # For a 3D tensor (N, SeqLen, C), we thus permute to obtain (N, C, SeqLen), adopting the "regular"
            # broadcasting semantics (features in the last dimension), and permute back afterwards.
            is_3d = len(x.shape) == 3
            if is_3d:
                x = x.permute((0, 2, 1))
            x = self.bn(x)
            if is_3d:
                x = x.permute((0, 2, 1))
            return x

    class _ResidualBlock(nn.Module, ABC):
        """
        A generic pre-activation residual block, which needs to be specialised by defining the skip path.
        The non-skip path is: [batch norm] -> ReLU -> [dropout] -> linear -> [batch norm] -> ReLU -> [dropout] -> linear,
        where bracketed layers are optional; the result is added to the output of the skip path.
        """

        def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: Optional[nn.Dropout],
                use_batch_normalisation: bool) -> None:
            """
            :param input_dim: the input dimension of the block
            :param hidden_dim: the inner dimension between the two fully-connected layers
            :param output_dim: the output dimension of the block
            :param dropout: the dropout module to apply after each activation, or None for no dropout
            :param use_batch_normalisation: whether to apply batch normalisation before each activation
            """
            super().__init__()
            self.inputDim = input_dim
            self.hiddenDim = hidden_dim
            self.outputDim = output_dim
            self.dropout = dropout
            self.useBatchNormalisation = use_batch_normalisation
            self.bnIn = ResidualFeedForwardNetwork._BatchNorm(self.inputDim) if use_batch_normalisation else None
            self.denseIn = nn.Linear(self.inputDim, self.hiddenDim)
            self.bnOut = ResidualFeedForwardNetwork._BatchNorm(self.hiddenDim) if use_batch_normalisation else None
            self.denseOut = nn.Linear(self.hiddenDim, self.outputDim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # the skip path operates on the unmodified block input
            x_skipped = self._skip(x)

            x = self.bnIn(x) if self.useBatchNormalisation else x
            x = torch.relu(x)
            x = self.dropout(x) if self.dropout is not None else x
            x = self.denseIn(x)

            x = self.bnOut(x) if self.useBatchNormalisation else x
            x = torch.relu(x)
            x = self.dropout(x) if self.dropout is not None else x
            x = self.denseOut(x)

            return x + x_skipped

        @abstractmethod
        def _skip(self, x: torch.Tensor) -> torch.Tensor:
            """
            Defines the skip path of the residual block. The input is identical to the argument passed to forward;
            the output must have the block's output dimension in its last axis.
            """
            pass

    class _IdentityBlock(_ResidualBlock):
        """
        A residual block preserving the dimension of the input.
        """

        def __init__(self, input_output_dim: int, hidden_dim: int, dropout: Optional[nn.Dropout],
                use_batch_normalisation: bool) -> None:
            super().__init__(input_output_dim, hidden_dim, input_output_dim, dropout, use_batch_normalisation)

        def _skip(self, x: torch.Tensor) -> torch.Tensor:
            """
            Defines the skip path as the identity function.
            """
            return x

    class _DenseBlock(_ResidualBlock):
        """
        A residual block changing the dimension of the input to the given value.
        """

        def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: Optional[nn.Dropout],
                use_batch_normalisation: bool) -> None:
            super().__init__(input_dim, hidden_dim, output_dim, dropout, use_batch_normalisation)
            self.denseSkip = nn.Linear(self.inputDim, self.outputDim)

        def _skip(self, x: torch.Tensor) -> torch.Tensor:
            """
            Defines the skip path as a fully connected linear layer (without activation, normalisation or dropout)
            which changes the dimension as required by this block.
            """
            return self.denseSkip(x)