Coverage for src/sensai/hyperopt.py: 23%
295 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1from dataclasses import dataclass
2from datetime import datetime
3from concurrent.futures import ProcessPoolExecutor
4import uuid
6import logging
7import os
8import pandas as pd
9from abc import ABC
10from abc import abstractmethod
11from random import Random
12from typing import Dict, Sequence, Any, Callable, Generator, Union, Tuple, List, Optional, Hashable
14from .evaluation.evaluator import MetricsDictProvider
15from .local_search import SACostValue, SACostValueNumeric, SAOperator, SAState, SimulatedAnnealing, \
16 SAProbabilitySchedule, SAProbabilityFunctionLinear
17from .tracking.tracking_base import TrackingMixin, TrackedExperiment
18from .vector_model import VectorModel
# logger for this module, named after the module itself
log = logging.getLogger(name=__name__)
def iter_param_combinations(hyper_param_values: Dict[str, Sequence[Any]]) -> Generator[Dict[str, Any], None, None]:
    """
    Lazily enumerates the Cartesian product of the given parameter value options.

    The first parameter (in the mapping's iteration order) varies slowest, the last one varies fastest.
    Every yielded dictionary is a fresh object whose keys follow the mapping's iteration order.

    :param hyper_param_values: a mapping from parameter names to sequences of candidate values
    :return: a generator of dictionaries, each assigning one candidate value to every parameter name
    """
    entries = list(hyper_param_values.items())
    num_entries = len(entries)

    def _combinations_from(position: int) -> Generator[Dict[str, Any], None, None]:
        # all parameters have been assigned: contribute the empty (partial) assignment
        if position >= num_entries:
            yield {}
            return
        name, candidates = entries[position]
        for candidate in candidates:
            for remainder in _combinations_from(position + 1):
                yield {name: candidate, **remainder}

    return _combinations_from(0)
class ParameterCombinationSkipDecider(ABC):
    """
    Interface for components which decide whether a parameter combination can be skipped during a search.

    Implementations are notified (via `tell`) of parameter combinations that have already been evaluated,
    together with their evaluation metrics, and may base the decision made in `is_skipped` on this knowledge.
    """

    @abstractmethod
    def tell(self, params: Dict[str, Any], metrics: Dict[str, Any]):
        """
        Notifies the decider of a parameter combination that has been evaluated.

        :param params: the evaluated parameter combination
        :param metrics: the metrics that were computed for the combination
        """
        ...

    @abstractmethod
    def is_skipped(self, params: Dict[str, Any]):
        """
        Determines whether the given parameter combination need not be evaluated.

        :param params: the parameter combination under consideration
        :return: True if and only if the combination shall be skipped
        """
        ...
class ParameterCombinationEquivalenceClassValueCache(ABC):
    """
    A cache of arbitrary values which are keyed by equivalence classes of parameter combinations.

    Subclasses define `_equivalence_class`, which maps a parameter combination to a hashable key;
    combinations mapped to the same key are considered equivalent and share a single cache entry.
    Hyper-parameter search can use this to avoid recomputing results for equivalent combinations.
    """

    def __init__(self):
        # maps equivalence class keys (as returned by _equivalence_class) to the cached values
        self._cache = {}

    @abstractmethod
    def _equivalence_class(self, params: Dict[str, Any]) -> Hashable:
        """
        Maps a parameter combination to a hashable representation of its equivalence class.
        If every parameter influences model evaluation and no two combinations are equivalent,
        a tuple of all parameter values (in a fixed order) is a suitable choice.

        :param params: the parameter combination
        :return: a hashable key capturing all information in the parameter combination that influences
            the computation of model evaluation results
        """
        ...

    def set(self, params: Dict[str, Any], value: Any):
        """
        Stores a value for the equivalence class of the given parameter combination,
        replacing any value previously stored for that class.

        :param params: the parameter combination
        :param value: the value to store
        """
        key = self._equivalence_class(params)
        self._cache[key] = value

    def get(self, params: Dict[str, Any]):
        """
        Looks up the value stored for the equivalence class of the given parameter combination.

        :param params: the parameter combination
        :return: the cached value, or None if no value is stored for the combination's equivalence class
        """
        key = self._equivalence_class(params)
        return self._cache.get(key)
