Coverage for src/sensai/torch/torch_opt.py: 80%
571 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
1import enum
2import functools
3import logging
4import math
5import time
6from abc import ABC, abstractmethod
7from collections import OrderedDict
8from enum import Enum
9from typing import List, Union, Sequence, Callable, TYPE_CHECKING, Tuple, Optional, Dict, Any
11import matplotlib.figure
12import numpy as np
13import pandas as pd
14import torch
15import torch.nn as nn
16import torch.optim as optim
17from matplotlib import pyplot as plt
18from torch import cuda as torchcuda
20from .torch_data import TensorScaler, DataUtil, TorchDataSet, TorchDataSetProviderFromDataUtil, TorchDataSetProvider, \
21 TensorScalerIdentity, TensorTransformer
22from .torch_enums import ClassificationOutputMode
23from ..util.string import ToStringMixin
25if TYPE_CHECKING:
26 from .torch_base import TorchModel
# module-level logger, named after this module
log = logging.getLogger(__name__)
class Optimiser(enum.Enum):
    """
    Enumeration of the supported torch optimisers.

    Each member's value is a tuple ``(name, optimiser_class)``, where ``name`` is the lower-case identifier
    accepted by :meth:`from_name` and ``optimiser_class`` is the corresponding ``torch.optim`` optimiser class.
    """
    SGD = ("sgd", optim.SGD)
    ASGD = ("asgd", optim.ASGD)
    ADAGRAD = ("adagrad", optim.Adagrad)
    ADADELTA = ("adadelta", optim.Adadelta)
    ADAM = ("adam", optim.Adam)
    ADAMW = ("adamw", optim.AdamW)
    ADAMAX = ("adamax", optim.Adamax)
    RMSPROP = ("rmsprop", optim.RMSprop)
    RPROP = ("rprop", optim.Rprop)
    LBFGS = ("lbfgs", optim.LBFGS)

    @classmethod
    def from_name(cls, name: str) -> "Optimiser":
        """
        Looks up an optimiser by its name (case-insensitive).

        :param name: the optimiser name, e.g. "adam" or "SGD"
        :return: the matching enum member
        :raises ValueError: if no optimiser with the given name exists
        """
        lname = name.lower()
        for o in cls:
            if o.value[0] == lname:
                return o
        raise ValueError(f"Unknown optimiser name '{name}'; known names: {[o.value[0] for o in cls]}")

    @classmethod
    def from_name_or_instance(cls, name_or_instance: Union[str, "Optimiser"]) -> "Optimiser":
        """
        Normalises the given specification to an enum member.

        :param name_or_instance: either an optimiser name (any ``str``, including subclasses of ``str``)
            or an enum member, which is returned as is
        :return: the corresponding enum member
        :raises ValueError: if a name is given that does not correspond to any optimiser
        """
        # isinstance (rather than an exact type comparison) so that str subclasses are also resolved by name
        if isinstance(name_or_instance, str):
            return cls.from_name(name_or_instance)
        return name_or_instance
class _Optimiser(object):
    """
    Wrapper around an instance of a torch.optim.Optimizer subclass.

    If enabled, it applies gradient norm shrinkage (clipping) inside the closure that is passed to the
    wrapped optimiser's ``step`` method.
    """
    def __init__(self, params, method: Union[str, Optimiser], lr, max_grad_norm, use_shrinkage=True, **optimiser_args):
        """
        :param params: an iterable of torch.Tensor s or dict s. Specifies what Tensors should be optimized.
        :param method: the optimiser to use (name or enum member)
        :param lr: the learning rate
        :param max_grad_norm: the gradient norm beyond which gradients are shrunk to this norm
        :param use_shrinkage: whether to apply gradient shrinkage; this is always disabled for L-BFGS
        :param optimiser_args: keyword arguments to be used in actual torch optimiser
        """
        self.method = Optimiser.from_name_or_instance(method)
        # materialise the parameters: params may be a generator, which could only be consumed once
        self.params = list(params)
        self.last_ppl = None
        self.lr = lr
        self.max_grad_norm = max_grad_norm
        self.start_decay = False
        self.optimiserArgs = optimiser_args
        self.use_shrinkage = use_shrinkage

        is_lbfgs = self.method == Optimiser.LBFGS
        if is_lbfgs:
            self.use_shrinkage = False
        optimiser_class = optim.LBFGS if is_lbfgs else self.method.value[1]
        # the explicitly given learning rate takes precedence over any 'lr' within optimiser_args
        self.optimizer = optimiser_class(self.params, **{**self.optimiserArgs, "lr": self.lr})

    def step(self, loss_backward: Callable):
        """
        Performs a single optimisation step.

        :param loss_backward: callable which performs the backward step and returns the loss
        :return: the loss value as returned by the wrapped optimiser's step method
        """
        if not self.use_shrinkage:
            return self.optimizer.step(loss_backward)

        def clipped_closure():
            loss_value = loss_backward()
            torch.nn.utils.clip_grad_norm_(self.params, self.max_grad_norm)
            return loss_value

        return self.optimizer.step(clipped_closure)
