Coverage for src/sensai/hyperopt.py: 23%

295 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1from dataclasses import dataclass 

2from datetime import datetime 

3from concurrent.futures import ProcessPoolExecutor 

4import uuid 

5 

6import logging 

7import os 

8import pandas as pd 

9from abc import ABC 

10from abc import abstractmethod 

11from random import Random 

12from typing import Dict, Sequence, Any, Callable, Generator, Union, Tuple, List, Optional, Hashable 

13 

14from .evaluation.evaluator import MetricsDictProvider 

15from .local_search import SACostValue, SACostValueNumeric, SAOperator, SAState, SimulatedAnnealing, \ 

16 SAProbabilitySchedule, SAProbabilityFunctionLinear 

17from .tracking.tracking_base import TrackingMixin, TrackedExperiment 

18from .vector_model import VectorModel 

19 

20log = logging.getLogger(__name__) 

21 

22 

def iter_param_combinations(hyper_param_values: Dict[str, Sequence[Any]]) -> Generator[Dict[str, Any], None, None]:
    """
    Create all possible combinations of values from a dictionary of possible parameter values.

    Combinations are generated in the order of the mapping's keys, with the values of the last
    parameter varying fastest. An empty mapping results in a single empty combination; if any
    parameter has no candidate values, no combination is generated at all.

    :param hyper_param_values: a mapping from parameter names to lists of possible values
    :return: a generator of dictionaries, each mapping every parameter name to one of its values
        (every generated dictionary is a new object)
    """
    # local import: keeps the function usable without touching the module-level imports
    from itertools import product

    param_names = list(hyper_param_values.keys())
    # product() reads each collection of candidate values exactly once, so values that are
    # provided as one-shot iterables are combined completely as well
    value_combinations = product(*hyper_param_values.values())
    return (dict(zip(param_names, combination)) for combination in value_combinations)

48 

49 

class ParameterCombinationSkipDecider(ABC):
    """
    Interface of a component which is informed about every parameter combination that has been
    evaluated and which, based on this information, decides whether a further parameter
    combination shall be skipped (i.e. not be considered at all).
    """

    @abstractmethod
    def tell(self, params: Dict[str, Any], metrics: Dict[str, Any]):
        """
        Notifies the decider of a parameter combination that was evaluated previously

        :param params: the parameter combination
        :param metrics: the evaluation metrics that were obtained for the combination
        """

    @abstractmethod
    def is_skipped(self, params: Dict[str, Any]):
        """
        Determines whether the given parameter combination is to be skipped

        :param params: the parameter combination in question
        :return: True iff the combination shall be skipped
        """

75 

76 

class ParameterCombinationEquivalenceClassValueCache(ABC):
    """
    A cache holding (arbitrary) values for parameter combinations, where the cache keys are
    derived from the parameter combinations.
    Several parameter combinations may be mapped to the same key, which marks them as equivalent;
    the keys are thus representations of equivalence classes over parameter combinations.
    A hyper-parameter search can use this to avoid re-computing results for parameter combinations
    that are equivalent to one that was already evaluated.
    """
    def __init__(self):
        self._cache = dict()

    @abstractmethod
    def _equivalence_class(self, params: Dict[str, Any]) -> Hashable:
        """
        Computes a (hashable) representation of the equivalence class of the given parameter combination.
        If, for example, every parameter influences the evaluation of a model and no two combinations
        lead to equivalent results, a tuple of all parameter values (in a fixed order) is a suitable result.

        :param params: the parameter combination
        :return: a hashable key which contains all the information from the parameter combination that
            influences the computation of model evaluation results
        """

    def set(self, params: Dict[str, Any], value: Any):
        """
        Stores a value for the (equivalence class of the) parameter combination, replacing any value stored before

        :param params: the parameter combination
        :param value: the value to store
        """
        key = self._equivalence_class(params)
        self._cache[key] = value

    def get(self, params: Dict[str, Any]):
        """
        Retrieves the value stored for the (equivalence class of the) parameter combination

        :param params: the parameter combination
        :return: the stored value or None if no value is stored
        """
        key = self._equivalence_class(params)
        return self._cache.get(key)