class ParametersMetricsCollection:
    """
    Utility class for holding and persisting evaluation results
    """
    def __init__(self, csv_path=None, sort_column_name=None, ascending=True, incremental=False):
        """
        :param csv_path: path to save the data frame to upon every update; if None, results are only held in memory
        :param sort_column_name: the column name by which to sort the data frame that is collected; if None, do not sort
        :param ascending: whether to sort in ascending order; has an effect only if sort_column_name is not None
        :param incremental: whether to add to an existing CSV file instead of overwriting it
        """
        self.sort_column_name = sort_column_name
        self.csv_path = csv_path
        self.ascending = ascending
        csv_path_exists = csv_path is not None and os.path.exists(csv_path)
        if csv_path_exists and incremental:
            self.df = pd.read_csv(csv_path)
            log.info(f"Found existing CSV file with {len(self.df)} entries; {csv_path} will be extended (incremental mode)")
            self._current_row = len(self.df)
            # to_dict(orient="records") keeps the original column names; itertuples() would rename columns that are
            # not valid Python identifiers (so that they could no longer be found by parameter name) and add an
            # extra "Index" entry
            self.value_dicts = self.df.to_dict(orient="records")
        else:
            if csv_path is not None:
                if not csv_path_exists:
                    log.info(f"Results will be written to new file {csv_path}")
                else:
                    log.warning(f"Results in existing file ({csv_path}) will be overwritten (non-incremental mode)")
            self.df = None
            self._current_row = 0
            self.value_dicts = []

    def add_values(self, values: Dict[str, Any]):
        """
        Adds the provided values as a new row to the collection.
        If csv_path was provided in the constructor, saves the updated collection to that file.

        :param values: dict holding the evaluation results and parameters
        """
        if self.df is None:
            cols = list(values.keys())

            # check sort column and move it to the front
            if self.sort_column_name is not None:
                if self.sort_column_name not in cols:
                    log.warning(f"Specified sort column '{self.sort_column_name}' not in list of columns: {cols}; "
                                f"sorting will not take place!")
                else:
                    cols.remove(self.sort_column_name)
                    cols.insert(0, self.sort_column_name)

            self.df = pd.DataFrame(columns=cols)
        else:
            # add columns for keys not seen before (existing rows get None)
            for col in values.keys():
                if col not in self.df.columns:
                    self.df[col] = None

        # append data to data frame (columns missing from values are filled with None)
        self.df.loc[self._current_row] = [values.get(c) for c in self.df.columns]
        self._current_row += 1

        # sort where applicable; the index is reset so that it is 0..n-1 again and _current_row stays a fresh label
        if self.sort_column_name is not None and self.sort_column_name in self.df.columns:
            self.df.sort_values(self.sort_column_name, axis=0, inplace=True, ascending=self.ascending)
            self.df.reset_index(drop=True, inplace=True)

        self._save_csv()

    def _save_csv(self):
        """
        Writes the current data frame to csv_path (creating parent directories as needed); does nothing if
        no path was given.
        """
        if self.csv_path is not None:
            dirname = os.path.dirname(self.csv_path)
            if dirname != "":
                os.makedirs(dirname, exist_ok=True)
            self.df.to_csv(self.csv_path, index=False)

    def get_data_frame(self) -> pd.DataFrame:
        """
        :return: the data frame of collected values; None if no values have been added yet (and no existing
            CSV file was loaded in incremental mode)
        """
        return self.df

    def contains(self, values: Dict[str, Any]) -> bool:
        """
        Checks whether the collection already contains a row matching all of the given values.
        Note that only rows loaded from a pre-existing CSV file at construction time (incremental mode) are
        considered; rows added via add_values are not.

        :param values: a mapping from column names to values that a row must match
        :return: True if a matching row exists, False otherwise
        """
        for existing_values in self.value_dicts:
            for key, value in values.items():
                existing_value = existing_values.get(key)
                # values also count as equal if their string representations agree (presumably to handle values
                # whose type was not preserved by the CSV round trip)
                if existing_value != value and str(existing_value) != str(value):
                    break
            else:
                return True
        return False
