Coverage for src/sensai/util/pickle.py: 33%
143 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1import logging
2import os
3import pickle
4from pathlib import Path
5from typing import Any, Callable, Dict, Iterable, List, Union
7from .io import S3Object, is_s3_path
9log = logging.getLogger(__name__)
def load_pickle(path: Union[str, Path], backend="pickle"):
    """
    Loads a serialised object from the given path, which may be a local file or an S3 path.

    NOTE: Unpickling can execute arbitrary code; only load files that come from a trusted source.

    :param path: the path of the file to load (a local path or a path for which is_s3_path returns true)
    :param backend: the deserialisation backend; one of 'pickle', 'cloudpickle' and 'joblib'
    :return: the deserialised object
    :raises ValueError: if the backend is not supported (raised before any file is opened)
    """
    if isinstance(path, Path):
        path = str(path)

    # Resolve the loader first, such that an unsupported backend (or a missing optional package)
    # is reported before any file or S3 object is opened.
    if backend == "pickle":
        loader = pickle.load
    elif backend == "cloudpickle":
        import cloudpickle
        loader = cloudpickle.load
    elif backend == "joblib":
        import joblib
        loader = joblib.load
    else:
        raise ValueError(f"Unknown backend '{backend}'. Supported backends are 'pickle', 'joblib' and 'cloudpickle'")

    def read_file(f):
        # all backends share the same error logging, so that the failing path is always reported
        try:
            return loader(f)
        except Exception:
            log.error(f"Error loading {path}")
            raise  # bare raise keeps the original traceback

    if is_s3_path(path):
        # use the S3 file as a context manager (as dump_pickle does) so that the handle is always closed
        with S3Object(path).open_file("rb") as f:
            return read_file(f)
    with open(path, "rb") as f:
        return read_file(f)
def dump_pickle(obj, pickle_path: Union[str, Path], backend="pickle", protocol=pickle.HIGHEST_PROTOCOL):
    """
    Serialises the given object to the given path, which may be a local file or an S3 path.
    For local paths, missing parent directories are created.

    :param obj: the object to serialise
    :param pickle_path: the path to write to (a local path or a path for which is_s3_path returns true)
    :param backend: the serialisation backend; one of 'pickle', 'joblib' and 'cloudpickle'
    :param protocol: the pickle protocol, which is passed on to the backend's dump function
    :raises ValueError: if the backend is not supported (raised before the target file is opened)
    :raises AttributeError: if the 'pickle' backend fails with an AttributeError; the message then lists
        the object paths that could not be pickled
    """
    if isinstance(pickle_path, Path):
        pickle_path = str(pickle_path)

    # Resolve the backend before touching the target: opening with "wb" truncates an existing file,
    # so an unsupported backend or a missing optional package must be detected first.
    if backend == "pickle":
        dump = pickle.dump
    elif backend == "joblib":
        import joblib
        dump = joblib.dump
    elif backend == "cloudpickle":
        import cloudpickle
        dump = cloudpickle.dump
    else:
        raise ValueError(f"Unknown backend '{backend}'. Supported backends are 'pickle', 'joblib' and 'cloudpickle'")

    def open_file():
        if is_s3_path(pickle_path):
            return S3Object(pickle_path).open_file("wb")
        # only local targets need their parent directory created; doing this for an S3 path
        # would create a meaningless local directory tree
        dir_name = os.path.dirname(pickle_path)
        if dir_name != "":
            os.makedirs(dir_name, exist_ok=True)
        return open(pickle_path, "wb")

    with open_file() as f:
        if backend == "pickle":
            try:
                dump(obj, f, protocol=protocol)
            except AttributeError as e:
                # determine which parts of the object are responsible, to make the error actionable
                failing_paths = PickleFailureDebugger.debug_failure(obj)
                raise AttributeError(f"Cannot pickle paths {failing_paths} of {obj}: {str(e)}") from e
        else:
            dump(obj, f, protocol=protocol)
class PickleFailureDebugger:
    """
    A collection of methods for testing whether objects can be pickled and logging useful infos in case they cannot
    """

    enabled = False  # global flag controlling the behaviour of log_failure_if_enabled

    @classmethod
    def _debug_failure(cls, obj, path, failures, handled_object_ids):
        """
        Tests whether the given object can be pickled and, if it cannot, recursively descends into its
        children in order to find the innermost objects that are responsible.

        :param obj: the object to test
        :param path: the list of path elements (strings) leading to obj
        :param failures: the list to which the paths of failing objects are appended; a path is added only
            if none of the object's children failed (i.e. for the innermost failures)
        :param handled_object_ids: the set of ids of objects that were already tested; it is used to avoid
            testing an object twice and to terminate on cyclic references
        :return: True if the object failed to pickle, False otherwise (including the case where the object
            was already handled)
        """
        if id(obj) in handled_object_ids:
            return False
        handled_object_ids.add(id(obj))

        try:
            pickle.dumps(obj)
        except Exception:  # not a bare except, such that KeyboardInterrupt/SystemExit are not swallowed
            # determine dictionary of children to investigate (if any)
            if hasattr(obj, '__dict__'):
                # Because of strange behaviour of getattr(_, '__getstate__'), we use try-except here
                # instead of checking for the presence of the method
                try:
                    d = obj.__getstate__()
                    if not isinstance(d, dict):
                        d = {"state": d}
                except Exception:
                    d = obj.__dict__
            elif isinstance(obj, dict):
                d = obj
            elif isinstance(obj, (list, tuple, set, frozenset)):
                # isinstance also covers subclasses without __dict__ (e.g. named tuples)
                d = dict(enumerate(obj))
            else:
                d = {}

            # recursively test children
            have_failed_child = False
            for key, child in d.items():
                child_path = list(path) + [f"{key}[{child.__class__.__name__}]"]
                # the recursive call comes first, such that every child is tested even if a sibling failed
                have_failed_child = cls._debug_failure(child, child_path, failures, handled_object_ids) or have_failed_child

            # report the object itself only if the failure cannot be attributed to one of its children
            if not have_failed_child:
                failures.append(path)

            return True
        else:
            return False

    @classmethod
    def debug_failure(cls, obj) -> List[str]:
        """
        Recursively tries to pickle the given object and returns a list of failed paths

        :param obj: the object for which to recursively test pickling
        :return: a list of object paths that failed to pickle
        """
        handled_object_ids = set()
        failures = []
        cls._debug_failure(obj, [obj.__class__.__name__], failures, handled_object_ids)
        return [".".join(failed_path) for failed_path in failures]

    @classmethod
    def log_failure_if_enabled(cls, obj, context_info: str = None):
        """
        If the class flag 'enabled' is set to true, the pickling of the given object is
        recursively tested and the results are logged at error level if there are problems and
        info level otherwise.
        If the flag is disabled, no action is taken.

        :param obj: the object for which to recursively test pickling
        :param context_info: optional additional string to be included in the log message
        """
        if cls.enabled:
            failures = cls.debug_failure(obj)
            prefix = f"Picklability analysis for {obj}"
            if context_info is not None:
                prefix += " (context: %s)" % context_info
            if len(failures) > 0:
                log.error(f"{prefix}: pickling would result in failures due to: {failures}")
            else:
                log.info(f"{prefix}: is picklable")
