Coverage for src/sensai/util/pickle.py: 35%

153 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-29 18:29 +0000

1import logging 

2import os 

3import pickle 

4from copy import copy 

5from pathlib import Path 

6from typing import Any, Callable, Dict, Iterable, List, Tuple, Union 

7 

8from .io import S3Object, is_s3_path 

9 

10log = logging.getLogger(__name__) 

11 

12 

def load_pickle(path: Union[str, Path], backend="pickle"):
    """
    Loads a serialised object from a local file or from S3.

    SECURITY NOTE: all supported backends are pickle-based, and unpickling can execute arbitrary code.
    Only load files that originate from a trusted source.

    :param path: the path to load from; if `is_s3_path` holds for it, the data is read via `S3Object`,
        otherwise it is opened as a local file
    :param backend: the deserialisation backend to use; one of 'pickle', 'joblib' and 'cloudpickle'
    :return: the deserialised object
    :raises ValueError: if the given backend is not supported
    """
    if isinstance(path, Path):
        path = str(path)

    # Resolve the loader up front, such that an unsupported backend fails fast (before any file is opened)
    if backend == "pickle":
        loader = pickle.load
    elif backend == "cloudpickle":
        import cloudpickle
        loader = cloudpickle.load
    elif backend == "joblib":
        import joblib
        loader = joblib.load
    else:
        raise ValueError(f"Unknown backend '{backend}'. Supported backends are 'pickle', 'joblib' and 'cloudpickle'")

    def read_file(f):
        # all backends share the same error logging, so that the offending path always shows up in the log
        try:
            return loader(f)
        except Exception:
            log.error(f"Error loading {path}")
            raise

    if is_s3_path(path):
        # use a context manager (as dump_pickle does) to make sure the S3 file handle is closed again
        with S3Object(path).open_file("rb") as f:
            return read_file(f)
    with open(path, "rb") as f:
        return read_file(f)

40 

41 

def dump_pickle(obj, pickle_path: Union[str, Path], backend="pickle", protocol=pickle.HIGHEST_PROTOCOL):
    """
    Serialises the given object to a local file or to S3.

    For local paths, missing parent directories are created.

    :param obj: the object to serialise
    :param pickle_path: the path to write to; if `is_s3_path` holds for it, the data is written via `S3Object`,
        otherwise a local file is written
    :param backend: the serialisation backend to use; one of 'pickle', 'joblib' and 'cloudpickle'
    :param protocol: the pickle protocol, which is passed on to the backend's dump function
    :raises ValueError: if the given backend is not supported
    :raises AttributeError: if the 'pickle' backend fails with an AttributeError; the message then lists the
        object paths that were found not to be picklable
    """
    if isinstance(pickle_path, Path):
        pickle_path = str(pickle_path)

    # Validate before touching the target: opening it in "wb" mode would otherwise create or truncate
    # the file before the error is raised
    if backend not in ("pickle", "joblib", "cloudpickle"):
        raise ValueError(f"Unknown backend '{backend}'. Supported backends are 'pickle', 'joblib' and 'cloudpickle'")

    is_s3 = is_s3_path(pickle_path)

    def open_file():
        if is_s3:
            return S3Object(pickle_path).open_file("wb")
        else:
            return open(pickle_path, "wb")

    # Directories need to be created for local paths only (for an S3 path, this would create
    # a bogus local directory tree named after the URL)
    if not is_s3:
        dir_name = os.path.dirname(pickle_path)
        if dir_name != "":
            os.makedirs(dir_name, exist_ok=True)

    with open_file() as f:
        if backend == "pickle":
            try:
                pickle.dump(obj, f, protocol=protocol)
            except AttributeError as e:
                failing_paths = PickleFailureDebugger.debug_failure(obj)
                raise AttributeError(f"Cannot pickle paths {failing_paths} of {obj}: {str(e)}") from e
        elif backend == "joblib":
            import joblib
            joblib.dump(obj, f, protocol=protocol)
        else:  # backend == "cloudpickle" (guaranteed by the validation above)
            import cloudpickle
            cloudpickle.dump(obj, f, protocol=protocol)

70 

71 

