Coverage for src/sensai/util/test.py: 0%

28 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1import json 

2 

3import numpy as np 

4import pandas as pd 

5 

6 

def snapshot_compatible(obj, float_decimals=6, significant_digits=12):
    """
    Converts an object into a form that is suitable for snapshot testing.

    The object is serialized to JSON (with :func:`json_mapper` handling types that JSON does not support natively)
    and parsed back, which yields a structure consisting only of plain JSON types. The precision of all floats in
    that structure is then reduced, such that snapshots are likely to compare equal across platforms despite tiny
    floating point differences.

    :param obj: the object to convert
    :param float_decimals: the number of float decimal places to consider
    :param significant_digits: the (maximum) number of significant digits to consider
    :return: the converted object
    """
    serialized = json.dumps(obj, default=json_mapper)
    plain_structure = json.loads(serialized)
    return convert_floats(plain_structure, float_decimals, significant_digits)

19 

20 

def reduce_float_precision(f, decimals, significant_digits):
    """
    Reduces the precision of a number in two steps: it is first rounded to the given number of decimal places
    (fixed-point formatting) and the result is then limited to the given number of significant digits
    (general formatting).

    :param f: the number to process
    :param decimals: the number of decimal places to keep in the first step
    :param significant_digits: the maximum number of significant digits to keep in the second step
    :return: the reduced-precision value as a float
    """
    fixed_point_spec = ".%df" % decimals
    rounded_to_decimals = float(format(f, fixed_point_spec))
    general_spec = ".%dg" % significant_digits
    limited = format(rounded_to_decimals, general_spec)
    return float(limited)

23 

24 

def convert_floats(o, float_decimals, significant_digits):
    """
    Recursively reduces the precision of all floats contained in the given object by applying
    reduce_float_precision to each of them.

    Lists and dicts (including subclasses such as ``OrderedDict``) are traversed recursively and rebuilt as plain
    lists and dicts; dictionary keys are left unchanged. Floats (including float subclasses) are rounded; any other
    value is returned as is.

    :param o: the object to convert (typically the result of a JSON round trip)
    :param float_decimals: the number of float decimal places to consider
    :param significant_digits: the (maximum) number of significant digits to consider
    :return: the converted object
    """
    # isinstance (rather than an exact type comparison) ensures that floats nested in list/dict subclasses
    # are rounded as well instead of the container being silently returned unchanged
    if isinstance(o, list):
        return [convert_floats(x, float_decimals, significant_digits) for x in o]
    if isinstance(o, dict):
        return {key: convert_floats(value, float_decimals, significant_digits) for key, value in o.items()}
    if isinstance(o, float):
        return reduce_float_precision(o, float_decimals, significant_digits)
    return o

34 

35 

def json_mapper(o):
    """
    Maps the given data object to a representation that is JSON-compatible; intended to be passed as the
    ``default`` hook of :func:`json.dumps`.
    Currently, the supported object types include, in particular, numpy arrays and numpy scalars as well as
    pandas Series and DataFrames; other objects are mapped to their ``__dict__``.

    The given object is never modified.

    :param o: the object to convert
    :return: the converted object
    :raises TypeError: if the object is of none of the supported types and has no ``__dict__`` (this is how
        :func:`json.dumps` expects a ``default`` hook to signal an unsupported object)
    """
    if isinstance(o, pd.DataFrame):
        if isinstance(o.index, pd.DatetimeIndex):
            # replace the timestamps with their integer representation; set_axis returns a new frame,
            # so the caller's DataFrame keeps its DatetimeIndex
            o = o.set_axis(o.index.astype('int64').tolist(), axis=0)
        return o.to_dict()
    if isinstance(o, pd.Series):
        return o.values.tolist()
    if isinstance(o, np.ndarray):
        return o.tolist()
    if isinstance(o, np.generic):
        # numpy scalars such as np.int64 or np.bool_ are not serializable by json and have no __dict__;
        # convert them to the corresponding Python scalar
        return o.item()
    if isinstance(o, list):
        return o
    try:
        return o.__dict__
    except AttributeError:
        raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable") from None