Source code for sensai.util.pickle

import logging
import os
import pickle
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Union

from .io import S3Object, is_s3_path

log = logging.getLogger(__name__)


def load_pickle(path: Union[str, Path], backend="pickle"):
    """
    Loads a pickled object from a local file or from S3.

    NOTE: Unpickling can execute arbitrary code; only load files from sources you trust.

    :param path: the path of the file to load; if ``is_s3_path`` accepts it, the file is read via ``S3Object``,
        otherwise it is opened as a local file
    :param backend: the deserialisation backend to use: 'pickle', 'joblib' or 'cloudpickle'
    :return: the deserialised object
    :raises ValueError: if the backend is unknown
    """
    if isinstance(path, Path):
        path = str(path)

    def read_file(f):
        # resolve the loader first, such that an unknown backend or a missing optional
        # library is not reported as a loading error of the file itself
        if backend == "pickle":
            loader = pickle.load
        elif backend == "cloudpickle":
            import cloudpickle
            loader = cloudpickle.load
        elif backend == "joblib":
            import joblib
            loader = joblib.load
        else:
            raise ValueError(f"Unknown backend '{backend}'. Supported backends are 'pickle', 'joblib' and 'cloudpickle'")
        try:
            return loader(f)
        except Exception:
            # log which file failed (for all backends alike), then propagate the original exception unchanged
            log.error(f"Error loading {path}")
            raise

    if is_s3_path(path):
        # use a context manager (as dump_pickle does for S3 files), such that the handle is always closed
        with S3Object(path).open_file("rb") as f:
            return read_file(f)
    with open(path, "rb") as f:
        return read_file(f)
def dump_pickle(obj, pickle_path: Union[str, Path], backend="pickle", protocol=pickle.HIGHEST_PROTOCOL):
    """
    Pickles the given object to a local file or to S3.

    For local paths, missing parent directories are created.

    :param obj: the object to serialise
    :param pickle_path: the path to write to; if ``is_s3_path`` accepts it, the file is written via ``S3Object``,
        otherwise it is opened as a local file
    :param backend: the serialisation backend to use: 'pickle', 'joblib' or 'cloudpickle'
    :param protocol: the pickle protocol, which is passed on to the backend's dump function
    :raises ValueError: if the backend is unknown
    :raises AttributeError: if the 'pickle' backend fails with an AttributeError; the message then lists the
        object paths that could not be pickled
    """
    if isinstance(pickle_path, Path):
        pickle_path = str(pickle_path)

    # resolve the dump function before touching the file system, such that an unknown backend
    # or a missing optional library does not leave behind an empty/truncated file
    if backend == "pickle":
        dump = pickle.dump
    elif backend == "joblib":
        import joblib
        dump = joblib.dump
    elif backend == "cloudpickle":
        import cloudpickle
        dump = cloudpickle.dump
    else:
        raise ValueError(f"Unknown backend '{backend}'. Supported backends are 'pickle', 'joblib' and 'cloudpickle'")

    on_s3 = is_s3_path(pickle_path)

    def open_file():
        if on_s3:
            return S3Object(pickle_path).open_file("wb")
        else:
            return open(pickle_path, "wb")

    # directories need to be created only for local files; applying os.makedirs to the
    # directory part of an S3 path would create a bogus local directory tree
    if not on_s3:
        dir_name = os.path.dirname(pickle_path)
        if dir_name != "":
            os.makedirs(dir_name, exist_ok=True)

    with open_file() as f:
        try:
            dump(obj, f, protocol=protocol)
        except AttributeError as e:
            if backend != "pickle":
                raise
            # determine which parts of the object are responsible for the failure and report them
            failing_paths = PickleFailureDebugger.debug_failure(obj)
            raise AttributeError(f"Cannot pickle paths {failing_paths} of {obj}: {str(e)}") from e
class PickleFailureDebugger:
    """
    A collection of methods for testing whether objects can be pickled and logging useful infos in case they cannot
    """

    enabled = False  # global flag controlling the behaviour of log_failure_if_enabled

    @staticmethod
    def _children(obj):
        """
        Determines the children of an object that could not be pickled, which shall be investigated further.

        :param obj: the object that failed to pickle
        :return: a mapping from child key (attribute name, dictionary key or position) to child object;
            empty if the object has no children that can be inspected
        """
        if hasattr(obj, '__dict__'):
            # __getstate__ may be missing or may not be callable without arguments (e.g. if obj is a class),
            # so we simply try to call it and fall back to the attribute dictionary
            try:
                state = obj.__getstate__()
            except Exception:
                return obj.__dict__
            return state if isinstance(state, dict) else {"state": state}
        if isinstance(obj, dict):
            return obj
        if isinstance(obj, (list, tuple, set, frozenset)):
            return dict(enumerate(obj))
        return {}

    @classmethod
    def _debug_failure(cls, obj, path, failures, handled_object_ids):
        """
        Tests whether the given object can be pickled and, if not, recursively searches its children
        for the culprits.

        :param obj: the object to test
        :param path: the list of path elements leading to obj
        :param failures: the list to which the paths of failing objects are appended; an object's path is added
            only if none of its children was found to fail
        :param handled_object_ids: the ids of the objects that were already tested (objects are tested only once)
        :return: True if pickling obj failed, False if it succeeded or if obj was already handled
        """
        if id(obj) in handled_object_ids:
            return False
        handled_object_ids.add(id(obj))

        try:
            pickle.dumps(obj)
        except Exception:  # deliberately broad: any error counts as a pickling failure (but not KeyboardInterrupt/SystemExit)
            # recursively test children
            have_failed_child = False
            for key, child in cls._children(obj).items():
                child_path = list(path) + [f"{key}[{child.__class__.__name__}]"]
                # the recursive call must come first, such that every child is tested
                have_failed_child = cls._debug_failure(child, child_path, failures, handled_object_ids) or have_failed_child

            # only report the object itself if the failure cannot be attributed to a child
            if not have_failed_child:
                failures.append(path)

            return True
        else:
            return False

    @classmethod
    def debug_failure(cls, obj) -> List[str]:
        """
        Recursively tries to pickle the given object and returns a list of failed paths

        :param obj: the object for which to recursively test pickling
        :return: a list of object paths that failed to pickle
        """
        handled_object_ids = set()
        failures = []
        cls._debug_failure(obj, [obj.__class__.__name__], failures, handled_object_ids)
        return [".".join(failed_path) for failed_path in failures]

    @classmethod
    def log_failure_if_enabled(cls, obj, context_info: str = None):
        """
        If the class flag 'enabled' is set to true, the pickling of the given object is
        recursively tested and the results are logged at error level if there are problems and
        info level otherwise.
        If the flag is disabled, no action is taken.

        :param obj: the object for which to recursively test pickling
        :param context_info: optional additional string to be included in the log message
        """
        if not cls.enabled:
            return
        failures = cls.debug_failure(obj)
        prefix = f"Picklability analysis for {obj}"
        if context_info is not None:
            prefix += " (context: %s)" % context_info
        if failures:
            log.error(f"{prefix}: pickling would result in failures due to: {failures}")
        else:
            log.info(f"{prefix}: is picklable")
