Coverage for src/sensai/util/plot.py: 23%
217 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
1import collections
2import logging
3from typing import Sequence, Callable, TypeVar, Tuple, Optional, List, Any, Union, Dict
5import matplotlib.figure
6import matplotlib.ticker as plticker
7import numpy as np
8import pandas as pd
9import seaborn as sns
10from matplotlib import pyplot as plt
11from matplotlib.colors import LinearSegmentedColormap
13from sensai.util.pandas import SeriesInterpolation
# module-level logger, named after this module
log = logging.getLogger(__name__)

# figure size as a (width, height) pair; presumably matplotlib's default size in inches - TODO confirm
MATPLOTLIB_DEFAULT_FIGURE_SIZE = (6.4, 4.8)
class Color:
    """
    Wraps an RGBA colour (member `rgba`) and provides simple manipulations of it.
    Every manipulation returns a new instance; the instance it is called on is left unchanged.
    """
    def __init__(self, c: Any):
        """
        :param c: any color specification that is understood by matplotlib
        """
        self.rgba = matplotlib.colors.to_rgba(c)

    def darken(self, amount: float) -> "Color":
        """
        :param amount: amount to darken in [0,1], where 1 results in black and 0 leaves the color unchanged
        :return: the darkened color
        """
        import colorsys
        rgb = matplotlib.colors.to_rgb(self.rgba)
        h, l, s = colorsys.rgb_to_hls(*rgb)
        # scale the lightness towards 0 (black); the previous implementation multiplied by `amount` directly,
        # which inverted the documented semantics (1 left the colour unchanged, 0 produced black)
        l *= 1 - amount
        rgb = colorsys.hls_to_rgb(h, l, s)
        # the alpha channel is carried over unchanged
        return Color((*rgb, self.rgba[3]))

    def lighten(self, amount: float) -> "Color":
        """
        :param amount: amount to lighten in [0,1], where 1 results in white and 0 leaves the color unchanged
        :return: the lightened color
        """
        import colorsys
        rgb = matplotlib.colors.to_rgb(self.rgba)
        h, l, s = colorsys.rgb_to_hls(*rgb)
        # move the lightness towards 1 (white) by the given fraction of the remaining distance
        l += (1 - l) * amount
        rgb = colorsys.hls_to_rgb(h, l, s)
        # the alpha channel is carried over unchanged
        return Color((*rgb, self.rgba[3]))

    def alpha(self, opacity: float) -> "Color":
        """
        Returns a new color with modified alpha channel (opacity)

        :param opacity: the opacity between 0 (transparent) and 1 (fully opaque)
        :return: the modified color
        :raises ValueError: if the opacity is outside of [0, 1]
        """
        if not (0 <= opacity <= 1):
            raise ValueError(f"Opacity must be between 0 and 1, got {opacity}")
        return Color((*self.rgba[:3], opacity))

    def to_hex(self, keep_alpha: bool = True) -> str:
        """
        :param keep_alpha: whether to include the alpha channel in the representation
        :return: the hex string representation as generated by matplotlib
        """
        return matplotlib.colors.to_hex(self.rgba, keep_alpha)
class LinearColorMap:
    """
    Simplifies the use of linear segmented colour maps: A normaliser (member `norm`) transforms original values
    into the range [0,1], and a colour map (member `cmap`) transforms such normalised values into colours.
    The member `scalarMapper` combines the two, i.e. it maps original values directly to colours.
    """
    def __init__(self, norm_min, norm_max, cmap_points: List[Tuple[float, Any]], cmap_points_normalised=False):
        """
        :param norm_min: the value that is mapped to 0 by the normaliser (smaller values are clipped to 0)
        :param norm_max: the value that is mapped to 1 by the normaliser (larger values are clipped to 1)
        :param cmap_points: a list of (at least two) tuples (v, c), each associating a value v with a colour c,
            where c may be any colour specification supported by matplotlib
        :param cmap_points_normalised: whether the values in `cmap_points` are already normalised; if False,
            they are passed through the normaliser first
        """
        normaliser = matplotlib.colors.Normalize(vmin=norm_min, vmax=norm_max, clip=True)
        if cmap_points_normalised:
            normalised_points = cmap_points
        else:
            normalised_points = [(normaliser(value), colour) for value, colour in cmap_points]
        self.norm = normaliser
        # the name of the colour map is derived from the identity of this object
        self.cmap = LinearSegmentedColormap.from_list(f"cmap{id(self)}", normalised_points)
        self.scalarMapper = matplotlib.cm.ScalarMappable(norm=self.norm, cmap=self.cmap)

    def get_color(self, value):
        """
        :param value: the (non-normalised) value for which to obtain the colour
        :return: the colour as a hex string with four two-digit channels (red, green, blue, alpha), prefixed with '#'
        """
        channels = self.scalarMapper.to_rgba(value)
        # each channel in [0,1] is scaled to [0,255] and truncated to an integer
        channel_bytes = tuple(int(channel * 255) for channel in channels)
        return '#%02x%02x%02x%02x' % channel_bytes