112 

113 

class ParametersMetricsCollection:
    """
    Utility class for holding and persisting evaluation results.

    The results are held in a data frame with one row per call to :meth:`add_values`; the data frame is
    (optionally) kept sorted by one column and is (optionally) written to a CSV file after every addition.
    """
    def __init__(self, csv_path=None, sort_column_name=None, ascending=True, incremental=False):
        """
        :param csv_path: path to save the data frame to upon every update
        :param sort_column_name: the column name by which to sort the data frame that is collected; if None, do not sort
        :param ascending: whether to sort in ascending order; has an effect only if sort_column_name is not None
        :param incremental: whether to add to an existing CSV file instead of overwriting it
        """
        self.sort_column_name = sort_column_name
        self.csv_path = csv_path
        self.ascending = ascending
        csv_path_exists = csv_path is not None and os.path.exists(csv_path)
        if csv_path_exists and incremental:
            self.df = pd.read_csv(csv_path)
            log.info(f"Found existing CSV file with {len(self.df)} entries; {csv_path} will be extended (incremental mode)")
            self._current_row = len(self.df)
            # One dict per existing row, keyed by the exact column names. (Named tuples from itertuples()
            # would rename columns that are not valid identifiers or that start with an underscore,
            # such that `contains` could never match these columns.)
            self.value_dicts = self.df.to_dict(orient="records")
        else:
            if csv_path is not None:
                if not csv_path_exists:
                    log.info(f"Results will be written to new file {csv_path}")
                else:
                    log.warning(f"Results in existing file ({csv_path}) will be overwritten (non-incremental mode)")
            self.df = None  # created upon the first call to add_values
            self._current_row = 0
            self.value_dicts = []

    def add_values(self, values: Dict[str, Any]):
        """
        Adds the provided values as a new row to the collection.
        If csv_path was provided in the constructor, saves the updated collection to that file.

        :param values: Dict holding the evaluation results and parameters
        """
        if self.df is None:
            cols = list(values.keys())

            # check sort column and move it to the front
            if self.sort_column_name is not None:
                if self.sort_column_name not in cols:
                    log.warning(f"Specified sort column '{self.sort_column_name}' not in list of columns: {cols}; "
                        f"sorting will not take place!")
                else:
                    cols.remove(self.sort_column_name)
                    cols.insert(0, self.sort_column_name)

            self.df = pd.DataFrame(columns=cols)
        else:
            # add columns for keys that have not been seen before (existing rows receive None)
            for col in values:
                if col not in self.df.columns:
                    self.df[col] = None

        # append data to data frame (keys that are missing in values result in None entries)
        self.df.loc[self._current_row] = [values.get(c) for c in self.df.columns]
        self._current_row += 1

        # sort where applicable; the index is reset such that _current_row remains the next free row label
        if self.sort_column_name is not None and self.sort_column_name in self.df.columns:
            self.df.sort_values(self.sort_column_name, axis=0, inplace=True, ascending=self.ascending)
            self.df.reset_index(drop=True, inplace=True)

        self._save_csv()

    def _save_csv(self):
        """
        Writes the data frame to the CSV file (if a path was configured), creating the parent directory if necessary
        """
        if self.csv_path is not None:
            dirname = os.path.dirname(self.csv_path)
            if dirname != "":
                os.makedirs(dirname, exist_ok=True)
            self.df.to_csv(self.csv_path, index=False)

    def get_data_frame(self) -> pd.DataFrame:
        """
        :return: the data frame holding all collected rows; None if there are no rows yet
            (i.e. no values were added and no existing CSV file was loaded)
        """
        return self.df

    def contains(self, values: Dict[str, Any]) -> bool:
        """
        Checks whether one of the rows that were loaded from the existing CSV file (incremental mode) matches
        all of the given values. Rows that were added via :meth:`add_values` are not considered.

        A value matches if it is equal to the row's value or if the string representations of the two are
        equal (values read from CSV may have a different type than the original values).

        :param values: the values (typically a parameter combination) to look for
        :return: True if a matching row exists, False otherwise
        """
        def matches(existing_values: Dict[str, Any]) -> bool:
            for k, v in values.items():
                ev = existing_values.get(k)
                if ev != v and str(ev) != str(v):
                    return False
            return True

        return any(matches(existing_values) for existing_values in self.value_dicts)

