Coverage for src/sensai/util/pickle.py: 33%

143 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1import logging 

2import os 

3import pickle 

4from pathlib import Path 

5from typing import Any, Callable, Dict, Iterable, List, Union 

6 

7from .io import S3Object, is_s3_path 

8 

# module-level logger; used by load_pickle (load errors) and PickleFailureDebugger.log_failure_if_enabled
log = logging.getLogger(__name__)

10 

11 

def load_pickle(path: Union[str, Path], backend="pickle"):
    """
    Loads an object from a pickle file, which may either be a local file or an S3 object.

    :param path: the path of the file to load; paths recognised by ``is_s3_path`` are read via ``S3Object``,
        all other paths are opened as local files
    :param backend: the deserialisation backend to use; one of "pickle", "cloudpickle" or "joblib"
    :return: the deserialised object
    :raises ValueError: if the backend is not supported
    """
    if isinstance(path, Path):
        path = str(path)

    def read_file(f):
        # select the loader function for the requested backend (optional backends are imported lazily)
        if backend == "pickle":
            loader = pickle.load
        elif backend == "cloudpickle":
            import cloudpickle
            loader = cloudpickle.load
        elif backend == "joblib":
            import joblib
            loader = joblib.load
        else:
            raise ValueError(f"Unknown backend '{backend}'. Supported backends are 'pickle', 'joblib' and 'cloudpickle'")
        # all backends log the failing path before propagating the original exception
        try:
            return loader(f)
        except Exception:
            log.error(f"Error loading {path}")
            raise

    if is_s3_path(path):
        # use the S3 file object as a context manager (as dump_pickle does) so that it is always closed
        with S3Object(path).open_file("rb") as f:
            return read_file(f)
    with open(path, "rb") as f:
        return read_file(f)

39 

40 

def dump_pickle(obj, pickle_path: Union[str, Path], backend="pickle", protocol=pickle.HIGHEST_PROTOCOL):
    """
    Serialises the given object to a file, which may either be a local file or an S3 object.

    :param obj: the object to serialise
    :param pickle_path: the target path; paths recognised by ``is_s3_path`` are written via ``S3Object``,
        all other paths are written as local files (missing parent directories are created)
    :param backend: the serialisation backend to use; one of "pickle", "joblib" or "cloudpickle"
    :param protocol: the pickle protocol to use
    :raises ValueError: if the backend is not supported (raised before any file is opened or created)
    :raises AttributeError: if the "pickle" backend fails with an AttributeError; the message then includes
        the paths of the sub-objects that could not be pickled
    """
    if isinstance(pickle_path, Path):
        pickle_path = str(pickle_path)

    # validate the backend before touching the file system, so that an invalid backend
    # does not leave behind an empty/truncated file
    if backend not in ("pickle", "joblib", "cloudpickle"):
        raise ValueError(f"Unknown backend '{backend}'. Supported backends are 'pickle', 'joblib' and 'cloudpickle'")

    def open_file():
        if is_s3_path(pickle_path):
            # no local directories must be created for S3 paths
            return S3Object(pickle_path).open_file("wb")
        dir_name = os.path.dirname(pickle_path)
        if dir_name != "":
            os.makedirs(dir_name, exist_ok=True)
        return open(pickle_path, "wb")

    with open_file() as f:
        if backend == "pickle":
            try:
                pickle.dump(obj, f, protocol=protocol)
            except AttributeError as e:
                failing_paths = PickleFailureDebugger.debug_failure(obj)
                raise AttributeError(f"Cannot pickle paths {failing_paths} of {obj}: {str(e)}") from e
        elif backend == "joblib":
            import joblib
            joblib.dump(obj, f, protocol=protocol)
        elif backend == "cloudpickle":
            import cloudpickle
            cloudpickle.dump(obj, f, protocol=protocol)

69 

70 

class PickleFailureDebugger:
    """
    A collection of methods for testing whether objects can be pickled and logging useful infos in case they cannot
    """

    enabled = False  # global flag controlling the behaviour of log_failure_if_enabled

    @classmethod
    def _debug_failure(cls, obj, path, failures, handled_object_ids):
        """
        Recursively tests whether obj and its children can be pickled, recording the paths of the innermost
        objects that fail.

        :param obj: the object to test
        :param path: the list of path components leading to obj
        :param failures: a list to which the paths (lists of components) of failing leaf objects are appended
        :param handled_object_ids: a dictionary mapping id(x) to x for every object visited so far; the objects
            themselves are retained so that temporaries (e.g. states returned by __getstate__) stay alive and
            their ids cannot be reused by other objects during the traversal
        :return: True if obj cannot be pickled, False if it can, None if obj was already visited
        """
        if id(obj) in handled_object_ids:
            return None
        handled_object_ids[id(obj)] = obj

        try:
            pickle.dumps(obj)
        except Exception:
            pass
        else:
            return False

        # determine dictionary of children to investigate (if any)
        if hasattr(obj, '__dict__'):
            # __getstate__ may be missing (Python < 3.11) or fail when called, so try it and fall back to __dict__
            try:
                children = obj.__getstate__()
                if not isinstance(children, dict):
                    children = {"state": children}
            except Exception:
                children = obj.__dict__
        elif isinstance(obj, dict):
            children = obj
        elif isinstance(obj, (list, tuple, set, frozenset)):
            children = dict(enumerate(obj))
        else:
            children = {}

        # recursively test children (all children are visited, not just the first failing one)
        have_failed_child = False
        for key, child in children.items():
            child_path = list(path) + [f"{key}[{child.__class__.__name__}]"]
            if cls._debug_failure(child, child_path, failures, handled_object_ids):
                have_failed_child = True

        # only report the innermost failing objects
        if not have_failed_child:
            failures.append(path)

        return True

    @classmethod
    def debug_failure(cls, obj) -> List[str]:
        """
        Recursively tries to pickle the given object and returns a list of failed paths

        :param obj: the object for which to recursively test pickling
        :return: a list of object paths that failed to pickle
        """
        handled_objects: Dict[int, Any] = {}
        failures = []
        cls._debug_failure(obj, [obj.__class__.__name__], failures, handled_objects)
        return [".".join(failure_path) for failure_path in failures]

    @classmethod
    def log_failure_if_enabled(cls, obj, context_info: str = None):
        """
        If the class flag 'enabled' is set to true, the pickling of the given object is
        recursively tested and the results are logged at error level if there are problems and
        info level otherwise.
        If the flag is disabled, no action is taken.

        :param obj: the object for which to recursively test pickling
        :param context_info: optional additional string to be included in the log message
        """
        if cls.enabled:
            failures = cls.debug_failure(obj)
            prefix = f"Picklability analysis for {obj}"
            if context_info is not None:
                prefix += " (context: %s)" % context_info
            if len(failures) > 0:
                log.error(f"{prefix}: pickling would result in failures due to: {failures}")
            else:
                log.info(f"{prefix}: is picklable")

