Coverage for src/sensai/util/cache.py: 27%

456 statements  

coverage.py v7.6.1, created at 2024-11-29 18:29 +0000

import atexit
import enum
import glob
import logging
import os
import pickle
import re
import sqlite3
import threading
import time
from abc import abstractmethod, ABC
from collections import OrderedDict
from functools import wraps
from pathlib import Path
from typing import Any, Callable, Iterator, List, Optional, TypeVar, Generic, Union, Hashable

from .hash import pickle_hash
from .pickle import load_pickle, dump_pickle, setstate

log = logging.getLogger(__name__)

T = TypeVar("T")
TKey = TypeVar("TKey")
THashableKey = TypeVar("THashableKey", bound=Hashable)
TValue = TypeVar("TValue")
TData = TypeVar("TData")

class BoxedValue(Generic[TValue]):
    """
    Container for a value, which can be used in caches where values may be None (to differentiate the value not being
    present in the cache from the cached value being None)
    """
    def __init__(self, value: TValue):
        self.value = value


class KeyValueCache(Generic[TKey, TValue], ABC):
    @abstractmethod
    def set(self, key: TKey, value: TValue):
        """
        Sets a cached value

        :param key: the key under which to store the value
        :param value: the value to store; since None is used to indicate the absence of a value, None should not be
            used as a value
        """
        pass

    @abstractmethod
    def get(self, key: TKey) -> Optional[TValue]:
        """
        Retrieves a cached value

        :param key: the lookup key
        :return: the cached value or None if no value is found
        """
        pass

class InMemoryKeyValueCache(KeyValueCache[TKey, TValue], Generic[TKey, TValue]):
    """A simple in-memory cache (which uses a dictionary internally).

    This class can be instantiated directly, but for better typing support, one can instead
    inherit from it and provide the types of the key and value as type arguments. For example for
    a cache with string keys and integer values:

    .. code-block:: python

        class MyCache(InMemoryKeyValueCache[str, int]):
            pass
    """
    def __init__(self):
        self.cache = {}

    def set(self, key: TKey, value: TValue):
        self.cache[key] = value

    def get(self, key: TKey) -> Optional[TValue]:
        return self.cache.get(key)

    def empty(self):
        self.cache = {}

    def __len__(self):
        return len(self.cache)
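# Illustrative usage sketch for InMemoryKeyValueCache, complementing the docstring above; the subclass
# name and values are made-up examples.
def _example_in_memory_cache_usage():
    class StringIntCache(InMemoryKeyValueCache[str, int]):
        pass

    cache = StringIntCache()
    cache.set("a", 1)
    assert cache.get("a") == 1
    assert cache.get("missing") is None  # None indicates the absence of a value
    cache.empty()
    assert len(cache) == 0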

# mainly kept as a marker and for backwards compatibility, but may be extended in the future
class PersistentKeyValueCache(KeyValueCache[TKey, TValue], Generic[TKey, TValue], ABC):
    pass


class PersistentList(Generic[TValue], ABC):
    @abstractmethod
    def append(self, item: TValue):
        """
        Adds an item to the cache

        :param item: the item to store
        """
        pass

    @abstractmethod
    def iter_items(self) -> Iterator[TValue]:
        """
        Iterates over the items in the persisted list

        :return: generator of items
        """
        pass

class DelayedUpdateHook:
    """
    Ensures that a given function is executed after an update happens, but delays the execution until
    there are no further updates for a certain time period
    """
    def __init__(self, fn: Callable[[], Any], time_period_secs, periodically_executed_fn: Optional[Callable[[], Any]] = None):
        """
        :param fn: the function to eventually call after an update
        :param time_period_secs: the time that must pass while not receiving further updates for fn to be called
        :param periodically_executed_fn: a function to execute periodically (every time_period_secs seconds) in the busy waiting loop,
            which may, for example, log information or apply additional executions, which must not interfere with the correctness of
            the execution of fn
        """
        self.periodicallyExecutedFn = periodically_executed_fn
        self.fn = fn
        self.timePeriodSecs = time_period_secs
        self._lastUpdateTime = None
        self._thread = None
        self._threadLock = threading.Lock()

    def handle_update(self):
        """
        Notifies of an update and ensures that the function passed at construction is eventually called
        (after no more updates are received within the respective time window)
        """
        self._lastUpdateTime = time.time()

        def do_periodic_check():
            while True:
                time.sleep(self.timePeriodSecs)
                time_passed_since_last_update = time.time() - self._lastUpdateTime
                if self.periodicallyExecutedFn is not None:
                    self.periodicallyExecutedFn()
                if time_passed_since_last_update >= self.timePeriodSecs:
                    self.fn()
                    return

        # noinspection DuplicatedCode
        if self._thread is None or not self._thread.is_alive():
            self._threadLock.acquire()
            if self._thread is None or not self._thread.is_alive():
                self._thread = threading.Thread(target=do_periodic_check, daemon=False)
                self._thread.start()
            self._threadLock.release()
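# Illustrative usage sketch for DelayedUpdateHook; the callback and delay below are made-up examples.
# A burst of updates results in a single call of the function once no further update has arrived within
# the given time period.
def _example_delayed_update_hook_usage():
    def persist():
        log.info("Persisting state after updates have settled")

    hook = DelayedUpdateHook(persist, time_period_secs=2.0)
    for _ in range(10):
        hook.handle_update()  # persist() will run once, roughly 2 seconds after the last update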

class PeriodicUpdateHook:
    """
    Periodically checks whether a function shall be called as a result of an update, the function potentially
    being non-atomic (i.e. it may take a long time to execute such that new updates may come in while it is
    executing). Two function call mechanisms are in place:

    * a function which is called if there has not been a new update for a certain time period (which may be called
      several times if updates come in while the function is being executed)
    * a function which is called periodically

    """
    def __init__(self, check_interval_secs: float, no_update_time_period_secs: float = None, no_update_fn: Callable[[], Any] = None,
            periodic_fn: Optional[Callable[[], Any]] = None):
        """
        :param check_interval_secs: the time period, in seconds, between checks
        :param no_update_time_period_secs: the time period after which to execute no_update_fn if no further updates have come in.
            This must be at least as large as check_interval_secs. If None, use check_interval_secs.
        :param no_update_fn: the function to call if there have been no further updates for no_update_time_period_secs seconds
        :param periodic_fn: a function to execute periodically (every check_interval_secs seconds) in the busy waiting loop,
            which may, for example, log information or apply additional executions, which must not interfere with the correctness of
            the execution of fn
        """
        if no_update_time_period_secs is None:
            no_update_time_period_secs = check_interval_secs
        elif no_update_time_period_secs < check_interval_secs:
            raise ValueError("no_update_time_period_secs must be at least as large as check_interval_secs")
        self._periodic_fn = periodic_fn
        self._check_interval_secs = check_interval_secs
        self._no_update_time_period_secs = no_update_time_period_secs
        self._no_update_fn = no_update_fn
        self._last_update_time = None
        self._thread = None
        self._thread_lock = threading.Lock()

    def handle_update(self):
        """
        Notifies of an update, making sure the functions given at construction will be called as specified
        """
        self._last_update_time = time.time()

        def do_periodic_check():
            while True:
                time.sleep(self._check_interval_secs)
                check_time = time.time()
                if self._periodic_fn is not None:
                    self._periodic_fn()
                time_passed_since_last_update = check_time - self._last_update_time
                if time_passed_since_last_update >= self._no_update_time_period_secs:
                    if self._no_update_fn is not None:
                        self._no_update_fn()
                    # if no further updates have come in, we terminate the thread
                    if self._last_update_time < check_time:
                        return

        # noinspection DuplicatedCode
        if self._thread is None or not self._thread.is_alive():
            self._thread_lock.acquire()
            if self._thread is None or not self._thread.is_alive():
                self._thread = threading.Thread(target=do_periodic_check, daemon=False)
                self._thread.start()
            self._thread_lock.release()

class PicklePersistentKeyValueCache(PersistentKeyValueCache[TKey, TValue]):
    """
    Represents a key-value cache as a dictionary which is persisted in a file using pickle
    """
    def __init__(self, pickle_path, version=1, save_on_update=True, deferred_save_delay_secs=1.0):
        """
        :param pickle_path: the path of the file where the cache values are to be persisted
        :param version: the version of cache entries. If a persisted cache with a non-matching version is found,
            it is discarded
        :param save_on_update: whether to persist the cache after an update; the cache is saved in a deferred
            manner and will be saved after deferred_save_delay_secs if no new updates have arrived in the meantime,
            i.e. it will ultimately be saved deferred_save_delay_secs after the latest update
        :param deferred_save_delay_secs: the number of seconds to wait for additional data to be added to the cache
            before actually storing the cache after a cache update
        """
        self.deferred_save_delay_secs = deferred_save_delay_secs
        self.pickle_path = pickle_path
        self.version = version
        self.save_on_update = save_on_update
        cache_found = False
        if os.path.exists(pickle_path):
            try:
                log.info(f"Loading cache from {pickle_path}")
                persisted_version, self.cache = load_pickle(pickle_path)
                if persisted_version == version:
                    cache_found = True
            except EOFError:
                log.warning(f"The cache file in {pickle_path} is corrupt")
        if not cache_found:
            self.cache = {}
        self._update_hook = DelayedUpdateHook(self.save, deferred_save_delay_secs)
        self._write_lock = threading.Lock()

    def save(self):
        """
        Saves the cache in the file whose path was provided at construction
        """
        with self._write_lock:  # avoid concurrent modification while saving
            log.info(f"Saving cache to {self.pickle_path}")
            dump_pickle((self.version, self.cache), self.pickle_path)

    def get(self, key: TKey) -> Optional[TValue]:
        return self.cache.get(key)

    def set(self, key: TKey, value: TValue):
        with self._write_lock:
            self.cache[key] = value
            if self.save_on_update:
                self._update_hook.handle_update()
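# Illustrative usage sketch for PicklePersistentKeyValueCache; the file path is a made-up example.
# Entries are persisted in a deferred manner, so the pickle file is written shortly after the last update.
def _example_pickle_cache_usage():
    cache = PicklePersistentKeyValueCache("demo_cache.pickle", version=1, deferred_save_delay_secs=1.0)
    cache.set("answer", 42)
    assert cache.get("answer") == 42
    cache.save()  # force an immediate save instead of waiting for the deferred save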

class SlicedPicklePersistentList(PersistentList):
    """
    Object handling the creation of and access to sliced pickle caches
    """
    def __init__(self, directory, pickle_base_name, num_entries_per_slice=100000):
        """
        :param directory: path to the directory where the sliced caches are to be stored
        :param pickle_base_name: base name for the pickle, where slices will have the names {pickle_base_name}_sliceX.pickle
        :param num_entries_per_slice: how many entries should be stored in each slice
        """
        self.directory = directory
        self.pickleBaseName = pickle_base_name
        self.numEntriesPerSlice = num_entries_per_slice

        # Set up the variables for the sliced cache
        self.slice_id = 0
        self.index_in_slice = 0
        self.cache_of_slice = []

        # Search directory for already present sliced caches
        self.slicedFiles = self._find_sliced_caches()

        # Helper variable to ensure object is only modified within a with-clause
        self._currentlyInWithClause = False

    def __enter__(self):
        self._currentlyInWithClause = True
        if self.cache_exists():
            # Reset state to enable the appending of more items to the cache
            self._set_last_cache_state()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._dump()
        self._currentlyInWithClause = False

    def append(self, item):
        """
        Appends an item to the cache

        :param item: entry in the cache
        """
        if not self._currentlyInWithClause:
            raise Exception("Class needs to be instantiated within a with-clause to ensure correct storage")

        if (self.index_in_slice + 1) % self.numEntriesPerSlice == 0:
            self._dump()

        self.cache_of_slice.append(item)
        self.index_in_slice += 1

    def iter_items(self) -> Iterator[Any]:
        """
        Iterates over the entries in the sliced cache

        :return: iterator over all items in the cache
        """
        for filePath in self.slicedFiles:
            log.info(f"Loading sliced pickle list from {filePath}")
            cached_pickle = self._load_pickle(filePath)
            for item in cached_pickle:
                yield item

    def clear(self):
        """
        Clears the cache if it exists
        """
        if self.cache_exists():
            for filePath in self.slicedFiles:
                os.unlink(filePath)

    def cache_exists(self) -> bool:
        """
        Checks whether this cache already exists

        :return: True if the cache exists, False otherwise
        """
        return len(self.slicedFiles) > 0

    def _set_last_cache_state(self):
        """
        Sets the state so as to be able to add items to an existing cache
        """
        log.info("Resetting last state of cache...")
        self.slice_id = len(self.slicedFiles) - 1
        self.cache_of_slice = self._load_pickle(self._pickle_path(self.slice_id))
        self.index_in_slice = len(self.cache_of_slice) - 1
        if self.index_in_slice >= self.numEntriesPerSlice:
            self._next_slice()

    def _dump(self):
        """
        Dumps the current slice cache (if non-empty)
        """
        if len(self.cache_of_slice) > 0:
            pickle_path = self._pickle_path(str(self.slice_id))
            log.info(f"Saving sliced cache to {pickle_path}")
            dump_pickle(self.cache_of_slice, pickle_path)
            self.slicedFiles.append(pickle_path)

            # Update slice number and reset indexing and cache
            self._next_slice()
        else:
            log.warning("Unexpected behavior: _dump was called although the current slice cache is empty!")

    def _next_slice(self):
        """
        Updates the sliced cache state for the next slice
        """
        self.slice_id += 1
        self.index_in_slice = 0
        self.cache_of_slice = []

    def _find_sliced_caches(self) -> List[str]:
        """
        Finds all pickled slices associated with this cache

        :return: list of sliced pickle files
        """
        # glob.glob permits the usage of unix-style pathname matching (below we find all ..._slice*.pickle files)
        list_of_file_names = glob.glob(self._pickle_path("*"))
        # Sort the slices to ensure they are in the same order as they were produced (the regex replaces everything
        # that is not a digit with the empty string)
        list_of_file_names.sort(key=lambda f: int(re.sub(r'\D', '', f)))
        return list_of_file_names

    def _load_pickle(self, pickle_path: str) -> List[Any]:
        """
        Loads the pickle if the file path exists and the persisted version is correct.

        :param pickle_path: file path
        :return: list with objects
        """
        cached_pickle = []
        if os.path.exists(pickle_path):
            try:
                cached_pickle = load_pickle(pickle_path)
            except EOFError:
                log.warning(f"The cache file in {pickle_path} is corrupt")
        else:
            raise Exception(f"The file {pickle_path} does not exist!")
        return cached_pickle

    def _pickle_path(self, slice_suffix) -> str:
        return f"{os.path.join(self.directory, self.pickleBaseName)}_slice{slice_suffix}.pickle"
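# Illustrative usage sketch for SlicedPicklePersistentList; the directory and base name are made-up
# examples. Items must be appended within a with-clause so that the final partial slice is dumped on exit.
def _example_sliced_pickle_list_usage():
    os.makedirs("cache_dir", exist_ok=True)  # assumes the target directory can be created locally
    with SlicedPicklePersistentList("cache_dir", "events", num_entries_per_slice=1000) as persistent_list:
        for i in range(2500):
            persistent_list.append({"event_id": i})  # items are spread across multiple slice files
    return list(persistent_list.iter_items())  # reads the items back from all slices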

class SqliteConnectionManager:
    _connections: List[sqlite3.Connection] = []
    _atexit_handler_registered = False

    @classmethod
    def _register_at_exit_handler(cls):
        if not cls._atexit_handler_registered:
            cls._atexit_handler_registered = True
            atexit.register(cls._cleanup)

    @classmethod
    def open_connection(cls, path):
        cls._register_at_exit_handler()
        conn = sqlite3.connect(path, check_same_thread=False)
        cls._connections.append(conn)
        return conn

    @classmethod
    def _cleanup(cls):
        for conn in cls._connections:
            conn.close()
        cls._connections = []

class SqlitePersistentKeyValueCache(PersistentKeyValueCache[TKey, TValue]):
    class KeyType(enum.Enum):
        STRING = ("VARCHAR(%d)", )
        INTEGER = ("LONG", )

    def __init__(self, path, table_name="cache", deferred_commit_delay_secs=1.0, key_type: KeyType = KeyType.STRING,
            max_key_length=255):
        """
        :param path: the path to the file that is to hold the SQLite database
        :param table_name: the name of the table to create in the database
        :param deferred_commit_delay_secs: the time frame during which no new data must be added for a pending transaction to be committed
        :param key_type: the type to use for keys; for complex keys (i.e. tuples), use STRING (conversions to string are automatic)
        :param max_key_length: the maximum key length for the case where the key_type can be parametrised (e.g. STRING)
        """
        self.path = path
        self.conn = SqliteConnectionManager.open_connection(path)
        self.table_name = table_name
        self.max_key_length = max_key_length
        self.key_type = key_type
        self._update_hook = DelayedUpdateHook(self._commit, deferred_commit_delay_secs)
        self._num_entries_to_be_committed = 0
        self._conn_mutex = threading.Lock()

        cursor = self.conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        if table_name not in [r[0] for r in cursor.fetchall()]:
            log.info(f"Creating cache table '{self.table_name}' in {path}")
            key_db_type = key_type.value[0]
            if "%d" in key_db_type:
                key_db_type = key_db_type % max_key_length
            cursor.execute(f"CREATE TABLE {table_name} (cache_key {key_db_type} PRIMARY KEY, cache_value BLOB);")
        cursor.close()

    def _key_db_value(self, key):
        if self.key_type == self.KeyType.STRING:
            s = str(key)
            if len(s) > self.max_key_length:
                raise ValueError(f"Key too long, maximal key length is {self.max_key_length}")
            return s
        elif self.key_type == self.KeyType.INTEGER:
            return int(key)
        else:
            raise Exception(f"Unhandled key type {self.key_type}")

    def _commit(self):
        self._conn_mutex.acquire()
        try:
            log.info(f"Committing {self._num_entries_to_be_committed} cache entries to the SQLite database {self.path}")
            self.conn.commit()
            self._num_entries_to_be_committed = 0
        finally:
            self._conn_mutex.release()

    def set(self, key: TKey, value: TValue):
        self._conn_mutex.acquire()
        try:
            cursor = self.conn.cursor()
            key = self._key_db_value(key)
            cursor.execute(f"SELECT COUNT(*) FROM {self.table_name} WHERE cache_key=?", (key,))
            if cursor.fetchone()[0] == 0:
                cursor.execute(f"INSERT INTO {self.table_name} (cache_key, cache_value) VALUES (?, ?)",
                    (key, pickle.dumps(value)))
            else:
                cursor.execute(f"UPDATE {self.table_name} SET cache_value=? WHERE cache_key=?", (pickle.dumps(value), key))
            self._num_entries_to_be_committed += 1
            cursor.close()
        finally:
            self._conn_mutex.release()

        self._update_hook.handle_update()

    def _execute(self, cursor, *query):
        try:
            cursor.execute(*query)
        except sqlite3.DatabaseError as e:
            raise Exception(f"Error executing query for {self.path}: {e}")

    def get(self, key: TKey) -> Optional[TValue]:
        self._conn_mutex.acquire()
        try:
            cursor = self.conn.cursor()
            key = self._key_db_value(key)
            self._execute(cursor, f"SELECT cache_value FROM {self.table_name} WHERE cache_key=?", (key,))
            row = cursor.fetchone()
            cursor.close()
            if row is None:
                return None
            return pickle.loads(row[0])
        finally:
            self._conn_mutex.release()

    def __len__(self):
        self._conn_mutex.acquire()
        try:
            cursor = self.conn.cursor()
            cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}")
            cnt = cursor.fetchone()[0]
            cursor.close()
            return cnt
        finally:
            self._conn_mutex.release()

    def iter_items(self):
        self._conn_mutex.acquire()
        try:
            cursor = self.conn.cursor()
            cursor.execute(f"SELECT cache_key, cache_value FROM {self.table_name}")
            while True:
                row = cursor.fetchone()
                if row is None:
                    break
                yield row[0], pickle.loads(row[1])
            cursor.close()
        finally:
            self._conn_mutex.release()
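# Illustrative usage sketch for SqlitePersistentKeyValueCache; the database path and entries are made-up
# examples. Writes are committed in a deferred manner (deferred_commit_delay_secs after the last update).
def _example_sqlite_cache_usage():
    cache = SqlitePersistentKeyValueCache("demo_cache.sqlite", table_name="cache",
        key_type=SqlitePersistentKeyValueCache.KeyType.STRING)
    cache.set("user:1", {"name": "Alice"})  # values are pickled into a BLOB column
    assert cache.get("user:1") == {"name": "Alice"}
    assert len(cache) >= 1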

class SqlitePersistentList(PersistentList):
    def __init__(self, path):
        self.keyValueCache = SqlitePersistentKeyValueCache(path, key_type=SqlitePersistentKeyValueCache.KeyType.INTEGER)
        self.nextKey = len(self.keyValueCache)

    def append(self, item):
        self.keyValueCache.set(self.nextKey, item)
        self.nextKey += 1

    def iter_items(self):
        for item in self.keyValueCache.iter_items():
            yield item[1]

class CachedValueProviderMixin(Generic[TKey, TValue, TData], ABC):
    """
    Represents a value provider that can provide values associated with (hashable) keys via a cache or, if
    cached values are not yet present, by computing them.
    """
    def __init__(self, cache: Optional[KeyValueCache[TKey, TValue]] = None,
            cache_factory: Optional[Callable[[], KeyValueCache[TKey, TValue]]] = None, persist_cache=False, box_values=False):
        """
        :param cache: the cache to use or None. If None, caching will be disabled
        :param cache_factory: a factory with which to create the cache (or recreate it after unpickling if `persist_cache` is False,
            in which case this factory must be picklable)
        :param persist_cache: whether to persist the cache when pickling
        :param box_values: whether to box values, such that None is admissible as a value
        """
        self._persistCache = persist_cache
        self._boxValues = box_values
        self._cache = cache
        self._cacheFactory = cache_factory
        if self._cache is None and cache_factory is not None:
            self._cache = cache_factory()

    def __getstate__(self):
        if not self._persistCache:
            d = self.__dict__.copy()
            d["_cache"] = None
            return d
        return self.__dict__

    def __setstate__(self, state):
        setstate(CachedValueProviderMixin, self, state, renamed_properties={"persistCache": "_persistCache"})
        if not self._persistCache and self._cacheFactory is not None:
            self._cache = self._cacheFactory()

    def _provide_value(self, key, data: Optional[TData] = None):
        """
        Provides the value for the key by retrieving the associated value from the cache or, if no entry in the
        cache is found, by computing the value via _compute_value

        :param key: the key for which to provide the value
        :param data: optional data required to compute a value
        :return: the retrieved or computed value
        """
        if self._cache is None:
            return self._compute_value(key, data)
        value = self._cache.get(key)
        if value is None:
            value = self._compute_value(key, data)
            self._cache.set(key, value if not self._boxValues else BoxedValue(value))
        else:
            if self._boxValues:
                value: BoxedValue[TValue]
                value = value.value
        return value

    @abstractmethod
    def _compute_value(self, key: TKey, data: Optional[TData]) -> TValue:
        """
        Computes the value for the given key

        :param key: the key for which to compute the value
        :param data: optional data required to compute the value
        :return: the computed value
        """
        pass
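# Illustrative sketch of how the mixin above can be used; the provider class and its computation are
# made up for demonstration. A subclass supplies _compute_value and routes lookups through _provide_value.
class _ExampleSquareProvider(CachedValueProviderMixin[int, int, None]):
    def __init__(self):
        super().__init__(cache=InMemoryKeyValueCache())

    def _compute_value(self, key: int, data: Optional[None]) -> int:
        return key * key  # computed only on a cache miss

    def get_square(self, key: int) -> int:
        return self._provide_value(key)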

def cached(fn: Callable[[], T], pickle_path, function_name=None, validity_check_fn: Optional[Callable[[T], bool]] = None,
        backend="pickle", protocol=pickle.HIGHEST_PROTOCOL, load=True, version=None) -> T:
    """
    Calls the given function unless its result is already cached (in a pickle), in which case it will read the cached result
    and return it.

    Rather than directly calling this function, consider using the decorator variant :func:`pickle_cached`.

    :param fn: the function whose result is to be cached
    :param pickle_path: the path in which to store the cached result
    :param function_name: the name of the function fn (for the case where its __name__ attribute is not
        informative)
    :param validity_check_fn: an optional function to call in order to check whether a cached result is still valid;
        the function shall return True if the result is still valid and False otherwise. If a cached result is invalid,
        the function fn is called to compute the result and the cached result is updated.
    :param backend: pickle or joblib
    :param protocol: the pickle protocol version
    :param load: whether to load a previously persisted result; if False, do not load an old result but store the newly computed result
    :param version: if not None, previously persisted data will only be returned if it was stored with the same version
    :return: the result (either obtained from the cache or the function)
    """
    if function_name is None:
        function_name = fn.__name__

    def call_fn_and_cache_result():
        res = fn()
        log.info(f"Saving cached result in {pickle_path}")
        if version is not None:
            persisted_res = {"__cacheVersion": version, "obj": res}
        else:
            persisted_res = res
        dump_pickle(persisted_res, pickle_path, backend=backend, protocol=protocol)
        return res

    if os.path.exists(pickle_path):
        if load:
            log.info(f"Loading cached result of function '{function_name}' from {pickle_path}")
            result = load_pickle(pickle_path, backend=backend)
            if validity_check_fn is not None:
                if not validity_check_fn(result):
                    log.info("Cached result is no longer valid, recomputing ...")
                    result = call_fn_and_cache_result()
            if version is not None:
                cached_version = None
                if type(result) == dict:
                    cached_version = result.get("__cacheVersion")
                if cached_version != version:
                    log.info(f"Cached result has incorrect version ({cached_version}, expected {version}), recomputing ...")
                    result = call_fn_and_cache_result()
                else:
                    result = result["obj"]
            return result
        else:
            log.info(f"Ignoring previously stored result in {pickle_path}, calling function '{function_name}' ...")
            return call_fn_and_cache_result()
    else:
        log.info(f"No cached result found in {pickle_path}, calling function '{function_name}' ...")
        return call_fn_and_cache_result()
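# Illustrative usage sketch for cached(); the computation and file path are made-up examples.
# The first call computes and pickles the result; subsequent calls load it from the pickle.
def _example_cached_usage():
    def expensive_computation():
        return sum(i * i for i in range(10_000))

    return cached(expensive_computation, "expensive_result.cache.pickle",
        function_name="expensive_computation", version=1)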

def pickle_cached(cache_base_path: str, filename_prefix: str = None, filename: str = None, backend="pickle",
        protocol=pickle.HIGHEST_PROTOCOL, load=True, version=None):
    """
    Function decorator for caching function results via pickle.

    Add this decorator to any function to cache its results in pickle files.
    The function may have arguments, in which case the cache will be specific to the actual arguments
    by computing a hash code from their pickled representation.

    :param cache_base_path: the directory where the pickle cache file will be stored
    :param filename_prefix: a prefix of the name of the cache file to be created, to which the function name and, where applicable,
        a hash code of the function arguments as well as the extension ".cache.pickle" will be appended.
        The prefix need not end in a separator, as "-" will automatically be added between filename components.
    :param filename: the full file name of the cache file to be created; if the function takes arguments, the filename must
        contain a placeholder '%s' for the argument hash
    :param backend: the serialisation backend to use (see dump_pickle)
    :param protocol: the pickle protocol version to use
    :param load: whether to load a previously persisted result; if False, do not load an old result but store the newly computed result
    :param version: if not None, previously persisted data will only be returned if it was stored with the same version
    """
    os.makedirs(cache_base_path, exist_ok=True)

    if filename_prefix is None:
        filename_prefix = ""
    else:
        filename_prefix += "-"

    def decorator(fn: Callable, *_args, **_kwargs):

        @wraps(fn)
        def wrapped(*args, **kwargs):
            hash_code_str = None
            have_args = len(args) > 0 or len(kwargs) > 0
            if have_args:
                hash_code_str = pickle_hash((args, kwargs))
            if filename is None:
                pickle_filename = filename_prefix + fn.__qualname__.replace(".<locals>.", ".")
                if hash_code_str is not None:
                    pickle_filename += "-" + hash_code_str
                pickle_filename += ".cache.pickle"
            else:
                if hash_code_str is not None:
                    if "%s" not in filename:
                        raise Exception("Function called with arguments but full cache filename contains no placeholder (%s) "
                            "for argument hash")
                    pickle_filename = filename % hash_code_str
                else:
                    if "%s" in filename:
                        raise Exception("Function without arguments but full cache filename with placeholder (%s) was specified")
                    pickle_filename = filename
            pickle_path = os.path.join(cache_base_path, pickle_filename)
            return cached(lambda: fn(*args, **kwargs), pickle_path, function_name=fn.__name__, backend=backend, load=load,
                version=version, protocol=protocol)

        return wrapped

    return decorator


PickleCached = pickle_cached  # for backward compatibility
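# Illustrative usage sketch for the pickle_cached decorator; the cache directory and decorated function
# are made-up examples. Each distinct argument combination gets its own cache file (named via an argument hash).
def _example_pickle_cached_usage():
    @pickle_cached("cache", filename_prefix="demo", version=1)
    def squares_sum(n: int) -> int:
        return sum(i * i for i in range(n))

    return squares_sum(1000)  # the first call computes and caches; repeated calls load the pickle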

class LoadSaveInterface(ABC):
    @abstractmethod
    def save(self, path: str) -> None:
        pass

    @classmethod
    @abstractmethod
    def load(cls: T, path: str) -> T:
        pass

class PickleLoadSaveMixin(LoadSaveInterface):
    def save(self, path: Union[str, Path], backend="pickle"):
        """
        Saves the instance as a pickle

        :param path: the path of the file in which to store the instance
        :param backend: pickle, cloudpickle, or joblib
        """
        dump_pickle(self, path, backend=backend)

    @classmethod
    def load(cls, path: Union[str, Path], backend="pickle"):
        """
        Loads a class instance from pickle

        :param path: the path of the file from which to load the instance
        :param backend: pickle, cloudpickle, or joblib
        :return: instance of the present class
        """
        log.info(f"Loading instance of {cls} from {path}")
        result = load_pickle(path, backend=backend)
        if not isinstance(result, cls):
            raise Exception(f"Expected instance of {cls}, instead got: {result.__class__.__name__}")
        return result

class LRUCache(KeyValueCache[THashableKey, TValue], Generic[THashableKey, TValue]):
    def __init__(self, capacity: int) -> None:
        self._cache = OrderedDict()
        self._capacity = capacity

    def get(self, key: THashableKey) -> Optional[TValue]:
        if key not in self._cache:
            return None
        self._cache.move_to_end(key)
        return self._cache[key]

    def set(self, key: THashableKey, value: TValue):
        if key in self._cache:
            self._cache.move_to_end(key)
        self._cache[key] = value
        if len(self._cache) > self._capacity:
            self._cache.popitem(last=False)

    def __len__(self) -> int:
        return len(self._cache)

    def clear(self) -> None:
        self._cache.clear()
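# Illustrative usage sketch for LRUCache (capacity and keys are made-up examples): once the capacity is
# exceeded, the least recently used entry is evicted.
def _example_lru_cache_usage():
    cache = LRUCache(capacity=2)
    cache.set("a", 1)
    cache.set("b", 2)
    cache.get("a")      # marks "a" as most recently used
    cache.set("c", 3)   # evicts "b", the least recently used entry
    assert cache.get("b") is None and cache.get("a") == 1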