def plot_matrix(matrix: np.ndarray, title: str, xtick_labels: Sequence[str], ytick_labels: Sequence[str], xlabel: str,
        ylabel: str, normalize=True, figsize: Tuple[int, int] = (9, 9), title_add: str = None) -> matplotlib.figure.Figure:
    """
    Renders a matrix as an annotated image, where every cell shows its value as text.

    :param matrix: the matrix to plot; entry matrix[i, j] is rendered at x=i, y=j
    :param title: the title of the plot
    :param xtick_labels: the tick labels for the x-axis
    :param ytick_labels: the tick labels for the y-axis
    :param xlabel: the x-axis label
    :param ylabel: the y-axis label
    :param normalize: whether to divide every entry by the sum of all entries before plotting
    :param figsize: the size of the figure to create
    :param title_add: an optional second line to append to the title
    :return: the figure object
    """
    # image rows correspond to y, so the input is transposed in order to have its first index run along x
    cells = np.transpose(matrix)

    if title_add is not None:
        title += f"\n {title_add} "

    if normalize:
        cells = cells.astype('float') / cells.sum()

    fig, ax = plt.subplots(figsize=figsize)
    window_title = title.replace("\n", " ")
    fig.canvas.manager.set_window_title(window_title)

    # one tick per row/column, labelled with the respective entry of the given label sequences
    ax.set(xticks=np.arange(cells.shape[1]),
        yticks=np.arange(cells.shape[0]),
        xticklabels=xtick_labels, yticklabels=ytick_labels,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel)
    image = ax.imshow(cells, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(image, ax=ax)

    # rotate the x tick labels and anchor them at their right end
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # determine the number format of the cell annotations
    if normalize:
        fmt = '.4f'
    elif cells.dtype.kind == 'f':
        fmt = '.2f'
    else:
        fmt = 'd'

    # cells whose value exceeds half of the maximum get white text, all others black text
    thresh = cells.max() / 2.
    for row, col in np.ndindex(cells.shape[0], cells.shape[1]):
        cell_value = cells[row, col]
        text_color = "white" if cell_value > thresh else "black"
        ax.text(col, row, format(cell_value, fmt), ha="center", va="center", color=text_color)

    fig.tight_layout()
    return fig
# type variable which allows the fluent methods of Plot to return the type of the concrete subclass
TPlot = TypeVar("TPlot", bound="Plot")


class Plot:
    """
    Base class for plots, which holds the axes (member `ax`) that is drawn to and, if the figure was created
    by this object, the figure (member `fig`). Most methods return the object itself to allow call chaining.
    """
    def __init__(self, draw: Optional[Callable[[plt.Axes], None]] = None, name=None, ax: Optional[plt.Axes] = None):
        """
        :param draw: a function which receives the axes and draws the plot's content; if None, nothing is drawn
        :param name: name/number of the figure, which determines the window caption; it should be unique, as any plot
            with the same name will have its contents rendered in the same window. By default, figures are numbered
            sequentially. It is used only if `ax` is None.
        :param ax: the axes to draw to; if None, a new figure with a single axes object is created
        """
        if ax is not None:
            fig = None
        else:
            fig, ax = plt.subplots(num=name)
        # None if the axes was provided by the caller
        self.fig: Optional[plt.Figure] = fig
        self.ax: plt.Axes = ax
        # the members are assigned before drawing, such that draw functions may already make use of them
        if draw is not None:
            draw(ax)

    def xlabel(self: TPlot, label) -> TPlot:
        """
        :param label: the label for the x-axis
        :return: self
        """
        self.ax.set_xlabel(label)
        return self

    def ylabel(self: TPlot, label) -> TPlot:
        """
        :param label: the label for the y-axis
        :return: self
        """
        self.ax.set_ylabel(label)
        return self

    def title(self: TPlot, title: str) -> TPlot:
        """
        :param title: the title of the axes
        :return: self
        """
        self.ax.set_title(title)
        return self

    def xlim(self: TPlot, min_value, max_value) -> TPlot:
        """
        :param min_value: the lower limit of the x-axis
        :param max_value: the upper limit of the x-axis
        :return: self
        """
        self.ax.set_xlim(min_value, max_value)
        return self

    def ylim(self: TPlot, min_value, max_value) -> TPlot:
        """
        :param min_value: the lower limit of the y-axis
        :param max_value: the upper limit of the y-axis
        :return: self
        """
        self.ax.set_ylim(min_value, max_value)
        return self

    def save(self, path):
        """
        Saves the figure that contains this plot.

        :param path: the path to save the figure to (passed on to the figure's `savefig`)
        """
        # if the axes was provided by the caller, no figure was stored, so we use the figure the axes belongs to
        # (previously, this case failed with an AttributeError)
        fig = self.fig if self.fig is not None else self.ax.figure
        log.info("Saving figure in %s", path)
        fig.savefig(path)

    def xtick(self: TPlot, major=None, minor=None) -> TPlot:
        """
        Sets a tick on every integer multiple of the given base values on the x-axis.
        The major ticks are labelled, the minor ticks are not.

        :param major: the major tick base value
        :param minor: the minor tick base value
        :return: self
        """
        if major is not None:
            self.xtick_major(major)
        if minor is not None:
            self.xtick_minor(minor)
        return self

    def xtick_major(self: TPlot, base) -> TPlot:
        """
        :param base: the value on whose integer multiples major ticks are placed on the x-axis
        :return: self
        """
        self.ax.xaxis.set_major_locator(plticker.MultipleLocator(base=base))
        return self

    def xtick_minor(self: TPlot, base) -> TPlot:
        """
        :param base: the value on whose integer multiples minor ticks are placed on the x-axis
        :return: self
        """
        self.ax.xaxis.set_minor_locator(plticker.MultipleLocator(base=base))
        return self

    def ytick(self: TPlot, major=None, minor=None) -> TPlot:
        """
        Sets a tick on every integer multiple of the given base values on the y-axis
        (the counterpart of `xtick`).

        :param major: the major tick base value
        :param minor: the minor tick base value
        :return: self
        """
        if major is not None:
            self.ytick_major(major)
        if minor is not None:
            self.ytick_minor(minor)
        return self

    def ytick_major(self: TPlot, base) -> TPlot:
        """
        :param base: the value on whose integer multiples major ticks are placed on the y-axis
        :return: self
        """
        self.ax.yaxis.set_major_locator(plticker.MultipleLocator(base=base))
        return self

    def ytick_minor(self: TPlot, base) -> TPlot:
        """
        :param base: the value on whose integer multiples minor ticks are placed on the y-axis
        :return: self
        """
        self.ax.yaxis.set_minor_locator(plticker.MultipleLocator(base=base))
        return self
class ScatterPlot(Plot):
    """
    A scatter plot of two equally long sequences of values, where the opacity of the points is, by default,
    reduced as the number of points increases.
    """
    # number of points above which the minimum opacity is applied
    N_MAX_TRANSPARENCY = 1000
    # number of points below which the maximum opacity is applied
    N_MIN_TRANSPARENCY = 100
    # opacity applied for small numbers of points
    MAX_OPACITY = 0.5
    # opacity applied for large numbers of points
    MIN_OPACITY = 0.05

    def __init__(self, x, y, c=None, c_base: Tuple[float, float, float] = (0, 0, 1), c_opacity=None, x_label=None,
            y_label=None, add_diagonal=False, **kwargs):
        """
        :param x: the x values; if has name (e.g. pd.Series), will be used as axis label
        :param y: the y values; if has name (e.g. pd.Series), will be used as axis label
        :param c: the colour specification; if None, compose from ``c_base`` and ``c_opacity``
        :param c_base: the base colour as (R, G, B) floats
        :param c_opacity: the opacity; if None, automatically determine from number of data points
        :param x_label: the label for the x-axis; if None, use the name of `x` (if any)
        :param y_label: the label for the y-axis; if None, use the name of `y` (if any)
        :param add_diagonal: whether to add a diagonal line which spans the common value range of `x` and `y`
        :param kwargs: parameters to pass on to plt.scatter
        """
        if c is None:
            if c_base is None:
                c_base = (0, 0, 1)
            if c_opacity is None:
                n = len(x)
                # interpolate linearly between the maximum and the minimum opacity based on the number of points
                if n > self.N_MAX_TRANSPARENCY:
                    transparency = 1
                elif n < self.N_MIN_TRANSPARENCY:
                    transparency = 0
                else:
                    transparency = (n - self.N_MIN_TRANSPARENCY) / (self.N_MAX_TRANSPARENCY - self.N_MIN_TRANSPARENCY)
                c_opacity = self.MIN_OPACITY + (self.MAX_OPACITY - self.MIN_OPACITY) * (1 - transparency)
            # a sequence containing a single RGBA tuple
            c = ((*c_base, c_opacity),)

        assert len(x) == len(y)
        if x_label is None and hasattr(x, "name"):
            x_label = x.name
        if y_label is None and hasattr(y, "name"):
            y_label = y.name

        def draw(ax):
            if x_label is not None:
                plt.xlabel(x_label)
            # the y label was previously guarded by the x label's presence (copy-paste error)
            if y_label is not None:
                plt.ylabel(y_label)
            plt.scatter(x, y, c=c, **kwargs)
            if add_diagonal:
                # the value range is required for the diagonal only (and cannot be computed for empty inputs)
                value_range = [min(min(x), min(y)), max(max(x), max(y))]
                plt.plot(value_range, value_range, '-', lw=1, label="_not in legend", color="green", zorder=1)

        super().__init__(draw)
class HeatMapPlot(Plot):
    """
    A heat map of the joint distribution of two equally long sequences of values, based on a two-dimensional histogram.
    """
    # creates the default colour map (white to red) with as many colour levels as there are data points
    DEFAULT_CMAP_FACTORY = lambda num_points: LinearSegmentedColormap.from_list(
        "whiteToRed",
        ((0, (1, 1, 1)), (1 / num_points, (1, 0.96, 0.96)), (1, (0.7, 0, 0))),
        num_points)

    def __init__(self, x, y, x_label=None, y_label=None, bins=60, cmap=None, common_range=True, diagonal=False,
            diagonal_color="green", **kwargs):
        """
        :param x: the x values
        :param y: the y values
        :param x_label: the label for the x-axis; if None, use the name of `x` (if it has one, e.g. pd.Series)
        :param y_label: the label for the y-axis; if None, use the name of `y` (if it has one, e.g. pd.Series)
        :param bins: the number of bins to use in each dimension
        :param cmap: the colour map to use for heat values (if None, use default)
        :param common_range: whether the heat map shall use a common range for the x- and y-axes (set to False if x and y
            are completely different quantities; set to True for use cases such as the evaluation of regression model quality)
        :param diagonal: whether to draw the diagonal line (useful for regression evaluation)
        :param diagonal_color: the colour to use for the diagonal line
        :param kwargs: parameters to pass on to plt.imshow
        """
        assert len(x) == len(y)
        if x_label is None and hasattr(x, "name"):
            x_label = x.name
        if y_label is None and hasattr(y, "name"):
            y_label = y.name

        def draw(ax):
            x_lo, x_hi = min(x), max(x)
            y_lo, y_hi = min(y), max(y)
            # the range which covers both the x and the y values
            joint_span = [min(x_lo, y_lo), max(x_hi, y_hi)]
            if common_range:
                x_span = y_span = joint_span
            else:
                x_span, y_span = [x_lo, x_hi], [y_lo, y_hi]
            if diagonal:
                plt.plot(joint_span, joint_span, '-', lw=0.75, label="_not in legend", color=diagonal_color, zorder=2)
            counts, _, _ = np.histogram2d(x, y, range=[x_span, y_span], bins=bins, density=False)
            heat_cmap = cmap if cmap is not None else HeatMapPlot.DEFAULT_CMAP_FACTORY(len(x))
            if x_label is not None:
                plt.xlabel(x_label)
            if y_label is not None:
                plt.ylabel(y_label)
            # the histogram's first dimension corresponds to x, whereas the image's first dimension is vertical
            plt.imshow(counts.T, extent=[x_span[0], x_span[1], y_span[0], y_span[1]], origin='lower',
                interpolation="none", cmap=heat_cmap, zorder=1, aspect="auto", **kwargs)

        super().__init__(draw)
class HistogramPlot(Plot):
    """
    A histogram of a collection of values, which can optionally be augmented with a kernel density estimate
    and a plot of the (complementary) cumulative distribution function.
    """
    def __init__(self, values, bins="auto", kde=False, cdf=False, cdf_complementary=False, cdf_secondary_axis=True,
            binwidth=None, stat="probability", xlabel=None,
            **kwargs):
        """
        :param values: the values to plot
        :param bins: a bin specification as understood by sns.histplot
        :param kde: whether to add a kernel density estimator
        :param cdf: whether to add a plot of the cumulative distribution function (cdf)
        :param cdf_complementary: whether to plot, if cdf is enabled, the complementary values
        :param cdf_secondary_axis: whether to use, if cdf is enabled, a secondary y-axis for the cdf;
            a secondary axis is always used if `stat` is not one of "count", "proportion" and "probability",
            because the cdf is then plotted as a proportion, which does not match the histogram's y-axis
        :param binwidth: the bin width; if None, inferred
        :param stat: the statistic to plot (as understood by sns.histplot)
        :param xlabel: the label for the x-axis
        :param kwargs: arguments to pass on to sns.histplot
        """

        def draw(ax):
            sns.histplot(values, bins=bins, kde=kde, binwidth=binwidth, stat=stat, **kwargs)
            plt.ylabel(stat)
            if cdf:
                # determine the statistic to use for the cdf, falling back to proportions (on a separate axis)
                # for statistics that cannot be used for the cdf
                if stat in ("count", "proportion", "probability"):
                    ecdf_stat = stat
                    use_secondary_axis = cdf_secondary_axis
                else:
                    ecdf_stat = "proportion"
                    use_secondary_axis = True
                cdf_ax: Optional[plt.Axes] = None
                cdf_ax_label = f"{ecdf_stat} (cdf)"
                if use_secondary_axis:
                    cdf_ax = plt.twinx()
                    # the tick spacing must match the statistic that is shown on the cdf axis
                    # (previously, it was derived from the histogram's statistic)
                    if ecdf_stat in ("proportion", "probability"):
                        cdf_ax.yaxis.set_major_locator(plticker.MultipleLocator(base=0.1))
                # same semantics but "probability" not understood by ecdfplot;
                # previously, the value was derived from `stat`, which discarded the fallback determined above
                # and resulted in unsupported statistics being passed to ecdfplot
                ecdfplot_stat = "proportion" if ecdf_stat == "probability" else ecdf_stat
                # if cdf_ax is None, seaborn plots to the current axes, i.e. the cdf shares the histogram's axis
                sns.ecdfplot(values, stat=ecdfplot_stat, complementary=cdf_complementary, color="orange", ax=cdf_ax)
                if cdf_ax is not None:
                    cdf_ax.set_ylabel(cdf_ax_label)
            if xlabel is not None:
                # self.ax has already been assigned by the super-class constructor when draw is called
                self.xlabel(xlabel)

        super().__init__(draw)
class AverageSeriesLinePlot(Plot):
    """
    Plots the mean of a collection of series (or the means of several collections of series) as a line.
    For each collection, the series are brought to a common index (the combination of all indices) via interpolation.
    A shaded area around each line is intended to indicate the spread of the values (standard deviation).
    NOTE(review): no error bar specification is passed to sns.lineplot, so the statistic that is actually shown
    is seaborn's default - to be confirmed.
    """
    def __init__(self,
            series_collection: Union[List[pd.Series], Dict[str, List[pd.Series]]],
            interpolation: SeriesInterpolation,
            collection_name="collection",
            ax: Optional[plt.Axes] = None,
            hue_order=None, palette=None):
        """
        :param series_collection: either a list of series to average or a dictionary which maps the name of a
            collection to the list of series to average
        :param interpolation: the interpolation used to obtain the values of each series for the combined index
            of its collection
        :param collection_name: a name which describes what the keys in `series_collection` refer to; it appears
            in the legend if more than one collection is passed
        :param ax: the axis to plot to; if None, create a new figure and axis
        :param hue_order: the hue order (relevant if there is more than one collection of series)
        :param palette: the colour palette to use
        """
        # a plain list is treated as a single collection with a placeholder name
        series_dict = series_collection if isinstance(series_collection, dict) else {"_": series_collection}

        # the axis names are taken from the first series of the first collection
        first_collection = next(iter(series_dict.values()))
        x_name = first_collection[0].index.name or "x"
        y_name = first_collection[0].name or "y"

        # create one data frame per interpolated series, tagged with an identifier and the name of its collection
        frames = []
        for key, members in series_dict.items():
            interpolated_members = interpolation.interpolate_all_with_combined_index(members)
            for member in interpolated_members:
                frame = pd.DataFrame({y_name: member, x_name: member.index})
                frame["series_id"] = id(member)
                frame[collection_name] = key
                frames.append(frame)
        full_df = pd.concat(frames, axis=0).reset_index(drop=True)

        def draw(ax):
            # collections are distinguished by hue only if there is more than one
            hue = collection_name if len(series_dict) > 1 else None
            sns.lineplot(data=full_df, x=x_name, y=y_name, estimator="mean", hue=hue, hue_order=hue_order,
                palette=palette, ax=ax)

        super().__init__(draw, ax=ax)