Coverage for src/sensai/util/pickle.py: 35%
153 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
1import logging
2import os
3import pickle
4from copy import copy
5from pathlib import Path
6from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
8from .io import S3Object, is_s3_path
# module-level logger; used by load_pickle and PickleFailureDebugger to report loading/pickling problems
log = logging.getLogger(__name__)
def load_pickle(path: Union[str, Path], backend="pickle"):
    """
    Loads an object from a file (local or S3) using the given deserialisation backend.

    Unpickling can execute arbitrary code, so never load files from untrusted sources.

    :param path: the local path or S3 path (as recognised by `is_s3_path`) of the file to load
    :param backend: the deserialisation backend; one of 'pickle', 'cloudpickle' or 'joblib'
    :return: the loaded object
    :raises ValueError: if the backend is not supported (checked before the file is accessed)
    """
    # validate the backend before touching the file system/S3
    if backend not in ("pickle", "cloudpickle", "joblib"):
        raise ValueError(f"Unknown backend '{backend}'. Supported backends are 'pickle', 'joblib' and 'cloudpickle'")
    if isinstance(path, Path):
        path = str(path)

    def read_file(f):
        # backends other than pickle are only imported when actually requested
        if backend == "pickle":
            loader = pickle.load
        elif backend == "cloudpickle":
            import cloudpickle
            loader = cloudpickle.load
        else:
            import joblib
            loader = joblib.load
        try:
            return loader(f)
        except Exception:
            # log the affected file (the exception itself may not mention it), then propagate the exception unchanged
            log.error(f"Error loading {path}")
            raise

    if is_s3_path(path):
        # use the S3 file object as a context manager so that it is closed after reading
        with S3Object(path).open_file("rb") as f:
            return read_file(f)
    with open(path, "rb") as f:
        return read_file(f)
def dump_pickle(obj, pickle_path: Union[str, Path], backend="pickle", protocol=pickle.HIGHEST_PROTOCOL):
    """
    Persists the given object in a file (local or S3) using the given serialisation backend.

    :param obj: the object to persist
    :param pickle_path: the local path or S3 path (as recognised by `is_s3_path`) of the file to write;
        for local paths, missing parent directories are created
    :param backend: the serialisation backend; one of 'pickle', 'joblib' or 'cloudpickle'
    :param protocol: the pickle protocol to pass on to the backend
    :raises ValueError: if the backend is not supported (checked before any file or directory is created)
    :raises AttributeError: if the 'pickle' backend fails with an AttributeError; the message then lists
        the object paths that could not be pickled
    """
    # validate the backend first: opening the target with "wb" would otherwise truncate an existing file
    # before the unknown backend is detected
    if backend not in ("pickle", "joblib", "cloudpickle"):
        raise ValueError(f"Unknown backend '{backend}'. Supported backends are 'pickle', 'joblib' and 'cloudpickle'")
    if isinstance(pickle_path, Path):
        pickle_path = str(pickle_path)

    is_s3 = is_s3_path(pickle_path)
    if not is_s3:
        # only local paths have parent directories to create; for an S3 path, os.path.dirname would yield
        # something like 's3://bucket', which os.makedirs would create as a local directory tree
        dir_name = os.path.dirname(pickle_path)
        if dir_name != "":
            os.makedirs(dir_name, exist_ok=True)

    def open_file():
        if is_s3:
            return S3Object(pickle_path).open_file("wb")
        return open(pickle_path, "wb")

    with open_file() as f:
        if backend == "pickle":
            try:
                pickle.dump(obj, f, protocol=protocol)
            except AttributeError as e:
                failing_paths = PickleFailureDebugger.debug_failure(obj)
                raise AttributeError(f"Cannot pickle paths {failing_paths} of {obj}: {str(e)}") from e
        elif backend == "joblib":
            import joblib
            joblib.dump(obj, f, protocol=protocol)
        else:
            import cloudpickle
            cloudpickle.dump(obj, f, protocol=protocol)