def setstate(
        cls,
        obj,
        state: Dict[str, Any],
        renamed_properties: Dict[str, str] = None,
        new_optional_properties: List[str] = None,
        new_default_properties: Dict[str, Any] = None,
        removed_properties: List[str] = None
) -> None:
    """
    Helper for implementing __setstate__ safely: it first migrates the given state dictionary (renamed, newly
    added and removed properties) and then hands it to the super-class' __setstate__ if there is one, or
    otherwise assigns it as the instance's attribute dictionary.
    Use it in place of a call to super().__setstate__(state), which is not valid in general, because
    object does not implement __setstate__.

    The state dictionary is modified in place.

    :param cls: the class in which you are implementing __setstate__
    :param obj: the instance of cls
    :param state: the state dictionary
    :param renamed_properties: a mapping from old property names to new property names
    :param new_optional_properties: a list of names of new properties, which, if not present, shall be
        initialised with None
    :param new_default_properties: a dictionary mapping property names to their default values, which shall be
        added if they are not present
    :param removed_properties: a list of names of properties that are no longer being used
    """
    # move values that are stored under an outdated name to the current name
    if renamed_properties is not None:
        for old_name, new_name in renamed_properties.items():
            if old_name not in state:
                continue
            state[new_name] = state[old_name]
            del state[old_name]

    # add properties that did not exist when the state was created
    if new_optional_properties is not None:
        for name in new_optional_properties:
            state.setdefault(name, None)
    if new_default_properties is not None:
        for name, default_value in new_default_properties.items():
            state.setdefault(name, default_value)

    # drop properties that are no longer in use
    if removed_properties is not None:
        for name in removed_properties:
            state.pop(name, None)

    # delegate to the super-class implementation, if any
    parent = super(cls, obj)
    if not hasattr(parent, '__setstate__'):
        obj.__dict__ = state
        return
    parent.__setstate__(state)
def getstate(
        cls,
        obj,
        transient_properties: Iterable[str] = None,
        excluded_properties: Iterable[str] = None,
        override_properties: Dict[str, Any] = None,
        excluded_default_properties: Dict[str, Any] = None
) -> Dict[str, Any]:
    """
    Helper function for safe implementations of __getstate__ in classes, which appropriately handles the cases
    where a parent class already implements __getstate__ and where it does not. Call this function whenever you
    would actually like to call the super-class' implementation.
    Before Python 3.11, __getstate__ is not implemented in object, rendering super().__getstate__() invalid in
    the general case. From Python 3.11 on, object does implement it, but the default implementation returns the
    instance's attribute dictionary itself (or None if the instance has no attributes), which must not be
    modified. This function therefore always returns a dictionary that is safe to modify.

    :param cls: the class in which you are implementing __getstate__
    :param obj: the instance of cls
    :param transient_properties: transient properties which shall be set to None in serialisations
    :param excluded_properties: properties which shall be completely removed from serialisations
    :param override_properties: a mapping from property names to values specifying (new or existing) properties
        which are to be set; use this to set a fixed value for an existing property or to add a completely new
        property
    :param excluded_default_properties: properties which shall be completely removed from serialisations, if
        they are set to the given default value
    :return: the state dictionary, which may be modified by the receiver
    """
    s = super(cls, obj)
    if hasattr(s, '__getstate__'):
        d = s.__getstate__()
    else:
        d = obj.__dict__
    if d is None:
        # default state of an instance without any attributes (Python 3.11+)
        d = {}
    elif isinstance(d, dict):
        # shallow copy: the state may be the instance's own __dict__, and the modifications
        # below must never affect the live object
        d = d.copy()

    if transient_properties is not None:
        for p in transient_properties:
            if p in d:
                d[p] = None
    if excluded_properties is not None:
        for p in excluded_properties:
            if p in d:
                del d[p]
    if override_properties is not None:
        for k, v in override_properties.items():
            d[k] = v
    if excluded_default_properties is not None:
        for p, v in excluded_default_properties.items():
            if p in d and d[p] == v:
                del d[p]
    return d