202 

203 

class GridSearch(TrackingMixin):
    """
    Instances of this class can be used for evaluating models with different user-provided parametrizations
    over the same data and persisting the results
    """
    log = log.getChild(__qualname__)

    def __init__(self,
            model_factory: Callable[..., VectorModel],
            parameter_options: Union[Dict[str, Sequence[Any]], List[Dict[str, Sequence[Any]]]],
            num_processes=1,
            csv_results_path: str = None,
            incremental=False,
            incremental_skip_existing=False,
            parameter_combination_skip_decider: ParameterCombinationSkipDecider = None,
            model_save_directory: str = None,
            name: str = None):
        """
        :param model_factory: the function to call with keyword arguments reflecting the parameters to try in order to obtain a model
            instance
        :param parameter_options: a dictionary which maps from parameter names to lists of possible values - or a list of such dictionaries,
            where each dictionary in the list has the same keys
        :param num_processes: the number of parallel processes to use for the search (use 1 to run without multi-processing)
        :param csv_results_path: the path to a directory or concrete CSV file to which the results shall be written;
            if it is None, no CSV data will be written; if it is a directory, a file name starting with this grid search's name (see below)
            will be created.
            The resulting CSV data will contain one line per evaluated parameter combination.
        :param incremental: whether to add to an existing CSV file instead of overwriting it
        :param incremental_skip_existing: if incremental mode is on, whether to skip any parameter combinations that are already present
            in the CSV file
        :param parameter_combination_skip_decider: an instance to which parameters combinations can be passed in order to decide whether the
            combination shall be skipped (e.g. because it is redundant/equivalent to another combination or inadmissible)
        :param model_save_directory: the directory where the serialized models shall be saved; if None, models are not saved
        :param name: the name of this grid search, which will, in particular, be prepended to all saved model files;
            if None, a default name will be generated of the form "gridSearch_<timestamp>"
        """
        self.model_factory = model_factory
        if isinstance(parameter_options, list):
            self.parameter_options_list = parameter_options
        else:
            self.parameter_options_list = [parameter_options]
        self.param_names = set(self.parameter_options_list[0].keys())
        for d in self.parameter_options_list[1:]:
            if set(d.keys()) != self.param_names:
                raise ValueError("Keys must be the same for all parameter options dictionaries")
        self.num_processes = num_processes
        self.parameter_combination_skip_decider = parameter_combination_skip_decider
        self.model_save_directory = model_save_directory
        self.name = name if name is not None else "gridSearch_" + datetime.now().strftime('%Y%m%d-%H%M%S')
        self.csv_results_path = csv_results_path
        self.incremental = incremental
        self.incremental_skip_existing = incremental_skip_existing
        # a directory was given: derive the CSV file name from the grid search's name
        if self.csv_results_path is not None and os.path.isdir(csv_results_path):
            self.csv_results_path = os.path.join(self.csv_results_path, f"{self.name}_results.csv")

        # the total number of combinations is the sum (over all option dictionaries) of the product of the option counts
        self.num_combinations = 0
        for parameter_options in self.parameter_options_list:
            n = 1
            for options in parameter_options.values():
                n *= len(options)
            self.num_combinations += n
        log.info(f"Created GridSearch object for {self.num_combinations} parameter combinations")

        self._executor = None

    @classmethod
    def _eval_params(cls,
            model_factory: Callable[..., VectorModel],
            metrics_evaluator: MetricsDictProvider,
            skip_decider: ParameterCombinationSkipDecider,
            grid_search_name, combination_idx,
            model_save_directory: Optional[str],
            **params) -> Optional[Dict[str, Any]]:
        """
        Evaluates a single parameter combination: creates the model via the factory, computes its metrics and
        optionally saves the model.

        NOTE(review): with num_processes > 1, this runs in a worker process; the skip decider that is consulted
        and informed here is then presumably a copy of the caller's instance - confirm whether this is intended.

        :param model_factory: the factory which is called with the parameters as keyword arguments
        :param metrics_evaluator: the evaluator with which to compute the model's metrics
        :param skip_decider: the (optional) decider which is asked whether to skip the combination and which
            is told the result of the evaluation
        :param grid_search_name: the name of the grid search (prefix of the saved model's file name)
        :param combination_idx: the index of the parameter combination (suffix of the saved model's file name)
        :param model_save_directory: the directory in which to save the model; if None, do not save it
        :param params: the parameter combination
        :return: a dictionary containing the metrics, the model's string representation, the parameters and,
            if the model was saved, its file name; None if the combination was skipped
        """
        if skip_decider is not None:
            if skip_decider.is_skipped(params):
                cls.log.info(f"Parameter combination is skipped according to {skip_decider}: {params}")
                return None
        cls.log.info(f"Evaluating {params}")
        model = model_factory(**params)
        values = metrics_evaluator.compute_metrics(model)
        if model_save_directory is not None:
            filename = f"{grid_search_name}_{combination_idx}.pickle"
            log.info(f"Saving trained model to {filename} ...")
            model.save(os.path.join(model_save_directory, filename))
            values["filename"] = filename
        values["str(model)"] = str(model)
        values.update(**params)
        if skip_decider is not None:
            skip_decider.tell(params, values)
        return values

    def run(self, metrics_evaluator: MetricsDictProvider, sort_column_name=None, ascending=True) -> "GridSearch.Result":
        """
        Run the grid search. If csv_results_path was provided in the constructor, each evaluation result will be saved
        to that file directly after being computed

        :param metrics_evaluator: the evaluator or cross-validator with which to evaluate models
        :param sort_column_name: the name of the metric (column) by which to sort the data frame of results; if None, do not sort.
            Note that all Metric instances have a static member `name`, e.g. you could use `RegressionMetricMSE.name`.
        :param ascending: whether to sort in ascending order; has an effect only if `sort_column_name` is specified.
            The result object will assume, by default, that the resulting top/first element is the best,
            i.e. ascending=False means "higher is better", and ascending=True means "Lower is better".
        :return: an object holding the results
        """
        # values are tracked via this object's experiment or, alternatively, via the evaluator's experiment
        if self.tracked_experiment is not None:
            logging_callback = self.tracked_experiment.track_values
        elif metrics_evaluator.tracked_experiment is not None:
            logging_callback = metrics_evaluator.tracked_experiment.track_values
        else:
            logging_callback = None
        params_metrics_collection = ParametersMetricsCollection(csv_path=self.csv_results_path, sort_column_name=sort_column_name,
            ascending=ascending, incremental=self.incremental)

        def collect_result(values):
            # skipped combinations yield no values
            if values is None:
                return
            if logging_callback is not None:
                logging_callback(values)
            params_metrics_collection.add_values(values)
            log.info(f"Updated grid search result:\n{params_metrics_collection.get_data_frame().to_string()}")

        def iter_pending_params():
            # generates pairs (combination index, parameter combination) for all combinations that are to be
            # evaluated; combinations that are skipped because they are already present do not consume an index
            skip_existing = self.incremental_skip_existing and self.incremental
            combination_idx = 0
            for parameter_options in self.parameter_options_list:
                for params_dict in iter_param_combinations(parameter_options):
                    if skip_existing and params_metrics_collection.contains(params_dict):
                        log.info(f"Skipped because parameters are already present in collection (incremental mode): {params_dict}")
                        continue
                    yield combination_idx, params_dict
                    combination_idx += 1

        if self.num_processes == 1:
            for combination_idx, params_dict in iter_pending_params():
                collect_result(self._eval_params(self.model_factory, metrics_evaluator, self.parameter_combination_skip_decider,
                    self.name, combination_idx, self.model_save_directory, **params_dict))
        else:
            # the context manager shuts the pool down, such that no worker processes are left behind
            with ProcessPoolExecutor(max_workers=self.num_processes) as executor:
                futures = [
                    executor.submit(self._eval_params, self.model_factory, metrics_evaluator,
                        self.parameter_combination_skip_decider,
                        self.name, combination_idx, self.model_save_directory, **params_dict)
                    for combination_idx, params_dict in iter_pending_params()]
                try:
                    # results are collected in submission order
                    for future in futures:
                        collect_result(future.result())
                except BaseException:
                    # the search is being aborted: do not start evaluations that are still queued
                    for future in futures:
                        future.cancel()
                    raise

        df = params_metrics_collection.get_data_frame()
        return self.Result(df, self.param_names, default_metric_name=sort_column_name, default_higher_is_better=not ascending)

    class Result:
        """
        Holds the results of a grid search (one row per evaluated parameter combination) and provides access
        to the best parameter combination
        """
        def __init__(self, df: pd.DataFrame, param_names: List[str], default_metric_name: Optional[str] = None,
                default_higher_is_better: Optional[bool] = None):
            """
            :param df: the data frame containing the parameter values and metrics of all evaluated combinations
            :param param_names: the names of the parameters (columns in `df`) that were varied
            :param default_metric_name: the name of the metric to use by default when determining the best parameters
            :param default_higher_is_better: whether higher is better for the default metric; disregarded
                if `default_metric_name` is None
            """
            self.df = df
            self.param_names = param_names
            self.default_metric_name = default_metric_name
            # a default direction is meaningful only in conjunction with a default metric
            self.default_higher_is_better = default_higher_is_better if default_metric_name is not None else None

        @dataclass
        class BestParams:
            metric_name: str
            metric_value: float
            params: dict

        def get_best_params(self, metric_name: Optional[str] = None, higher_is_better: Optional[bool] = None) -> BestParams:
            """
            Determines the parameter combination which is best with respect to the given metric

            :param metric_name: the metric name for which to return the best result; can be None if the GridSearch used
                a metric to sort by
            :param higher_is_better: whether higher is better for the metric to sort by; can be None if the GridSearch
                used a metric to sort by and configured the sort order such that the best configuration is at the top
            :return: an object holding the name of the metric, the best metric value and the dictionary of parameters
                with which it was achieved
            :raises ValueError: if the metric name or the direction is neither given nor available as a default
            """
            if metric_name is None:
                metric_name = self.default_metric_name
                if metric_name is None:
                    raise ValueError("metric_name must be specified")
            if higher_is_better is None:
                higher_is_better = self.default_higher_is_better
                if higher_is_better is None:
                    raise ValueError("higher_is_better must be specified")

            # sort a copy such that the best row comes first
            df = self.df.sort_values(metric_name, axis=0, inplace=False, ascending=not higher_is_better)
            best_row = df.iloc[0]

            best_params = {param: best_row[param] for param in self.param_names}
            best_metric_value = best_row[metric_name]

            return self.BestParams(metric_name=metric_name, metric_value=best_metric_value, params=best_params)

