Coverage for src/sensai/tensorflow/tf_base.py: 31%
78 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
1from abc import ABC, abstractmethod
2import logging
3import os
4import tempfile
5from typing import Optional
7import pandas as pd
8import tensorflow as tf
10from .. import normalisation
11from ..vector_model import VectorRegressionModel
13log = logging.getLogger(__name__)
class TensorFlowSession:
    """
    Holds a single, class-level TensorFlow (v1-style) session and registers it with Keras.

    The session is shared by all users of this class. It is created by :meth:`configure_session` and handed to
    the Keras backend (at most once) by :meth:`set_keras_session`.
    """
    # the configured session, or None if configure_session has not been called yet
    session = None
    # whether the session has already been passed to the Keras backend
    _isKerasSessionSet = False

    @classmethod
    def configure_session(cls, gpu_allow_growth=True, gpu_per_process_memory_fraction=None):
        """
        Creates a new session with the given GPU options and stores it in the class attribute ``session``.

        Note that this does not reset the "Keras session set" flag: if :meth:`set_keras_session` has already been
        called, a session configured afterwards is not passed on to Keras by subsequent calls.

        :param gpu_allow_growth: value for the session's ``gpu_options.allow_growth`` setting
        :param gpu_per_process_memory_fraction: value for the session's ``gpu_options.per_process_gpu_memory_fraction``
            setting; if None, the setting is left untouched
        """
        tf_config = tf.compat.v1.ConfigProto()
        tf_config.gpu_options.allow_growth = gpu_allow_growth  # dynamically grow the memory used on the GPU
        tf_config.log_device_placement = False
        if gpu_per_process_memory_fraction is not None:
            # in case we get CUDNN_STATUS_INTERNAL_ERROR
            tf_config.gpu_options.per_process_gpu_memory_fraction = gpu_per_process_memory_fraction
        cls.session = tf.compat.v1.Session(config=tf_config)

    @classmethod
    def set_keras_session(cls, allow_default=True):
        """
        Sets the (previously configured) session for use with Keras if it has not previously been set.
        If no session has been configured, the parameter allow_default controls whether it is admissible to create a
        session with default parameters.

        :param allow_default: whether to configure, for the case where no session was previously configured, a new
            session with the defaults.
        :raises Exception: if no session has been configured and allow_default is False
        """
        if cls.session is None:
            if allow_default:
                log.info("No TensorFlow session was configured. Creating a new session with default values.")
                cls.configure_session()
            else:
                raise Exception(f"The session has not yet been configured. Call {cls.__name__}.{cls.configure_session.__name__} beforehand")
        if not cls._isKerasSessionSet:
            # Use the compat.v1 entry point, consistent with the ConfigProto/Session creation above;
            # the non-compat tf.keras.backend.set_session is not available in TensorFlow 2.
            tf.compat.v1.keras.backend.set_session(cls.session)
            cls._isKerasSessionSet = True
class KerasVectorRegressionModel(VectorRegressionModel, ABC):
    """An abstract simple model which maps vectors to vectors and works on pandas.DataFrames (for inputs and outputs)"""

    def __init__(self, normalisation_mode: normalisation.NormalisationMode, loss, metrics, optimiser,
            batch_size=64, epochs=1000, validation_fraction=0.2):
        """
        :param normalisation_mode: the normalisation mode used to scale both the inputs and the outputs
        :param loss: the loss passed to the Keras model's ``compile`` call
        :param metrics: the metrics passed to the Keras model's ``compile`` call (any iterable; None is treated as
            no metrics)
        :param optimiser: the optimiser passed to the Keras model's ``compile`` call
        :param batch_size: the batch size passed to the Keras model's ``fit`` call
        :param epochs: the number of epochs passed to the Keras model's ``fit`` call
        :param validation_fraction: the fraction of the data (taken from the end of the data set, without shuffling)
            that is used as the validation set during training
        """
        super().__init__()
        self.normalisation_mode = normalisation_mode
        self.batch_size = batch_size
        self.epochs = epochs
        self.optimiser = optimiser
        self.loss = loss
        # copy into a list of our own; tolerate None instead of failing with a TypeError
        self.metrics = list(metrics) if metrics is not None else []
        self.validation_fraction = validation_fraction

        # the following attributes are set by _fit
        self.model = None
        self.input_scaler = None
        self.output_scaler = None
        self.training_history = None

    def __str__(self):
        params = dict(normalisationMode=self.normalisation_mode, optimiser=self.optimiser, loss=self.loss, metrics=self.metrics,
            epochs=self.epochs, validationFraction=self.validation_fraction, batchSize=self.batch_size)
        return f"{self.__class__.__name__}{params}"

    @abstractmethod
    def _create_model(self, input_dim, output_dim):
        """
        Creates a keras model

        :param input_dim: the number of input dimensions
        :param output_dim: the number of output dimensions
        :return: the model
        """
        pass

    def _fit(self, inputs: pd.DataFrame, outputs: pd.DataFrame, weights: Optional[pd.Series]):
        """
        Normalises the data, splits off a validation set, trains a newly created Keras model and restores the
        weights of the epoch with the best validation loss.

        :param inputs: the training inputs
        :param outputs: the training outputs
        :param weights: sample weights; they are not used for training (a warning helper is invoked instead)
        """
        self._warn_sample_weights_unsupported(False, weights)

        # normalise data
        self.input_scaler = normalisation.VectorDataScaler(inputs, self.normalisation_mode)
        self.output_scaler = normalisation.VectorDataScaler(outputs, self.normalisation_mode)
        norm_inputs = self.input_scaler.get_normalised_array(inputs)
        norm_outputs = self.output_scaler.get_normalised_array(outputs)

        # split data into training and validation set (the leading rows are used for training, the rest for validation)
        train_split = int(norm_inputs.shape[0] * (1-self.validation_fraction))
        train_inputs = norm_inputs[:train_split]
        train_outputs = norm_outputs[:train_split]
        val_inputs = norm_inputs[train_split:]
        val_outputs = norm_outputs[train_split:]

        # create and fit model
        TensorFlowSession.set_keras_session()
        model = self._create_model(inputs.shape[1], outputs.shape[1])
        model.compile(optimizer=self.optimiser, loss=self.loss, metrics=self.metrics)

        # The best weights are checkpointed into a private temporary directory rather than a single temporary file:
        # depending on the weight format, the checkpoint writer may create several files derived from the given path
        # (e.g. "<path>.index", "<path>.data-*" and a "checkpoint" file next to them), which unlinking one file
        # would leave behind. Removing the whole directory cleans up regardless of the format.
        with tempfile.TemporaryDirectory() as temp_dir:
            checkpoint_path = os.path.join(temp_dir, "best_weights.keras.model")
            checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, monitor='val_loss', save_best_only=True,
                save_weights_only=True)
            self.training_history = model.fit(train_inputs, train_outputs, batch_size=self.batch_size, epochs=self.epochs, verbose=2,
                validation_data=(val_inputs, val_outputs), callbacks=[checkpoint_callback])
            # restore the weights of the best epoch (w.r.t. validation loss)
            model.load_weights(checkpoint_path)
        self.model = model

    def _predict(self, inputs: pd.DataFrame) -> pd.DataFrame:
        """
        Applies the trained model to the given inputs.

        :param inputs: the model inputs
        :return: a data frame with the de-normalised predictions, whose columns are the output scaler's dimension names
        """
        x = self.input_scaler.get_normalised_array(inputs)
        y = self.model.predict(x)
        y = self.output_scaler.get_denormalised_array(y)
        return pd.DataFrame(y, columns=self.output_scaler.dimension_names)