class GridSearch(TrackingMixin):
    """
    Instances of this class can be used for evaluating models with different user-provided parametrizations
    over the same data and persisting the results
    """
    log = log.getChild(__qualname__)

    def __init__(self,
            model_factory: Callable[..., VectorModel],
            parameter_options: Union[Dict[str, Sequence[Any]], List[Dict[str, Sequence[Any]]]],
            num_processes=1,
            csv_results_path: str = None,
            incremental=False,
            incremental_skip_existing=False,
            parameter_combination_skip_decider: ParameterCombinationSkipDecider = None,
            model_save_directory: str = None,
            name: str = None):
        """
        :param model_factory: the function to call with keyword arguments reflecting the parameters to try in order to obtain a model
            instance
        :param parameter_options: a dictionary which maps from parameter names to lists of possible values - or a list of such dictionaries,
            where each dictionary in the list has the same keys
        :param num_processes: the number of parallel processes to use for the search (use 1 to run without multi-processing)
        :param csv_results_path: the path to a directory or concrete CSV file to which the results shall be written;
            if it is None, no CSV data will be written; if it is a directory, a file name starting with this grid search's name (see below)
            will be created.
            The resulting CSV data will contain one line per evaluated parameter combination.
        :param incremental: whether to add to an existing CSV file instead of overwriting it
        :param incremental_skip_existing: if incremental mode is on, whether to skip any parameter combinations that are already present
            in the CSV file
        :param parameter_combination_skip_decider: an instance to which parameters combinations can be passed in order to decide whether the
            combination shall be skipped (e.g. because it is redundant/equivalent to another combination or inadmissible).
            Note that with num_processes > 1, the decider is evaluated (and informed of results) in the worker processes,
            which operate on copies of the instance.
        :param model_save_directory: the directory where the serialized models shall be saved (it is created if necessary);
            if None, models are not saved
        :param name: the name of this grid search, which will, in particular, be prepended to all saved model files;
            if None, a default name will be generated of the form "gridSearch_<timestamp>"
        :raises ValueError: if an empty list of parameter options dictionaries is given or if the dictionaries do not all
            have the same keys
        """
        self.model_factory = model_factory
        if isinstance(parameter_options, list):
            self.parameter_options_list = parameter_options
        else:
            self.parameter_options_list = [parameter_options]
        if len(self.parameter_options_list) == 0:
            raise ValueError("At least one parameter options dictionary must be provided")
        self.param_names = set(self.parameter_options_list[0].keys())
        for d in self.parameter_options_list[1:]:
            if set(d.keys()) != self.param_names:
                raise ValueError("Keys must be the same for all parameter options dictionaries")
        self.num_processes = num_processes
        self.parameter_combination_skip_decider = parameter_combination_skip_decider
        self.model_save_directory = model_save_directory
        self.name = name if name is not None else "gridSearch_" + datetime.now().strftime('%Y%m%d-%H%M%S')
        self.csv_results_path = csv_results_path
        self.incremental = incremental
        self.incremental_skip_existing = incremental_skip_existing
        if self.csv_results_path is not None and os.path.isdir(csv_results_path):
            self.csv_results_path = os.path.join(self.csv_results_path, f"{self.name}_results.csv")

        # total number of combinations = sum over all options dictionaries of the product of the option counts
        self.num_combinations = 0
        for options_dict in self.parameter_options_list:
            n = 1
            for options in options_dict.values():
                n *= len(options)
            self.num_combinations += n
        self.log.info(f"Created GridSearch object for {self.num_combinations} parameter combinations")

        # not used within this class; kept for backward compatibility
        self._executor = None

    @classmethod
    def _eval_params(cls,
            model_factory: Callable[..., VectorModel],
            metrics_evaluator: MetricsDictProvider,
            skip_decider: ParameterCombinationSkipDecider,
            grid_search_name, combination_idx,
            model_save_directory: Optional[str],
            **params) -> Optional[Dict[str, Any]]:
        """
        Evaluates a single parameter combination (this is the unit of work submitted to the worker processes
        when multi-processing is used).

        :param model_factory: the factory with which to create the model from the parameters
        :param metrics_evaluator: the evaluator with which to compute the metrics
        :param skip_decider: the (optional) decider which may cause the combination to be skipped
        :param grid_search_name: the name of the grid search (used for the model file name)
        :param combination_idx: the index of the combination (used for the model file name)
        :param model_save_directory: the directory in which to save the model; if None, the model is not saved
        :param params: the parameter combination
        :return: the metrics dictionary extended by the parameters, the model's string representation and
            (if the model was saved) the model file name; None if the combination was skipped
        """
        if skip_decider is not None and skip_decider.is_skipped(params):
            cls.log.info(f"Parameter combination is skipped according to {skip_decider}: {params}")
            return None
        cls.log.info(f"Evaluating {params}")
        model = model_factory(**params)
        values = metrics_evaluator.compute_metrics(model)
        if model_save_directory is not None:
            filename = f"{grid_search_name}_{combination_idx}.pickle"
            cls.log.info(f"Saving trained model to {filename} ...")
            os.makedirs(model_save_directory, exist_ok=True)
            model.save(os.path.join(model_save_directory, filename))
            values["filename"] = filename
        values["str(model)"] = str(model)
        values.update(**params)
        if skip_decider is not None:
            skip_decider.tell(params, values)
        return values

    def _iter_combinations_to_evaluate(self, params_metrics_collection: ParametersMetricsCollection) \
            -> Generator[Tuple[int, Dict[str, Any]], None, None]:
        """
        Yields pairs (combination_idx, params) for all parameter combinations which are not skipped because they are
        already present in the collection (incremental mode with incremental_skip_existing).
        The index is incremented only for combinations that are yielded.
        """
        combination_idx = 0
        for parameter_options in self.parameter_options_list:
            for params_dict in iter_param_combinations(parameter_options):
                if self.incremental_skip_existing and self.incremental and params_metrics_collection.contains(params_dict):
                    self.log.info(f"Skipped because parameters are already present in collection (incremental mode): {params_dict}")
                    continue
                yield combination_idx, params_dict
                combination_idx += 1

    def run(self, metrics_evaluator: MetricsDictProvider, sort_column_name=None, ascending=True) -> "GridSearch.Result":
        """
        Run the grid search. If csv_results_path was provided in the constructor, each evaluation result will be saved
        to that file directly after being collected

        :param metrics_evaluator: the evaluator or cross-validator with which to evaluate models
        :param sort_column_name: the name of the metric (column) by which to sort the data frame of results; if None, do not sort.
            Note that all Metric instances have a static member `name`, e.g. you could use `RegressionMetricMSE.name`.
        :param ascending: whether to sort in ascending order; has an effect only if `sort_column_name` is specified.
            The result object will assume, by default, that the resulting top/first element is the best,
            i.e. ascending=False means "higher is better", and ascending=True means "Lower is better".
        :return: an object holding the results
        """
        if self.tracked_experiment is not None:
            logging_callback = self.tracked_experiment.track_values
        elif metrics_evaluator.tracked_experiment is not None:
            logging_callback = metrics_evaluator.tracked_experiment.track_values
        else:
            logging_callback = None
        params_metrics_collection = ParametersMetricsCollection(csv_path=self.csv_results_path, sort_column_name=sort_column_name,
            ascending=ascending, incremental=self.incremental)

        def collect_result(values):
            # None indicates that the combination was skipped by the skip decider
            if values is None:
                return
            if logging_callback is not None:
                logging_callback(values)
            params_metrics_collection.add_values(values)
            self.log.info(f"Updated grid search result:\n{params_metrics_collection.get_data_frame().to_string()}")

        combinations = self._iter_combinations_to_evaluate(params_metrics_collection)
        if self.num_processes == 1:
            for combination_idx, params_dict in combinations:
                collect_result(self._eval_params(self.model_factory, metrics_evaluator, self.parameter_combination_skip_decider,
                    self.name, combination_idx, self.model_save_directory, **params_dict))
        else:
            # using the executor as a context manager ensures that the worker processes are shut down
            with ProcessPoolExecutor(max_workers=self.num_processes) as executor:
                futures = [executor.submit(self._eval_params, self.model_factory, metrics_evaluator,
                        self.parameter_combination_skip_decider, self.name, combination_idx, self.model_save_directory,
                        **params_dict)
                    for combination_idx, params_dict in combinations]
                # results are collected in submission order
                for future in futures:
                    collect_result(future.result())

        df = params_metrics_collection.get_data_frame()
        return self.Result(df, self.param_names, default_metric_name=sort_column_name, default_higher_is_better=not ascending)

    class Result:
        """
        Holds the data frame of grid search results and supports the retrieval of the best parameter combination
        """
        def __init__(self, df: pd.DataFrame, param_names: List[str], default_metric_name: Optional[str] = None,
                default_higher_is_better: Optional[bool] = None):
            """
            :param df: the data frame with one row per evaluated parameter combination; may be None if no combination
                was evaluated
            :param param_names: the names of the parameters (columns in df)
            :param default_metric_name: the metric to use in get_best_params if none is specified
            :param default_higher_is_better: the sort order to use in get_best_params if none is specified;
                ignored if default_metric_name is None
            """
            self.df = df
            self.param_names = param_names
            self.default_metric_name = default_metric_name
            if default_metric_name is not None:
                self.default_higher_is_better = default_higher_is_better
            else:
                self.default_higher_is_better = None

        @dataclass
        class BestParams:
            metric_name: str
            metric_value: float
            params: dict

        def get_best_params(self, metric_name: Optional[str] = None, higher_is_better: Optional[bool] = None) -> BestParams:
            """
            :param metric_name: the metric name for which to return the best result; can be None if the GridSearch used
                a metric to sort by
            :param higher_is_better: whether higher is better for the metric to sort by; can be None if the GridSearch
                used a metric to sort by and configured the sort order such that the best configuration is at the top
            :return: a BestParams object holding the metric name, the best metric value and a dictionary with the
                corresponding parameters
            :raises ValueError: if the metric name or sort order cannot be determined, or if there are no results
            """
            if metric_name is None:
                metric_name = self.default_metric_name
                if metric_name is None:
                    raise ValueError("metric_name must be specified")
            if higher_is_better is None:
                higher_is_better = self.default_higher_is_better
                if higher_is_better is None:
                    raise ValueError("higher_is_better must be specified")
            if self.df is None or len(self.df) == 0:
                raise ValueError("No results available (no parameter combination was evaluated)")

            df = self.df.sort_values(metric_name, axis=0, inplace=False, ascending=not higher_is_better)
            best_row = df.iloc[0]
            best_params = {param: best_row[param] for param in self.param_names}
            best_metric_value = best_row[metric_name]

            return self.BestParams(metric_name=metric_name, metric_value=best_metric_value, params=best_params)