400 

401 

class SAHyperOpt(TrackingMixin):
    """
    Hyper-parameter optimisation based on simulated annealing: starting from an initial parameter combination,
    operators propose changes to the parameters, and every proposed combination is evaluated by creating
    a model via the model factory and computing its metrics.
    """
    log = log.getChild(__qualname__)

    class State(SAState):
        """
        The simulated annealing state, which consists of the current parameter combination
        """
        def __init__(self, params: Dict[str, Any], random_state: Random, results: Dict, compute_metric: Callable[[Dict[str, Any]], float]):
            """
            :param params: the initial parameter combination (which is copied)
            :param random_state: the random number generator, which is passed on to the super-class
            :param results: the dictionary which is updated when a state representation is applied
            :param compute_metric: the function which computes the cost for a parameter combination
            """
            # the attributes are assigned before the super-class constructor is called
            # (NOTE(review): presumably because the latter already computes the cost value - confirm)
            self.compute_metric = compute_metric
            self.results = results
            self.params = dict(params)
            super().__init__(random_state)

        def compute_cost_value(self) -> SACostValueNumeric:
            return SACostValueNumeric(self.compute_metric(self.params))

        def get_state_representation(self):
            return self.params

        def apply_state_representation(self, representation):
            # the representation (a parameter combination) is written to the shared results dictionary
            self.results.update(representation)

    class ParameterChangeOperator(SAOperator[State]):
        """
        Base class for operators which change one or more parameters of the current parameter combination
        """
        def __init__(self, state: 'SAHyperOpt.State'):
            super().__init__(state)

        def apply_state_change(self, params):
            self.state.params.update(params)

        def cost_delta(self, params) -> SACostValue:
            # evaluate the changed combination on a copy, leaving the state's parameters untouched
            model_params = dict(self.state.params)
            model_params.update(params)
            return SACostValueNumeric(self.state.compute_metric(model_params) - self.state.cost.value())

        def choose_params(self) -> Optional[Tuple[Tuple, Optional[SACostValue]]]:
            params = self._choose_changed_model_parameters()
            if params is None:
                return None
            # the changed parameters are the operator's only argument; no cost value is provided along with them
            return ((params, ), None)

        @abstractmethod
        def _choose_changed_model_parameters(self) -> Dict[str, Any]:
            """
            :return: a dictionary containing the parameters to change along with their new values;
                if None is returned, choose_params returns None as well
            """
            pass

    def __init__(self,
            model_factory: Callable[..., VectorModel],
            ops_and_weights: List[Tuple[Callable[['SAHyperOpt.State'], 'SAHyperOpt.ParameterChangeOperator'], float]],
            initial_parameters: Dict[str, Any],
            metrics_evaluator: MetricsDictProvider,
            metric_to_optimise: str,
            minimise_metric: bool = False,
            collect_data_frame: bool = True,
            csv_results_path: Optional[str] = None,
            parameter_combination_equivalence_class_value_cache: ParameterCombinationEquivalenceClassValueCache = None,
            p0: float = 0.5,
            p1: float = 0.0):
        """
        :param model_factory: a factory for the generation of models which is called with the current parameter combination
            (all keyword arguments), initially initial_parameters
        :param ops_and_weights: a sequence of tuples (operator factory, operator weight) for simulated annealing
        :param initial_parameters: the initial parameter combination
        :param metrics_evaluator: the evaluator/validator to use in order to evaluate models
        :param metric_to_optimise: the name of the metric (as generated by the evaluator/validator) to optimise
        :param minimise_metric: whether the metric is to be minimised; if False, maximise the metric
        :param collect_data_frame: whether to collect (and regularly log) the data frame of all parameter combinations and
            evaluation results
        :param csv_results_path: the (optional) path of a CSV file in which to store a table of all computed results;
            if this is not None, then collect_data_frame is automatically set to True
        :param parameter_combination_equivalence_class_value_cache: a cache in which to store computed results and whose notion
            of equivalence can be used to avoid duplicate computations
        :param p0: the initial probability (at the start of the optimisation) of accepting a state with an inferior evaluation
            to the current state's (for the mean observed evaluation delta)
        :param p1: the final probability (at the end of the optimisation) of accepting a state with an inferior evaluation
            to the current state's (for the mean observed evaluation delta)
        """
        self.minimise_metric = minimise_metric
        self.evaluator_or_validator = metrics_evaluator
        self.metric_to_optimise = metric_to_optimise
        self.initial_parameters = initial_parameters
        self.ops_and_weights = ops_and_weights
        self.model_factory = model_factory
        self.csv_results_path = csv_results_path
        # writing a CSV file requires the data frame to be collected
        if csv_results_path is not None:
            collect_data_frame = True
        self.parameters_metrics_collection = ParametersMetricsCollection(csv_path=csv_results_path) if collect_data_frame else None
        self.parameter_combination_equivalence_class_value_cache = parameter_combination_equivalence_class_value_cache
        self.p0 = p0
        self.p1 = p1
        self._sa = None

    @classmethod
    def _eval_params(cls,
            model_factory,
            metrics_evaluator: MetricsDictProvider,
            parameters_metrics_collection: Optional[ParametersMetricsCollection],
            parameter_combination_equivalence_class_value_cache,
            tracked_experiment: Optional[TrackedExperiment],
            **params):
        """
        Obtains the metrics for a parameter combination, either from the cache or by creating and evaluating a model.
        Newly computed results are tracked, added to the collection and stored in the cache (where applicable).

        :param model_factory: the factory which is called with the parameters as keyword arguments
        :param metrics_evaluator: the evaluator with which to compute the model's metrics
        :param parameters_metrics_collection: the (optional) collection to which to add newly computed results
        :param parameter_combination_equivalence_class_value_cache: the (optional) cache for the metrics
        :param tracked_experiment: the (optional) experiment with which to track newly computed results
        :param params: the parameter combination
        :return: the metrics for the parameter combination
        """
        if tracked_experiment is not None and metrics_evaluator.tracked_experiment is not None:
            log.warning("Tracked experiment already set in evaluator, results will be tracked twice and "
                "might get overwritten!")

        cache = parameter_combination_equivalence_class_value_cache
        metrics = cache.get(params) if cache is not None else None
        if metrics is not None:
            # cached results were recorded when they were first computed, so there is nothing left to do
            cls.log.info(f"Result for parameter combination {params} could be retrieved from cache, not adding new result")
            return metrics

        cls.log.info(f"Evaluating parameter combination {params}")
        model = model_factory(**params)
        metrics = metrics_evaluator.compute_metrics(model)
        cls.log.info(f"Got metrics {metrics} for {params}")

        # the recorded values extend a copy of the metrics, such that the returned/cached metrics stay unmodified
        values = dict(metrics)
        values["str(model)"] = str(model)
        values.update(**params)
        if tracked_experiment is not None:
            tracked_experiment.track_values(values)
        if parameters_metrics_collection is not None:
            parameters_metrics_collection.add_values(values)
            cls.log.info(f"Data frame with all results:\n\n{parameters_metrics_collection.get_data_frame().to_string()}\n")
        if cache is not None:
            cache.set(params, metrics)
        return metrics

    def _compute_metric(self, params: Dict[str, Any]):
        """
        Computes the cost of the given parameter combination

        :param params: the parameter combination
        :return: the value of the metric to optimise, negated if the metric is to be maximised
            (such that a lower value is always better)
        """
        metrics = self._eval_params(self.model_factory, self.evaluator_or_validator, self.parameters_metrics_collection,
            self.parameter_combination_equivalence_class_value_cache, self.tracked_experiment, **params)
        metric_value = metrics[self.metric_to_optimise]
        return metric_value if self.minimise_metric else -metric_value

    def run(self, max_steps: Optional[int] = None, duration: Optional[float] = None, random_seed: int = 42, collect_stats: bool = True):
        """
        Runs the optimisation via simulated annealing

        :param max_steps: passed on to SimulatedAnnealing
        :param duration: passed on to SimulatedAnnealing
        :param random_seed: passed on to SimulatedAnnealing
        :param collect_stats: passed on to SimulatedAnnealing
        :return: the dictionary which received the state representation (parameter combination) that was applied
            by the optimisation
        """
        sa = SimulatedAnnealing(lambda: SAProbabilitySchedule(None, SAProbabilityFunctionLinear(p0=self.p0, p1=self.p1)),
            self.ops_and_weights, max_steps=max_steps, duration=duration, random_seed=random_seed, collect_stats=collect_stats)
        results = {}
        self._sa = sa
        sa.optimise(lambda r: self.State(self.initial_parameters, r, results, self._compute_metric))
        return results

    def get_simulated_annealing(self) -> SimulatedAnnealing:
        """
        :return: the simulated annealing instance that was used in the most recent call to run; None if run was not yet called
        """
        return self._sa