def setstate(
    cls,
    obj,
    state: Dict[str, Any],
    renamed_properties: Dict[str, str] = None,
    new_optional_properties: List[str] = None,
    new_default_properties: Dict[str, Any] = None,
    removed_properties: List[str] = None
) -> None:
    """
    Helper for implementing __setstate__ safely: it first migrates the given state dictionary (renaming,
    adding and removing properties as specified) and then delegates to the super-class' __setstate__ if
    there is one, falling back to assigning the state as the object's __dict__ otherwise.
    Use it in place of super().__setstate__(state), which is not valid in general because object does
    not implement __setstate__.

    Note that the given state dictionary is modified in place.

    :param cls: the class in which you are implementing __setstate__
    :param obj: the instance of cls
    :param state: the state dictionary
    :param renamed_properties: a mapping from old property names to new property names
    :param new_optional_properties: a list of names of new properties, which, if not present, shall be initialised with None
    :param new_default_properties: a dictionary mapping property names to their default values, which shall be added if
        they are not present
    :param removed_properties: a list of names of properties that are no longer being used
    """
    # migrate the state: renames first, then additions, then removals
    if renamed_properties is not None:
        for old_name, new_name in renamed_properties.items():
            if old_name not in state:
                continue
            state[new_name] = state[old_name]
            del state[old_name]
    if new_optional_properties is not None:
        for name in new_optional_properties:
            state.setdefault(name, None)
    if new_default_properties is not None:
        for name, default_value in new_default_properties.items():
            state.setdefault(name, default_value)
    if removed_properties is not None:
        for name in removed_properties:
            state.pop(name, None)

    # delegate to the super-class implementation where one exists
    parent = super(cls, obj)
    if not hasattr(parent, '__setstate__'):
        obj.__dict__ = state
        return
    parent.__setstate__(state)
def getstate(
    cls,
    obj,
    transient_properties: Iterable[str] = None,
    excluded_properties: Iterable[str] = None,
    override_properties: Dict[str, Any] = None,
    excluded_default_properties: Dict[str, Any] = None
) -> Dict[str, Any]:
    """
    Helper function for safe implementations of __getstate__ in classes, which appropriately handles the cases where
    a parent class already implements __getstate__ and where it does not. Call this function whenever you would actually
    like to call the super-class' implementation.

    Before Python 3.11, object does not implement __getstate__, rendering super().__getstate__() invalid in the general
    case. From Python 3.11 onwards, object does implement it, but the default implementation returns the instance's
    actual __dict__ (not a copy) or None if the instance has no attributes. Therefore, a state dictionary obtained from
    the super-class is always copied (and None is replaced by an empty dictionary) before it is modified, such that the
    object itself is never changed by this function.

    :param cls: the class in which you are implementing __getstate__
    :param obj: the instance of cls
    :param transient_properties: transient properties which shall be set to None in serialisations
    :param excluded_properties: properties which shall be completely removed from serialisations
    :param override_properties: a mapping from property names to values specifying (new or existing) properties which are to be set;
        use this to set a fixed value for an existing property or to add a completely new property
    :param excluded_default_properties: properties which shall be completely removed from serialisations, if they are set
        to the given default value
    :return: the state dictionary, which may be modified by the receiver
    """
    s = super(cls, obj)
    if hasattr(s, '__getstate__'):
        d = s.__getstate__()
        if d is None:
            # default implementation of object (Python >= 3.11) for instances without attributes
            d = {}
        elif isinstance(d, dict):
            # may be the live __dict__ of obj, which must not be modified
            d = d.copy()
        # any other kind of state (e.g. the tuple used for classes with __slots__) is passed on unchanged
    else:
        d = obj.__dict__.copy()
    if transient_properties is not None:
        for p in transient_properties:
            if p in d:
                d[p] = None
    if excluded_properties is not None:
        for p in excluded_properties:
            d.pop(p, None)
    if override_properties is not None:
        d.update(override_properties)
    if excluded_default_properties is not None:
        for p, v in excluded_default_properties.items():
            if p in d and d[p] == v:
                del d[p]
    return d