class NNLossEvaluator(ABC):
    """
    Base class defining the interface for training and validation loss evaluation.
    """
    class Evaluation(ABC):
        """
        Stateful object which performs the loss/metric computations during the training of a model
        (as returned by :meth:`NNLossEvaluator.start_evaluation`).
        """
        @abstractmethod
        def start_epoch(self) -> None:
            """
            Starts a new epoch, resetting any aggregated values required to ultimately return the
            epoch's overall training loss (via get_epoch_train_loss) and validation metrics (via get_validation_metrics)
            """
            pass

        @abstractmethod
        def compute_train_batch_loss(self, model_output, ground_truth, x, y) -> torch.Tensor:
            """
            Computes the loss for the given model outputs and ground truth values for a batch
            and aggregates the computed loss values such that :meth:`get_epoch_train_loss` can return an appropriate
            result for the entire epoch.
            The original batch tensors x and y are provided as meta-information only.

            :param model_output: the model output
            :param ground_truth: the ground truth values
            :param x: the original batch input tensor
            :param y: the original batch output (ground truth) tensor
            :return: the loss (scalar tensor)
            """
            pass

        @abstractmethod
        def get_epoch_train_loss(self) -> float:
            """
            :return: the epoch's overall training loss (as obtained by collecting data from individual training
                batch data passed to compute_train_batch_loss)
            """
            pass

        @abstractmethod
        def process_validation_batch(self, model_output, ground_truth, x, y) -> None:
            """
            Processes the given model outputs and ground truth values in order to compute sufficient statistics for
            validation metrics, which, at the end of the epoch, shall be retrievable via method get_validation_metrics

            :param model_output: the model output
            :param ground_truth: the ground truth values
            :param x: the original batch input tensor
            :param y: the original batch output (ground truth) tensor
            """
            pass

        @abstractmethod
        def get_validation_metrics(self) -> Dict[str, float]:
            """
            :return: a mapping from metric name to metric value, computed from the validation data previously
                passed to process_validation_batch
            """
            pass

    @abstractmethod
    def start_evaluation(self, cuda: bool) -> Evaluation:
        """
        Begins the evaluation of a model, returning a (stateful) object which is to perform the necessary computations.

        :param cuda: whether CUDA is being applied (all tensors/models on the GPU)
        :return: the evaluation object
        """
        pass

    @abstractmethod
    def get_validation_metric_name(self) -> str:
        """
        :return: the name of the validation metric which is to be used to determine the best model (key for the
            dictionary returned by method Evaluation.get_validation_metrics)
        """
        pass
class NNLossEvaluatorFixedDim(NNLossEvaluator, ABC):
    """
    Base class defining the interface for training and validation loss evaluation, which uses fixed-dimension
    outputs and aggregates individual training batch losses that are summed losses per batch
    (averaging appropriately internally).
    """
    class Evaluation(NNLossEvaluator.Evaluation):
        def __init__(self, criterion, validation_loss_evaluator: "NNLossEvaluatorFixedDim.ValidationLossEvaluator",
                output_dim_weights: torch.Tensor = None):
            """
            :param criterion: the training criterion; the losses it returns are summed up and divided by the number of
                samples in get_epoch_train_loss, so it is expected to use summation as its reduction
            :param validation_loss_evaluator: the evaluator which collects validation data and computes validation metrics
            :param output_dim_weights: optional tensor with one weight per output dimension; if given, the training loss is
                the weighted average of the per-dimension losses
            """
            self.output_dim_weights = output_dim_weights
            self.output_dim_weight_sum = torch.sum(output_dim_weights) if output_dim_weights is not None else None
            self.validation_loss_evaluator = validation_loss_evaluator
            self.criterion = criterion
            self.total_loss = None
            self.num_samples = None
            self.num_outputs_per_data_point: Optional[int] = None
            # shape of a single ground truth data point; None indicates that a new validation collection is to be started
            self.validation_ground_truth_shape = None

        def start_epoch(self):
            self.total_loss = 0
            self.num_samples = 0
            self.validation_ground_truth_shape = None

        def compute_train_batch_loss(self, model_output, ground_truth, x, y) -> torch.Tensor:
            """
            Computes the training loss for a batch and aggregates it for the epoch.

            Without output dimension weights, the criterion is applied to the entire tensors and every scalar output counts
            as a sample. With weights, the criterion is applied per output dimension (model_output[:, o]), the weighted
            average of these losses is returned, and every data point counts as a sample.

            :return: the (possibly weighted) loss for the batch
            """
            if self.num_outputs_per_data_point is None:
                # number of scalar outputs per data point = product of all non-batch dimensions
                self.num_outputs_per_data_point = math.prod(y.shape[1:])
                assert self.output_dim_weights is None or len(self.output_dim_weights) == self.num_outputs_per_data_point
            num_data_points_in_batch = y.shape[0]
            if self.output_dim_weights is None:
                # treat all dimensions as equal, applying criterion to entire tensors
                loss = self.criterion(model_output, ground_truth)
                self.num_samples += num_data_points_in_batch * self.num_outputs_per_data_point
                self.total_loss += loss.item()
                return loss
            else:
                # compute loss per dimension and return weighted loss
                loss_per_dim = torch.zeros(self.num_outputs_per_data_point, device=model_output.device, dtype=torch.float)
                for o in range(self.num_outputs_per_data_point):
                    loss_per_dim[o] = self.criterion(model_output[:, o], ground_truth[:, o])
                weighted_loss = (loss_per_dim * self.output_dim_weights).sum() / self.output_dim_weight_sum
                self.num_samples += num_data_points_in_batch
                self.total_loss += weighted_loss.item()
                return weighted_loss

        def get_epoch_train_loss(self) -> float:
            return self.total_loss / self.num_samples

        def process_validation_batch(self, model_output, ground_truth, x, y):
            if self.validation_ground_truth_shape is None:
                self.validation_ground_truth_shape = y.shape[1:]  # the shape of the output of a single model application
                self.validation_loss_evaluator.start_validation_collection(self.validation_ground_truth_shape)
            self.validation_loss_evaluator.process_validation_result_batch(model_output, ground_truth)

        def get_validation_metrics(self) -> Dict[str, float]:
            """
            Ends the current validation data collection and computes the metrics.

            :return: the metrics computed by the validation loss evaluator
            """
            metrics = self.validation_loss_evaluator.end_validation_collection()
            # Reset, such that the next validation pass (e.g. on a further validation set within the same epoch) starts a
            # new collection instead of accumulating onto the data of the pass that was just completed.
            self.validation_ground_truth_shape = None
            return metrics

    def start_evaluation(self, cuda: bool) -> Evaluation:
        """
        Creates an evaluation object using this evaluator's training criterion, output dimension weights and
        validation loss evaluator.

        :param cuda: whether to move the criterion and the weights to the GPU
        :return: the evaluation object
        """
        criterion = self.get_training_criterion()
        output_dim_weights_array = self.get_output_dim_weights()
        output_dim_weights_tensor = torch.from_numpy(output_dim_weights_array).float() if output_dim_weights_array is not None else None
        if cuda:
            criterion = criterion.cuda()
            if output_dim_weights_tensor is not None:
                output_dim_weights_tensor = output_dim_weights_tensor.cuda()
        return self.Evaluation(criterion, self.create_validation_loss_evaluator(cuda), output_dim_weights=output_dim_weights_tensor)

    @abstractmethod
    def get_training_criterion(self) -> nn.Module:
        """
        Gets the optimisation criterion (loss function) for training.
        Standard implementations are available in torch.nn (torch.nn.MSELoss, torch.nn.CrossEntropyLoss, etc.).
        """
        pass

    @abstractmethod
    def get_output_dim_weights(self) -> Optional[np.ndarray]:
        """
        :return: an array with one weight per output dimension, or None if all dimensions are to be weighted equally
        """
        pass

    @abstractmethod
    def create_validation_loss_evaluator(self, cuda: bool) -> "ValidationLossEvaluator":
        """
        :param cuda: whether to use CUDA-based tensors
        :return: the evaluator instance which is to be used to evaluate the model on validation data
        """
        pass

    def get_validation_metric_name(self) -> str:
        """
        Gets the name of the metric (key of dictionary as returned by the validation loss evaluator's
        end_validation_collection method), which is defining for the quality of the model and thus determines which
        epoch's model is considered the best.
        This base implementation returns None; subclasses are to override it.

        :return: the name of the metric
        """
        pass

    class ValidationLossEvaluator(ABC):
        @abstractmethod
        def start_validation_collection(self, ground_truth_shape):
            """
            Initiates validation data collection for a new epoch, appropriately resetting this object's internal state.

            :param ground_truth_shape: the tensor shape of a single ground truth data point (not including the batch
                entry dimension)
            """
            pass

        @abstractmethod
        def process_validation_result_batch(self, output, ground_truth):
            """
            Collects, for validation, the given output and ground truth data (tensors holding data on one batch,
            where the first dimension is the batch entry)

            :param output: the model's output
            :param ground_truth: the corresponding ground truth
            """
            pass

        @abstractmethod
        def end_validation_collection(self) -> OrderedDict:
            """
            Computes validation metrics based on the data previously processed.

            :return: an ordered dictionary with validation metrics
            """
            pass
