Coverage for src/sensai/util/pandas.py: 35% (124 statements)

coverage.py v7.6.1, created at 2024-11-29 18:29 +0000

import logging
from abc import ABC, abstractmethod
from copy import copy
from typing import List, Optional

import numpy as np
import pandas as pd

from sensai.util import mark_used

log = logging.getLogger(__name__)


class DataFrameColumnChangeTracker:
    """
    A simple class for keeping track of changes in columns between an initial data frame and some other data frame
    (usually the result of some transformations performed on the initial one).

    Example:

    >>> from sensai.util.pandas import DataFrameColumnChangeTracker
    >>> import pandas as pd

    >>> df = pd.DataFrame({"bar": [1, 2]})
    >>> columnChangeTracker = DataFrameColumnChangeTracker(df)
    >>> df["foo"] = [4, 5]
    >>> columnChangeTracker.track_change(df)
    >>> columnChangeTracker.get_removed_columns()
    set()
    >>> columnChangeTracker.get_added_columns()
    {'foo'}
    """
    def __init__(self, initial_df: pd.DataFrame):
        self.initialColumns = copy(initial_df.columns)
        self.final_columns = None

    def track_change(self, changed_df: pd.DataFrame):
        self.final_columns = copy(changed_df.columns)

    def get_removed_columns(self):
        self.assert_change_was_tracked()
        return set(self.initialColumns).difference(self.final_columns)

    def get_added_columns(self):
        """
        Returns the columns in the last tracked data frame that were not present in the initial one
        """
        self.assert_change_was_tracked()
        return set(self.final_columns).difference(self.initialColumns)

    def column_change_string(self):
        """
        Returns a string representation of the change
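
        Example (an illustrative sketch; the frames and column names "a"/"b" are arbitrary):

        >>> import pandas as pd
        >>> from sensai.util.pandas import DataFrameColumnChangeTracker
        >>> tracker = DataFrameColumnChangeTracker(pd.DataFrame({"a": [1]}))
        >>> tracker.track_change(pd.DataFrame({"b": [1]}))
        >>> tracker.column_change_string()
        "added=['b'], removed=['a']"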

54 """ 

55 self.assert_change_was_tracked() 

56 if list(self.initialColumns) == list(self.final_columns): 

57 return "none" 

58 removed_cols, added_cols = self.get_removed_columns(), self.get_added_columns() 

59 if removed_cols == added_cols == set(): 

60 return f"reordered {list(self.final_columns)}" 

61 

62 return f"added={list(added_cols)}, removed={list(removed_cols)}" 

    def assert_change_was_tracked(self):
        if self.final_columns is None:
            raise Exception("No change was tracked yet. "
                            "Did you forget to call track_change on the resulting data frame?")


def extract_array(df: pd.DataFrame, dtype=None):
    """
    Extracts an array from the data frame. It is expected that each row corresponds to a data point and
    each column corresponds to a "channel". Moreover, all entries are expected to be arrays of the same shape
    (or scalars or sequences of the same length). We will refer to that shape as tensorShape.

    The output will be of shape `(N_rows, N_columns, *tensorShape)`. Thus, `N_rows` can be interpreted as the
    dataset length (or batch size, if a single batch is passed) and `N_columns` can be interpreted as the
    number of channels. Empty dimensions will be stripped, thus if the data frame has only one column,
    the array will have shape `(N_rows, *tensorShape)`.
    E.g. an image with three channels could equally be passed as a data frame of the type

    +------------------+------------------+------------------+
    | R                | G                | B                |
    +==================+==================+==================+
    | channel          | channel          | channel          |
    +------------------+------------------+------------------+
    | channel          | channel          | channel          |
    +------------------+------------------+------------------+
    | ...              | ...              | ...              |
    +------------------+------------------+------------------+

    or as a data frame of the type

    +------------------+
    | image            |
    +==================+
    | RGB-array        |
    +------------------+
    | RGB-array        |
    +------------------+
    | ...              |
    +------------------+

    In both cases the returned array will have shape `(N_images, 3, width, height)`.

    :param df: data frame where each entry is an array of shape tensorShape
    :param dtype: if not None, convert the array's data type to this type (string or numpy dtype)
    :return: array of shape `(N_rows, N_columns, *tensorShape)` with stripped empty dimensions
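
    Example (an illustrative sketch; the column name "im" and the tiny 2x2 "images" are arbitrary):

    >>> import numpy as np
    >>> import pandas as pd
    >>> from sensai.util.pandas import extract_array
    >>> df = pd.DataFrame({"im": [np.zeros((2, 2)), np.ones((2, 2))]})
    >>> extract_array(df).shape
    (2, 2, 2)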

110 """ 

    log.debug(f"Stacking tensors of shape {np.array(df.iloc[0, 0]).shape}")
    try:
        # This compact way of extracting the array causes dtypes to be modified,
        #   arr = np.stack(df.apply(np.stack, axis=1)).squeeze()
        # so we use this numpy-only alternative:
        arr = df.values
        if arr.shape[1] > 1:
            arr = np.stack([np.stack(arr[i]) for i in range(arr.shape[0])])
        else:
            arr = np.stack(arr[:, 0])
        # For the case where there is only one row, the old implementation above removed the first dimension,
        # so we do the same, even though it seems odd to do so (potential problem for batch size 1)
        # TODO: remove this behavior
        if arr.shape[0] == 1:
            arr = arr[0]
    except ValueError:
        raise ValueError(f"No array can be extracted from frame of length {len(df)} with columns {list(df.columns)}. "
                         f"Make sure that all entries have the same shape")
    if dtype is not None:
        arr = arr.astype(dtype, copy=False)
    return arr


def remove_duplicate_index_entries(df: pd.DataFrame):
    """
    Removes successive duplicate index entries by keeping only the first occurrence for every duplicate index element.

    :param df: the data frame, which is assumed to have a sorted index
    :return: the filtered data frame, from which rows with duplicate index entries have been removed
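
    Example (an illustrative sketch; the column name "v" and the values are arbitrary):

    >>> import pandas as pd
    >>> from sensai.util.pandas import remove_duplicate_index_entries
    >>> df = pd.DataFrame({"v": [1, 2, 3]}, index=[0, 0, 1])
    >>> remove_duplicate_index_entries(df)
       v
    0  1
    1  3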

140 """ 

141 keep = [True] 

142 prev_item = df.index[0] 

143 for item in df.index[1:]: 

144 keep.append(item != prev_item) 

145 prev_item = item 

146 return df[keep] 


def query_data_frame(df: pd.DataFrame, sql: str):
    """
    Queries the given data frame with the given condition specified in SQL syntax.

    NOTE: Requires duckdb to be installed.

    :param df: the data frame to query
    :param sql: an SQL query starting with the WHERE clause (excluding the 'where' keyword itself)
    :return: the filtered/transformed data frame
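
    Example (an illustrative sketch; requires duckdb, and the frame and query below are arbitrary):

    >>> import pandas as pd
    >>> from sensai.util.pandas import query_data_frame
    >>> df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
    >>> result = query_data_frame(df, "a >= 2")
    >>> list(result["b"])
    ['y', 'z']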

158 """ 

159 import duckdb 

160 

161 NUM_TYPE_INFERENCE_ROWS = 100 

162 

163 def is_supported_object_col(col_name: str): 

164 supported_type_set = set() 

165 contains_unsupported_types = False 

166 # check the first N values 

167 for value in df[col_name].iloc[:NUM_TYPE_INFERENCE_ROWS]: 

168 if isinstance(value, str): 

169 supported_type_set.add(str) 

170 elif value is None: 

171 pass 

172 else: 

173 contains_unsupported_types = True 

174 return not contains_unsupported_types and len(supported_type_set) == 1 

175 

176 # determine which columns are object columns that are unsupported by duckdb and would raise errors 

177 # if they remained in the data frame that is queried 

178 added_index_col = "__sensai_resultset_index__" 

179 original_columns = df.columns 

180 object_columns = list(df.dtypes[df.dtypes == object].index) 

181 object_columns = [c for c in object_columns if not is_supported_object_col(c)] 

182 

183 # add an artificial index which we will use to identify the rows for object column reconstruction 

184 df[added_index_col] = np.arange(len(df)) 

185 

186 try: 

187 # remove the object columns from the data frame but save them for subsequent reconstruction 

188 objects_df = df[object_columns + [added_index_col]] 

189 query_df = df.drop(columns=object_columns) 

190 mark_used(query_df) 

191 

192 # apply query with reduced df 

193 result_df = duckdb.query(f"select * from query_df where {sql}").to_df() 

194 

195 # restore object columns in result 

196 objects_df.set_index(added_index_col, drop=True, inplace=True) 

197 result_df.set_index(added_index_col, drop=True, inplace=True) 

198 result_objects_df = objects_df.loc[result_df.index] 

199 assert len(result_df) == len(result_objects_df) 

200 full_result_df = pd.concat([result_df, result_objects_df], axis=1) 

201 full_result_df = full_result_df[original_columns] 

202 

203 finally: 

204 # clean up 

205 df.drop(columns=added_index_col, inplace=True) 

206 

207 return full_result_df 


class SeriesInterpolation(ABC):
    def interpolate(self, series: pd.Series, inplace: bool = False) -> Optional[pd.Series]:
        if not inplace:
            series = series.copy()
        self._interpolate_in_place(series)
        return series if not inplace else None

    @abstractmethod
    def _interpolate_in_place(self, series: pd.Series) -> None:
        pass

    def interpolate_all_with_combined_index(self, series_list: List[pd.Series]) -> List[pd.Series]:
        """
        Interpolates the given series using the combined index of all series.

        :param series_list: the list of series to interpolate
        :return: a list of corresponding interpolated series, each having the same index
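
        Example (an illustrative sketch using the SeriesInterpolationLinearIndex subclass defined below;
        the series values and indices are arbitrary):

        >>> import pandas as pd
        >>> from sensai.util.pandas import SeriesInterpolationLinearIndex
        >>> s1 = pd.Series([1.0, 3.0], index=[0, 2])
        >>> s2 = pd.Series([2.0, 4.0], index=[1, 3])
        >>> interp = SeriesInterpolationLinearIndex(ffill=True, bfill=True)
        >>> [s.tolist() for s in interp.interpolate_all_with_combined_index([s1, s2])]
        [[1.0, 2.0, 3.0, 3.0], [2.0, 2.0, 3.0, 4.0]]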

227 """ 

228 # determine common index and 

229 index_elements = set() 

230 for series in series_list: 

231 index_elements.update(series.index) 

232 common_index = sorted(index_elements) 

233 

234 # reindex, filling the gaps via interpolation 

235 interpolated_series_list = [] 

236 for series in series_list: 

237 series = series.copy() 

238 series = series.reindex(common_index, method=None) 

239 self.interpolate(series, inplace=True) 

240 interpolated_series_list.append(series) 

241 

242 return interpolated_series_list 


class SeriesInterpolationLinearIndex(SeriesInterpolation):
    def __init__(self, ffill: bool = False, bfill: bool = False):
        """
        :param ffill: whether to fill any N/A values at the end of the series with the last valid observation
        :param bfill: whether to fill any N/A values at the start of the series with the first valid observation
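
        Example (an illustrative sketch; the series values and index are arbitrary and show that the
        interpolation is index-based rather than positional):

        >>> import pandas as pd
        >>> from sensai.util.pandas import SeriesInterpolationLinearIndex
        >>> s = pd.Series([1.0, None, 5.0], index=[0, 1, 4])
        >>> SeriesInterpolationLinearIndex().interpolate(s).tolist()
        [1.0, 2.0, 5.0]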

250 """ 

251 self.ffill = ffill 

252 self.bfill = bfill 

253 

254 def _interpolate_in_place(self, series: pd.Series) -> None: 

255 series.interpolate(method="index", inplace=True) 

256 if self.ffill: 

257 series.interpolate(method="ffill", limit_direction="forward") 

258 if self.bfill: 

259 series.interpolate(method="bfill", limit_direction="backward") 


class SeriesInterpolationRepeatPreceding(SeriesInterpolation):
    def __init__(self, bfill: bool = False):
        """
        :param bfill: whether to fill any N/A values at the start of the series with the first valid observation
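
        Example (an illustrative sketch; the series values are arbitrary):

        >>> import pandas as pd
        >>> from sensai.util.pandas import SeriesInterpolationRepeatPreceding
        >>> s = pd.Series([1.0, None, None], index=[0, 1, 2])
        >>> SeriesInterpolationRepeatPreceding().interpolate(s).tolist()
        [1.0, 1.0, 1.0]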

266 """ 

267 self.bfill = bfill 

268 

269 def _interpolate_in_place(self, series: pd.Series) -> None: 

270 series.interpolate(method="pad", limit_direction="forward", inplace=True) 

271 if self.bfill: 

272 series.interpolate(method="bfill", limit_direction="backward") 


def average_series(series_list: List[pd.Series], interpolation: SeriesInterpolation) -> pd.Series:
    """
    Averages the given series, using the given interpolation to reindex all series with the combined
    index (the sorted union of all indices) such that the average is well-defined for every index element.

    :param series_list: the series to average
    :param interpolation: the interpolation with which to fill gaps after reindexing
    :return: the averaged series, whose index is the union of the indices of the given series
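
    Example (an illustrative sketch; the series values and indices are arbitrary):

    >>> import pandas as pd
    >>> from sensai.util.pandas import SeriesInterpolationLinearIndex, average_series
    >>> s1 = pd.Series([1.0, 3.0], index=[0, 2])
    >>> s2 = pd.Series([2.0, 4.0], index=[1, 3])
    >>> average_series([s1, s2], SeriesInterpolationLinearIndex(ffill=True, bfill=True)).tolist()
    [1.5, 2.0, 3.0, 3.5]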

    """
    interpolated_series_list = interpolation.interpolate_all_with_combined_index(series_list)
    return sum(interpolated_series_list) / len(interpolated_series_list)  # type: ignore