152 

153 

def setstate(
    cls,
    obj,
    state: Dict[str, Any],
    renamed_properties: Dict[str, str] = None,
    new_optional_properties: List[str] = None,
    new_default_properties: Dict[str, Any] = None,
    removed_properties: List[str] = None
) -> None:
    """
    Helper for writing a safe ``__setstate__``: it migrates the given state dictionary (renamed, added and
    removed properties) and then delegates to the parent class' ``__setstate__`` if one exists, or assigns the
    state to ``obj.__dict__`` otherwise. Use it wherever you would like to call ``super().__setstate__(state)``,
    which is not valid in general, because ``object`` does not implement ``__setstate__``.

    :param cls: the class in which you are implementing __setstate__
    :param obj: the instance of cls
    :param state: the state dictionary (modified in place)
    :param renamed_properties: a mapping from old property names to new property names
    :param new_optional_properties: names of new properties which, if not present, shall be initialised with None
    :param new_default_properties: a mapping from new property names to the default values to use if they are not present
    :param removed_properties: names of properties that are no longer being used and shall be dropped
    """
    # move values stored under outdated names to their current names
    if renamed_properties is not None:
        for old_name, new_name in renamed_properties.items():
            if old_name not in state:
                continue
            state[new_name] = state[old_name]
            del state[old_name]

    # add properties that did not exist when the state was saved
    if new_optional_properties is not None:
        for name in new_optional_properties:
            state.setdefault(name, None)
    if new_default_properties is not None:
        for name, default_value in new_default_properties.items():
            state.setdefault(name, default_value)

    # drop properties that are no longer used
    if removed_properties is not None:
        for name in removed_properties:
            state.pop(name, None)

    # delegate to the parent implementation where available
    parent = super(cls, obj)
    if not hasattr(parent, '__setstate__'):
        obj.__dict__ = state
        return
    parent.__setstate__(state)

201 

202 

def getstate(
    cls,
    obj,
    transient_properties: Iterable[str] = None,
    excluded_properties: Iterable[str] = None,
    override_properties: Dict[str, Any] = None,
    excluded_default_properties: Dict[str, Any] = None
) -> Dict[str, Any]:
    """
    Helper function for safe implementations of __getstate__ in classes, which appropriately handles the cases where
    a parent class already implements __getstate__ and where it does not. Call this function whenever you would actually
    like to call the super-class' implementation.
    Note that super().__getstate__() is not a safe choice: before Python 3.11, object does not implement __getstate__,
    and since Python 3.11, object's implementation returns the instance's actual __dict__ (or None if it is empty),
    so modifying its result would modify the live object.

    :param cls: the class in which you are implementing __getstate__
    :param obj: the instance of cls
    :param transient_properties: transient properties which be set to None in serialisations
    :param excluded_properties: properties which shall be completely removed from serialisations
    :param override_properties: a mapping from property names to values specifying (new or existing) properties which are to be set;
        use this to set a fixed value for an existing property or to add a completely new property
    :param excluded_default_properties: properties which shall be completely removed from serialisations, if they are set
        to the given default value
    :return: the state dictionary, which may be modified by the receiver
    """
    s = super(cls, obj)
    # Only delegate if a class between cls and object in the MRO defines __getstate__ itself;
    # object.__getstate__ (Python >= 3.11) must not be used, since it returns the live __dict__.
    mro = type(obj).__mro__
    parent_defines_getstate = any(
        "__getstate__" in vars(klass) for klass in mro[mro.index(cls) + 1:] if klass is not object
    )
    if parent_defines_getstate:
        d = s.__getstate__()
        if isinstance(d, dict):
            d = d.copy()  # guarantee that modifications below cannot affect state held by the parent/object
    else:
        d = obj.__dict__.copy()
    if transient_properties is not None:
        for p in transient_properties:
            if p in d:
                d[p] = None
    if excluded_properties is not None:
        for p in excluded_properties:
            if p in d:
                del d[p]
    if override_properties is not None:
        for k, v in override_properties.items():
            d[k] = v
    if excluded_default_properties is not None:
        for p, v in excluded_default_properties.items():
            if p in d and d[p] == v:
                del d[p]
    return d