class NNLossEvaluatorRegression(NNLossEvaluatorFixedDim, ToStringMixin):
    """A loss evaluator for (multi-variate) regression."""

    class LossFunction(Enum):
        L1LOSS = "L1Loss"
        L2LOSS = "L2Loss"
        MSELOSS = "MSELoss"
        SMOOTHL1LOSS = "SmoothL1Loss"

    def __init__(self, loss_fn: LossFunction = LossFunction.L2LOSS, validation_tensor_transformer: Optional[TensorTransformer] = None,
            output_dim_weights: Sequence[float] = None, apply_output_dim_weights_in_validation=True,
            validation_metric_name: Optional[str] = None):
        """
        :param loss_fn: the loss function to use (an enum member or its string value); if None, use L2LOSS
        :param validation_tensor_transformer: a transformer which is to be applied to validation tensors (both model outputs and ground
            truth) prior to computing the validation metrics
        :param output_dim_weights: vector of weights to apply to the mean loss per output dimension, i.e. for the case where for each data
            point, the model produces n output dimensions, the mean loss for the i-th dimension is to be computed separately and be scaled
            with the weight, and the overall loss returned is the weighted average. The weights need not sum to 1 (normalisation is
            applied).
        :param apply_output_dim_weights_in_validation: whether output dimension weights are also to be applied to the metrics computed
            for validation. Note that this may not be possible if a validation_tensor_transformer which changes the output dimensions is
            used.
        :param validation_metric_name: the metric to use for model selection during validation; if None, use default depending on loss_fn
        :raises ValueError: if the given loss function is not supported
        """
        self.validation_tensor_transformer = validation_tensor_transformer
        self.output_dim_weights = np.array(output_dim_weights) if output_dim_weights is not None else None
        self.apply_output_dim_weights_in_validation = apply_output_dim_weights_in_validation
        self.validation_metric_name = validation_metric_name
        if loss_fn is None:
            loss_fn = self.LossFunction.L2LOSS
        try:
            self.loss_fn = self.LossFunction(loss_fn)
        except ValueError:
            # ValueError is a subclass of Exception, so handlers for the previously raised generic Exception still apply
            raise ValueError(f"The loss function '{loss_fn}' is not supported. "
                f"Available options are: {[e.value for e in self.LossFunction]}") from None

    def create_validation_loss_evaluator(self, cuda):
        return self.ValidationLossEvaluator(cuda, self.validation_tensor_transformer, self.output_dim_weights,
            self.apply_output_dim_weights_in_validation)

    def get_training_criterion(self):
        """
        :return: the torch criterion corresponding to the configured loss function, using summation as reduction
        """
        if self.loss_fn is self.LossFunction.L1LOSS:
            return nn.L1Loss(reduction='sum')
        if self.loss_fn in (self.LossFunction.L2LOSS, self.LossFunction.MSELOSS):
            return nn.MSELoss(reduction='sum')
        if self.loss_fn is self.LossFunction.SMOOTHL1LOSS:
            return nn.SmoothL1Loss(reduction='sum')
        raise AssertionError(f"Loss function {self.loss_fn} defined but instantiation not implemented.")

    def get_output_dim_weights(self) -> Optional[np.ndarray]:
        return self.output_dim_weights

    class ValidationLossEvaluator(NNLossEvaluatorFixedDim.ValidationLossEvaluator):
        """
        Collects validation outputs and computes, per output dimension, the metrics RRSE, RAE, MSE and MAE, which are
        then averaged across dimensions (weighted by the output dimension weights if these are to be applied).
        """
        def __init__(self, cuda: bool, validation_tensor_transformer: Optional[TensorTransformer], output_dim_weights: np.ndarray,
                apply_output_dim_weights: bool):
            self.validationTensorTransformer = validation_tensor_transformer
            self.outputDimWeights = output_dim_weights
            self.applyOutputDimWeights = apply_output_dim_weights
            self.total_loss_l1 = None
            self.total_loss_l2 = None
            self.output_dims = None
            # concatenation of all ground truth batches, with shape (output_dims, num_samples); set in end_validation_collection
            self.allTrueOutputs = None
            # ground truth batches collected during the current validation collection (concatenated only once at the end)
            self._true_output_batches: List[torch.Tensor] = []
            self.evaluate_l1 = nn.L1Loss(reduction='sum')
            self.evaluate_l2 = nn.MSELoss(reduction='sum')
            if cuda:
                self.evaluate_l1 = self.evaluate_l1.cuda()
                self.evaluate_l2 = self.evaluate_l2.cuda()
            self.begin_new_validation_collection: Optional[bool] = None

        def start_validation_collection(self, ground_truth_shape):
            if len(ground_truth_shape) != 1:
                raise ValueError("Outputs that are not vectors are currently unsupported")
            self.begin_new_validation_collection = True

        def process_validation_result_batch(self, output, ground_truth):
            # apply tensor transformer (if any)
            if self.validationTensorTransformer is not None:
                output = self.validationTensorTransformer.transform(output)
                ground_truth = self.validationTensorTransformer.transform(ground_truth)

            # check if new collection
            if self.begin_new_validation_collection:
                self.output_dims = ground_truth.shape[-1]
                self.total_loss_l1 = np.zeros(self.output_dims)
                self.total_loss_l2 = np.zeros(self.output_dims)
                self.allTrueOutputs = None
                self._true_output_batches = []
                self.begin_new_validation_collection = False

            assert len(output.shape) == 2 and len(ground_truth.shape) == 2

            # obtain series of outputs per output dimension: (batch_size, output_size) -> (output_size, batch_size)
            predicted_output = output.permute(1, 0)
            true_output = ground_truth.permute(1, 0)

            # concatenating onto a growing tensor for every batch would be quadratic; concatenate once at the end instead
            self._true_output_batches.append(true_output)

            for i in range(self.output_dims):
                self.total_loss_l1[i] += self.evaluate_l1(predicted_output[i], true_output[i]).item()
                self.total_loss_l2[i] += self.evaluate_l2(predicted_output[i], true_output[i]).item()

        def end_validation_collection(self):
            """
            :return: an ordered dictionary with the metrics RRSE, RAE, MSE and MAE (averaged across output dimensions).
                RAE and RRSE relate the model's errors to those of a reference model which always predicts the mean of the
                true values; they are infinite if the reference model's errors are zero.
            """
            self.allTrueOutputs = torch.cat(self._true_output_batches, dim=1)
            output_dims = self.output_dims
            rae = np.zeros(output_dims)
            rrse = np.zeros(output_dims)
            mae = np.zeros(output_dims)
            mse = np.zeros(output_dims)

            for i in range(output_dims):
                true_mean = torch.mean(self.allTrueOutputs[i])
                ref_model_errors = self.allTrueOutputs[i] - true_mean
                ref_model_sum_abs_errors = torch.sum(torch.abs(ref_model_errors)).item()
                ref_model_sum_squared_errors = torch.sum(ref_model_errors * ref_model_errors).item()
                num_samples = ref_model_errors.size(0)

                mae[i] = self.total_loss_l1[i] / num_samples
                mse[i] = self.total_loss_l2[i] / num_samples
                rae[i] = self.total_loss_l1[i] / ref_model_sum_abs_errors if ref_model_sum_abs_errors != 0 else np.inf
                rrse[i] = np.sqrt(mse[i]) / np.sqrt(
                    ref_model_sum_squared_errors / num_samples) if ref_model_sum_squared_errors != 0 else np.inf

            def mean(x):
                if self.applyOutputDimWeights:
                    return np.average(x, weights=self.outputDimWeights)
                else:
                    return np.mean(x)

            metrics = OrderedDict([("RRSE", mean(rrse)), ("RAE", mean(rae)), ("MSE", mean(mse)), ("MAE", mean(mae))])
            return metrics

    def get_validation_metric_name(self):
        """
        :return: the explicitly configured validation metric name or, if none was configured, "MAE" for L1-type losses
            and "MSE" for L2-type losses
        """
        if self.validation_metric_name is not None:
            return self.validation_metric_name
        if self.loss_fn in (self.LossFunction.L1LOSS, self.LossFunction.SMOOTHL1LOSS):
            return "MAE"
        if self.loss_fn in (self.LossFunction.L2LOSS, self.LossFunction.MSELOSS):
            return "MSE"
        raise AssertionError(f"No validation metric defined as selection criterion for loss function {self.loss_fn}")