class PickleFailureDebugger:
    """
    A collection of methods for testing whether objects can be pickled and logging useful information in case they cannot
    """

    enabled = False  # global flag controlling the behaviour of log_failure_if_enabled

    @staticmethod
    def _is_picklable(obj) -> bool:
        """
        :param obj: the object to test
        :return: True if the object can be serialised with pickle.dumps, False if doing so raises an exception
        """
        try:
            pickle.dumps(obj)
        except Exception:  # deliberately not a bare except: KeyboardInterrupt/SystemExit must propagate
            return False
        return True

    @staticmethod
    def _children(obj):
        """
        Determines the children of an object that could not be pickled, which are to be investigated further.

        :param obj: the object that failed to pickle
        :return: a mapping from a child's key (attribute name, dictionary key or index) to the child object;
            empty if the object has no children that can be inspected
        """
        if hasattr(obj, '__dict__'):
            # __getstate__ may be missing or may itself fail, so we try calling it (rather than checking for
            # its presence) and fall back to the attribute dictionary
            try:
                state = obj.__getstate__()
            except Exception:
                return obj.__dict__
            return state if isinstance(state, dict) else {"state": state}
        if isinstance(obj, dict):
            return obj
        if isinstance(obj, (list, tuple, set, frozenset)):
            return dict(enumerate(obj))
        return {}

    @classmethod
    def _debug_failure(cls, obj, path, failures, handled_object_ids):
        """
        Recursively determines the innermost objects that cannot be pickled.

        :param obj: the object to test
        :param path: the list of path elements leading to obj
        :param failures: the list to which the paths of failing objects are appended; a path is added only
            if none of the object's children fail, i.e. only the innermost culprits are reported
        :param handled_object_ids: the set of ids of objects already visited (guards against cycles and repeated work)
        :return: True if obj failed to pickle, False if it was pickled successfully,
            None if obj had already been handled
        """
        if id(obj) in handled_object_ids:
            return
        handled_object_ids.add(id(obj))

        if cls._is_picklable(obj):
            return False

        # recursively test children
        have_failed_child = False
        for key, child in cls._children(obj).items():
            child_path = list(path) + [f"{key}[{child.__class__.__name__}]"]
            # the recursive call comes first, such that every child is visited even after a failure was found
            have_failed_child = cls._debug_failure(child, child_path, failures, handled_object_ids) or have_failed_child

        if not have_failed_child:
            failures.append(path)

        return True

    @classmethod
    def debug_failure(cls, obj) -> List[str]:
        """
        Recursively tries to pickle the given object and returns a list of failed paths

        :param obj: the object for which to recursively test pickling
        :return: a list of object paths that failed to pickle
        """
        handled_object_ids = set()
        failures = []
        cls._debug_failure(obj, [obj.__class__.__name__], failures, handled_object_ids)
        return [".".join(failed_path) for failed_path in failures]

    @classmethod
    def log_failure_if_enabled(cls, obj, context_info: str = None):
        """
        If the class flag 'enabled' is set to true, the pickling of the given object is
        recursively tested and the results are logged at error level if there are problems and
        info level otherwise.
        If the flag is disabled, no action is taken.

        :param obj: the object for which to recursively test pickling
        :param context_info: optional additional string to be included in the log message
        """
        if not cls.enabled:
            return
        failures = cls.debug_failure(obj)
        prefix = f"Picklability analysis for {obj}"
        if context_info is not None:
            prefix += f" (context: {context_info})"
        if failures:
            log.error(f"{prefix}: pickling would result in failures due to: {failures}")
        else:
            log.info(f"{prefix}: is picklable")

153 

154 

class PersistableObject:
    """
    Base class for objects that are to be persisted via pickle.

    IMPORTANT:
    Both methods mirror what pickle does by default for an object with a non-empty set of attributes.
    Defining `__getstate__` explicitly matters for objects whose set of attributes can be empty:
    it ensures that `__setstate__` is called upon unpickling in that case as well.
    Without it, an object persisted while it had no attributes could not have later refactorings
    handled via `__setstate__`; with this base class, it can.
    """
    def __getstate__(self):
        # the attribute dictionary itself (not a copy) is the state
        return vars(self)

    def __setstate__(self, state):
        # the given state dictionary becomes the attribute dictionary
        self.__dict__ = state

