Coverage for src/sensai/torch/torch_models/lstnet/lstnet_modules.py: 21%
108 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1from enum import Enum
2from typing import Union, Callable
4import torch
5from torch import nn
6from torch.nn import functional as F
8from sensai.util.pickle import setstate
9from ...torch_base import MCDropoutCapableNNModule
10from ...torch_enums import ActivationFunction
class LSTNetwork(MCDropoutCapableNNModule):
    """
    Network for (auto-regressive) time-series prediction with long- and short-term dependencies as proposed by G. Lai et al.
    It applies two parallel paths to a time series of size (numInputTimeSlices, inputDimPerTimeSlice):

        * Complex path with the following stages:

            * Convolutions on the time series input data (CNNs):
              For a CNN with numCnnTimeSlices (= kernel size), it produces an output series of size numInputTimeSlices-numCnnTimeSlices+1.
              If the number of parallel convolutions is numConvolutions, the total output size of this stage is thus
              numConvolutions*(numInputTimeSlices-numCnnTimeSlices+1)
            * Two RNN components which process the CNN output in parallel:

                * RNN (GRU)
                  The output dimension of this stage is the hidden state of the GRU after seeing the entire
                  input data from the previous stage, i.e. it has size hidRNN.
                * Skip-RNN (GRU), which processes time series elements that are 'skip' time slices apart.
                  It does this by grouping the input such that 'skip' GRUs are applied in parallel, which all use the same parameters.
                  If the hidden state dimension of each GRU is hidSkip, then the output size of this stage is skip*hidSkip.

            * Dense layer

        * Direct regression dense layer (so-called "highway" path) which uses the features of the last hwWindow time slices to
          directly make a prediction

    The model ultimately combines the outputs of these two paths via a combination function.
    Many parts of the model are optional and can be completely disabled.
    The model can produce one or more (potentially multi-dimensional) outputs, where each output typically corresponds
    to a time slice for which a prediction is made.

    The model expects as input a tensor of size (batchSize, numInputTimeSlices, inputDimPerTimeSlice).
    As output, the model will produce a tensor of size (batchSize, numOutputTimeSlices, outputDimPerTimeSlice)
    if mode==REGRESSION and a tensor of size (batchSize, outputDimPerTimeSlice=numClasses, numOutputTimeSlices)
    if mode==CLASSIFICATION; the latter shape matches what is required by the multi-dimensional case of loss function
    CrossEntropyLoss, for example, and therefore is suitable for classification use cases.
    For mode==ENCODER, the model will produce a tensor of size (batch_size, hidRNN + skip * hidSkip).
    """

    class Mode(Enum):
        """Determines the shape/kind of output produced by :meth:`LSTNetwork.forward`."""
        REGRESSION = "regression"
        CLASSIFICATION = "classification"
        ENCODER = "encoder"

    def __init__(self,
            num_input_time_slices: int,
            input_dim_per_time_slice: int,
            num_output_time_slices: int = 1,
            output_dim_per_time_slice: int = 1,
            num_convolutions: int = 100,
            num_cnn_time_slices: int = 6,
            hid_rnn: int = 100,
            skip: int = 0,
            hid_skip: int = 5,
            hw_window: int = 0,
            hw_combine: str = "plus",
            dropout: float = 0.2,
            output_activation: Union[str, ActivationFunction, Callable] = "sigmoid",
            mode: Mode = Mode.REGRESSION):
        """
        :param num_input_time_slices: the number of input time slices
        :param input_dim_per_time_slice: the dimension of the input data per time slice
        :param num_output_time_slices: the number of time slices predicted by the model
        :param output_dim_per_time_slice: the number of dimensions per output time slice. While this is the number of
            target variables per time slice for regression problems, this must be the number of classes for classification problems.
        :param num_cnn_time_slices: the number of time slices considered by each convolution (i.e. it is one of the dimensions of the matrix used for
            convolutions, the other dimension being inputDimPerTimeSlice), a.k.a. "Ck"; must not exceed num_input_time_slices
            if the complex path is enabled
        :param num_convolutions: the number of separate convolutions to apply, i.e. the number of independent convolution matrices, a.k.a "hidC";
            if it is 0, then the entire complex processing path is not applied.
        :param hid_rnn: the number of hidden output dimensions for the RNN stage
        :param skip: the number of time slices to skip for the skip-RNN. If it is 0, then the skip-RNN is not used.
        :param hid_skip: the number of output dimensions of each of the skip parallel RNNs
        :param hw_window: the number of time slices from the end of the input time series to consider as input for the highway component.
            If it is 0, the highway component is not used. It must not exceed num_input_time_slices.
        :param hw_combine: {"plus", "product", "bilinear"} the function with which the highway component's output is combined with the complex path's output
        :param dropout: the dropout probability to use during training (dropouts are applied after every major step in the evaluation path)
        :param output_activation: the output activation function
        :param mode: the prediction mode. For `CLASSIFICATION`, the output tensor dimension ordering is adapted to suit loss functions such
            as CrossEntropyLoss. When set to `ENCODER`, will output the latent representation prior to the dense layer in the complex path
            of the network (see class docstring); this requires the complex path to be enabled (num_convolutions > 0).
        :raises ValueError: if the configuration is inconsistent (no processing path, incompatible numbers of time slices,
            ENCODER mode without the complex path, Skip-RNN sequence length of 0, unknown highway combination function)
        """
        if num_convolutions == 0 and hw_window == 0:
            raise ValueError("No processing paths remain")
        # The CNN kernel must fit into the input window (relevant only if the complex path is used), and the highway
        # component can only take into account time slices that are actually part of the input window.
        if (num_convolutions > 0 and num_input_time_slices < num_cnn_time_slices) \
                or (hw_window != 0 and hw_window > num_input_time_slices):
            raise ValueError("Inconsistent numbers of time slices provided: num_input_time_slices=%d, num_cnn_time_slices=%d, "
                "hw_window=%d" % (num_input_time_slices, num_cnn_time_slices, hw_window))
        if mode == self.Mode.ENCODER and num_convolutions == 0:
            # the latent representation is produced by the complex path only
            raise ValueError("Mode ENCODER requires the complex path to be enabled (num_convolutions > 0)")

        super().__init__()
        self.inputDimPerTimeSlice = input_dim_per_time_slice
        self.timeSeriesDimPerTimeSlice = output_dim_per_time_slice
        self.totalOutputDim = self.timeSeriesDimPerTimeSlice * num_output_time_slices
        self.numOutputTimeSlices = num_output_time_slices
        self.window = num_input_time_slices
        self.hidRNN = hid_rnn
        self.numConv = num_convolutions
        self.hidSkip = hid_skip
        self.Ck = num_cnn_time_slices  # the "height" of the CNN filter/kernel; the "width" being inputDimPerTimeSlice
        self.convSeqLength = self.window - self.Ck + 1  # the length of the output sequence produced by the CNN for each kernel matrix
        self.skip = skip
        self.hw = hw_window
        self.pDropout = dropout
        self.mode = mode

        # configure CNN-RNN path
        if self.numConv > 0:
            # produce numConv sequences using numConv kernel matrices of size (height=Ck, width=inputDimPerTimeSlice)
            self.conv1 = nn.Conv2d(1, self.numConv, kernel_size=(self.Ck, self.inputDimPerTimeSlice))
            self.GRU1 = nn.GRU(self.numConv, self.hidRNN)
            if self.skip > 0:
                # We divide by skip to obtain the sequence length because, in order to support skipping via a regrouping
                # of the tensor, the Skip-RNN processes skip entries of the series in parallel to produce skip hidden
                # output vectors
                self.skipRnnSeqLength = self.convSeqLength // self.skip
                if self.skipRnnSeqLength == 0:
                    raise ValueError("Window size %d is not large enough for skip length %d; would result in Skip-RNN sequence length of 0!" % (self.window, self.skip))
                self.GRUskip = nn.GRU(self.numConv, self.hidSkip)
                self.linear1 = nn.Linear(self.hidRNN + self.skip * self.hidSkip, self.totalOutputDim)
            else:
                self.linear1 = nn.Linear(self.hidRNN, self.totalOutputDim)

        # configure highway component
        if self.hw > 0:
            # direct mapping from all inputs to all outputs
            self.highway = nn.Linear(self.hw * self.inputDimPerTimeSlice, self.totalOutputDim)
            if hw_combine == 'plus':
                self.highwayCombine = self._plus
            elif hw_combine == 'product':
                self.highwayCombine = self._product
            elif hw_combine == 'bilinear':
                self.highwayCombine = nn.Bilinear(self.totalOutputDim, self.totalOutputDim, self.totalOutputDim)
            else:
                raise ValueError("Unknown highway combination function '%s'" % hw_combine)

        self.output = ActivationFunction.torch_function_from_any(output_activation)

    def __setstate__(self, state):
        """
        Restores a pickled instance, converting the legacy boolean attribute 'isClassification'
        (if present) into the corresponding :class:`LSTNetwork.Mode` value.
        """
        if "isClassification" in state:
            state["mode"] = self.Mode.CLASSIFICATION if state["isClassification"] else self.Mode.REGRESSION
        setstate(LSTNetwork, self, state, removed_properties=["isClassification"])

    @staticmethod
    def compute_encoder_dim(hid_rnn: int, skip: int, hid_skip: int) -> int:
        """
        :return: the dimension of the latent vector produced in mode ENCODER for the given parameters (hid_rnn + skip * hid_skip)
        """
        return hid_rnn + skip * hid_skip

    def get_encoder_dim(self):
        """
        :return: the vector dimension that is output for the case where mode=ENCODER
        """
        return self.compute_encoder_dim(self.hidRNN, self.skip, self.hidSkip)

    def _lstnet_dropout(self, t):
        # Delegates to the base class's dropout, passing self.pDropout for both training and inference;
        # whether dropout is actually active at inference is decided by MCDropoutCapableNNModule (not visible here).
        return self._dropout(t, p_training=self.pDropout, p_inference=self.pDropout)

    def _complex_path_latent(self, x, batch_size: int):
        """
        Applies the CNN, the RNN and (if enabled) the Skip-RNN to the input.

        :param x: the input tensor of size (batch_size, window=numInputTimeSlices, inputDimPerTimeSlice)
        :param batch_size: the batch size
        :return: the latent representation of size (batch_size, hidRNN + skip * hidSkip)
        """
        # CNN: via numConv kernel matrices of dimension (height=Ck, width=inputDimPerTimeSlice), produces from an
        # original input sequence of length window numConv output sequences of length convSeqLength = window - Ck + 1
        c = x.reshape(batch_size, 1, self.window, self.inputDimPerTimeSlice)  # one channel: (batch_size, 1, height=window, width=inputDimPerTimeSlice)
        c = F.relu(self.conv1(c))  # (batch_size, channels=numConv, convSeqLength, 1)
        c = self._lstnet_dropout(c)
        c = torch.squeeze(c, 3)  # (batch_size, numConv, convSeqLength)

        # RNN: processes the numConv sequences of length convSeqLength and keeps only the final hidden state (size hidRNN)
        r = c.permute(2, 0, 1).contiguous()  # (convSeqLength, batch_size, numConv)
        self.GRU1.flatten_parameters()
        _, r = self.GRU1(r)  # hidden state (num_layers=1, batch_size, hidRNN)
        r = torch.squeeze(r, 0)  # (batch_size, hidRNN)
        r = self._lstnet_dropout(r)

        # Skip-RNN
        if self.skip > 0:
            s = c[:, :, -(self.skipRnnSeqLength * self.skip):].contiguous()  # (batch_size, numConv, skipRnnSeqLength * skip)
            s = s.view(batch_size, self.numConv, self.skipRnnSeqLength, self.skip)  # (batch_size, numConv, skipRnnSeqLength, skip)
            s = s.permute(2, 0, 3, 1).contiguous()  # (skipRnnSeqLength, batch_size, skip, numConv)
            # The batch_size * skip entries of the second dimension are processed in parallel, i.e. there are
            # batch_size * skip RNNs, each of which consecutively processes entries that are 'skip' steps apart
            # (the sequence length per RNN being skipRnnSeqLength).
            s = s.view(self.skipRnnSeqLength, batch_size * self.skip, self.numConv)
            self.GRUskip.flatten_parameters()
            _, s = self.GRUskip(s)  # hidden state (num_layers=1, batch_size * skip, hidSkip)
            # because of the grouping, we obtain skip vectors of size hidSkip per batch element; regroup by batch
            s = s.view(batch_size, self.skip * self.hidSkip)  # (batch_size, skip * hidSkip)
            s = self._lstnet_dropout(s)
            r = torch.cat((r, s), 1)  # (batch_size, hidRNN + skip * hidSkip)

        return r

    def _highway_path(self, x, batch_size: int):
        """
        Applies the auto-regressive highway component to the last hw time slices of the input.

        :param x: the input tensor of size (batch_size, window=numInputTimeSlices, inputDimPerTimeSlice)
        :param batch_size: the batch size
        :return: a tensor of size (batch_size, totalOutputDim)
        """
        res_hw = x[:, -self.hw:, :]  # keep only the last hw entries for each input: (batch_size, hw, inputDimPerTimeSlice)
        res_hw = res_hw.reshape(batch_size, self.hw * self.inputDimPerTimeSlice)  # (batch_size, hw * inputDimPerTimeSlice)
        return self.highway(res_hw)  # (batch_size, totalOutputDim)

    def forward(self, x):
        """
        :param x: a tensor of size (batch_size, window=numInputTimeSlices, inputDimPerTimeSlice)
        :return: for mode ENCODER, the latent representation of size (batch_size, hidRNN + skip * hidSkip);
            for mode REGRESSION, a tensor of size (batch_size, numOutputTimeSlices, outputDimPerTimeSlice);
            for mode CLASSIFICATION, a tensor of size (batch_size, outputDimPerTimeSlice, numOutputTimeSlices)
        """
        batch_size = x.size(0)
        res = None

        if self.numConv > 0:
            latent = self._complex_path_latent(x, batch_size)
            if self.mode == self.Mode.ENCODER:
                return latent
            res = self.linear1(latent)  # (batch_size, totalOutputDim)

        # auto-regressive highway model
        if self.hw > 0:
            res_hw = self._highway_path(x, batch_size)
            res = res_hw if res is None else self.highwayCombine(res, res_hw)  # (batch_size, totalOutputDim)

        if self.output is not None:
            res = self.output(res)

        res = res.view(batch_size, self.numOutputTimeSlices, self.timeSeriesDimPerTimeSlice)
        if self.mode == self.Mode.CLASSIFICATION:
            res = res.permute(0, 2, 1)
        return res

    @staticmethod
    def _plus(x, y):
        """Highway combination function: element-wise sum."""
        return x + y

    @staticmethod
    def _product(x, y):
        """Highway combination function: element-wise product."""
        return x * y