class NNLossEvaluatorClassification(NNLossEvaluatorFixedDim):
    """A loss evaluator for classification"""

    class LossFunction(Enum):
        CROSSENTROPY = "CrossEntropy"
        NLL = "NegativeLogLikelihood"

        def create_criterion(self) -> Callable:
            """
            :return: the torch criterion for this loss function, using summation as reduction
            """
            if self is self.CROSSENTROPY:
                return nn.CrossEntropyLoss(reduction='sum')
            elif self is self.NLL:
                return nn.NLLLoss(reduction="sum")
            raise AssertionError(f"No criterion defined for {self}")

        def get_validation_metric_key(self) -> str:
            """
            :return: the key of the validation metric corresponding to this loss function
            """
            if self is self.CROSSENTROPY:
                return "CE"
            elif self is self.NLL:
                return "NLL"
            raise AssertionError(f"No validation metric key defined for {self}")

        @classmethod
        def default_for_output_mode(cls, output_mode: ClassificationOutputMode):
            """
            :param output_mode: the type of output produced by the model
            :return: the loss function to use by default for the given output mode
            :raises ValueError: if there is no suitable loss function for the output mode
            """
            if output_mode == ClassificationOutputMode.PROBABILITIES:
                raise ValueError(f"No loss function available for {output_mode}; Either apply log at the end and use "
                    f"{ClassificationOutputMode.LOG_PROBABILITIES} or use a different final activation (e.g. log_softmax) "
                    f"to avoid this type of output.")
            elif output_mode == ClassificationOutputMode.LOG_PROBABILITIES:
                return cls.NLL
            elif output_mode == ClassificationOutputMode.UNNORMALISED_LOG_PROBABILITIES:
                return cls.CROSSENTROPY
            else:
                raise ValueError(f"No default specified for {output_mode}")

    def __init__(self, loss_fn: LossFunction):
        """
        :param loss_fn: the loss function to use (an enum member or its string value)
        """
        self.lossFn: "NNLossEvaluatorClassification.LossFunction" = self.LossFunction(loss_fn)

    def __str__(self):
        return f"{self.__class__.__name__}[{self.lossFn}]"

    def create_validation_loss_evaluator(self, cuda):
        return self.ValidationLossEvaluator(cuda, self.lossFn)

    def get_training_criterion(self):
        return self.lossFn.create_criterion()

    def get_output_dim_weights(self) -> Optional[np.ndarray]:
        return None

    class ValidationLossEvaluator(NNLossEvaluatorFixedDim.ValidationLossEvaluator):
        """
        Accumulates the (summed) validation loss of the loss function's criterion and reports its mean per sample.
        """
        def __init__(self, cuda: bool, loss_fn: "NNLossEvaluatorClassification.LossFunction"):
            self.loss_fn = loss_fn
            self.total_loss = None
            self.num_validation_samples = None
            self.criterion = self.loss_fn.create_criterion()
            if cuda:
                self.criterion = self.criterion.cuda()

        def start_validation_collection(self, ground_truth_shape):
            self.total_loss = 0
            self.num_validation_samples = 0

        def process_validation_result_batch(self, output, ground_truth):
            self.total_loss += self.criterion(output, ground_truth).item()
            self.num_validation_samples += output.shape[0]

        def end_validation_collection(self):
            """
            :return: an ordered dictionary with the mean loss per sample ("CE" or "NLL"); for cross-entropy, additionally
                "GeoMeanProbTrueClass", i.e. exp(-mean cross-entropy)
            """
            mean_loss = self.total_loss / self.num_validation_samples
            if isinstance(self.criterion, nn.CrossEntropyLoss):
                return OrderedDict([("CE", mean_loss), ("GeoMeanProbTrueClass", math.exp(-mean_loss))])
            elif isinstance(self.criterion, nn.NLLLoss):
                return OrderedDict([("NLL", mean_loss)])
            raise ValueError(f"Unsupported criterion for validation: {self.criterion}")

    def get_validation_metric_name(self):
        return self.lossFn.get_validation_metric_key()
