Coverage for src/sensai/torch/torch_models/lstnet/lstnet_modules.py: 21%

108 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-29 18:29 +0000

1from enum import Enum 

2from typing import Union, Callable 

3 

4import torch 

5from torch import nn 

6from torch.nn import functional as F 

7 

8from sensai.util.pickle import setstate 

9from ...torch_base import MCDropoutCapableNNModule 

10from ...torch_enums import ActivationFunction 

11 

12 

class LSTNetwork(MCDropoutCapableNNModule):
    """
    Network for (auto-regressive) time-series prediction with long- and short-term dependencies as proposed by G. Lai et al.
    It applies two parallel paths to a time series of size (numInputTimeSlices, inputDimPerTimeSlice):

    * Complex path with the following stages:

        * Convolutions on the time series input data (CNNs):
          For a CNN with numCnnTimeSlices (= kernel size), it produces an output series of size
          numInputTimeSlices-numCnnTimeSlices+1.
          If the number of parallel convolutions is numConvolutions, the total output size of this stage is thus
          numConvolutions*(numInputTimeSlices-numCnnTimeSlices+1)
        * Two RNN components which process the CNN output in parallel:

            * RNN (GRU)
              The output dimension of this stage is the hidden state of the GRU after seeing the entire
              input data from the previous stage, i.e. it has size hidRNN.
            * Skip-RNN (GRU), which processes time series elements that are 'skip' time slices apart.
              It does this by grouping the input such that 'skip' GRUs are applied in parallel, which all use
              the same parameters.
              If the hidden state dimension of each GRU is hidSkip, then the output size of this stage is skip*hidSkip.

        * Dense layer

    * Direct regression dense layer (so-called "highway" path) which uses the features of the last hwWindow time slices to
      directly make a prediction

    The model ultimately combines the outputs of these two paths via a combination function.
    Many parts of the model are optional and can be completely disabled.
    The model can produce one or more (potentially multi-dimensional) outputs, where each output typically corresponds
    to a time slice for which a prediction is made.

    The model expects as input a tensor of size (batchSize, numInputTimeSlices, inputDimPerTimeSlice).
    As output, the model will produce a tensor of size (batchSize, numOutputTimeSlices, outputDimPerTimeSlice)
    if mode==REGRESSION and a tensor of size (batchSize, outputDimPerTimeSlice=numClasses, numOutputTimeSlices)
    if mode==CLASSIFICATION; the latter shape matches what is required by the multi-dimensional case of loss function
    CrossEntropyLoss, for example, and therefore is suitable for classification use cases.
    For mode==ENCODER, the model will produce a tensor of size (batch_size, hidRNN + skip * hidSkip).
    """

    class Mode(Enum):
        REGRESSION = "regression"
        CLASSIFICATION = "classification"
        ENCODER = "encoder"

    def __init__(self,
            num_input_time_slices: int,
            input_dim_per_time_slice: int,
            num_output_time_slices: int = 1,
            output_dim_per_time_slice: int = 1,
            num_convolutions: int = 100,
            num_cnn_time_slices: int = 6,
            hid_rnn: int = 100,
            skip: int = 0,
            hid_skip: int = 5,
            hw_window: int = 0,
            hw_combine: str = "plus",
            dropout=0.2,
            output_activation: Union[str, ActivationFunction, Callable] = "sigmoid",
            mode: Mode = Mode.REGRESSION):
        """
        :param num_input_time_slices: the number of input time slices
        :param input_dim_per_time_slice: the dimension of the input data per time slice
        :param num_output_time_slices: the number of time slices predicted by the model
        :param output_dim_per_time_slice: the number of dimensions per output time slice. While this is the number of
            target variables per time slice for regression problems, this must be the number of classes for
            classification problems.
        :param num_cnn_time_slices: the number of time slices considered by each convolution (i.e. it is one of the
            dimensions of the matrix used for convolutions, the other dimension being inputDimPerTimeSlice), a.k.a. "Ck";
            must not exceed num_input_time_slices
        :param num_convolutions: the number of separate convolutions to apply, i.e. the number of independent
            convolution matrices, a.k.a "hidC"; if it is 0, then the entire complex processing path is not applied.
        :param hid_rnn: the number of hidden output dimensions for the RNN stage
        :param skip: the number of time slices to skip for the skip-RNN. If it is 0, then the skip-RNN is not used.
        :param hid_skip: the number of output dimensions of each of the skip parallel RNNs
        :param hw_window: the number of time slices from the end of the input time series to consider as input for
            the highway component. If it is 0, the highway component is not used.
            Must not exceed num_input_time_slices.
        :param hw_combine: {"plus", "product", "bilinear"} the function with which the highway component's output is
            combined with the complex path's output
        :param dropout: the dropout probability to use during training (dropouts are applied after every major step
            in the evaluation path)
        :param output_activation: the output activation function
        :param mode: the prediction mode. For `CLASSIFICATION`, the output tensor dimension ordering is adapted to
            suit loss functions such as CrossEntropyLoss. When set to `ENCODER`, will output the latent representation
            prior to the dense layer in the complex path of the network (see class docstring).
        :raises ValueError: if both processing paths are disabled, if the window sizes are inconsistent with the
            number of input time slices, if the skip length is too large for the input window, or if hw_combine
            is unknown
        """
        if num_convolutions == 0 and hw_window == 0:
            raise ValueError("No processing paths remain")
        # The highway path reads the *last* hw_window slices of the input (see forward), so the window may be
        # shorter than or equal to the input length, but never longer; likewise the CNN kernel must fit the input.
        if num_input_time_slices < num_cnn_time_slices or hw_window > num_input_time_slices:
            raise ValueError("Inconsistent numbers of times slices provided")

        super().__init__()
        self.inputDimPerTimeSlice = input_dim_per_time_slice
        self.timeSeriesDimPerTimeSlice = output_dim_per_time_slice
        self.totalOutputDim = self.timeSeriesDimPerTimeSlice * num_output_time_slices
        self.numOutputTimeSlices = num_output_time_slices
        self.window = num_input_time_slices
        self.hidRNN = hid_rnn
        self.numConv = num_convolutions
        self.hidSkip = hid_skip
        self.Ck = num_cnn_time_slices  # the "height" of the CNN filter/kernel; the "width" being inputDimPerTimeSlice
        self.convSeqLength = self.window - self.Ck + 1  # the length of the output sequence produced by the CNN for each kernel matrix
        self.skip = skip
        self.hw = hw_window
        self.pDropout = dropout
        self.mode = mode

        # configure CNN-RNN path
        if self.numConv > 0:
            # produce numConv sequences using numConv kernel matrices of size (height=Ck, width=inputDimPerTimeSlice)
            self.conv1 = nn.Conv2d(1, self.numConv, kernel_size=(self.Ck, self.inputDimPerTimeSlice))
            self.GRU1 = nn.GRU(self.numConv, self.hidRNN)
            if self.skip > 0:
                # we divide by skip to obtain the sequence length, because, in order to support skipping via a
                # regrouping of the tensor, the Skip-RNN processes skip entries of the series in parallel to produce
                # skip hidden output vectors
                self.skipRnnSeqLength = self.convSeqLength // self.skip
                if self.skipRnnSeqLength == 0:
                    raise ValueError("Window size %d is not large enough for skip length %d; would result in Skip-RNN sequence length of 0!" % (self.window, self.skip))
                self.GRUskip = nn.GRU(self.numConv, self.hidSkip)
                self.linear1 = nn.Linear(self.hidRNN + self.skip * self.hidSkip, self.totalOutputDim)
            else:
                self.linear1 = nn.Linear(self.hidRNN, self.totalOutputDim)

        # configure highway component
        if self.hw > 0:
            # direct mapping from all inputs to all outputs
            self.highway = nn.Linear(self.hw * self.inputDimPerTimeSlice, self.totalOutputDim)
            if hw_combine == 'plus':
                self.highwayCombine = self._plus
            elif hw_combine == 'product':
                self.highwayCombine = self._product
            elif hw_combine == 'bilinear':
                self.highwayCombine = nn.Bilinear(self.totalOutputDim, self.totalOutputDim, self.totalOutputDim)
            else:
                raise ValueError("Unknown highway combination function '%s'" % hw_combine)

        self.output = ActivationFunction.torch_function_from_any(output_activation)

    def __setstate__(self, state):
        # Backward compatibility: older persisted states carry a boolean "isClassification" instead of "mode";
        # translate it and let the setstate helper drop the legacy key.
        if "isClassification" in state:
            state["mode"] = self.Mode.CLASSIFICATION if state["isClassification"] else self.Mode.REGRESSION
        setstate(LSTNetwork, self, state, removed_properties=["isClassification"])

    @staticmethod
    def compute_encoder_dim(hid_rnn: int, skip: int, hid_skip: int) -> int:
        """
        :param hid_rnn: the number of hidden output dimensions of the RNN stage
        :param skip: the number of time slices skipped by the skip-RNN (0 if the skip-RNN is not used)
        :param hid_skip: the number of output dimensions of each of the skip parallel RNNs
        :return: the dimension of the concatenated RNN/skip-RNN representation, hid_rnn + skip * hid_skip
        """
        return hid_rnn + skip * hid_skip

    def get_encoder_dim(self):
        """
        :return: the vector dimension that is output for the case where mode=ENCODER
        """
        return self.compute_encoder_dim(self.hidRNN, self.skip, self.hidSkip)

    def forward(self, x):
        """
        :param x: input tensor of size (batch_size, window=numInputTimeSlices, inputDimPerTimeSlice)
        :return: the prediction tensor (see class docstring for the shape depending on the mode)
        """
        batch_size = x.size(0)

        def dropout(t):
            # the same probability is passed for training and inference to the base class's dropout helper
            return self._dropout(t, p_training=self.pDropout, p_inference=self.pDropout)

        res = None

        if self.numConv > 0:
            # CNN
            # convSeqLength = self.window - self.Ck + 1
            # convolution produces, via numConv kernel matrices of dimension (height=Ck, width=inputDimPerTimeSlice),
            # from an original input sequence of length window, numConv output sequences of length convSeqLength
            # insert one dim of size 1 (one channel): (batch_size, 1, height=window, width=inputDimPerTimeSlice)
            c = x.view(batch_size, 1, self.window, self.inputDimPerTimeSlice)
            c = F.relu(self.conv1(c))  # (batch_size, channels=numConv, convSeqLength, 1)
            c = dropout(c)
            c = torch.squeeze(c, 3)  # drops last dimension, i.e. new size (batch_size, numConv, convSeqLength)

            # RNN
            # It processes the numConv sequences of length convSeqLength obtained through convolution and keeps the
            # hidden state at the end, which is comprised of hidR entries.
            # Specifically, it squashes the numConv sequences of length convSeqLength to a vector of size hidR
            # (by iterating through the sequences and applying the same model in each step, processing all batches
            # in parallel)
            r = c.permute(2, 0, 1).contiguous()  # (convSeqLength, batch_size, numConv)
            self.GRU1.flatten_parameters()
            # maps (seq_len=convSeqLength, batch=batch_size, input_size=numConv)
            # -> hidden state (num_layers=1, batch=batch_size, hidden_size=hidR)
            _, r = self.GRU1(r)
            r = torch.squeeze(r, 0)  # (batch_size, hidR)
            r = dropout(r)

            # Skip-RNN
            if self.skip > 0:
                # keep only the trailing entries that can be grouped evenly:
                # (batch_size, numConv, convSeqLength) -> (batch_size, numConv, skipRnnSeqLength * skip)
                s = c[:, :, -(self.skipRnnSeqLength * self.skip):].contiguous()
                s = s.view(batch_size, self.numConv, self.skipRnnSeqLength, self.skip)  # (batch_size, numConv, skipRnnSeqLength, skip)
                s = s.permute(2, 0, 3, 1).contiguous()  # (skipRnnSeqLength, batch_size, skip, numConv)
                s = s.view(self.skipRnnSeqLength, batch_size * self.skip, self.numConv)  # (skipRnnSeqLength, batch_size * skip, numConv)
                # Why the above view makes sense:
                # skipRnnSeqLength is the sequence length considered for the RNN, i.e. the number of steps that is
                # taken for each sequence.
                # The batch_size*skip elements of the second dimension are all processed in parallel, i.e. there are
                # batch_size*skip RNNs being applied in parallel.
                # By scaling the actual batch size with 'skip', we process 'skip' RNNs of each batch in parallel,
                # such that each RNN consecutively processes entries that are 'skip' steps apart
                self.GRUskip.flatten_parameters()
                # maps (seq_len=skipRnnSeqLength, batch=batch_size * skip, input_size=numConv)
                # -> hidden state (num_layers=1, batch=batch_size * skip, hidden_size=hidS)
                _, s = self.GRUskip(s)
                # Because of the way the data is grouped, we obtain not one vector of size hidS but skip vectors
                # of size hidS
                s = s.view(batch_size, self.skip * self.hidSkip)  # regroup by batch -> (batch_size, skip * hidS)
                s = dropout(s)
                r = torch.cat((r, s), 1)  # (batch_size, hidR + skip * hidS)

            if self.mode == self.Mode.ENCODER:
                # the latent representation prior to the dense layer is the encoder output
                return r

            res = self.linear1(r)  # (batch_size, totalOutputDim)

        # auto-regressive highway model
        if self.hw > 0:
            resHW = x[:, -self.hw:, :]  # keep only the last hw entries for each input: (batch_size, hw, inputDimPerTimeSlice)
            resHW = resHW.view(-1, self.hw * self.inputDimPerTimeSlice)  # (batch_size, hw * inputDimPerTimeSlice)
            resHW = self.highway(resHW)  # (batch_size, totalOutputDim)
            if res is None:
                # complex path disabled: the highway output is the sole result
                res = resHW
            else:
                res = self.highwayCombine(res, resHW)  # (batch_size, totalOutputDim)

        # the activation is skipped if it resolved to a falsy value (presumably None for "no activation")
        if self.output:
            res = self.output(res)

        res = res.view(batch_size, self.numOutputTimeSlices, self.timeSeriesDimPerTimeSlice)
        if self.mode == self.Mode.CLASSIFICATION:
            # move the class dimension to position 1: (batch_size, numClasses, numOutputTimeSlices)
            res = res.permute(0, 2, 1)
        return res

    @staticmethod
    def _plus(x, y):
        """Combines the two path outputs by element-wise addition."""
        return x + y

    @staticmethod
    def _product(x, y):
        """Combines the two path outputs by element-wise multiplication."""
        return x * y