class SAHyperOpt(TrackingMixin):
    """
    Hyper-parameter optimisation based on simulated annealing: operators propose changes to the current parameter
    combination, and each combination is assessed by evaluating a model created by the model factory.
    """
    log = log.getChild(__qualname__)

    class State(SAState):
        """
        Simulated annealing state holding the current parameter combination; its cost is computed from the
        parameters via `compute_metric`.
        """
        def __init__(self, params: Dict[str, Any], random_state: Random, results: Dict, compute_metric: Callable[[Dict[str, Any]], float]):
            """
            :param params: the initial parameter combination (copied)
            :param random_state: the random number generator, passed on to the base class
            :param results: the dictionary which apply_state_representation updates
            :param compute_metric: a function mapping a parameter combination to the cost value (lower is better)
            """
            # attributes are set before invoking the base constructor
            # NOTE(review): presumably SAState.__init__ may already call compute_cost_value; keep this order
            self.compute_metric = compute_metric
            self.results = results
            self.params = dict(params)
            super().__init__(random_state)

        def compute_cost_value(self) -> SACostValueNumeric:
            return SACostValueNumeric(self.compute_metric(self.params))

        def get_state_representation(self):
            # return a copy: self.params is mutated in place by ParameterChangeOperator.apply_state_change,
            # so handing out the live dict would let later changes alter a previously archived representation
            return dict(self.params)

        def apply_state_representation(self, representation):
            self.results.update(representation)

    class ParameterChangeOperator(SAOperator[State]):
        """
        Base class for operators which change (a subset of) the parameters of the current state;
        subclasses implement `_choose_changed_model_parameters`.
        """
        def __init__(self, state: 'SAHyperOpt.State'):
            super().__init__(state)

        def apply_state_change(self, params):
            self.state.params.update(params)

        def cost_delta(self, params) -> SACostValue:
            # cost of the current parameters with the proposed changes applied, relative to the current cost
            model_params = dict(self.state.params)
            model_params.update(params)
            return SACostValueNumeric(self.state.compute_metric(model_params) - self.state.cost.value())

        def choose_params(self) -> Optional[Tuple[Tuple, Optional[SACostValue]]]:
            params = self._choose_changed_model_parameters()
            if params is None:
                return None
            return ((params, ), None)

        @abstractmethod
        def _choose_changed_model_parameters(self) -> Optional[Dict[str, Any]]:
            """
            :return: a dictionary mapping the names of the parameters to change to their new values,
                or None if no change shall be proposed
            """
            pass

    def __init__(self,
            model_factory: Callable[..., VectorModel],
            ops_and_weights: List[Tuple[Callable[['SAHyperOpt.State'], 'SAHyperOpt.ParameterChangeOperator'], float]],
            initial_parameters: Dict[str, Any],
            metrics_evaluator: MetricsDictProvider,
            metric_to_optimise: str,
            minimise_metric: bool = False,
            collect_data_frame: bool = True,
            csv_results_path: Optional[str] = None,
            parameter_combination_equivalence_class_value_cache: ParameterCombinationEquivalenceClassValueCache = None,
            p0: float = 0.5,
            p1: float = 0.0):
        """
        :param model_factory: a factory for the generation of models which is called with the current parameter combination
            (all keyword arguments), initially initial_parameters
        :param ops_and_weights: a sequence of tuples (operator factory, operator weight) for simulated annealing
        :param initial_parameters: the initial parameter combination
        :param metrics_evaluator: the evaluator/validator to use in order to evaluate models
        :param metric_to_optimise: the name of the metric (as generated by the evaluator/validator) to optimise
        :param minimise_metric: whether the metric is to be minimised; if False, maximise the metric
        :param collect_data_frame: whether to collect (and regularly log) the data frame of all parameter combinations and
            evaluation results
        :param csv_results_path: the (optional) path of a CSV file in which to store a table of all computed results;
            if this is not None, then collect_data_frame is automatically set to True
        :param parameter_combination_equivalence_class_value_cache: a cache in which to store computed results and whose notion
            of equivalence can be used to avoid duplicate computations
        :param p0: the initial probability (at the start of the optimisation) of accepting a state with an inferior evaluation
            to the current state's (for the mean observed evaluation delta)
        :param p1: the final probability (at the end of the optimisation) of accepting a state with an inferior evaluation
            to the current state's (for the mean observed evaluation delta)
        """
        self.minimise_metric = minimise_metric
        self.evaluator_or_validator = metrics_evaluator
        self.metric_to_optimise = metric_to_optimise
        self.initial_parameters = initial_parameters
        self.ops_and_weights = ops_and_weights
        self.model_factory = model_factory
        self.csv_results_path = csv_results_path
        if csv_results_path is not None:
            collect_data_frame = True
        self.parameters_metrics_collection = ParametersMetricsCollection(csv_path=csv_results_path) if collect_data_frame else None
        self.parameter_combination_equivalence_class_value_cache = parameter_combination_equivalence_class_value_cache
        self.p0 = p0
        self.p1 = p1
        self._sa = None

    @classmethod
    def _eval_params(cls,
            model_factory,
            metrics_evaluator: MetricsDictProvider,
            parameters_metrics_collection: Optional[ParametersMetricsCollection],
            parameter_combination_equivalence_class_value_cache,
            tracked_experiment: Optional[TrackedExperiment],
            **params):
        """
        Computes the metrics for the given parameter combination, using the cache (if any) to avoid re-evaluation.
        Newly computed results are tracked, added to the collection and stored in the cache (where applicable);
        results retrieved from the cache are not added again.

        :return: the metrics dictionary
        """
        if tracked_experiment is not None and metrics_evaluator.tracked_experiment is not None:
            cls.log.warning(f"Tracked experiment already set in evaluator, results will be tracked twice and "
                f"might get overwritten!")

        cache = parameter_combination_equivalence_class_value_cache
        metrics = cache.get(params) if cache is not None else None
        if metrics is not None:
            cls.log.info(f"Result for parameter combination {params} could be retrieved from cache, not adding new result")
            return metrics

        cls.log.info(f"Evaluating parameter combination {params}")
        model = model_factory(**params)
        metrics = metrics_evaluator.compute_metrics(model)
        cls.log.info(f"Got metrics {metrics} for {params}")

        values = dict(metrics)
        values["str(model)"] = str(model)
        values.update(**params)
        if tracked_experiment is not None:
            tracked_experiment.track_values(values)
        if parameters_metrics_collection is not None:
            parameters_metrics_collection.add_values(values)
            cls.log.info(f"Data frame with all results:\n\n{parameters_metrics_collection.get_data_frame().to_string()}\n")
        if cache is not None:
            cache.set(params, metrics)
        return metrics

    def _compute_metric(self, params: Dict[str, Any]):
        """
        :return: the value of the metric to optimise for the given parameters, negated if the metric is to be
            maximised (so that lower values are always better)
        """
        metrics = self._eval_params(self.model_factory, self.evaluator_or_validator, self.parameters_metrics_collection,
            self.parameter_combination_equivalence_class_value_cache, self.tracked_experiment, **params)
        metric_value = metrics[self.metric_to_optimise]
        if not self.minimise_metric:
            return -metric_value
        return metric_value

    def run(self, max_steps: Optional[int] = None, duration: Optional[float] = None, random_seed: int = 42, collect_stats: bool = True):
        """
        Runs the simulated annealing optimisation, starting from the initial parameters.

        :param max_steps: passed on to SimulatedAnnealing as max_steps
        :param duration: passed on to SimulatedAnnealing as duration
        :param random_seed: the random seed for SimulatedAnnealing
        :param collect_stats: passed on to SimulatedAnnealing as collect_stats
        :return: the dictionary which is updated via State.apply_state_representation (presumably with the best
            parameter combination found; confirm against SimulatedAnnealing)
        """
        sa = SimulatedAnnealing(lambda: SAProbabilitySchedule(None, SAProbabilityFunctionLinear(p0=self.p0, p1=self.p1)),
            self.ops_and_weights, max_steps=max_steps, duration=duration, random_seed=random_seed, collect_stats=collect_stats)
        results = {}
        self._sa = sa
        sa.optimise(lambda r: self.State(self.initial_parameters, r, results, self._compute_metric))
        return results

    def get_simulated_annealing(self) -> SimulatedAnnealing:
        """
        :return: the SimulatedAnnealing instance used by the most recent call to run (None before the first run)
        """
        return self._sa