class NNOptimiserParams(ToStringMixin):
    """
    Parameters of the neural network optimisation (training) process.
    """
    # parameters of earlier versions which are dropped when unpickling/constructing from a dictionary
    REMOVED_PARAMS = {"cuda"}
    # mapping from old (camelCase) parameter names to current parameter names
    RENAMED_PARAMS = {
        "optimiserClip": "optimiser_clip",
        "lossEvaluator": "loss_evaluator",
        "optimiserLR": "optimiser_lr",
        "earlyStoppingEpochs": "early_stopping_epochs",
        "batchSize": "batch_size",
        "trainFraction": "train_fraction",
        "scaledOutputs": "scaled_outputs",
        "useShrinkage": "use_shrinkage",
        "shrinkageClip": "shrinkage_clip",
    }

    def __init__(self,
            loss_evaluator: NNLossEvaluator = None,
            gpu: Optional[int] = None,
            optimiser: Union[str, Optimiser] = "adam",
            optimiser_lr=0.001,
            early_stopping_epochs=None,
            batch_size=None,
            epochs=1000,
            train_fraction=0.75,
            scaled_outputs=False,
            use_shrinkage=True,
            shrinkage_clip=10.,
            shuffle=True,
            optimiser_args: Optional[Dict[str, Any]] = None):
        """
        :param loss_evaluator: the loss evaluator to use
        :param gpu: the index of the GPU to be used (if CUDA is enabled for the model to be trained); if None, default to first GPU
        :param optimiser: the optimiser to use
        :param optimiser_lr: the optimiser's learning rate
        :param early_stopping_epochs: the number of epochs without validation score improvement after which to abort training and
            use the best epoch's model (early stopping); if None, never abort training before all epochs are completed
        :param batch_size: the batch size to use; for algorithms L-BFGS (optimiser='lbfgs'), which do not use batches, leave this at None.
            If the algorithm uses batches and None is specified, batch size 64 will be used by default.
        :param train_fraction: the fraction of the data used for training (with the remainder being used for validation).
            If no validation is to be performed, pass 1.0.
        :param scaled_outputs: whether to scale all outputs, resulting in computations of the loss function based on scaled values rather
            than normalised values.
            Enabling scaling may not be appropriate in cases where there are multiple outputs on different scales/with completely different
            units.
        :param use_shrinkage: whether to apply shrinkage to gradients whose norm exceeds ``shrinkage_clip``, scaling the gradient down to
            ``shrinkage_clip``
        :param shrinkage_clip: the maximum gradient norm beyond which to apply shrinkage (if ``use_shrinkage`` is True)
        :param shuffle: whether to shuffle the training data
        :param optimiser_args: keyword arguments to be passed on to the actual torch optimiser
        """
        if Optimiser.from_name_or_instance(optimiser) == Optimiser.LBFGS:
            # an integer batch size that is large enough for all data to fit into a single batch
            large_batch_size = 10 ** 12
            if batch_size is not None:
                log.warning(f"LBFGS does not make use of batches, therefore using large batch size {large_batch_size} "
                    f"to achieve use of a single batch")
            batch_size = large_batch_size
        else:
            if batch_size is None:
                log.debug("No batch size was specified, using batch size 64 by default")
                batch_size = 64

        self.epochs = epochs
        self.batch_size = batch_size
        self.optimiser_lr = optimiser_lr
        self.shrinkage_clip = shrinkage_clip
        self.optimiser = optimiser
        self.gpu = gpu
        self.train_fraction = train_fraction
        self.scaled_outputs = scaled_outputs
        self.loss_evaluator = loss_evaluator
        self.optimiser_args = optimiser_args if optimiser_args is not None else {}
        self.use_shrinkage = use_shrinkage
        self.early_stopping_epochs = early_stopping_epochs
        self.shuffle = shuffle

    @classmethod
    def _updated_params(cls, params: dict) -> dict:
        """
        :return: a copy of the given parameter dictionary with removed parameters dropped and renamed parameters
            mapped to their current names
        """
        return {cls.RENAMED_PARAMS.get(k, k): v for k, v in params.items() if k not in cls.REMOVED_PARAMS}

    def __setstate__(self, state):
        # support unpickling of instances created by earlier versions
        if "shuffle" not in state:
            state["shuffle"] = True
        self.__dict__ = self._updated_params(state)

    @classmethod
    def from_dict_or_instance(cls, nn_optimiser_params: Union[dict, "NNOptimiserParams"]) -> "NNOptimiserParams":
        if isinstance(nn_optimiser_params, NNOptimiserParams):
            return nn_optimiser_params
        else:
            return cls.from_dict(nn_optimiser_params)

    @classmethod
    def from_dict(cls, params: dict) -> "NNOptimiserParams":
        """
        :param params: the parameters as keyword arguments of the constructor (old parameter names are supported)
        :return: the parameters object
        """
        return cls(**cls._updated_params(params))

    # TODO remove deprecated dict interface
    @classmethod
    def from_either_dict_or_instance(cls, nn_optimiser_dict_params: dict, nn_optimiser_params: Optional["NNOptimiserParams"]):
        """
        :param nn_optimiser_dict_params: a (possibly empty or None) dictionary of parameters
        :param nn_optimiser_params: an instance (or None)
        :return: the given instance if present, otherwise an instance created from the dictionary
        :raises ValueError: if both a non-empty dictionary and an instance are given
        """
        have_instance = nn_optimiser_params is not None
        have_dict = bool(nn_optimiser_dict_params)
        if have_instance and have_dict:
            raise ValueError("Received both a non-empty dictionary and an instance")
        if have_instance:
            return nn_optimiser_params
        else:
            return cls.from_dict(nn_optimiser_dict_params or {})