171 

172 

def setstate(
    cls,
    obj,
    state: Dict[str, Any],
    renamed_properties: Dict[str, Union[str, Tuple[str, Callable[[Dict[str, Any]], Any]]]] = None,
    new_optional_properties: List[str] = None,
    new_default_properties: Dict[str, Any] = None,
    removed_properties: List[str] = None,
) -> None:
    """
    Helper for implementing __setstate__ safely: it migrates the given state dictionary (in place) and then
    applies it, delegating to the super-class' __setstate__ if there is one and assigning the instance
    dictionary directly otherwise. Use it wherever you would want to call super().__setstate__(state),
    which is invalid in the general case because object does not implement __setstate__.

    The migration steps are applied in the order of the parameters below.

    :param cls: the class in which you are implementing __setstate__
    :param obj: the instance of cls
    :param state: the state dictionary
    :param renamed_properties: supports renaming properties as well as assigning new values.
        Maps an old property name to either the new property name or a tuple consisting of the new property
        name and a function which computes the new value from the state dictionary.
        Entries whose old name is not present in the state are ignored.
    :param new_optional_properties: names of new properties which are initialised with None if they are not present
    :param new_default_properties: a mapping from names of new properties to the default values which are
        added if the respective property is not present
    :param removed_properties: names of properties which are no longer used and are dropped from the state
    """
    # renamed/recomputed properties; a target is either a name or a pair (name, function of the state)
    for old_name, target in (renamed_properties or {}).items():
        if old_name not in state:
            continue
        if isinstance(target, str):
            new_name = target
            new_value = state[old_name]
        else:
            new_name = target[0]
            new_value = target[1](state)
        del state[old_name]
        state[new_name] = new_value

    # newly introduced properties
    for name in new_optional_properties or ():
        state.setdefault(name, None)
    for name, default_value in (new_default_properties or {}).items():
        state.setdefault(name, default_value)

    # obsolete properties
    for name in removed_properties or ():
        state.pop(name, None)

    # apply the state, using the super implementation if there is one
    parent = super(cls, obj)
    if not hasattr(parent, '__setstate__'):
        obj.__dict__ = state
        return
    parent.__setstate__(state)

229 

230 

def getstate(
    cls,
    obj,
    transient_properties: Iterable[str] = None,
    excluded_properties: Iterable[str] = None,
    override_properties: Dict[str, Any] = None,
    excluded_default_properties: Dict[str, Any] = None
) -> Dict[str, Any]:
    """
    Helper function for safe implementations of __getstate__ in classes, which appropriately handles the cases where
    a parent class already implements __getstate__ and where it does not. Call this function whenever you would actually
    like to call the super-class' implementation.
    Before Python 3.11, __getstate__ is not implemented in object, rendering super().__getstate__() invalid in the
    general case; as of Python 3.11, object does implement it, but the default implementation returns None for
    an instance without attributes, which is handled here by using an empty dictionary.

    :param cls: the class in which you are implementing __getstate__
    :param obj: the instance of cls
    :param transient_properties: transient properties which shall be set to None in serialisations
    :param excluded_properties: properties which shall be completely removed from serialisations
    :param override_properties: a mapping from property names to values specifying (new or existing) properties which are to be set;
        use this to set a fixed value for an existing property or to add a completely new property
    :param excluded_default_properties: properties which shall be completely removed from serialisations, if they are set
        to the given default value
    :return: the state dictionary, which may be modified by the receiver
    """
    s = super(cls, obj)
    if hasattr(s, '__getstate__'):
        d = s.__getstate__()
        if d is None:
            # default object.__getstate__ (Python 3.11+) yields None if the instance has no attributes
            d = {}
    else:
        d = obj.__dict__
    # shallow copy, such that the modifications below do not affect the object itself
    d = copy(d)
    if transient_properties is not None:
        for p in transient_properties:
            if p in d:
                d[p] = None
    if excluded_properties is not None:
        for p in excluded_properties:
            d.pop(p, None)
    if override_properties is not None:
        d.update(override_properties)
    if excluded_default_properties is not None:
        for p, v in excluded_default_properties.items():
            if p in d and d[p] == v:
                del d[p]
    return d