class PickleFailureDebugger:
    """
    A collection of methods for testing whether objects can be pickled and logging useful infos in case they cannot
    """

    enabled = False  # global flag controlling the behaviour of log_failure_if_enabled

    @classmethod
    def _debug_failure(cls, obj, path: List[str], failures: List[List[str]], visited: Dict[int, list]) -> bool:
        """
        Recursively tests whether the given object can be pickled, recording the paths of the innermost
        objects that cannot.

        :param obj: the object to test
        :param path: the path components leading to `obj` (starting with the root object's class name)
        :param failures: the list to which the paths of failing objects are appended; an object's path is
            only appended if none of its children failed
        :param visited: maps the id of every object visited so far to a list [object, failed].
            Holding a reference to the object keeps it alive, so its id cannot be reused by a newly created
            object (e.g. a state returned by `__getstate__`) during the traversal. `failed` is only set once the
            object has been fully investigated, so objects revisited via reference cycles count as not failed.
        :return: True if `obj` cannot be pickled, False otherwise (including objects still under investigation)
        """
        entry = visited.get(id(obj))
        if entry is not None:
            # return the outcome of the earlier investigation, so that a parent which merely shares an
            # already reported failing object is not itself reported as a failure
            return entry[1]
        entry = [obj, False]
        visited[id(obj)] = entry

        try:
            pickle.dumps(obj)
        except Exception:
            pass
        else:
            return False

        # determine dictionary of children to investigate (if any)
        if hasattr(obj, '__dict__'):
            # Prefer the state returned by __getstate__ (which is what pickle serialises); try-except is used,
            # because __getstate__ may not exist (before Python 3.11) or may fail
            try:
                d = obj.__getstate__()
                if not isinstance(d, dict):
                    d = {"state": d}
            except Exception:
                d = obj.__dict__
        elif type(obj) == dict:
            d = obj
        elif type(obj) in (list, tuple, set, frozenset):
            d = dict(enumerate(obj))
        else:
            d = {}

        # recursively test children (every child is tested, even after a failure has been found)
        have_failed_child = False
        for key, child in d.items():
            child_path = list(path) + [f"{key}[{child.__class__.__name__}]"]
            if cls._debug_failure(child, child_path, failures, visited):
                have_failed_child = True

        if not have_failed_child:
            failures.append(path)
        entry[1] = True
        return True

    @classmethod
    def debug_failure(cls, obj) -> List[str]:
        """
        Recursively tries to pickle the given object and returns a list of failed paths

        :param obj: the object for which to recursively test pickling
        :return: a list of object paths that failed to pickle, each consisting of dot-separated components of
            the form `key[ClassName]` and starting with the class name of `obj`; empty if `obj` can be pickled
        """
        failures: List[List[str]] = []
        cls._debug_failure(obj, [obj.__class__.__name__], failures, {})
        return [".".join(path) for path in failures]

    @classmethod
    def log_failure_if_enabled(cls, obj, context_info: str = None):
        """
        If the class flag 'enabled' is set to true, the pickling of the given object is
        recursively tested and the results are logged at error level if there are problems and
        info level otherwise.
        If the flag is disabled, no action is taken.

        :param obj: the object for which to recursively test pickling
        :param context_info: optional additional string to be included in the log message
        """
        if cls.enabled:
            failures = cls.debug_failure(obj)
            prefix = f"Picklability analysis for {obj}"
            if context_info is not None:
                prefix += f" (context: {context_info})"
            if len(failures) > 0:
                log.error(f"{prefix}: pickling would result in failures due to: {failures}")
            else:
                log.info(f"{prefix}: is picklable")
class PersistableObject:
    """
    Base class for objects that are to be persisted via pickle.

    IMPORTANT:
    For objects with at least one attribute, the methods below merely replicate pickle's default behaviour.
    They make a difference for objects whose attribute set may be empty: only because `__getstate__` is
    implemented explicitly is `__setstate__` guaranteed to be invoked when such an object is unpickled.
    Without it, an object persisted while it had no attributes could not later be migrated via `__setstate__`
    after a refactoring; subclasses of this class can always be migrated that way.
    """

    def __getstate__(self):
        # the instance dictionary itself is the state, even when it is empty
        return vars(self)

    def __setstate__(self, state):
        # the unpickled dictionary becomes the instance dictionary
        self.__dict__ = state