645class NNOptimiser:
646 log = log.getChild(__qualname__)
648 def __init__(self, params: NNOptimiserParams):
649 """
650 :param params: parameters
651 """
652 if params.loss_evaluator is None:
653 raise ValueError("Must provide a loss evaluator")
655 self.params = params
656 self.cuda = None
657 self.best_epoch = None
659 def __str__(self):
660 return f"{self.__class__.__name__}[params={self.params}]"
662 def fit(self,
663 model: "TorchModel",
664 data: Union[DataUtil, List[DataUtil], TorchDataSetProvider, List[TorchDataSetProvider],
665 TorchDataSet, List[TorchDataSet], Tuple[TorchDataSet, TorchDataSet], List[Tuple[TorchDataSet, TorchDataSet]]],
666 create_torch_module=True) -> "TrainingInfo":
667 """
668 Fits the parameters of the given model to the given data, which can be a list of or single instance of one of the following:
670 * a `DataUtil` or `TorchDataSetProvider` (from which a training set and validation set will be obtained according to
671 the `trainFraction` parameter of this object)
672 * a `TorchDataSet` which shall be used as the training set (for the case where no validation set shall be used)
673 * a tuple with two `TorchDataSet` instances, where the first shall be used as the training set and the second as
674 the validation set
676 :param model: the model to be fitted
677 :param data: the data to use (see variants above)
678 :param create_torch_module: whether to newly create the torch module that is to be trained from the model's factory.
679 If False, (re-)train the existing module.
680 """
681 self.cuda = model.cuda
682 self.log.info(f"Preparing parameter learning of {model} via {self} with cuda={self.cuda}")
684 use_validation = self.params.train_fraction != 1.0
686 def to_data_set_provider(d) -> TorchDataSetProvider:
687 if isinstance(d, TorchDataSetProvider):
688 return d
689 elif isinstance(d, DataUtil):
690 return TorchDataSetProviderFromDataUtil(d, self.cuda)
691 else:
692 raise ValueError(f"Cannot create a TorchDataSetProvider from {d}")
694 training_log_entries = []
696 def training_log(s):
697 self.log.info(s)
698 training_log_entries.append(s)
700 self._init_cuda()
702 # Set the random seed manually for reproducibility.
703 seed = 42
704 torch.manual_seed(seed)
705 if self.cuda:
706 torchcuda.manual_seed_all(seed)
707 torch.backends.cudnn.benchmark = False
708 torch.backends.cudnn.deterministic = True
710 # obtain data, splitting it into training and validation set(s)
711 validation_sets = []
712 training_sets = []
713 output_scalers = []
714 if type(data) != list:
715 data = [data]
716 self.log.info("Obtaining input/output training instances")
717 for idx_data_item, data_item in enumerate(data):
718 if isinstance(data_item, TorchDataSet):
719 if use_validation:
720 raise ValueError("Passing a TorchDataSet instance is not admissible when validation is enabled (trainFraction != 1.0). "
721 "Pass a TorchDataSetProvider or another representation that supports validation instead.")
722 training_set = data_item
723 validation_set = None
724 output_scaler = TensorScalerIdentity()
725 elif type(data_item) == tuple:
726 training_set, validation_set = data_item
727 output_scaler = TensorScalerIdentity()
728 else:
729 data_set_provider = to_data_set_provider(data_item)
730 training_set, validation_set = data_set_provider.provide_split(self.params.train_fraction)
731 output_scaler = data_set_provider.get_output_tensor_scaler()
732 training_sets.append(training_set)
733 if validation_set is not None:
734 validation_sets.append(validation_set)
735 output_scalers.append(output_scaler)
736 training_log(f"Data set {idx_data_item+1}/{len(data)}: #train={training_set.size()}, "
737 f"#validation={validation_set.size() if validation_set is not None else 'None'}")
738 training_log("Number of validation sets: %d" % len(validation_sets))
740 torch_model = model.create_torch_module() if create_torch_module else model.get_torch_module()
741 if self.cuda:
742 torch_model.cuda()
743 model.set_torch_module(torch_model)
745 n_params = sum([p.nelement() for p in torch_model.parameters()])
746 self.log.info(f"Learning parameters of {model}")
747 training_log('Number of parameters: %d' % n_params)
748 training_log(f"Starting training process via {self}")
750 loss_evaluator = self.params.loss_evaluator
752 total_epochs = None
753 best_val = 1e9
754 best_epoch = 0
755 optim = _Optimiser(torch_model.parameters(), method=self.params.optimiser, lr=self.params.optimiser_lr,
756 max_grad_norm=self.params.shrinkage_clip, use_shrinkage=self.params.use_shrinkage, **self.params.optimiser_args)
758 best_model_bytes = model.get_module_bytes()
759 loss_evaluation = loss_evaluator.start_evaluation(self.cuda)
760 validation_metric_name = loss_evaluator.get_validation_metric_name()
761 training_loss_values = []
762 validation_metric_values = []
763 try:
764 self.log.info(f'Begin training with cuda={self.cuda}')
765 self.log.info('Press Ctrl+C to end training early')
766 for epoch in range(1, self.params.epochs + 1):
767 loss_evaluation.start_epoch()
768 epoch_start_time = time.time()
770 # perform training step, processing all the training data once
771 train_loss = self._train(training_sets, torch_model, optim, loss_evaluation, self.params.batch_size, output_scalers)
772 training_loss_values.append(train_loss)
774 # perform validation, computing the mean metrics across all validation sets (if more than one),
775 # and check for new best result according to validation results
776 is_new_best = False
777 if use_validation:
778 metrics_sum = None
779 metrics_keys = None
780 for i, (validation_set, output_scaler) in enumerate(zip(validation_sets, output_scalers)):
781 metrics = self._evaluate(validation_set, torch_model, loss_evaluation, output_scaler)
782 metrics_array = np.array(list(metrics.values()))
783 if i == 0:
784 metrics_sum = metrics_array
785 metrics_keys = metrics.keys()
786 else:
787 metrics_sum += metrics_array
788 metrics_sum /= len(validation_sets) # mean results
789 metrics = dict(zip(metrics_keys, metrics_sum))
790 current_val = metrics[loss_evaluator.get_validation_metric_name()]
791 validation_metric_values.append(current_val)
792 is_new_best = current_val < best_val
793 if is_new_best:
794 best_val = current_val
795 best_epoch = epoch
796 best_str = "best {:s} {:5.6f} from this epoch".format(validation_metric_name, best_val)
797 else:
798 best_str = "best {:s} {:5.6f} from epoch {:d}".format(validation_metric_name, best_val, best_epoch)
799 val_str = f' | validation {", ".join(["%s %5.4f" % e for e in metrics.items()])} | {best_str}'
800 else:
801 val_str = ""
802 training_log(
803 'Epoch {:3d}/{} completed in {:5.2f}s | train loss {:5.4f}{:s}'.format(
804 epoch, self.params.epochs, (time.time() - epoch_start_time), train_loss, val_str))
805 total_epochs = epoch
806 if use_validation:
807 if is_new_best:
808 best_model_bytes = model.get_module_bytes()
810 # check for early stopping
811 num_epochs_without_improvement = epoch - best_epoch
812 if self.params.early_stopping_epochs is not None and \
813 num_epochs_without_improvement >= self.params.early_stopping_epochs:
814 training_log(f"Stopping early: {num_epochs_without_improvement} epochs without validation metric improvement")
815 break
817 training_log("Training complete")
818 except KeyboardInterrupt:
819 training_log('Exiting from training early because of keyboard interrupt')
821 # reload best model according to validation results
822 if use_validation:
823 training_log(f'Best model is from epoch {best_epoch} with {validation_metric_name} {best_val} on validation set')
824 self.best_epoch = best_epoch
825 model.set_module_bytes(best_model_bytes)
827 return TrainingInfo(best_epoch=best_epoch if use_validation else None, log=training_log_entries, total_epochs=total_epochs,
828 training_loss_sequence=training_loss_values, validation_metric_sequence=validation_metric_values)
def _apply_model(self, model, input: Union[torch.Tensor, Sequence[torch.Tensor]], ground_truth, output_scaler: TensorScaler):
    """
    Computes the model's output for the given input.

    :param model: the module to apply
    :param input: either a single tensor, which is passed to the model as its sole argument, or a sequence of
        tensors, which are passed to the model as separate positional arguments
    :param ground_truth: the ground truth corresponding to the input
    :param output_scaler: the scaler that is passed to ``_scaled_values`` if ``params.scaled_outputs`` is enabled
    :return: a pair (output, ground truth); if ``params.scaled_outputs`` is enabled, both elements are the results
        of ``_scaled_values``, otherwise they are returned unchanged
    """
    # a single tensor is treated as a one-element argument list
    model_args = (input,) if isinstance(input, torch.Tensor) else input
    output = model(*model_args)
    if not self.params.scaled_outputs:
        return output, ground_truth
    scaled_output, scaled_truth = self._scaled_values(output, ground_truth, output_scaler)
    return scaled_output, scaled_truth
