Coverage for src/sensai/util/cache.py: 26% (435 statements)
import atexit
import enum
import glob
import logging
import os
import pickle
import re
import sqlite3
import threading
import time
from abc import abstractmethod, ABC
from functools import wraps
from pathlib import Path
from typing import Any, Callable, Iterator, List, Optional, TypeVar, Generic, Union

from .hash import pickle_hash
from .pickle import load_pickle, dump_pickle, setstate

log = logging.getLogger(__name__)

T = TypeVar("T")
TKey = TypeVar("TKey")
TValue = TypeVar("TValue")
TData = TypeVar("TData")


class BoxedValue(Generic[TValue]):
    """
    Container for a value, which can be used in caches where values may be None (to differentiate the value not being
    present in the cache from the cached value being None)
    """
    def __init__(self, value: TValue):
        self.value = value


class KeyValueCache(Generic[TKey, TValue], ABC):
    @abstractmethod
    def set(self, key: TKey, value: TValue):
        """
        Sets a cached value

        :param key: the key under which to store the value
        :param value: the value to store; since None is used to indicate the absence of a value, None should not be
            used as a value
        """
        pass

    @abstractmethod
    def get(self, key: TKey) -> Optional[TValue]:
        """
        Retrieves a cached value

        :param key: the lookup key
        :return: the cached value or None if no value is found
        """
        pass


class InMemoryKeyValueCache(KeyValueCache[TKey, TValue], Generic[TKey, TValue]):
    """A simple in-memory cache (which uses a dictionary internally).

    This class can be instantiated directly, but for better typing support, one can instead
    inherit from it and provide the types of the key and value as type arguments. For example, for
    a cache with string keys and integer values:

    .. code-block:: python

        class MyCache(InMemoryKeyValueCache[str, int]):
            pass
    """
    def __init__(self):
        self.cache = {}

    def set(self, key: TKey, value: TValue):
        self.cache[key] = value

    def get(self, key: TKey) -> Optional[TValue]:
        return self.cache.get(key)

    def empty(self):
        self.cache = {}

    def __len__(self):
        return len(self.cache)


# mainly kept as a marker and for backwards compatibility, but may be extended in the future
class PersistentKeyValueCache(KeyValueCache[TKey, TValue], Generic[TKey, TValue], ABC):
    pass


class PersistentList(Generic[TValue], ABC):
    @abstractmethod
    def append(self, item: TValue):
        """
        Adds an item to the persisted list

        :param item: the item to store
        """
        pass

    @abstractmethod
    def iter_items(self) -> Iterator[TValue]:
        """
        Iterates over the items in the persisted list

        :return: generator of items
        """
        pass


class DelayedUpdateHook:
    """
    Ensures that a given function is executed after an update happens, but delays the execution until
    there are no further updates for a certain time period.
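
    Illustrative usage (a minimal sketch; ``save_state``, ``stream_of_updates`` and ``apply_update`` are hypothetical):

    .. code-block:: python

        hook = DelayedUpdateHook(save_state, time_period_secs=5.0)
        for item in stream_of_updates():
            apply_update(item)
            hook.handle_update()  # save_state() runs once no update has arrived for 5 seconds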
117 """
118 def __init__(self, fn: Callable[[], Any], time_period_secs, periodically_executed_fn: Optional[Callable[[], Any]] = None):
119 """
120 :param fn: the function to eventually call after an update
121 :param time_period_secs: the time that must pass while not receiving further updates for fn to be called
122 :param periodically_executed_fn: a function to execute periodically (every timePeriodSecs seconds) in the busy waiting loop,
123 which may, for example, log information or apply additional executions, which must not interfere with the correctness of
124 the execution of fn
125 """
        self.periodicallyExecutedFn = periodically_executed_fn
        self.fn = fn
        self.timePeriodSecs = time_period_secs
        self._lastUpdateTime = None
        self._thread = None
        self._threadLock = threading.Lock()

    def handle_update(self):
        """
        Notifies of an update and ensures that the function passed at construction is eventually called
        (after no more updates are received within the respective time window)
        """
        self._lastUpdateTime = time.time()

        def do_periodic_check():
            while True:
                time.sleep(self.timePeriodSecs)
                time_passed_since_last_update = time.time() - self._lastUpdateTime
                if self.periodicallyExecutedFn is not None:
                    self.periodicallyExecutedFn()
                if time_passed_since_last_update >= self.timePeriodSecs:
                    self.fn()
                    return

        # noinspection DuplicatedCode
        if self._thread is None or not self._thread.is_alive():
            self._threadLock.acquire()
            if self._thread is None or not self._thread.is_alive():
                self._thread = threading.Thread(target=do_periodic_check, daemon=False)
                self._thread.start()
            self._threadLock.release()


class PeriodicUpdateHook:
    """
    Periodically checks whether a function shall be called as a result of an update, the function potentially
    being non-atomic (i.e. it may take a long time to execute such that new updates may come in while it is
    executing). Two function call mechanisms are in place:

    * a function which is called if there has not been a new update for a certain time period (which may be called
      several times if updates come in while the function is being executed)
    * a function which is called periodically
    """
    def __init__(self, check_interval_secs: float, no_update_time_period_secs: float = None, no_update_fn: Callable[[], Any] = None,
            periodic_fn: Optional[Callable[[], Any]] = None):
        """
        :param check_interval_secs: the time period, in seconds, between checks
        :param no_update_time_period_secs: the time period after which to execute no_update_fn if no further updates have come in.
            This must be at least as large as check_interval_secs. If None, use check_interval_secs.
        :param no_update_fn: the function to call if there have been no further updates for no_update_time_period_secs seconds
        :param periodic_fn: a function to execute periodically (every check_interval_secs seconds) in the busy waiting loop,
            which may, for example, log information or apply additional executions, which must not interfere with the correctness of
            the execution of no_update_fn
        """
        if no_update_time_period_secs is None:
            no_update_time_period_secs = check_interval_secs
        elif no_update_time_period_secs < check_interval_secs:
            raise ValueError("no_update_time_period_secs must be at least as large as check_interval_secs")
        self._periodic_fn = periodic_fn
        self._check_interval_secs = check_interval_secs
        self._no_update_time_period_secs = no_update_time_period_secs
        self._no_update_fn = no_update_fn
        self._last_update_time = None
        self._thread = None
        self._thread_lock = threading.Lock()

    def handle_update(self):
        """
        Notifies of an update, making sure the functions given at construction will be called as specified
        """
        self._last_update_time = time.time()

        def do_periodic_check():
            while True:
                time.sleep(self._check_interval_secs)
                check_time = time.time()
                if self._periodic_fn is not None:
                    self._periodic_fn()
                time_passed_since_last_update = check_time - self._last_update_time
                if time_passed_since_last_update >= self._no_update_time_period_secs:
                    if self._no_update_fn is not None:
                        self._no_update_fn()
                    # if no further updates have come in, we terminate the thread
                    if self._last_update_time < check_time:
                        return

        # noinspection DuplicatedCode
        if self._thread is None or not self._thread.is_alive():
            self._thread_lock.acquire()
            if self._thread is None or not self._thread.is_alive():
                self._thread = threading.Thread(target=do_periodic_check, daemon=False)
                self._thread.start()
            self._thread_lock.release()


class PicklePersistentKeyValueCache(PersistentKeyValueCache[TKey, TValue]):
    """
    Represents a key-value cache as a dictionary which is persisted in a file using pickle.
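
    Illustrative usage (a minimal sketch; ``my_cache.pickle`` is a hypothetical path):

    .. code-block:: python

        cache = PicklePersistentKeyValueCache("my_cache.pickle")
        cache.set("some_key", 42)  # triggers a deferred save of the full dictionary
        assert cache.get("some_key") == 42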
225 """
226 def __init__(self, pickle_path, version=1, save_on_update=True, deferred_save_delay_secs=1.0):
227 """
228 :param pickle_path: the path of the file where the cache values are to be persisted
229 :param version: the version of cache entries. If a persisted cache with a non-matching version is found,
230 it is discarded
231 :param save_on_update: whether to persist the cache after an update; the cache is saved in a deferred
232 manner and will be saved after deferredSaveDelaySecs if no new updates have arrived in the meantime,
233 i.e. it will ultimately be saved deferredSaveDelaySecs after the latest update
234 :param deferred_save_delay_secs: the number of seconds to wait for additional data to be added to the cache
235 before actually storing the cache after a cache update
236 """
        self.deferred_save_delay_secs = deferred_save_delay_secs
        self.pickle_path = pickle_path
        self.version = version
        self.save_on_update = save_on_update
        cache_found = False
        if os.path.exists(pickle_path):
            try:
                log.info(f"Loading cache from {pickle_path}")
                persisted_version, self.cache = load_pickle(pickle_path)
                if persisted_version == version:
                    cache_found = True
            except EOFError:
                log.warning(f"The cache file in {pickle_path} is corrupt")
        if not cache_found:
            self.cache = {}
        self._update_hook = DelayedUpdateHook(self.save, deferred_save_delay_secs)
        self._write_lock = threading.Lock()

    def save(self):
        """
        Saves the cache in the file whose path was provided at construction
        """
        with self._write_lock:  # avoid concurrent modification while saving
            log.info(f"Saving cache to {self.pickle_path}")
            dump_pickle((self.version, self.cache), self.pickle_path)

    def get(self, key: TKey) -> Optional[TValue]:
        return self.cache.get(key)

    def set(self, key: TKey, value: TValue):
        with self._write_lock:
            self.cache[key] = value
            if self.save_on_update:
                self._update_hook.handle_update()


class SlicedPicklePersistentList(PersistentList):
    """
    Object handling the creation of and access to sliced pickle caches.
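
    Illustrative usage (a minimal sketch; ``cache_dir``, ``compute_items`` and ``process`` are hypothetical, and items
    must be appended within a with-statement):

    .. code-block:: python

        with SlicedPicklePersistentList("cache_dir", "items") as persistent_list:
            for item in compute_items():
                persistent_list.append(item)

        for item in persistent_list.iter_items():
            process(item)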
276 """
277 def __init__(self, directory, pickle_base_name, num_entries_per_slice=100000):
278 """
279 :param directory: path to the directory where the sliced caches are to be stored
280 :param pickle_base_name: base name for the pickle, where slices will have the names {pickleBaseName}_sliceX.pickle
281 :param num_entries_per_slice: how many entries should be stored in each cache
282 """
        self.directory = directory
        self.pickleBaseName = pickle_base_name
        self.numEntriesPerSlice = num_entries_per_slice

        # Set up the variables for the sliced cache
        self.slice_id = 0
        self.index_in_slice = 0
        self.cache_of_slice = []

        # Search directory for already present sliced caches
        self.slicedFiles = self._find_sliced_caches()

        # Helper variable to ensure object is only modified within a with-clause
        self._currentlyInWithClause = False

    def __enter__(self):
        self._currentlyInWithClause = True
        if self.cache_exists():
            # Reset state to enable the appending of more items to the cache
            self._set_last_cache_state()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._dump()
        self._currentlyInWithClause = False

    def append(self, item):
        """
        Appends an item to the cache

        :param item: the entry to add to the cache
        """
        if not self._currentlyInWithClause:
            raise Exception("Class needs to be instantiated within a with-clause to ensure correct storage")

        if (self.index_in_slice + 1) % self.numEntriesPerSlice == 0:
            self._dump()

        self.cache_of_slice.append(item)
        self.index_in_slice += 1
    def iter_items(self) -> Iterator[Any]:
        """
        Iterate over entries in the sliced cache

        :return: iterator over all items in the cache
        """
        for filePath in self.slicedFiles:
            log.info(f"Loading sliced pickle list from {filePath}")
            cached_pickle = self._load_pickle(filePath)
            for item in cached_pickle:
                yield item

    def clear(self):
        """
        Clears the cache if it exists
        """
        if self.cache_exists():
            for filePath in self.slicedFiles:
                os.unlink(filePath)
    def cache_exists(self) -> bool:
        """
        Checks whether this cache already exists

        :return: True if the cache exists, False otherwise
        """
        return len(self.slicedFiles) > 0
    def _set_last_cache_state(self):
        """
        Sets the state so as to be able to add items to an existing cache
        """
        log.info("Resetting last state of cache...")
        self.slice_id = len(self.slicedFiles) - 1
        self.cache_of_slice = self._load_pickle(self._pickle_path(self.slice_id))
        self.index_in_slice = len(self.cache_of_slice) - 1
        if self.index_in_slice >= self.numEntriesPerSlice:
            self._next_slice()
    def _dump(self):
        """
        Dumps the cache of the current slice (if non-empty)
        """
        if len(self.cache_of_slice) > 0:
            pickle_path = self._pickle_path(str(self.slice_id))
            log.info(f"Saving sliced cache to {pickle_path}")
            dump_pickle(self.cache_of_slice, pickle_path)
            self.slicedFiles.append(pickle_path)

            # Update slice number and reset indexing and cache
            self._next_slice()
        else:
            log.warning("Unexpected behavior: _dump was called while the cache of the current slice is empty!")
    def _next_slice(self):
        """
        Updates sliced cache state for the next slice
        """
        self.slice_id += 1
        self.index_in_slice = 0
        self.cache_of_slice = []
    def _find_sliced_caches(self) -> List[str]:
        """
        Finds all pickled slices associated with this cache

        :return: list of sliced pickle files
        """
        # glob.glob permits the use of unix-style pathname matching (below, we find all ..._slice*.pickle files)
        list_of_file_names = glob.glob(self._pickle_path("*"))
        # Sort the slices to ensure they are in the same order as they were produced
        # (the regex replaces everything that is not a digit with the empty string)
        list_of_file_names.sort(key=lambda f: int(re.sub(r'\D', '', f)))
        return list_of_file_names
    def _load_pickle(self, pickle_path: str) -> List[Any]:
        """
        Loads the pickle at the given file path if it exists

        :param pickle_path: file path
        :return: list of objects
        """
        cached_pickle = []
        if os.path.exists(pickle_path):
            try:
                cached_pickle = load_pickle(pickle_path)
            except EOFError:
                log.warning(f"The cache file in {pickle_path} is corrupt")
        else:
            raise Exception(f"The file {pickle_path} does not exist!")
        return cached_pickle
    def _pickle_path(self, slice_suffix) -> str:
        return f"{os.path.join(self.directory, self.pickleBaseName)}_slice{slice_suffix}.pickle"


class SqliteConnectionManager:
    _connections: List[sqlite3.Connection] = []
    _atexit_handler_registered = False

    @classmethod
    def _register_at_exit_handler(cls):
        if not cls._atexit_handler_registered:
            cls._atexit_handler_registered = True
            atexit.register(cls._cleanup)

    @classmethod
    def open_connection(cls, path):
        cls._register_at_exit_handler()
        conn = sqlite3.connect(path, check_same_thread=False)
        cls._connections.append(conn)
        return conn

    @classmethod
    def _cleanup(cls):
        for conn in cls._connections:
            conn.close()
        cls._connections = []


class SqlitePersistentKeyValueCache(PersistentKeyValueCache[TKey, TValue]):
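    """
    A persistent key-value cache backed by an SQLite database, in which values are pickled for storage.
    Illustrative usage (a minimal sketch; ``cache.sqlite`` is a hypothetical path):

    .. code-block:: python

        cache = SqlitePersistentKeyValueCache("cache.sqlite")
        cache.set("some_key", [1, 2, 3])
        assert cache.get("some_key") == [1, 2, 3]
    """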
    class KeyType(enum.Enum):
        STRING = ("VARCHAR(%d)", )
        INTEGER = ("LONG", )

    def __init__(self, path, table_name="cache", deferred_commit_delay_secs=1.0, key_type: KeyType = KeyType.STRING,
            max_key_length=255):
        """
        :param path: the path to the file that is to hold the SQLite database
        :param table_name: the name of the table to create in the database
        :param deferred_commit_delay_secs: the time frame during which no new data must be added for a pending transaction to be committed
        :param key_type: the type to use for keys; for complex keys (i.e. tuples), use STRING (conversions to string are automatic)
        :param max_key_length: the maximum key length for the case where the key_type can be parametrised (e.g. STRING)
        """
        self.path = path
        self.conn = SqliteConnectionManager.open_connection(path)
        self.table_name = table_name
        self.max_key_length = max_key_length
        self.key_type = key_type
        self._update_hook = DelayedUpdateHook(self._commit, deferred_commit_delay_secs)
        self._num_entries_to_be_committed = 0
        self._conn_mutex = threading.Lock()

        cursor = self.conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        if table_name not in [r[0] for r in cursor.fetchall()]:
            log.info(f"Creating cache table '{self.table_name}' in {path}")
            key_db_type = key_type.value[0]
            if "%d" in key_db_type:
                key_db_type = key_db_type % max_key_length
            cursor.execute(f"CREATE TABLE {table_name} (cache_key {key_db_type} PRIMARY KEY, cache_value BLOB);")
        cursor.close()
    def _key_db_value(self, key):
        if self.key_type == self.KeyType.STRING:
            s = str(key)
            if len(s) > self.max_key_length:
                raise ValueError(f"Key too long, maximal key length is {self.max_key_length}")
            return s
        elif self.key_type == self.KeyType.INTEGER:
            return int(key)
        else:
            raise Exception(f"Unhandled key type {self.key_type}")
    def _commit(self):
        self._conn_mutex.acquire()
        try:
            log.info(f"Committing {self._num_entries_to_be_committed} cache entries to the SQLite database {self.path}")
            self.conn.commit()
            self._num_entries_to_be_committed = 0
        finally:
            self._conn_mutex.release()
    def set(self, key: TKey, value: TValue):
        self._conn_mutex.acquire()
        try:
            cursor = self.conn.cursor()
            key = self._key_db_value(key)
            cursor.execute(f"SELECT COUNT(*) FROM {self.table_name} WHERE cache_key=?", (key,))
            if cursor.fetchone()[0] == 0:
                cursor.execute(f"INSERT INTO {self.table_name} (cache_key, cache_value) VALUES (?, ?)",
                               (key, pickle.dumps(value)))
            else:
                cursor.execute(f"UPDATE {self.table_name} SET cache_value=? WHERE cache_key=?", (pickle.dumps(value), key))
            self._num_entries_to_be_committed += 1
            cursor.close()
        finally:
            self._conn_mutex.release()

        self._update_hook.handle_update()
    def _execute(self, cursor, *query):
        try:
            cursor.execute(*query)
        except sqlite3.DatabaseError as e:
            raise Exception(f"Error executing query for {self.path}: {e}") from e
    def get(self, key: TKey) -> Optional[TValue]:
        self._conn_mutex.acquire()
        try:
            cursor = self.conn.cursor()
            key = self._key_db_value(key)
            self._execute(cursor, f"SELECT cache_value FROM {self.table_name} WHERE cache_key=?", (key,))
            row = cursor.fetchone()
            cursor.close()
            if row is None:
                return None
            return pickle.loads(row[0])
        finally:
            self._conn_mutex.release()
    def __len__(self):
        self._conn_mutex.acquire()
        try:
            cursor = self.conn.cursor()
            cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}")
            cnt = cursor.fetchone()[0]
            cursor.close()
            return cnt
        finally:
            self._conn_mutex.release()

    def iter_items(self):
        self._conn_mutex.acquire()
        try:
            cursor = self.conn.cursor()
            cursor.execute(f"SELECT cache_key, cache_value FROM {self.table_name}")
            while True:
                row = cursor.fetchone()
                if row is None:
                    break
                yield row[0], pickle.loads(row[1])
            cursor.close()
        finally:
            self._conn_mutex.release()


class SqlitePersistentList(PersistentList):
    """
    A persistent list backed by an SQLite key-value cache, which uses integer indices as keys
    """
    def __init__(self, path):
        self.keyValueCache = SqlitePersistentKeyValueCache(path, key_type=SqlitePersistentKeyValueCache.KeyType.INTEGER)
        self.nextKey = len(self.keyValueCache)

    def append(self, item):
        self.keyValueCache.set(self.nextKey, item)
        self.nextKey += 1

    def iter_items(self):
        for item in self.keyValueCache.iter_items():
            yield item[1]


class CachedValueProviderMixin(Generic[TKey, TValue, TData], ABC):
    """
    Represents a value provider that can provide values associated with (hashable) keys via a cache or, if
    cached values are not yet present, by computing them.
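
    Illustrative usage (a minimal sketch; ``SquareProvider`` is a hypothetical subclass):

    .. code-block:: python

        class SquareProvider(CachedValueProviderMixin[int, int, None]):
            def __init__(self):
                super().__init__(cache=InMemoryKeyValueCache())

            def square(self, x: int) -> int:
                return self._provide_value(x)

            def _compute_value(self, key: int, data) -> int:
                return key * key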
573 """
574 def __init__(self, cache: Optional[KeyValueCache[TKey, TValue]] = None,
575 cache_factory: Optional[Callable[[], KeyValueCache[TKey, TValue]]] = None, persist_cache=False, box_values=False):
576 """
577 :param cache: the cache to use or None. If None, caching will be disabled
578 :param cache_factory: a factory with which to create the cache (or recreate it after unpickling if `persistCache` is False, in which
579 case this factory must be picklable)
580 :param persist_cache: whether to persist the cache when pickling
581 :param box_values: whether to box values, such that None is admissible as a value
582 """
        self._persistCache = persist_cache
        self._boxValues = box_values
        self._cache = cache
        self._cacheFactory = cache_factory
        if self._cache is None and cache_factory is not None:
            self._cache = cache_factory()

    def __getstate__(self):
        if not self._persistCache:
            d = self.__dict__.copy()
            d["_cache"] = None
            return d
        return self.__dict__

    def __setstate__(self, state):
        setstate(CachedValueProviderMixin, self, state, renamed_properties={"persistCache": "_persistCache"})
        if not self._persistCache and self._cacheFactory is not None:
            self._cache = self._cacheFactory()
    def _provide_value(self, key, data: Optional[TData] = None):
        """
        Provides the value for the key by retrieving the associated value from the cache or, if no entry in the
        cache is found, by computing the value via _compute_value

        :param key: the key for which to provide the value
        :param data: optional data required to compute a value
        :return: the retrieved or computed value
        """
        if self._cache is None:
            return self._compute_value(key, data)
        value = self._cache.get(key)
        if value is None:
            value = self._compute_value(key, data)
            self._cache.set(key, value if not self._boxValues else BoxedValue(value))
        else:
            if self._boxValues:
                value: BoxedValue[TValue]
                value = value.value
        return value
    @abstractmethod
    def _compute_value(self, key: TKey, data: Optional[TData]) -> TValue:
        """
        Computes the value for the given key

        :param key: the key for which to compute the value
        :param data: optional data required to compute the value
        :return: the computed value
        """
        pass


def cached(fn: Callable[[], T], pickle_path, function_name=None, validity_check_fn: Optional[Callable[[T], bool]] = None,
        backend="pickle", protocol=pickle.HIGHEST_PROTOCOL, load=True, version=None) -> T:
    """
    Calls the given function unless its result is already cached (in a pickle), in which case it will read the cached result
    and return it.

    Rather than directly calling this function, consider using the decorator variant :func:`pickle_cached`.
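
    Illustrative usage (a minimal sketch; ``expensive_computation`` and the path are hypothetical):

    .. code-block:: python

        result = cached(lambda: expensive_computation(), "result.cache.pickle")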

    :param fn: the function whose result is to be cached
    :param pickle_path: the path in which to store the cached result
    :param function_name: the name of the function fn (for the case where its __name__ attribute is not
        informative)
    :param validity_check_fn: an optional function to call in order to check whether a cached result is still valid;
        the function shall return True if the result is still valid and False otherwise. If a cached result is invalid,
        the function fn is called to compute the result and the cached result is updated.
    :param backend: pickle or joblib
    :param protocol: the pickle protocol version
    :param load: whether to load a previously persisted result; if False, do not load an old result but store the newly computed result
    :param version: if not None, previously persisted data will only be returned if it was stored with the same version
    :return: the result (either obtained from the cache or the function)
    """
    if function_name is None:
        function_name = fn.__name__

    def call_fn_and_cache_result():
        res = fn()
        log.info(f"Saving cached result in {pickle_path}")
        if version is not None:
            persisted_res = {"__cacheVersion": version, "obj": res}
        else:
            persisted_res = res
        dump_pickle(persisted_res, pickle_path, backend=backend, protocol=protocol)
        return res

    if os.path.exists(pickle_path):
        if load:
            log.info(f"Loading cached result of function '{function_name}' from {pickle_path}")
            result = load_pickle(pickle_path, backend=backend)
            if validity_check_fn is not None:
                if not validity_check_fn(result):
                    log.info("Cached result is no longer valid, recomputing ...")
                    result = call_fn_and_cache_result()
            if version is not None:
                cached_version = None
                if type(result) == dict:
                    cached_version = result.get("__cacheVersion")
                if cached_version != version:
                    log.info(f"Cached result has incorrect version ({cached_version}, expected {version}), recomputing ...")
                    result = call_fn_and_cache_result()
                else:
                    result = result["obj"]
            return result
        else:
            log.info(f"Ignoring previously stored result in {pickle_path}, calling function '{function_name}' ...")
            return call_fn_and_cache_result()
    else:
        log.info(f"No cached result found in {pickle_path}, calling function '{function_name}' ...")
        return call_fn_and_cache_result()


def pickle_cached(cache_base_path: str, filename_prefix: str = None, filename: str = None, backend="pickle",
        protocol=pickle.HIGHEST_PROTOCOL, load=True, version=None):
    """
    Function decorator for caching function results via pickle.

    Add this decorator to any function to cache its results in pickle files.
    The function may have arguments, in which case the cache will be specific to the actual arguments
    by computing a hash code from their pickled representation.
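
    Illustrative usage (a minimal sketch; ``cache`` is a hypothetical directory and ``expensive_computation`` a
    hypothetical function):

    .. code-block:: python

        @pickle_cached("cache")
        def compute(x: int) -> int:
            return expensive_computation(x)

        compute(42)  # computes the result and persists it
        compute(42)  # loads the persisted result, identified by the hash of the arguments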

    :param cache_base_path: the directory where the pickle cache file will be stored
    :param filename_prefix: a prefix of the name of the cache file to be created, to which the function name and, where applicable,
        a hash code of the function arguments as well as the extension ".cache.pickle" will be appended.
        The prefix need not end in a separator, as "-" will automatically be added between filename components.
    :param filename: the full file name of the cache file to be created; if the function takes arguments, the filename must
        contain a placeholder '%s' for the argument hash
    :param backend: the serialisation backend to use (see dump_pickle)
    :param protocol: the pickle protocol version to use
    :param load: whether to load a previously persisted result; if False, do not load an old result but store the newly computed result
    :param version: if not None, previously persisted data will only be returned if it was stored with the same version
    """
    os.makedirs(cache_base_path, exist_ok=True)

    if filename_prefix is None:
        filename_prefix = ""
    else:
        filename_prefix += "-"

    def decorator(fn: Callable, *_args, **_kwargs):

        @wraps(fn)
        def wrapped(*args, **kwargs):
            hash_code_str = None
            have_args = len(args) > 0 or len(kwargs) > 0
            if have_args:
                hash_code_str = pickle_hash((args, kwargs))
            if filename is None:
                pickle_filename = filename_prefix + fn.__qualname__.replace(".<locals>.", ".")
                if hash_code_str is not None:
                    pickle_filename += "-" + hash_code_str
                pickle_filename += ".cache.pickle"
            else:
                if hash_code_str is not None:
                    if "%s" not in filename:
                        raise Exception("Function called with arguments but full cache filename contains no placeholder (%s) "
                                        "for the argument hash")
                    pickle_filename = filename % hash_code_str
                else:
                    if "%s" in filename:
                        raise Exception("Function called without arguments but full cache filename contains a placeholder (%s)")
                    pickle_filename = filename
            pickle_path = os.path.join(cache_base_path, pickle_filename)
            return cached(lambda: fn(*args, **kwargs), pickle_path, function_name=fn.__name__, backend=backend, load=load,
                version=version, protocol=protocol)

        return wrapped

    return decorator


PickleCached = pickle_cached  # for backward compatibility


class LoadSaveInterface(ABC):
    @abstractmethod
    def save(self, path: str) -> None:
        pass

    @classmethod
    @abstractmethod
    def load(cls: T, path: str) -> T:
        pass


class PickleLoadSaveMixin(LoadSaveInterface):
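    """
    Mixin which provides pickle-based saving and loading for the inheriting class.
    Illustrative usage (a minimal sketch; ``Model`` and the path are hypothetical):

    .. code-block:: python

        class Model(PickleLoadSaveMixin):
            ...

        model = Model()
        model.save("model.pickle")
        restored_model = Model.load("model.pickle")
    """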
    def save(self, path: Union[str, Path], backend="pickle"):
        """
        Saves the instance as a pickle file

        :param path: the path of the file in which to store the instance
        :param backend: pickle, cloudpickle, or joblib
        """
        dump_pickle(self, path, backend=backend)

    @classmethod
    def load(cls, path: Union[str, Path], backend="pickle"):
        """
        Loads a class instance from a pickle file

        :param path: the path of the file from which to load the instance
        :param backend: pickle, cloudpickle, or joblib
        :return: instance of the present class
        """
        log.info(f"Loading instance of {cls} from {path}")
        result = load_pickle(path, backend=backend)
        if not isinstance(result, cls):
            raise Exception(f"Expected instance of {cls}, instead got: {result.__class__.__name__}")
        return result