import atexit
import enum
import glob
import logging
import os
import pickle
import re
import sqlite3
import threading
import time
from abc import abstractmethod, ABC
from functools import wraps
from pathlib import Path
from typing import Any, Callable, Iterator, List, Optional, TypeVar, Generic, Union

from .hash import pickle_hash
from .pickle import load_pickle, dump_pickle, setstate

log = logging.getLogger(__name__)

T = TypeVar("T")
TKey = TypeVar("TKey")
TValue = TypeVar("TValue")
TData = TypeVar("TData")


class BoxedValue(Generic[TValue]):
    """
    Container for a value, which can be used in caches where values may be None (to differentiate the value not being present in the cache
    from the cached value being None)
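
    For example (an illustrative sketch; ``cache`` and ``compute_value`` are placeholders, assuming the cache stores boxed values):

    .. code-block:: python

        boxed = cache.get(key)
        if boxed is None:
            cache.set(key, BoxedValue(compute_value(key)))  # not cached yet
        else:
            value = boxed.value  # cached, and possibly None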
    """
    def __init__(self, value: TValue):
        self.value = value


class KeyValueCache(Generic[TKey, TValue], ABC):
    @abstractmethod
    def set(self, key: TKey, value: TValue):
        """
        Sets a cached value

        :param key: the key under which to store the value
        :param value: the value to store; since None is used to indicate the absence of a value, None should not be
            used as a value
        """
        pass

    @abstractmethod
    def get(self, key: TKey) -> Optional[TValue]:
        """
        Retrieves a cached value

        :param key: the lookup key
        :return: the cached value or None if no value is found
        """
        pass


class InMemoryKeyValueCache(KeyValueCache[TKey, TValue], Generic[TKey, TValue]):
    """A simple in-memory cache (which uses a dictionary internally).

    This class can be instantiated directly, but for better typing support, one can instead
    inherit from it and provide the types of the key and value as type arguments. For example for
    a cache with string keys and integer values:

    .. code-block:: python

        class MyCache(InMemoryKeyValueCache[str, int]):
            pass
    """
    def __init__(self):
        self.cache = {}

    def set(self, key: TKey, value: TValue):
        self.cache[key] = value

    def get(self, key: TKey) -> Optional[TValue]:
        return self.cache.get(key)

    def empty(self):
        self.cache = {}

    def __len__(self):
        return len(self.cache)


# mainly kept as a marker and for backwards compatibility, but may be extended in the future
class PersistentKeyValueCache(KeyValueCache[TKey, TValue], Generic[TKey, TValue], ABC):
    pass


class PersistentList(Generic[TValue], ABC):
    @abstractmethod
    def append(self, item: TValue):
        """
        Adds an item to the list

        :param item: the item to store
        """
        pass

    @abstractmethod
    def iter_items(self) -> Iterator[TValue]:
        """
        Iterates over the items in the persisted list

        :return: generator of items
        """
        pass


class DelayedUpdateHook:
    """
    Ensures that a given function is executed after an update happens, but delays the execution until
    there are no further updates for a certain time period
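
    For example (an illustrative sketch; ``save_fn``, ``update`` and ``items`` are placeholders), to persist
    state at most once per burst of updates:

    .. code-block:: python

        hook = DelayedUpdateHook(save_fn, time_period_secs=5.0)
        for item in items:
            update(item)
            hook.handle_update()  # save_fn will run 5 seconds after the last update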
    """
    def __init__(self, fn: Callable[[], Any], time_period_secs, periodically_executed_fn: Optional[Callable[[], Any]] = None):
        """
        :param fn: the function to eventually call after an update
        :param time_period_secs: the time that must pass while not receiving further updates for fn to be called
        :param periodically_executed_fn: a function to execute periodically (every time_period_secs seconds) in the busy waiting loop,
            which may, for example, log information or apply additional executions, which must not interfere with the correctness of
            the execution of fn
        """
        self.periodicallyExecutedFn = periodically_executed_fn
        self.fn = fn
        self.timePeriodSecs = time_period_secs
        self._lastUpdateTime = None
        self._thread = None
        self._threadLock = threading.Lock()

    def handle_update(self):
        """
        Notifies of an update and ensures that the function passed at construction is eventually called
        (after no more updates are received within the respective time window)
        """
        self._lastUpdateTime = time.time()

        def do_periodic_check():
            while True:
                time.sleep(self.timePeriodSecs)
                time_passed_since_last_update = time.time() - self._lastUpdateTime
                if self.periodicallyExecutedFn is not None:
                    self.periodicallyExecutedFn()
                if time_passed_since_last_update >= self.timePeriodSecs:
                    self.fn()
                    return

        # noinspection DuplicatedCode
        if self._thread is None or not self._thread.is_alive():
            self._threadLock.acquire()
            if self._thread is None or not self._thread.is_alive():
                self._thread = threading.Thread(target=do_periodic_check, daemon=False)
                self._thread.start()
            self._threadLock.release()


class PeriodicUpdateHook:
    """
    Periodically checks whether a function shall be called as a result of an update, the function potentially
    being non-atomic (i.e. it may take a long time to execute such that new updates may come in while it is
    executing). Two function call mechanisms are in place:

    * a function which is called if there has not been a new update for a certain time period (which may be called
      several times if updates come in while the function is being executed)
    * a function which is called periodically
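
    For example (an illustrative sketch; ``flush`` and ``log_stats`` are placeholders):

    .. code-block:: python

        hook = PeriodicUpdateHook(check_interval_secs=1.0, no_update_time_period_secs=10.0,
            no_update_fn=flush, periodic_fn=log_stats)
        hook.handle_update()  # log_stats runs every second; flush runs after 10s without updates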

    """
    def __init__(self, check_interval_secs: float, no_update_time_period_secs: float = None, no_update_fn: Callable[[], Any] = None,
                 periodic_fn: Optional[Callable[[], Any]] = None):
        """
        :param check_interval_secs: the time period, in seconds, between checks
        :param no_update_time_period_secs: the time period after which to execute no_update_fn if no further updates have come in.
            This must be at least as large as check_interval_secs. If None, use check_interval_secs.
        :param no_update_fn: the function to call if there have been no further updates for no_update_time_period_secs seconds
        :param periodic_fn: a function to execute periodically (every check_interval_secs seconds) in the busy waiting loop,
            which may, for example, log information or apply additional executions, which must not interfere with the correctness of
            the execution of fn
        """
        if no_update_time_period_secs is None:
            no_update_time_period_secs = check_interval_secs
        elif no_update_time_period_secs < check_interval_secs:
            raise ValueError("no_update_time_period_secs must be at least as large as check_interval_secs")
        self._periodic_fn = periodic_fn
        self._check_interval_secs = check_interval_secs
        self._no_update_time_period_secs = no_update_time_period_secs
        self._no_update_fn = no_update_fn
        self._last_update_time = None
        self._thread = None
        self._thread_lock = threading.Lock()

    def handle_update(self):
        """
        Notifies of an update, making sure the functions given at construction will be called as specified
        """
        self._last_update_time = time.time()

        def do_periodic_check():
            while True:
                time.sleep(self._check_interval_secs)
                check_time = time.time()
                if self._periodic_fn is not None:
                    self._periodic_fn()
                time_passed_since_last_update = check_time - self._last_update_time
                if time_passed_since_last_update >= self._no_update_time_period_secs:
                    if self._no_update_fn is not None:
                        self._no_update_fn()
                    # if no further updates have come in, we terminate the thread
                    if self._last_update_time < check_time:
                        return

        # noinspection DuplicatedCode
        if self._thread is None or not self._thread.is_alive():
            self._thread_lock.acquire()
            if self._thread is None or not self._thread.is_alive():
                self._thread = threading.Thread(target=do_periodic_check, daemon=False)
                self._thread.start()
            self._thread_lock.release()


class PicklePersistentKeyValueCache(PersistentKeyValueCache[TKey, TValue]):
    """
    Represents a key-value cache as a dictionary which is persisted in a file using pickle
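
    For example (an illustrative sketch; the path is a placeholder):

    .. code-block:: python

        cache = PicklePersistentKeyValueCache("my_cache.pickle")
        if cache.get("answer") is None:
            cache.set("answer", 42)  # persisted automatically via a deferred save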
    """
    def __init__(self, pickle_path, version=1, save_on_update=True, deferred_save_delay_secs=1.0):
        """
        :param pickle_path: the path of the file where the cache values are to be persisted
        :param version: the version of cache entries. If a persisted cache with a non-matching version is found,
            it is discarded
        :param save_on_update: whether to persist the cache after an update; the cache is saved in a deferred
            manner and will be saved after deferred_save_delay_secs if no new updates have arrived in the meantime,
            i.e. it will ultimately be saved deferred_save_delay_secs after the latest update
        :param deferred_save_delay_secs: the number of seconds to wait for additional data to be added to the cache
            before actually storing the cache after a cache update
        """
        self.deferred_save_delay_secs = deferred_save_delay_secs
        self.pickle_path = pickle_path
        self.version = version
        self.save_on_update = save_on_update
        cache_found = False
        if os.path.exists(pickle_path):
            try:
                log.info(f"Loading cache from {pickle_path}")
                persisted_version, self.cache = load_pickle(pickle_path)
                if persisted_version == version:
                    cache_found = True
            except EOFError:
                log.warning(f"The cache file in {pickle_path} is corrupt")
        if not cache_found:
            self.cache = {}
        self._update_hook = DelayedUpdateHook(self.save, deferred_save_delay_secs)
        self._write_lock = threading.Lock()

    def save(self):
        """
        Saves the cache in the file whose path was provided at construction
        """
        with self._write_lock:  # avoid concurrent modification while saving
            log.info(f"Saving cache to {self.pickle_path}")
            dump_pickle((self.version, self.cache), self.pickle_path)

    def get(self, key: TKey) -> Optional[TValue]:
        return self.cache.get(key)

    def set(self, key: TKey, value: TValue):
        with self._write_lock:
            self.cache[key] = value
            if self.save_on_update:
                self._update_hook.handle_update()


class SlicedPicklePersistentList(PersistentList):
    """
    Handles the creation of and access to sliced pickle caches, i.e. lists that are persisted as a sequence of pickle files ("slices")
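
    For example (an illustrative sketch; "cache_dir" is a placeholder for an existing directory):

    .. code-block:: python

        with SlicedPicklePersistentList("cache_dir", "results") as persistent_list:
            persistent_list.append({"x": 1})  # items are stored in slices of the configured size
        for item in persistent_list.iter_items():
            print(item)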
    """
    def __init__(self, directory, pickle_base_name, num_entries_per_slice=100000):
        """
        :param directory: path to the directory where the sliced caches are to be stored
        :param pickle_base_name: base name for the pickle, where slices will have the names {pickle_base_name}_slice<i>.pickle
        :param num_entries_per_slice: how many entries should be stored in each slice
        """
        self.directory = directory
        self.pickleBaseName = pickle_base_name
        self.numEntriesPerSlice = num_entries_per_slice

        # Set up the variables for the sliced cache
        self.slice_id = 0
        self.index_in_slice = 0
        self.cache_of_slice = []

        # Search directory for already present sliced caches
        self.slicedFiles = self._find_sliced_caches()

        # Helper variable to ensure object is only modified within a with-clause
        self._currentlyInWithClause = False

    def __enter__(self):
        self._currentlyInWithClause = True
        if self.cache_exists():
            # Reset state to enable the appending of more items to the cache
            self._set_last_cache_state()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._dump()
        self._currentlyInWithClause = False

    def append(self, item):
        """
        Appends an item to the cache

        :param item: entry in the cache
        """
        if not self._currentlyInWithClause:
            raise Exception("Object must be used within a with-clause to ensure correct storage")

        if (self.index_in_slice + 1) % self.numEntriesPerSlice == 0:
            self._dump()

        self.cache_of_slice.append(item)
        self.index_in_slice += 1

    def iter_items(self) -> Iterator[Any]:
        """
        Iterates over the entries in the sliced cache

        :return: iterator over all items in the cache
        """
        for filePath in self.slicedFiles:
            log.info(f"Loading sliced pickle list from {filePath}")
            cached_pickle = self._load_pickle(filePath)
            for item in cached_pickle:
                yield item

    def clear(self):
        """
        Clears the cache if it exists
        """
        if self.cache_exists():
            for filePath in self.slicedFiles:
                os.unlink(filePath)

    def cache_exists(self) -> bool:
        """
        Checks whether this cache already exists

        :return: True if the cache exists, False otherwise
        """
        return len(self.slicedFiles) > 0

    def _set_last_cache_state(self):
        """
        Sets the state so as to be able to add items to an existing cache
        """
        log.info("Resetting last state of cache...")
        self.slice_id = len(self.slicedFiles) - 1
        self.cache_of_slice = self._load_pickle(self._pickle_path(self.slice_id))
        self.index_in_slice = len(self.cache_of_slice) - 1
        if self.index_in_slice >= self.numEntriesPerSlice:
            self._next_slice()

    def _dump(self):
        """
        Dumps the current slice cache (if non-empty)
        """
        if len(self.cache_of_slice) > 0:
            pickle_path = self._pickle_path(str(self.slice_id))
            log.info(f"Saving sliced cache to {pickle_path}")
            dump_pickle(self.cache_of_slice, pickle_path)
            self.slicedFiles.append(pickle_path)

            # Update slice number and reset indexing and cache
            self._next_slice()
        else:
            log.warning("Unexpected behavior: _dump was called while the current slice cache was empty!")

    def _next_slice(self):
        """
        Updates the sliced cache state for the next slice
        """
        self.slice_id += 1
        self.index_in_slice = 0
        self.cache_of_slice = []

    def _find_sliced_caches(self) -> List[str]:
        """
        Finds all pickled slices associated with this cache

        :return: list of sliced pickle files
        """
        # glob.glob permits the use of unix-style pathname matching (below, we find all ..._slice*.pickle files)
        list_of_file_names = glob.glob(self._pickle_path("*"))
        # Sort the slices to ensure they are in the same order as they were produced (the regex replaces everything
        # that is not a number with the empty string)
        list_of_file_names.sort(key=lambda f: int(re.sub(r'\D', '', f)))
        return list_of_file_names

    def _load_pickle(self, pickle_path: str) -> List[Any]:
        """
        Loads the list of objects from the given pickle file if it exists

        :param pickle_path: the file path
        :return: list of objects
        """
        cached_pickle = []
        if os.path.exists(pickle_path):
            try:
                cached_pickle = load_pickle(pickle_path)
            except EOFError:
                log.warning(f"The cache file in {pickle_path} is corrupt")
        else:
            raise Exception(f"The file {pickle_path} does not exist!")
        return cached_pickle

    def _pickle_path(self, slice_suffix) -> str:
        return f"{os.path.join(self.directory, self.pickleBaseName)}_slice{slice_suffix}.pickle"


class SqliteConnectionManager:
    _connections: List[sqlite3.Connection] = []
    _atexit_handler_registered = False

    @classmethod
    def _register_at_exit_handler(cls):
        if not cls._atexit_handler_registered:
            cls._atexit_handler_registered = True
            atexit.register(cls._cleanup)

    @classmethod
    def open_connection(cls, path):
        cls._register_at_exit_handler()
        conn = sqlite3.connect(path, check_same_thread=False)
        cls._connections.append(conn)
        return conn

    @classmethod
    def _cleanup(cls):
        for conn in cls._connections:
            conn.close()
        cls._connections = []


class SqlitePersistentKeyValueCache(PersistentKeyValueCache[TKey, TValue]):
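    """
    A persistent key-value cache backed by an SQLite database, in which values are stored in pickled form.

    For example (an illustrative sketch; the path is a placeholder):

    .. code-block:: python

        cache = SqlitePersistentKeyValueCache("cache.sqlite")
        cache.set("greeting", {"text": "hello"})  # committed after the deferred commit delay
        print(cache.get("greeting"))
    """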
    class KeyType(enum.Enum):
        STRING = ("VARCHAR(%d)", )
        INTEGER = ("LONG", )

    def __init__(self, path, table_name="cache", deferred_commit_delay_secs=1.0, key_type: KeyType = KeyType.STRING,
                 max_key_length=255):
        """
        :param path: the path to the file that is to hold the SQLite database
        :param table_name: the name of the table to create in the database
        :param deferred_commit_delay_secs: the time frame during which no new data must be added for a pending transaction to be committed
        :param key_type: the type to use for keys; for complex keys (i.e. tuples), use STRING (conversions to string are automatic)
        :param max_key_length: the maximum key length for the case where the key_type can be parametrised (e.g. STRING)
        """
        self.path = path
        self.conn = SqliteConnectionManager.open_connection(path)
        self.table_name = table_name
        self.max_key_length = max_key_length
        self.key_type = key_type
        self._update_hook = DelayedUpdateHook(self._commit, deferred_commit_delay_secs)
        self._num_entries_to_be_committed = 0
        self._conn_mutex = threading.Lock()

        cursor = self.conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        if table_name not in [r[0] for r in cursor.fetchall()]:
            log.info(f"Creating cache table '{self.table_name}' in {path}")
            key_db_type = key_type.value[0]
            if "%d" in key_db_type:
                key_db_type = key_db_type % max_key_length
            cursor.execute(f"CREATE TABLE {table_name} (cache_key {key_db_type} PRIMARY KEY, cache_value BLOB);")
        cursor.close()

    def _key_db_value(self, key):
        if self.key_type == self.KeyType.STRING:
            s = str(key)
            if len(s) > self.max_key_length:
                raise ValueError(f"Key too long, maximal key length is {self.max_key_length}")
            return s
        elif self.key_type == self.KeyType.INTEGER:
            return int(key)
        else:
            raise Exception(f"Unhandled key type {self.key_type}")

    def _commit(self):
        self._conn_mutex.acquire()
        try:
            log.info(f"Committing {self._num_entries_to_be_committed} cache entries to the SQLite database {self.path}")
            self.conn.commit()
            self._num_entries_to_be_committed = 0
        finally:
            self._conn_mutex.release()

    def set(self, key: TKey, value: TValue):
        self._conn_mutex.acquire()
        try:
            cursor = self.conn.cursor()
            key = self._key_db_value(key)
            cursor.execute(f"SELECT COUNT(*) FROM {self.table_name} WHERE cache_key=?", (key,))
            if cursor.fetchone()[0] == 0:
                cursor.execute(f"INSERT INTO {self.table_name} (cache_key, cache_value) VALUES (?, ?)",
                               (key, pickle.dumps(value)))
            else:
                cursor.execute(f"UPDATE {self.table_name} SET cache_value=? WHERE cache_key=?", (pickle.dumps(value), key))
            self._num_entries_to_be_committed += 1
            cursor.close()
        finally:
            self._conn_mutex.release()

        self._update_hook.handle_update()

    def _execute(self, cursor, *query):
        try:
            cursor.execute(*query)
        except sqlite3.DatabaseError as e:
            raise Exception(f"Error executing query for {self.path}: {e}")

    def get(self, key: TKey) -> Optional[TValue]:
        self._conn_mutex.acquire()
        try:
            cursor = self.conn.cursor()
            key = self._key_db_value(key)
            self._execute(cursor, f"SELECT cache_value FROM {self.table_name} WHERE cache_key=?", (key,))
            row = cursor.fetchone()
            cursor.close()
            if row is None:
                return None
            return pickle.loads(row[0])
        finally:
            self._conn_mutex.release()

    def __len__(self):
        self._conn_mutex.acquire()
        try:
            cursor = self.conn.cursor()
            cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}")
            cnt = cursor.fetchone()[0]
            cursor.close()
            return cnt
        finally:
            self._conn_mutex.release()

    def iter_items(self):
        self._conn_mutex.acquire()
        try:
            cursor = self.conn.cursor()
            cursor.execute(f"SELECT cache_key, cache_value FROM {self.table_name}")
            while True:
                row = cursor.fetchone()
                if row is None:
                    break
                yield row[0], pickle.loads(row[1])
            cursor.close()
        finally:
            self._conn_mutex.release()


class SqlitePersistentList(PersistentList):
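    """
    A list that is persisted in an SQLite database, backed by a :class:`SqlitePersistentKeyValueCache`
    with integer keys.

    For example (an illustrative sketch; the path is a placeholder):

    .. code-block:: python

        persistent_list = SqlitePersistentList("list.sqlite")
        persistent_list.append("entry")
        print(list(persistent_list.iter_items()))
    """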
    def __init__(self, path):
        self.keyValueCache = SqlitePersistentKeyValueCache(path, key_type=SqlitePersistentKeyValueCache.KeyType.INTEGER)
        self.nextKey = len(self.keyValueCache)

    def append(self, item):
        self.keyValueCache.set(self.nextKey, item)
        self.nextKey += 1

    def iter_items(self):
        for item in self.keyValueCache.iter_items():
            yield item[1]


class CachedValueProviderMixin(Generic[TKey, TValue, TData], ABC):
    """
    Represents a value provider that can provide values associated with (hashable) keys via a cache or, if
    cached values are not yet present, by computing them.
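
    For example (an illustrative sketch of a subclass computing squares):

    .. code-block:: python

        class SquareProvider(CachedValueProviderMixin[int, int, None]):
            def __init__(self):
                super().__init__(cache=InMemoryKeyValueCache())

            def square(self, x: int) -> int:
                return self._provide_value(x)

            def _compute_value(self, key: int, data) -> int:
                return key * key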
    """
    def __init__(self, cache: Optional[KeyValueCache[TKey, TValue]] = None,
                 cache_factory: Optional[Callable[[], KeyValueCache[TKey, TValue]]] = None, persist_cache=False, box_values=False):
        """
        :param cache: the cache to use or None. If None, caching will be disabled
        :param cache_factory: a factory with which to create the cache (or recreate it after unpickling if `persist_cache` is False,
            in which case this factory must be picklable)
        :param persist_cache: whether to persist the cache when pickling
        :param box_values: whether to box values, such that None is admissible as a value
        """
        self._persistCache = persist_cache
        self._boxValues = box_values
        self._cache = cache
        self._cacheFactory = cache_factory
        if self._cache is None and cache_factory is not None:
            self._cache = cache_factory()

    def __getstate__(self):
        if not self._persistCache:
            d = self.__dict__.copy()
            d["_cache"] = None
            return d
        return self.__dict__

    def __setstate__(self, state):
        setstate(CachedValueProviderMixin, self, state, renamed_properties={"persistCache": "_persistCache"})
        if not self._persistCache and self._cacheFactory is not None:
            self._cache = self._cacheFactory()

    def _provide_value(self, key, data: Optional[TData] = None):
        """
        Provides the value for the key by retrieving the associated value from the cache or, if no entry in the
        cache is found, by computing the value via _compute_value

        :param key: the key for which to provide the value
        :param data: optional data required to compute a value
        :return: the retrieved or computed value
        """
        if self._cache is None:
            return self._compute_value(key, data)
        value = self._cache.get(key)
        if value is None:
            value = self._compute_value(key, data)
            self._cache.set(key, value if not self._boxValues else BoxedValue(value))
        else:
            if self._boxValues:
                value: BoxedValue[TValue]
                value = value.value
        return value

    @abstractmethod
    def _compute_value(self, key: TKey, data: Optional[TData]) -> TValue:
        """
        Computes the value for the given key

        :param key: the key for which to compute the value
        :param data: optional data required to compute the value
        :return: the computed value
        """
        pass


def cached(fn: Callable[[], T], pickle_path, function_name=None, validity_check_fn: Optional[Callable[[T], bool]] = None,
           backend="pickle", protocol=pickle.HIGHEST_PROTOCOL, load=True, version=None) -> T:
    """
    Calls the given function unless its result is already cached (in a pickle file), in which case it will read the cached result
    and return it.

    Rather than directly calling this function, consider using the decorator variant :func:`pickle_cached`.

    :param fn: the function whose result is to be cached
    :param pickle_path: the path in which to store the cached result
    :param function_name: the name of the function fn (for the case where its __name__ attribute is not
        informative)
    :param validity_check_fn: an optional function to call in order to check whether a cached result is still valid;
        the function shall return True if the result is still valid and False otherwise. If a cached result is invalid,
        the function fn is called to compute the result and the cached result is updated.
    :param backend: the serialisation backend to use ("pickle" or "joblib")
    :param protocol: the pickle protocol version
    :param load: whether to load a previously persisted result; if False, do not load an old result but store the newly computed result
    :param version: if not None, previously persisted data will only be returned if it was stored with the same version
    :return: the result (either obtained from the cache or the function)
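
    For example (an illustrative sketch; the path and the computation are placeholders):

    .. code-block:: python

        def compute_result():
            return sum(range(1000000))

        result = cached(compute_result, "result.cache.pickle")  # recomputed only if no cache file exists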
    """
    if function_name is None:
        function_name = fn.__name__

    def call_fn_and_cache_result():
        res = fn()
        log.info(f"Saving cached result in {pickle_path}")
        if version is not None:
            persisted_res = {"__cacheVersion": version, "obj": res}
        else:
            persisted_res = res
        dump_pickle(persisted_res, pickle_path, backend=backend, protocol=protocol)
        return res

    if os.path.exists(pickle_path):
        if load:
            log.info(f"Loading cached result of function '{function_name}' from {pickle_path}")
            result = load_pickle(pickle_path, backend=backend)
            if validity_check_fn is not None:
                if not validity_check_fn(result):
                    log.info("Cached result is no longer valid, recomputing ...")
                    result = call_fn_and_cache_result()
            if version is not None:
                cached_version = None
                if type(result) == dict:
                    cached_version = result.get("__cacheVersion")
                if cached_version != version:
                    log.info(f"Cached result has incorrect version ({cached_version}, expected {version}), recomputing ...")
                    result = call_fn_and_cache_result()
                else:
                    result = result["obj"]
            return result
        else:
            log.info(f"Ignoring previously stored result in {pickle_path}, calling function '{function_name}' ...")
            return call_fn_and_cache_result()
    else:
        log.info(f"No cached result found in {pickle_path}, calling function '{function_name}' ...")
        return call_fn_and_cache_result()


def pickle_cached(cache_base_path: str, filename_prefix: str = None, filename: str = None, backend="pickle",
                  protocol=pickle.HIGHEST_PROTOCOL, load=True, version=None):
    """
    Function decorator for caching function results via pickle.

    Add this decorator to any function to cache its results in pickle files.
    The function may have arguments, in which case the cache will be specific to the actual arguments
    by computing a hash code from their pickled representation.

    :param cache_base_path: the directory where the pickle cache file will be stored
    :param filename_prefix: a prefix of the name of the cache file to be created, to which the function name and, where applicable,
        a hash code of the function arguments as well as the extension ".cache.pickle" will be appended.
        The prefix need not end in a separator, as "-" will automatically be added between filename components.
    :param filename: the full file name of the cache file to be created; if the function takes arguments, the filename must
        contain a placeholder '%s' for the argument hash
    :param backend: the serialisation backend to use (see dump_pickle)
    :param protocol: the pickle protocol version to use
    :param load: whether to load a previously persisted result; if False, do not load an old result but store the newly computed result
    :param version: if not None, previously persisted data will only be returned if it was stored with the same version
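
    For example (an illustrative sketch; the directory and function are placeholders):

    .. code-block:: python

        @pickle_cached("cache")
        def add(x, y):
            return x + y

        add(1, 2)  # computed and stored in cache/add-<argument hash>.cache.pickle
        add(1, 2)  # loaded from the cache file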
    """
    os.makedirs(cache_base_path, exist_ok=True)

    if filename_prefix is None:
        filename_prefix = ""
    else:
        filename_prefix += "-"

    def decorator(fn: Callable, *_args, **_kwargs):

        @wraps(fn)
        def wrapped(*args, **kwargs):
            hash_code_str = None
            have_args = len(args) > 0 or len(kwargs) > 0
            if have_args:
                hash_code_str = pickle_hash((args, kwargs))
            if filename is None:
                pickle_filename = filename_prefix + fn.__qualname__.replace(".<locals>.", ".")
                if hash_code_str is not None:
                    pickle_filename += "-" + hash_code_str
                pickle_filename += ".cache.pickle"
            else:
                if hash_code_str is not None:
                    if "%s" not in filename:
                        raise Exception("Function called with arguments but full cache filename contains no placeholder (%s) "
                                        "for argument hash")
                    pickle_filename = filename % hash_code_str
                else:
                    if "%s" in filename:
                        raise Exception("Function without arguments but full cache filename with placeholder (%s) was specified")
                    pickle_filename = filename
            pickle_path = os.path.join(cache_base_path, pickle_filename)
            return cached(lambda: fn(*args, **kwargs), pickle_path, function_name=fn.__name__, backend=backend, load=load,
                          version=version, protocol=protocol)

        return wrapped

    return decorator


PickleCached = pickle_cached  # for backward compatibility


class LoadSaveInterface(ABC):
    @abstractmethod
    def save(self, path: str) -> None:
        pass

    @classmethod
    @abstractmethod
    def load(cls: T, path: str) -> T:
        pass


class PickleLoadSaveMixin(LoadSaveInterface):
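    """
    Mixin which adds pickle-based saving and loading to a class via :class:`LoadSaveInterface`.

    For example (an illustrative sketch; the class and path are placeholders):

    .. code-block:: python

        class MyModel(PickleLoadSaveMixin):
            pass

        model = MyModel()
        model.save("model.pickle")
        restored = MyModel.load("model.pickle")
    """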
    def save(self, path: Union[str, Path], backend="pickle"):
        """
        Saves the instance as a pickle file

        :param path: the path of the file to save to
        :param backend: pickle, cloudpickle, or joblib
        """
        dump_pickle(self, path, backend=backend)

    @classmethod
    def load(cls, path: Union[str, Path], backend="pickle"):
        """
        Loads a class instance from a pickle file

        :param path: the path of the file to load from
        :param backend: pickle, cloudpickle, or joblib
        :return: an instance of the present class
        """
        log.info(f"Loading instance of {cls} from {path}")
        result = load_pickle(path, backend=backend)
        if not isinstance(result, cls):
            raise Exception(f"Expected instance of {cls}, instead got: {result.__class__.__name__}")
        return result