@classmethod
def _scaled_values(cls, model_output, ground_truth, output_scaler):
    """
    Applies the output scaler's ``denormalise`` transformation to both the model output and the ground truth.

    :param model_output: the values produced by the model
    :param ground_truth: the corresponding ground truth values
    :param output_scaler: an object providing a ``denormalise`` method
    :return: a pair (denormalised model output, denormalised ground truth)
    """
    denormalise = output_scaler.denormalise
    return denormalise(model_output), denormalise(ground_truth)
def _train(self, data_sets: Sequence[TorchDataSet], model: nn.Module, optim: _Optimiser,
        loss_evaluation: NNLossEvaluator.Evaluation, batch_size: int, output_scalers: Sequence[TensorScaler]):
    """
    Performs a single training epoch, processing every batch of every given training data set once.

    For each batch, a closure that recomputes the loss and its gradients is handed to ``optim.step``.
    Presumably the optimiser wrapper invokes it, possibly several times per step, as closure-based
    optimisers such as LBFGS require.

    :param data_sets: the training data sets
    :param model: the module being trained (switched to training mode)
    :param optim: the optimiser wrapper whose ``step`` receives the loss closure
    :param loss_evaluation: the evaluation object that computes the batch losses and accumulates the epoch's loss
    :param batch_size: the number of data points per batch
    :param output_scalers: the output scalers, paired element-wise with ``data_sets``
    :return: the epoch's training loss as reported by ``loss_evaluation.get_epoch_train_loss()``
    """
    model.train()
    shuffle = self.params.shuffle
    for data_set, output_scaler in zip(data_sets, output_scalers):
        batches = data_set.iter_batches(batch_size, shuffle=shuffle)
        for inputs, targets in batches:
            def compute_loss():
                # gradients are reset on every (re-)evaluation, since backward() accumulates them
                model.zero_grad()
                predictions, truth = self._apply_model(model, inputs, targets, output_scaler)
                batch_loss = loss_evaluation.compute_train_batch_loss(predictions, truth, inputs, targets)
                batch_loss.backward()
                return batch_loss

            optim.step(compute_loss)
    return loss_evaluation.get_epoch_train_loss()