def setstate(
    cls,
    obj,
    state: Dict[str, Any],
    renamed_properties: Dict[str, Union[str, Tuple[str, Callable[[Dict[str, Any]], Any]]]] = None,
    new_optional_properties: List[str] = None,
    new_default_properties: Dict[str, Any] = None,
    removed_properties: List[str] = None,
) -> None:
    """
    Helper for robust `__setstate__` implementations. It migrates the given (unpickled) state dictionary in
    place. It then passes the state to the parent class' `__setstate__` if one exists; otherwise it installs
    the state as the instance dictionary of `obj`. Call it wherever you would like to call
    `super().__setstate__(state)`, which does not work in general because `object` does not implement
    `__setstate__`.

    The migration steps are applied in this order: renaming, adding optional properties, adding defaulted
    properties, removing properties.

    :param cls: the class whose `__setstate__` is being implemented
    :param obj: the instance of `cls` being restored
    :param state: the state dictionary (modified in place)
    :param renamed_properties: can be used for renaming as well as for assigning new values; maps an old
        property name either to a new property name (the old value is carried over) or to a tuple of a new
        property name and a function that computes the new value from the state dictionary (which at that
        point still contains the old property)
    :param new_optional_properties: names of new properties which, if absent, are initialised with None
    :param new_default_properties: maps names of new properties to the values they are initialised with if absent
    :param removed_properties: names of properties that are no longer used and are dropped if present
    """
    if renamed_properties is not None:
        for old_name, target in renamed_properties.items():
            if old_name not in state:
                continue
            # `target` is either a new name or a tuple (new name, function computing the new value from the state)
            if isinstance(target, str):
                new_name = target
                new_value = state[old_name]
            else:
                new_name = target[0]
                new_value = target[1](state)
            del state[old_name]
            state[new_name] = new_value

    if new_optional_properties is not None:
        for name in new_optional_properties:
            state.setdefault(name, None)

    if new_default_properties is not None:
        for name, default_value in new_default_properties.items():
            state.setdefault(name, default_value)

    if removed_properties is not None:
        for name in removed_properties:
            state.pop(name, None)

    # delegate to the parent implementation where there is one
    parent = super(cls, obj)
    if hasattr(parent, '__setstate__'):
        parent.__setstate__(state)
    else:
        obj.__dict__ = state
def getstate(
    cls,
    obj,
    transient_properties: Iterable[str] = None,
    excluded_properties: Iterable[str] = None,
    override_properties: Dict[str, Any] = None,
    excluded_default_properties: Dict[str, Any] = None
) -> Dict[str, Any]:
    """
    Helper function for safe implementations of __getstate__ in classes, which appropriately handles the cases where
    a parent class already implements __getstate__ and where it does not. Call this function whenever you would actually
    like to call the super-class' implementation.
    Unfortunately, __getstate__ is not implemented in object prior to Python 3.11, rendering super().__getstate__()
    invalid in the general case.

    :param cls: the class in which you are implementing __getstate__
    :param obj: the instance of cls
    :param transient_properties: transient properties which shall be set to None in serialisations
    :param excluded_properties: properties which shall be completely removed from serialisations
    :param override_properties: a mapping from property names to values specifying (new or existing) properties which are to be set;
        use this to set a fixed value for an existing property or to add a completely new property
    :param excluded_default_properties: properties which shall be completely removed from serialisations, if they are set
        to the given default value
    :return: the state dictionary (a shallow copy, never the instance's own __dict__), which may be modified by the receiver
    """
    s = super(cls, obj)
    if hasattr(s, '__getstate__'):
        d = s.__getstate__()
    else:
        d = obj.__dict__
    # Since Python 3.11, object implements __getstate__; for instances with an empty __dict__ its default
    # implementation returns None rather than a dictionary, which is treated as an empty state here.
    # Otherwise, copy, since the state may be the instance's own __dict__, which must not be modified.
    d = {} if d is None else copy(d)
    if transient_properties is not None:
        for p in transient_properties:
            if p in d:
                d[p] = None
    if excluded_properties is not None:
        for p in excluded_properties:
            if p in d:
                del d[p]
    if override_properties is not None:
        for k, v in override_properties.items():
            d[k] = v
    if excluded_default_properties is not None:
        for p, v in excluded_default_properties.items():
            if p in d and d[p] == v:
                del d[p]
    return d