Coverage for src/sensai/util/cache.py: 27%
import atexit
import enum
import glob
import logging
import os
import pickle
import re
import sqlite3
import threading
import time
from abc import abstractmethod, ABC
from collections import OrderedDict
from functools import wraps
from pathlib import Path
from typing import Any, Callable, Iterator, List, Optional, TypeVar, Generic, Union, Hashable

from .hash import pickle_hash
from .pickle import load_pickle, dump_pickle, setstate

log = logging.getLogger(__name__)

T = TypeVar("T")
TKey = TypeVar("TKey")
THashableKey = TypeVar("THashableKey", bound=Hashable)
TValue = TypeVar("TValue")
TData = TypeVar("TData")


class BoxedValue(Generic[TValue]):
    """
    Container for a value, which can be used in caches where values may be None (to differentiate the value not being
    present in the cache from the cached value being None)
    """
    def __init__(self, value: TValue):
        self.value = value


class KeyValueCache(Generic[TKey, TValue], ABC):
    @abstractmethod
    def set(self, key: TKey, value: TValue):
        """
        Sets a cached value

        :param key: the key under which to store the value
        :param value: the value to store; since None is used to indicate the absence of a value, None should not be
            used as a value
        """
        pass

    @abstractmethod
    def get(self, key: TKey) -> Optional[TValue]:
        """
        Retrieves a cached value

        :param key: the lookup key
        :return: the cached value or None if no value is found
        """
        pass


class InMemoryKeyValueCache(KeyValueCache[TKey, TValue], Generic[TKey, TValue]):
    """A simple in-memory cache (which uses a dictionary internally).

    This class can be instantiated directly, but for better typing support, one can instead
    inherit from it and provide the types of the key and value as type arguments. For example, for
    a cache with string keys and integer values:

    .. code-block:: python

        class MyCache(InMemoryKeyValueCache[str, int]):
            pass
    """
    def __init__(self):
        self.cache = {}

    def set(self, key: TKey, value: TValue):
        self.cache[key] = value

    def get(self, key: TKey) -> Optional[TValue]:
        return self.cache.get(key)

    def empty(self):
        self.cache = {}

    def __len__(self):
        return len(self.cache)
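
# Usage sketch (illustrative, not part of the original module): a typed subclass
# as suggested in the docstring above, with a basic get/set round trip.
#
#     class StringIntCache(InMemoryKeyValueCache[str, int]):
#         pass
#
#     cache = StringIntCache()
#     cache.set("answer", 42)
#     assert cache.get("answer") == 42
#     assert cache.get("missing") is None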


# mainly kept as a marker and for backwards compatibility, but may be extended in the future
class PersistentKeyValueCache(KeyValueCache[TKey, TValue], Generic[TKey, TValue], ABC):
    pass


class PersistentList(Generic[TValue], ABC):
    @abstractmethod
    def append(self, item: TValue):
        """
        Adds an item to the list

        :param item: the item to store
        """
        pass

    @abstractmethod
    def iter_items(self) -> Iterator[TValue]:
        """
        Iterates over the items in the persisted list

        :return: generator of items
        """
        pass


class DelayedUpdateHook:
    """
    Ensures that a given function is executed after an update happens, but delays the execution until
    there are no further updates for a certain time period
    """
    def __init__(self, fn: Callable[[], Any], time_period_secs, periodically_executed_fn: Optional[Callable[[], Any]] = None):
        """
        :param fn: the function to eventually call after an update
        :param time_period_secs: the time that must pass while not receiving further updates for fn to be called
        :param periodically_executed_fn: a function to execute periodically (every time_period_secs seconds) in the busy
            waiting loop, which may, for example, log information or apply additional executions, which must not interfere
            with the correctness of the execution of fn
        """
        self.periodicallyExecutedFn = periodically_executed_fn
        self.fn = fn
        self.timePeriodSecs = time_period_secs
        self._lastUpdateTime = None
        self._thread = None
        self._threadLock = threading.Lock()

    def handle_update(self):
        """
        Notifies of an update and ensures that the function passed at construction is eventually called
        (after no more updates are received within the respective time window)
        """
        self._lastUpdateTime = time.time()

        def do_periodic_check():
            while True:
                time.sleep(self.timePeriodSecs)
                time_passed_since_last_update = time.time() - self._lastUpdateTime
                if self.periodicallyExecutedFn is not None:
                    self.periodicallyExecutedFn()
                if time_passed_since_last_update >= self.timePeriodSecs:
                    self.fn()
                    return

        # noinspection DuplicatedCode
        if self._thread is None or not self._thread.is_alive():
            self._threadLock.acquire()
            if self._thread is None or not self._thread.is_alive():
                self._thread = threading.Thread(target=do_periodic_check, daemon=False)
                self._thread.start()
            self._threadLock.release()
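
# Usage sketch (illustrative): debouncing an expensive operation so that it runs
# only once, two seconds after the last of a burst of updates.
#
#     hook = DelayedUpdateHook(lambda: print("saving ..."), time_period_secs=2.0)
#     for _ in range(100):
#         hook.handle_update()  # "saving ..." is printed once, ~2s after the last call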


class PeriodicUpdateHook:
    """
    Periodically checks whether a function shall be called as a result of an update, the function potentially
    being non-atomic (i.e. it may take a long time to execute such that new updates may come in while it is
    executing). Two function call mechanisms are in place:

    * a function which is called if there has not been a new update for a certain time period (which may be called
      several times if updates come in while the function is being executed)
    * a function which is called periodically
    """
    def __init__(self, check_interval_secs: float, no_update_time_period_secs: float = None, no_update_fn: Callable[[], Any] = None,
            periodic_fn: Optional[Callable[[], Any]] = None):
        """
        :param check_interval_secs: the time period, in seconds, between checks
        :param no_update_time_period_secs: the time period after which to execute no_update_fn if no further updates have come in.
            This must be at least as large as check_interval_secs. If None, use check_interval_secs.
        :param no_update_fn: the function to call if there have been no further updates for no_update_time_period_secs seconds
        :param periodic_fn: a function to execute periodically (every check_interval_secs seconds) in the busy waiting loop,
            which may, for example, log information or apply additional executions, which must not interfere with the
            correctness of the execution of no_update_fn
        """
        if no_update_time_period_secs is None:
            no_update_time_period_secs = check_interval_secs
        elif no_update_time_period_secs < check_interval_secs:
            raise ValueError("no_update_time_period_secs must be at least as large as check_interval_secs")
        self._periodic_fn = periodic_fn
        self._check_interval_secs = check_interval_secs
        self._no_update_time_period_secs = no_update_time_period_secs
        self._no_update_fn = no_update_fn
        self._last_update_time = None
        self._thread = None
        self._thread_lock = threading.Lock()

    def handle_update(self):
        """
        Notifies of an update, making sure the functions given at construction will be called as specified
        """
        self._last_update_time = time.time()

        def do_periodic_check():
            while True:
                time.sleep(self._check_interval_secs)
                check_time = time.time()
                if self._periodic_fn is not None:
                    self._periodic_fn()
                time_passed_since_last_update = check_time - self._last_update_time
                if time_passed_since_last_update >= self._no_update_time_period_secs:
                    if self._no_update_fn is not None:
                        self._no_update_fn()
                    # if no further updates have come in, we terminate the thread
                    if self._last_update_time < check_time:
                        return

        # noinspection DuplicatedCode
        if self._thread is None or not self._thread.is_alive():
            self._thread_lock.acquire()
            if self._thread is None or not self._thread.is_alive():
                self._thread = threading.Thread(target=do_periodic_check, daemon=False)
                self._thread.start()
            self._thread_lock.release()
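
# Usage sketch (illustrative; log_progress and flush are hypothetical user-defined
# functions): log progress every second while updates arrive, and flush once no
# update has been received for five seconds.
#
#     hook = PeriodicUpdateHook(check_interval_secs=1.0, no_update_time_period_secs=5.0,
#         no_update_fn=flush, periodic_fn=log_progress)
#     hook.handle_update()  # to be called upon every update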


class PicklePersistentKeyValueCache(PersistentKeyValueCache[TKey, TValue]):
    """
    Represents a key-value cache as a dictionary which is persisted in a file using pickle
    """
    def __init__(self, pickle_path, version=1, save_on_update=True, deferred_save_delay_secs=1.0):
        """
        :param pickle_path: the path of the file where the cache values are to be persisted
        :param version: the version of cache entries. If a persisted cache with a non-matching version is found,
            it is discarded
        :param save_on_update: whether to persist the cache after an update; the cache is saved in a deferred
            manner and will be saved after deferred_save_delay_secs if no new updates have arrived in the meantime,
            i.e. it will ultimately be saved deferred_save_delay_secs after the latest update
        :param deferred_save_delay_secs: the number of seconds to wait for additional data to be added to the cache
            before actually storing the cache after a cache update
        """
        self.deferred_save_delay_secs = deferred_save_delay_secs
        self.pickle_path = pickle_path
        self.version = version
        self.save_on_update = save_on_update
        cache_found = False
        if os.path.exists(pickle_path):
            try:
                log.info(f"Loading cache from {pickle_path}")
                persisted_version, self.cache = load_pickle(pickle_path)
                if persisted_version == version:
                    cache_found = True
            except EOFError:
                log.warning(f"The cache file in {pickle_path} is corrupt")
        if not cache_found:
            self.cache = {}
        self._update_hook = DelayedUpdateHook(self.save, deferred_save_delay_secs)
        self._write_lock = threading.Lock()

    def save(self):
        """
        Saves the cache in the file whose path was provided at construction
        """
        with self._write_lock:  # avoid concurrent modification while saving
            log.info(f"Saving cache to {self.pickle_path}")
            dump_pickle((self.version, self.cache), self.pickle_path)

    def get(self, key: TKey) -> Optional[TValue]:
        return self.cache.get(key)

    def set(self, key: TKey, value: TValue):
        with self._write_lock:
            self.cache[key] = value
            if self.save_on_update:
                self._update_hook.handle_update()
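
# Usage sketch (illustrative; expensive_computation is a hypothetical user-defined
# function): a dictionary cache that is transparently persisted to disk and
# survives across runs as long as the version matches.
#
#     cache = PicklePersistentKeyValueCache("my_cache.pickle", version=1)
#     result = cache.get("result")
#     if result is None:
#         result = expensive_computation()
#         cache.set("result", result)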


class SlicedPicklePersistentList(PersistentList):
    """
    Object handling the creation of and access to sliced pickle caches
    """
    def __init__(self, directory, pickle_base_name, num_entries_per_slice=100000):
        """
        :param directory: path to the directory where the sliced caches are to be stored
        :param pickle_base_name: base name for the pickle files, where the slices will be named {pickle_base_name}_slice<i>.pickle
        :param num_entries_per_slice: how many entries should be stored in each slice
        """
        self.directory = directory
        self.pickleBaseName = pickle_base_name
        self.numEntriesPerSlice = num_entries_per_slice

        # Set up the variables for the sliced cache
        self.slice_id = 0
        self.index_in_slice = 0
        self.cache_of_slice = []

        # Search directory for already present sliced caches
        self.slicedFiles = self._find_sliced_caches()

        # Helper variable to ensure object is only modified within a with-clause
        self._currentlyInWithClause = False

    def __enter__(self):
        self._currentlyInWithClause = True
        if self.cache_exists():
            # Reset state to enable the appending of more items to the cache
            self._set_last_cache_state()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._dump()
        self._currentlyInWithClause = False

    def append(self, item):
        """
        Appends an item to the list

        :param item: the entry to add
        """
        if not self._currentlyInWithClause:
            raise Exception("This object must be used within a with-clause to ensure correct storage")

        if (self.index_in_slice + 1) % self.numEntriesPerSlice == 0:
            self._dump()

        self.cache_of_slice.append(item)
        self.index_in_slice += 1

    def iter_items(self) -> Iterator[Any]:
        """
        Iterates over the entries in the sliced cache

        :return: iterator over all items in the cache
        """
        for file_path in self.slicedFiles:
            log.info(f"Loading sliced pickle list from {file_path}")
            cached_pickle = self._load_pickle(file_path)
            for item in cached_pickle:
                yield item

    def clear(self):
        """
        Clears the cache if it exists
        """
        if self.cache_exists():
            for file_path in self.slicedFiles:
                os.unlink(file_path)
            self.slicedFiles = []

    def cache_exists(self) -> bool:
        """
        Checks whether this cache already exists

        :return: True if the cache exists, False if not
        """
        return len(self.slicedFiles) > 0

    def _set_last_cache_state(self):
        """
        Sets the state so as to be able to add items to an existing cache
        """
        log.info("Resetting last state of cache...")
        self.slice_id = len(self.slicedFiles) - 1
        self.cache_of_slice = self._load_pickle(self._pickle_path(self.slice_id))
        self.index_in_slice = len(self.cache_of_slice) - 1
        if self.index_in_slice >= self.numEntriesPerSlice:
            self._next_slice()

    def _dump(self):
        """
        Dumps the current slice of the cache (if non-empty)
        """
        if len(self.cache_of_slice) > 0:
            pickle_path = self._pickle_path(str(self.slice_id))
            log.info(f"Saving sliced cache to {pickle_path}")
            dump_pickle(self.cache_of_slice, pickle_path)
            self.slicedFiles.append(pickle_path)

            # Update slice number and reset indexing and cache
            self._next_slice()
        else:
            log.warning("Unexpected behavior: dump was called although the current slice is empty!")

    def _next_slice(self):
        """
        Updates the sliced cache state for the next slice
        """
        self.slice_id += 1
        self.index_in_slice = 0
        self.cache_of_slice = []

    def _find_sliced_caches(self) -> List[str]:
        """
        Finds all pickled slices associated with this cache

        :return: list of sliced pickle files
        """
        # glob.glob permits the use of unix-style pathname matching (below, we find all ..._slice*.pickle files)
        list_of_file_names = glob.glob(self._pickle_path("*"))
        # Sort the slices to ensure they are in the same order as they were produced (the regex replaces everything
        # that is not a number with the empty string)
        list_of_file_names.sort(key=lambda f: int(re.sub(r'\D', '', f)))
        return list_of_file_names

    def _load_pickle(self, pickle_path: str) -> List[Any]:
        """
        Loads the pickle at the given file path if it exists

        :param pickle_path: the file path
        :return: list of objects
        """
        cached_pickle = []
        if os.path.exists(pickle_path):
            try:
                cached_pickle = load_pickle(pickle_path)
            except EOFError:
                log.warning(f"The cache file in {pickle_path} is corrupt")
        else:
            raise Exception(f"The file {pickle_path} does not exist!")
        return cached_pickle

    def _pickle_path(self, slice_suffix) -> str:
        return f"{os.path.join(self.directory, self.pickleBaseName)}_slice{slice_suffix}.pickle"
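
# Usage sketch (illustrative; records and process are hypothetical): items must be
# appended within a with-clause, which ensures the final slice is written on exit.
#
#     plist = SlicedPicklePersistentList("cache_dir", "my_list", num_entries_per_slice=1000)
#     with plist:
#         for record in records:
#             plist.append(record)
#     for item in plist.iter_items():
#         process(item)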


class SqliteConnectionManager:
    _connections: List[sqlite3.Connection] = []
    _atexit_handler_registered = False

    @classmethod
    def _register_at_exit_handler(cls):
        if not cls._atexit_handler_registered:
            cls._atexit_handler_registered = True
            atexit.register(cls._cleanup)

    @classmethod
    def open_connection(cls, path):
        cls._register_at_exit_handler()
        conn = sqlite3.connect(path, check_same_thread=False)
        cls._connections.append(conn)
        return conn

    @classmethod
    def _cleanup(cls):
        for conn in cls._connections:
            conn.close()
        cls._connections = []


class SqlitePersistentKeyValueCache(PersistentKeyValueCache[TKey, TValue]):
    class KeyType(enum.Enum):
        STRING = ("VARCHAR(%d)", )
        INTEGER = ("LONG", )

    def __init__(self, path, table_name="cache", deferred_commit_delay_secs=1.0, key_type: KeyType = KeyType.STRING,
            max_key_length=255):
        """
        :param path: the path to the file that is to hold the SQLite database
        :param table_name: the name of the table to create in the database
        :param deferred_commit_delay_secs: the time frame during which no new data must be added for a pending transaction to be committed
        :param key_type: the type to use for keys; for complex keys (i.e. tuples), use STRING (conversions to string are automatic)
        :param max_key_length: the maximum key length for the case where the key_type can be parametrised (e.g. STRING)
        """
        self.path = path
        self.conn = SqliteConnectionManager.open_connection(path)
        self.table_name = table_name
        self.max_key_length = max_key_length
        self.key_type = key_type
        self._update_hook = DelayedUpdateHook(self._commit, deferred_commit_delay_secs)
        self._num_entries_to_be_committed = 0
        self._conn_mutex = threading.Lock()

        cursor = self.conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        if table_name not in [r[0] for r in cursor.fetchall()]:
            log.info(f"Creating cache table '{self.table_name}' in {path}")
            key_db_type = key_type.value[0]
            if "%d" in key_db_type:
                key_db_type = key_db_type % max_key_length
            cursor.execute(f"CREATE TABLE {table_name} (cache_key {key_db_type} PRIMARY KEY, cache_value BLOB);")
        cursor.close()

    def _key_db_value(self, key):
        if self.key_type == self.KeyType.STRING:
            s = str(key)
            if len(s) > self.max_key_length:
                raise ValueError(f"Key too long, maximal key length is {self.max_key_length}")
            return s
        elif self.key_type == self.KeyType.INTEGER:
            return int(key)
        else:
            raise Exception(f"Unhandled key type {self.key_type}")

    def _commit(self):
        self._conn_mutex.acquire()
        try:
            log.info(f"Committing {self._num_entries_to_be_committed} cache entries to the SQLite database {self.path}")
            self.conn.commit()
            self._num_entries_to_be_committed = 0
        finally:
            self._conn_mutex.release()

    def set(self, key: TKey, value: TValue):
        self._conn_mutex.acquire()
        try:
            cursor = self.conn.cursor()
            key = self._key_db_value(key)
            cursor.execute(f"SELECT COUNT(*) FROM {self.table_name} WHERE cache_key=?", (key,))
            if cursor.fetchone()[0] == 0:
                cursor.execute(f"INSERT INTO {self.table_name} (cache_key, cache_value) VALUES (?, ?)",
                    (key, pickle.dumps(value)))
            else:
                cursor.execute(f"UPDATE {self.table_name} SET cache_value=? WHERE cache_key=?", (pickle.dumps(value), key))
            self._num_entries_to_be_committed += 1
            cursor.close()
        finally:
            self._conn_mutex.release()

        self._update_hook.handle_update()

    def _execute(self, cursor, *query):
        try:
            cursor.execute(*query)
        except sqlite3.DatabaseError as e:
            raise Exception(f"Error executing query for {self.path}: {e}")

    def get(self, key: TKey) -> Optional[TValue]:
        self._conn_mutex.acquire()
        try:
            cursor = self.conn.cursor()
            key = self._key_db_value(key)
            self._execute(cursor, f"SELECT cache_value FROM {self.table_name} WHERE cache_key=?", (key,))
            row = cursor.fetchone()
            cursor.close()
            if row is None:
                return None
            return pickle.loads(row[0])
        finally:
            self._conn_mutex.release()

    def __len__(self):
        self._conn_mutex.acquire()
        try:
            cursor = self.conn.cursor()
            cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}")
            cnt = cursor.fetchone()[0]
            cursor.close()
            return cnt
        finally:
            self._conn_mutex.release()

    def iter_items(self):
        self._conn_mutex.acquire()
        try:
            cursor = self.conn.cursor()
            cursor.execute(f"SELECT cache_key, cache_value FROM {self.table_name}")
            while True:
                row = cursor.fetchone()
                if row is None:
                    break
                yield row[0], pickle.loads(row[1])
            cursor.close()
        finally:
            self._conn_mutex.release()
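
# Usage sketch (illustrative): with KeyType.STRING, complex keys such as tuples
# are converted to strings automatically; commits are deferred and batched.
#
#     cache = SqlitePersistentKeyValueCache("cache.sqlite", table_name="results")
#     cache.set(("modelA", 0.1), {"score": 0.93})
#     print(cache.get(("modelA", 0.1)))  # {'score': 0.93}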


class SqlitePersistentList(PersistentList):
    def __init__(self, path):
        self.keyValueCache = SqlitePersistentKeyValueCache(path, key_type=SqlitePersistentKeyValueCache.KeyType.INTEGER)
        self.nextKey = len(self.keyValueCache)

    def append(self, item):
        self.keyValueCache.set(self.nextKey, item)
        self.nextKey += 1

    def iter_items(self):
        for item in self.keyValueCache.iter_items():
            yield item[1]


class CachedValueProviderMixin(Generic[TKey, TValue, TData], ABC):
    """
    Represents a value provider that can provide values associated with (hashable) keys via a cache or, if
    cached values are not yet present, by computing them.
    """
    def __init__(self, cache: Optional[KeyValueCache[TKey, TValue]] = None,
            cache_factory: Optional[Callable[[], KeyValueCache[TKey, TValue]]] = None, persist_cache=False, box_values=False):
        """
        :param cache: the cache to use or None. If None, caching will be disabled
        :param cache_factory: a factory with which to create the cache (or recreate it after unpickling if `persist_cache` is
            False, in which case this factory must be picklable)
        :param persist_cache: whether to persist the cache when pickling
        :param box_values: whether to box values, such that None is admissible as a value
        """
        self._persistCache = persist_cache
        self._boxValues = box_values
        self._cache = cache
        self._cacheFactory = cache_factory
        if self._cache is None and cache_factory is not None:
            self._cache = cache_factory()

    def __getstate__(self):
        if not self._persistCache:
            d = self.__dict__.copy()
            d["_cache"] = None
            return d
        return self.__dict__

    def __setstate__(self, state):
        setstate(CachedValueProviderMixin, self, state, renamed_properties={"persistCache": "_persistCache"})
        if not self._persistCache and self._cacheFactory is not None:
            self._cache = self._cacheFactory()

    def _provide_value(self, key, data: Optional[TData] = None):
        """
        Provides the value for the key by retrieving the associated value from the cache or, if no entry in the
        cache is found, by computing the value via _compute_value

        :param key: the key for which to provide the value
        :param data: optional data required to compute a value
        :return: the retrieved or computed value
        """
        if self._cache is None:
            return self._compute_value(key, data)
        value = self._cache.get(key)
        if value is None:
            value = self._compute_value(key, data)
            self._cache.set(key, value if not self._boxValues else BoxedValue(value))
        else:
            if self._boxValues:
                value: BoxedValue[TValue]
                value = value.value
        return value

    @abstractmethod
    def _compute_value(self, key: TKey, data: Optional[TData]) -> TValue:
        """
        Computes the value for the given key

        :param key: the key for which to compute the value
        :param data: optional data required to compute the value
        :return: the computed value
        """
        pass
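
# Usage sketch (illustrative): a provider that computes values on cache misses
# and serves repeated requests from the cache.
#
#     class SquareProvider(CachedValueProviderMixin[int, int, None]):
#         def __init__(self):
#             super().__init__(cache_factory=InMemoryKeyValueCache)
#
#         def square(self, x: int) -> int:
#             return self._provide_value(x)
#
#         def _compute_value(self, key, data):
#             return key * key
#
#     provider = SquareProvider()
#     provider.square(4)  # computed: 16
#     provider.square(4)  # now served from the cache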


def cached(fn: Callable[[], T], pickle_path, function_name=None, validity_check_fn: Optional[Callable[[T], bool]] = None,
        backend="pickle", protocol=pickle.HIGHEST_PROTOCOL, load=True, version=None) -> T:
    """
    Calls the given function unless its result is already cached (in a pickle), in which case it will read the cached result
    and return it.

    Rather than directly calling this function, consider using the decorator variant :func:`pickle_cached`.

    :param fn: the function whose result is to be cached
    :param pickle_path: the path in which to store the cached result
    :param function_name: the name of the function fn (for the case where its __name__ attribute is not
        informative)
    :param validity_check_fn: an optional function to call in order to check whether a cached result is still valid;
        the function shall return True if the result is still valid and False otherwise. If a cached result is invalid,
        the function fn is called to compute the result and the cached result is updated.
    :param backend: pickle or joblib
    :param protocol: the pickle protocol version
    :param load: whether to load a previously persisted result; if False, do not load an old result but store the newly computed result
    :param version: if not None, previously persisted data will only be returned if it was stored with the same version
    :return: the result (either obtained from the cache or the function)
    """
    if function_name is None:
        function_name = fn.__name__

    def call_fn_and_cache_result():
        res = fn()
        log.info(f"Saving cached result in {pickle_path}")
        if version is not None:
            persisted_res = {"__cacheVersion": version, "obj": res}
        else:
            persisted_res = res
        dump_pickle(persisted_res, pickle_path, backend=backend, protocol=protocol)
        return res

    if os.path.exists(pickle_path):
        if load:
            log.info(f"Loading cached result of function '{function_name}' from {pickle_path}")
            result = load_pickle(pickle_path, backend=backend)
            if validity_check_fn is not None:
                if not validity_check_fn(result):
                    log.info("Cached result is no longer valid, recomputing ...")
                    result = call_fn_and_cache_result()
            if version is not None:
                cached_version = None
                if type(result) == dict:
                    cached_version = result.get("__cacheVersion")
                if cached_version != version:
                    log.info(f"Cached result has incorrect version ({cached_version}, expected {version}), recomputing ...")
                    result = call_fn_and_cache_result()
                else:
                    result = result["obj"]
            return result
        else:
            log.info(f"Ignoring previously stored result in {pickle_path}, calling function '{function_name}' ...")
            return call_fn_and_cache_result()
    else:
        log.info(f"No cached result found in {pickle_path}, calling function '{function_name}' ...")
        return call_fn_and_cache_result()
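
# Usage sketch (illustrative; expensive_computation is hypothetical): the first
# call computes and pickles the result; subsequent calls load it from disk.
#
#     result = cached(lambda: expensive_computation(), "result.cache.pickle",
#         function_name="expensive_computation", version=2)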


def pickle_cached(cache_base_path: str, filename_prefix: str = None, filename: str = None, backend="pickle",
        protocol=pickle.HIGHEST_PROTOCOL, load=True, version=None):
    """
    Function decorator for caching function results via pickle.

    Add this decorator to any function to cache its results in pickle files.
    The function may have arguments, in which case the cache will be specific to the actual arguments,
    a hash code being computed from their pickled representation.

    :param cache_base_path: the directory where the pickle cache file will be stored
    :param filename_prefix: a prefix of the name of the cache file to be created, to which the function name and, where applicable,
        a hash code of the function arguments as well as the extension ".cache.pickle" will be appended.
        The prefix need not end in a separator, as "-" will automatically be added between filename components.
    :param filename: the full file name of the cache file to be created; if the function takes arguments, the filename must
        contain a placeholder '%s' for the argument hash
    :param backend: the serialisation backend to use (see dump_pickle)
    :param protocol: the pickle protocol version to use
    :param load: whether to load a previously persisted result; if False, do not load an old result but store the newly computed result
    :param version: if not None, previously persisted data will only be returned if it was stored with the same version
    """
    os.makedirs(cache_base_path, exist_ok=True)

    if filename_prefix is None:
        filename_prefix = ""
    else:
        filename_prefix += "-"

    def decorator(fn: Callable, *_args, **_kwargs):

        @wraps(fn)
        def wrapped(*args, **kwargs):
            hash_code_str = None
            have_args = len(args) > 0 or len(kwargs) > 0
            if have_args:
                hash_code_str = pickle_hash((args, kwargs))
            if filename is None:
                pickle_filename = filename_prefix + fn.__qualname__.replace(".<locals>.", ".")
                if hash_code_str is not None:
                    pickle_filename += "-" + hash_code_str
                pickle_filename += ".cache.pickle"
            else:
                if hash_code_str is not None:
                    if "%s" not in filename:
                        raise Exception("Function called with arguments but full cache filename contains no placeholder (%s) "
                            "for argument hash")
                    pickle_filename = filename % hash_code_str
                else:
                    if "%s" in filename:
                        raise Exception("Function without arguments but full cache filename with placeholder (%s) was specified")
                    pickle_filename = filename
            pickle_path = os.path.join(cache_base_path, pickle_filename)
            return cached(lambda: fn(*args, **kwargs), pickle_path, function_name=fn.__name__, backend=backend, load=load,
                version=version, protocol=protocol)

        return wrapped

    return decorator


PickleCached = pickle_cached  # for backward compatibility
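
# Usage sketch (illustrative): results are cached per argument combination, using
# a hash of the pickled arguments as part of the cache filename.
#
#     @pickle_cached("cache_dir", filename_prefix="experiment")
#     def train_model(learning_rate: float, num_epochs: int):
#         ...  # some expensive computation
#
#     train_model(0.1, 10)  # computed and stored in cache_dir
#     train_model(0.1, 10)  # loaded from the cache file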


class LoadSaveInterface(ABC):
    @abstractmethod
    def save(self, path: str) -> None:
        pass

    @classmethod
    @abstractmethod
    def load(cls: T, path: str) -> T:
        pass


class PickleLoadSaveMixin(LoadSaveInterface):
    def save(self, path: Union[str, Path], backend="pickle"):
        """
        Saves the instance as a pickle

        :param path: the path of the file in which to store the instance
        :param backend: pickle, cloudpickle, or joblib
        """
        dump_pickle(self, path, backend=backend)

    @classmethod
    def load(cls, path: Union[str, Path], backend="pickle"):
        """
        Loads a class instance from a pickle

        :param path: the path of the file from which to load the instance
        :param backend: pickle, cloudpickle, or joblib
        :return: an instance of the present class
        """
        log.info(f"Loading instance of {cls} from {path}")
        result = load_pickle(path, backend=backend)
        if not isinstance(result, cls):
            raise Exception(f"Expected instance of {cls}, instead got: {result.__class__.__name__}")
        return result
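
# Usage sketch (illustrative): any class inheriting from the mixin gains
# pickle-based save/load methods.
#
#     class Model(PickleLoadSaveMixin):
#         def __init__(self, weights):
#             self.weights = weights
#
#     Model([1, 2, 3]).save("model.pickle")
#     model = Model.load("model.pickle")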


class LRUCache(KeyValueCache[THashableKey, TValue], Generic[THashableKey, TValue]):
    def __init__(self, capacity: int) -> None:
        self._cache = OrderedDict()
        self._capacity = capacity

    def get(self, key: THashableKey) -> Optional[TValue]:
        if key not in self._cache:
            return None
        self._cache.move_to_end(key)
        return self._cache[key]

    def set(self, key: THashableKey, value: TValue):
        if key in self._cache:
            self._cache.move_to_end(key)
        self._cache[key] = value
        if len(self._cache) > self._capacity:
            self._cache.popitem(last=False)

    def __len__(self) -> int:
        return len(self._cache)

    def clear(self) -> None:
        self._cache.clear()
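
# Usage sketch (illustrative): once the capacity is exceeded, the least recently
# used entry is evicted.
#
#     cache = LRUCache[str, int](capacity=2)
#     cache.set("a", 1)
#     cache.set("b", 2)
#     cache.get("a")     # marks "a" as recently used
#     cache.set("c", 3)  # evicts "b" (least recently used)
#     assert cache.get("b") is None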