def _evaluate(self, data_set: TorchDataSet, model: nn.Module, loss_evaluation: NNLossEvaluator.Evaluation,
        output_scaler: TensorScaler):
    """
    Evaluates the model on the given data set (a validation set).

    The model is switched to evaluation mode. The data set is traversed in unshuffled batches of size
    ``params.batch_size``, and each batch is processed without gradient tracking.

    :param data_set: the validation data set
    :param model: the module to evaluate
    :param loss_evaluation: the evaluation object that accumulates the validation results
    :param output_scaler: the output scaler passed on to ``_apply_model``
    :return: the validation metrics reported by ``loss_evaluation.get_validation_metrics()``
        (a mapping from metric name to value)
    """
    model.eval()
    batch_size = self.params.batch_size
    for inputs, targets in data_set.iter_batches(batch_size, shuffle=False):
        with torch.no_grad():
            predictions, truth = self._apply_model(model, inputs, targets, output_scaler)
            loss_evaluation.process_validation_batch(predictions, truth, inputs, targets)
    return loss_evaluation.get_validation_metrics()
def _init_cuda(self):
    """
    Initialises CUDA (for learning) by selecting the appropriate device if CUDA is enabled.

    If ``self.cuda`` is enabled, the GPU with index ``self.params.gpu`` is selected; GPU 0 is used if no index
    was specified. If CUDA is disabled but a CUDA device is available, a hint is logged.

    :raises RuntimeError: if CUDA is enabled but no CUDA device is found
    """
    if not self.cuda:
        if torchcuda.is_available():
            self.log.info("NOTE: You have a CUDA device; consider running with cuda=True")
        return

    device_count = torchcuda.device_count()
    if device_count == 0:
        # RuntimeError is a subclass of Exception, so existing `except Exception` handlers still apply
        raise RuntimeError("CUDA is enabled but no device found")
    gpu_index = self.params.gpu
    if gpu_index is None:
        if device_count > 1:
            self.log.warning("More than one GPU detected but no GPU index was specified, using GPU 0 by default.")
        gpu_index = 0
    torchcuda.set_device(gpu_index)
class TrainingInfo:
    """
    Holds the results of a training process: the training log, the per-epoch sequences of training loss and
    validation metric values, the best epoch according to the validation metric, and the number of epochs run.
    """

    def __init__(self, best_epoch: int = None, log: List[str] = None, training_loss_sequence: Sequence[float] = None,
            validation_metric_sequence: Sequence[float] = None, total_epochs=None):
        """
        :param best_epoch: the (1-based) epoch in which the best validation metric value was observed;
            None if no validation was performed
        :param log: the log entries produced during training
        :param training_loss_sequence: the training loss value of each epoch (in epoch order)
        :param validation_metric_sequence: the validation metric value of each epoch (in epoch order)
        :param total_epochs: the number of epochs that were completed (may be less than configured, e.g. due to
            early stopping); None if not known
        """
        self.validation_metric_sequence = validation_metric_sequence
        self.training_loss_sequence = training_loss_sequence
        self.log = log
        self.best_epoch = best_epoch
        self.total_epochs = total_epochs

    def __setstate__(self, state):
        # states lacking the total_epochs attribute (e.g. pickled by earlier versions of this class) get a default;
        # the key must match the attribute name set in __init__
        state.setdefault("total_epochs", None)
        self.__dict__ = state

    def get_training_loss_series(self) -> pd.Series:
        """
        :return: the training loss values as a series indexed by (1-based) epoch
        """
        return self._create_series_with_one_based_index(self.training_loss_sequence, name="training loss")

    def get_validation_metric_series(self) -> pd.Series:
        """
        :return: the validation metric values as a series indexed by (1-based) epoch
        """
        return self._create_series_with_one_based_index(self.validation_metric_sequence, name="validation metric")

    def _create_series_with_one_based_index(self, sequence: Sequence, name: str):
        # epochs are numbered starting at 1, so shift the default 0-based index
        series = pd.Series(sequence, name=name)
        series.index += 1
        return series

    def plot_all(self) -> "matplotlib.figure.Figure":
        """
        Plots both the sequence of training loss values (primary y-axis) and the sequence of validation metric
        values (secondary y-axis) against the epoch, marking the best epoch (if known) with a dashed vertical line.

        :return: the figure
        """
        ts = self.get_training_loss_series()
        vs = self.get_validation_metric_series()

        fig, primary_ax = plt.subplots(1, 1)
        secondary_ax = primary_ax.twinx()

        training_line = primary_ax.plot(ts, color='blue')
        validation_line = secondary_ax.plot(vs, color='orange')
        legend_handles = training_line + validation_line
        legend_labels = [ts.name, vs.name]
        # best_epoch is None if training was performed without validation; there is no line to draw then
        if self.best_epoch is not None:
            best_epoch_line = primary_ax.axvline(self.best_epoch, color='black', linestyle='dashed')
            legend_handles.append(best_epoch_line)
            legend_labels.append("best epoch")

        primary_ax.set_xlabel("epoch")
        primary_ax.set_ylabel(ts.name)
        secondary_ax.set_ylabel(vs.name)

        primary_ax.legend(legend_handles, legend_labels)
        fig.tight_layout()

        return fig