Coverage for src/sensai/util/pandas.py: 35% (124 statements)
import logging
from abc import ABC, abstractmethod
from copy import copy
from typing import List, Optional

import numpy as np
import pandas as pd

from sensai.util import mark_used

log = logging.getLogger(__name__)


class DataFrameColumnChangeTracker:
    """
    A simple class for keeping track of changes in columns between an initial data frame and some other data frame
    (usually the result of some transformations performed on the initial one).

    Example:

    >>> from sensai.util.pandas import DataFrameColumnChangeTracker
    >>> import pandas as pd

    >>> df = pd.DataFrame({"bar": [1, 2]})
    >>> columnChangeTracker = DataFrameColumnChangeTracker(df)
    >>> df["foo"] = [4, 5]
    >>> columnChangeTracker.track_change(df)
    >>> columnChangeTracker.get_removed_columns()
    set()
    >>> columnChangeTracker.get_added_columns()
    {'foo'}
    """
    def __init__(self, initial_df: pd.DataFrame):
        self.initial_columns = copy(initial_df.columns)
        self.final_columns = None

    def track_change(self, changed_df: pd.DataFrame):
        self.final_columns = copy(changed_df.columns)

    def get_removed_columns(self):
        """
        Returns the set of columns that were removed relative to the initial data frame
        """
        self.assert_change_was_tracked()
        return set(self.initial_columns).difference(self.final_columns)

    def get_added_columns(self):
        """
        Returns the set of columns in the tracked data frame that were not present in the initial one
        """
        self.assert_change_was_tracked()
        return set(self.final_columns).difference(self.initial_columns)

    def column_change_string(self):
        """
        Returns a string representation of the change
        """
        self.assert_change_was_tracked()
        if list(self.initial_columns) == list(self.final_columns):
            return "none"
        removed_cols, added_cols = self.get_removed_columns(), self.get_added_columns()
        if removed_cols == added_cols == set():
            return f"reordered {list(self.final_columns)}"

        return f"added={list(added_cols)}, removed={list(removed_cols)}"

    def assert_change_was_tracked(self):
        if self.final_columns is None:
            raise Exception("No change was tracked yet. "
                            "Did you forget to call track_change on the resulting data frame?")


def extract_array(df: pd.DataFrame, dtype=None):
    """
    Extracts an array from a data frame. It is expected that each row corresponds to a data point and
    each column corresponds to a "channel". Moreover, all entries are expected to be arrays of the same shape
    (or scalars or sequences of the same length). We will refer to that shape as tensorShape.

    The output will be of shape `(N_rows, N_columns, *tensorShape)`. Thus, `N_rows` can be interpreted as the
    dataset length (or batch size, if a single batch is passed) and `N_columns` as the number of channels.
    Empty dimensions will be stripped; thus, if the data frame has only one column, the array will have shape
    `(N_rows, *tensorShape)`.

    For example, an image with three channels could equally be passed as a data frame of the type

    +------------------+------------------+------------------+
    | R                | G                | B                |
    +==================+==================+==================+
    | channel          | channel          | channel          |
    +------------------+------------------+------------------+
    | channel          | channel          | channel          |
    +------------------+------------------+------------------+
    | ...              | ...              | ...              |
    +------------------+------------------+------------------+

    or as a data frame of the type

    +------------------+
    | image            |
    +==================+
    | RGB-array        |
    +------------------+
    | RGB-array        |
    +------------------+
    | ...              |
    +------------------+

    In both cases, the returned array will have shape `(N_images, 3, width, height)`.

    :param df: data frame where each entry is an array of shape tensorShape
    :param dtype: if not None, convert the array's data type to this type (string or numpy dtype)
    :return: array of shape `(N_rows, N_columns, *tensorShape)` with stripped empty dimensions
    """
    log.debug(f"Stacking tensors of shape {np.array(df.iloc[0, 0]).shape}")
    try:
        # The following compact way of extracting the array would cause dtypes to be modified:
        #   arr = np.stack(df.apply(np.stack, axis=1)).squeeze()
        # so we use this numpy-only alternative:
        arr = df.values
        if arr.shape[1] > 1:
            arr = np.stack([np.stack(arr[i]) for i in range(arr.shape[0])])
        else:
            arr = np.stack(arr[:, 0])
        # For the case where there is only one row, the old implementation above removed the first dimension,
        # so we do the same, even though it seems odd to do so (potential problem for batch size 1)
        # TODO: remove this behavior
        if arr.shape[0] == 1:
            arr = arr[0]
    except ValueError:
        raise ValueError(f"No array can be extracted from frame of length {len(df)} with columns {list(df.columns)}. "
                         "Make sure that all entries have the same shape.")
    if dtype is not None:
        arr = arr.astype(dtype, copy=False)
    return arr
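
# A usage sketch for extract_array (the data below is illustrative, not from the
# original module): a single column whose entries are channel-first arrays of shape
# (3, 4, 4) yields an array of shape (N_rows, 3, 4, 4).
#
#   df_images = pd.DataFrame({"image": [np.zeros((3, 4, 4)), np.ones((3, 4, 4))]})
#   arr = extract_array(df_images, dtype="float32")
#   assert arr.shape == (2, 3, 4, 4)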


def remove_duplicate_index_entries(df: pd.DataFrame):
    """
    Removes successive duplicate index entries by keeping only the first occurrence for every duplicate index element.

    :param df: the data frame, which is assumed to have a sorted index
    :return: the (modified) data frame with duplicate index entries removed
    """
    keep = [True]
    prev_item = df.index[0]
    for item in df.index[1:]:
        keep.append(item != prev_item)
        prev_item = item
    return df[keep]
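
# Illustrative example (assumed data, not from the original module): for a sorted
# index with successive duplicates, only the first occurrence of each entry is kept.
#
#   df_dup = pd.DataFrame({"x": [1, 2, 3]}, index=[0, 0, 1])
#   remove_duplicate_index_entries(df_dup)  # retains the rows with index values 0 and 1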


def query_data_frame(df: pd.DataFrame, sql: str):
    """
    Queries the given data frame with the given condition specified in SQL syntax.

    NOTE: Requires duckdb to be installed.

    :param df: the data frame to query
    :param sql: an SQL query starting with the WHERE clause (excluding the 'where' keyword itself)
    :return: the filtered/transformed data frame
    """
    import duckdb

    NUM_TYPE_INFERENCE_ROWS = 100

    def is_supported_object_col(col_name: str):
        supported_type_set = set()
        contains_unsupported_types = False
        # check the first N values
        for value in df[col_name].iloc[:NUM_TYPE_INFERENCE_ROWS]:
            if isinstance(value, str):
                supported_type_set.add(str)
            elif value is None:
                pass
            else:
                contains_unsupported_types = True
        return not contains_unsupported_types and len(supported_type_set) == 1

    # determine which columns are object columns that are unsupported by duckdb and would raise errors
    # if they remained in the data frame that is queried
    added_index_col = "__sensai_resultset_index__"
    original_columns = df.columns
    object_columns = list(df.dtypes[df.dtypes == object].index)
    object_columns = [c for c in object_columns if not is_supported_object_col(c)]

    # add an artificial index which we will use to identify the rows for object column reconstruction
    df[added_index_col] = np.arange(len(df))

    try:
        # remove the object columns from the data frame but save them for subsequent reconstruction
        objects_df = df[object_columns + [added_index_col]]
        query_df = df.drop(columns=object_columns)
        mark_used(query_df)

        # apply the query to the reduced df
        result_df = duckdb.query(f"select * from query_df where {sql}").to_df()

        # restore the object columns in the result
        objects_df.set_index(added_index_col, drop=True, inplace=True)
        result_df.set_index(added_index_col, drop=True, inplace=True)
        result_objects_df = objects_df.loc[result_df.index]
        assert len(result_df) == len(result_objects_df)
        full_result_df = pd.concat([result_df, result_objects_df], axis=1)
        full_result_df = full_result_df[original_columns]
    finally:
        # clean up: remove the artificial index column from the input data frame
        df.drop(columns=added_index_col, inplace=True)

    return full_result_df
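
# Illustrative example (assumed data; requires duckdb): unsupported object columns are
# detached before the query and re-attached afterwards, so frames containing such
# columns can still be filtered.
#
#   df_mixed = pd.DataFrame({"a": [1, 2, 3], "obj": [object(), object(), object()]})
#   query_data_frame(df_mixed, "a >= 2")  # returns the last two rows, 'obj' included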


class SeriesInterpolation(ABC):
    def interpolate(self, series: pd.Series, inplace: bool = False) -> Optional[pd.Series]:
        if not inplace:
            series = series.copy()
        self._interpolate_in_place(series)
        return series if not inplace else None

    @abstractmethod
    def _interpolate_in_place(self, series: pd.Series) -> None:
        pass

    def interpolate_all_with_combined_index(self, series_list: List[pd.Series]) -> List[pd.Series]:
        """
        Interpolates the given series using the combined index of all series.

        :param series_list: the list of series to interpolate
        :return: a list of corresponding interpolated series, each having the same index
        """
        # determine the common index
        index_elements = set()
        for series in series_list:
            index_elements.update(series.index)
        common_index = sorted(index_elements)

        # reindex, filling the gaps via interpolation
        interpolated_series_list = []
        for series in series_list:
            series = series.copy()
            series = series.reindex(common_index, method=None)
            self.interpolate(series, inplace=True)
            interpolated_series_list.append(series)

        return interpolated_series_list


class SeriesInterpolationLinearIndex(SeriesInterpolation):
    def __init__(self, ffill: bool = False, bfill: bool = False):
        """
        :param ffill: whether to fill any N/A values at the end of the series with the last valid observation
        :param bfill: whether to fill any N/A values at the start of the series with the first valid observation
        """
        self.ffill = ffill
        self.bfill = bfill

    def _interpolate_in_place(self, series: pd.Series) -> None:
        series.interpolate(method="index", inplace=True)
        if self.ffill:
            # forward-fill any N/A values remaining at the end of the series
            series.ffill(inplace=True)
        if self.bfill:
            # backward-fill any N/A values remaining at the start of the series
            series.bfill(inplace=True)


class SeriesInterpolationRepeatPreceding(SeriesInterpolation):
    def __init__(self, bfill: bool = False):
        """
        :param bfill: whether to fill any N/A values at the start of the series with the first valid observation
        """
        self.bfill = bfill

    def _interpolate_in_place(self, series: pd.Series) -> None:
        series.interpolate(method="pad", limit_direction="forward", inplace=True)
        if self.bfill:
            # backward-fill any N/A values remaining at the start of the series
            series.bfill(inplace=True)


def average_series(series_list: List[pd.Series], interpolation: SeriesInterpolation) -> pd.Series:
    """
    Averages the given series, interpolating each series on the combined index of all series.

    :param series_list: the list of series to average
    :param interpolation: the interpolation to use for filling gaps after reindexing
    :return: the average of the interpolated series
    """
    interpolated_series_list = interpolation.interpolate_all_with_combined_index(series_list)
    return sum(interpolated_series_list) / len(interpolated_series_list)  # type: ignore
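
# Illustrative sketch (assumed values, not from the original module): averaging two
# series with disjoint indices. Reindexing to the combined index [0, 1, 2, 3] creates
# gaps, which linear interpolation fills; bfill/ffill guard against N/A values at the ends.
#
#   s1 = pd.Series([1.0, 3.0], index=[0, 2])
#   s2 = pd.Series([2.0, 4.0], index=[1, 3])
#   avg = average_series([s1, s2], SeriesInterpolationLinearIndex(ffill=True, bfill=True))
#   # avg has index [0, 1, 2, 3] and values [1.5, 2.0, 3.0, 3.5]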