Coverage for src/sensai/util/plot.py: 23%
189 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1import logging
2from typing import Sequence, Callable, TypeVar, Tuple, Optional, List, Any
4import matplotlib.figure
5import matplotlib.ticker as plticker
6import numpy as np
7import seaborn as sns
8from matplotlib import pyplot as plt
9from matplotlib.colors import LinearSegmentedColormap
# module-level logger, named after this module
log = logging.getLogger(__name__)

# (width, height) tuple; presumably matplotlib's default figure size in inches - TODO confirm against rcParams
MATPLOTLIB_DEFAULT_FIGURE_SIZE = (6.4, 4.8)
class Color:
    """
    Wraps a colour as an RGBA tuple (member `rgba`) and provides methods for deriving modified colours.
    All modifying methods return a new instance and leave the original unchanged.
    """
    def __init__(self, c: Any):
        """
        :param c: any color specification that is understood by matplotlib
        """
        self.rgba = matplotlib.colors.to_rgba(c)

    def darken(self, amount: float) -> "Color":
        """
        :param amount: amount to darken in [0,1], where 1 results in black and 0 leaves the color unchanged
        :return: the darkened color
        """
        import colorsys
        rgb = matplotlib.colors.to_rgb(self.rgba)
        h, l, s = colorsys.rgb_to_hls(*rgb)
        # move the lightness towards 0 (black) by the given fraction; this mirrors `lighten`,
        # which moves the lightness towards 1 (white) by the given fraction
        l *= 1 - amount
        rgb = colorsys.hls_to_rgb(h, l, s)
        # the alpha channel is carried over unchanged
        return Color((*rgb, self.rgba[3]))

    def lighten(self, amount: float) -> "Color":
        """
        :param amount: amount to lighten in [0,1], where 1 results in white and 0 leaves the color unchanged
        :return: the lightened color
        """
        import colorsys
        rgb = matplotlib.colors.to_rgb(self.rgba)
        h, l, s = colorsys.rgb_to_hls(*rgb)
        # move the lightness towards 1 (white) by the given fraction of the remaining distance
        l += (1 - l) * amount
        rgb = colorsys.hls_to_rgb(h, l, s)
        # the alpha channel is carried over unchanged
        return Color((*rgb, self.rgba[3]))

    def alpha(self, opacity: float) -> "Color":
        """
        Returns a new color with modified alpha channel (opacity)

        :param opacity: the opacity between 0 (transparent) and 1 (fully opaque)
        :return: the modified color
        :raises ValueError: if the opacity is not in [0, 1]
        """
        if not (0 <= opacity <= 1):
            raise ValueError(f"Opacity must be between 0 and 1, got {opacity}")
        return Color((*self.rgba[:3], opacity))

    def to_hex(self, keep_alpha: bool = True) -> str:
        """
        :param keep_alpha: whether to include the alpha channel in the representation
        :return: the hex string representation as generated by matplotlib
        """
        return matplotlib.colors.to_hex(self.rgba, keep_alpha)
class LinearColorMap:
    """
    Facilitates usage of linear segmented colour maps by combining a colour map (member `cmap`), which transforms
    normalised values in [0,1] into colours, with a normaliser (member `norm`) that transforms the original values.
    The member `scalarMapper` is a matplotlib ScalarMappable which combines the two, i.e. it maps original values
    directly to colours.
    """
    def __init__(self, norm_min, norm_max, cmap_points: List[Tuple[float, Any]], cmap_points_normalised=False):
        """
        :param norm_min: the value that shall be mapped to 0 in the normalised representation (any smaller values are also clipped to 0)
        :param norm_max: the value that shall be mapped to 1 in the normalised representation (any larger values are also clipped to 1)
        :param cmap_points: a list (of at least two) tuples (v, c) where v is the value and c is the colour associated with the value;
            any colour specification supported by matplotlib is admissible
        :param cmap_points_normalised: whether the values in `cmap_points` are already normalised
        """
        normaliser = matplotlib.colors.Normalize(vmin=norm_min, vmax=norm_max, clip=True)
        self.norm = normaliser
        if cmap_points_normalised:
            anchor_points = cmap_points
        else:
            # bring the anchor values into the normalised space expected by the colour map
            anchor_points = [(normaliser(value), colour) for value, colour in cmap_points]
        colour_map = LinearSegmentedColormap.from_list(f"cmap{id(self)}", anchor_points)
        self.cmap = colour_map
        self.scalarMapper = matplotlib.cm.ScalarMappable(norm=normaliser, cmap=colour_map)

    def get_color(self, value):
        """
        :param value: the (unnormalised) value for which to obtain the colour
        :return: the colour as a hex string of the form "#rrggbbaa"
        """
        red, green, blue, alpha = (int(channel * 255) for channel in self.scalarMapper.to_rgba(value))
        return f"#{red:02x}{green:02x}{blue:02x}{alpha:02x}"
def plot_matrix(matrix: np.ndarray, title: str, xtick_labels: Sequence[str], ytick_labels: Sequence[str], xlabel: str,
        ylabel: str, normalize=True, figsize: Tuple[int, int] = (9, 9), title_add: str = None) -> matplotlib.figure.Figure:
    """
    Renders a matrix as an annotated heat map (blue colour scale with a colour bar), writing each
    entry's value into its cell.

    :param matrix: matrix whose data to plot, where matrix[i, j] will be rendered at x=i, y=j
    :param title: the plot's title
    :param xtick_labels: the labels for the x-axis ticks
    :param ytick_labels: the labels for the y-axis ticks
    :param xlabel: the label for the x-axis
    :param ylabel: the label for the y-axis
    :param normalize: whether to normalise the matrix before plotting it (dividing each entry by the sum of all entries)
    :param figsize: an optional size of the figure to be created
    :param title_add: an optional second line to add to the title
    :return: the figure object
    """
    # imshow renders rows along the y-axis, so transpose in order to have the first index on the x-axis
    data = np.transpose(matrix)
    full_title = title if title_add is None else title + f"\n {title_add} "
    if normalize:
        data = data.astype('float') / data.sum()

    fig, ax = plt.subplots(figsize=figsize)
    # the window caption is a single line
    fig.canvas.manager.set_window_title(full_title.replace("\n", " "))

    n_rows = data.shape[0]
    n_cols = data.shape[1]
    # one tick per cell, labelled with the respective list entries
    axis_settings = dict(
        xticks=np.arange(n_cols),
        yticks=np.arange(n_rows),
        xticklabels=xtick_labels,
        yticklabels=ytick_labels,
        title=full_title,
        xlabel=xlabel,
        ylabel=ylabel,
    )
    ax.set(**axis_settings)
    image = ax.imshow(data, interpolation='nearest', cmap=plt.cm.Blues)
    fig.colorbar(image, ax=ax)

    # rotate the x tick labels and anchor them at their right end
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # select the number format for the cell annotations
    if normalize:
        cell_format = '.4f'
    elif data.dtype.kind == 'f':
        cell_format = '.2f'
    else:
        cell_format = 'd'

    # cells above half the maximum are annotated in white, the others in black
    threshold = data.max() / 2.
    for row, col in np.ndindex(n_rows, n_cols):
        value = data[row, col]
        text_colour = "white" if value > threshold else "black"
        ax.text(col, row, format(value, cell_format), ha="center", va="center", color=text_colour)

    fig.tight_layout()
    return fig
# type variable which allows the fluent methods of Plot to return the concrete subclass type
TPlot = TypeVar("TPlot", bound="Plot")


class Plot:
    """
    Base class for plots, which creates a figure with a single axes object (members `fig` and `ax`)
    and provides a fluent interface for configuring labels, limits and ticks.
    """
    def __init__(self, draw: Optional[Callable[[], Any]] = None, name=None):
        """
        :param draw: function which draws the plot's content; it is called (without arguments) right after the
            figure and axes have been created, and its return value is ignored
        :param name: name/number of the figure, which determines the window caption; it should be unique, as any plot
            with the same name will have its contents rendered in the same window. By default, figures are numbered
            sequentially.
        """
        fig, ax = plt.subplots(num=name)
        self.fig: plt.Figure = fig
        self.ax: plt.Axes = ax
        if draw is not None:
            draw()

    def xlabel(self: TPlot, label) -> TPlot:
        """
        :param label: the label for the x-axis
        :return: self
        """
        self.ax.set_xlabel(label)
        return self

    def ylabel(self: TPlot, label) -> TPlot:
        """
        :param label: the label for the y-axis
        :return: self
        """
        self.ax.set_ylabel(label)
        return self

    def title(self: TPlot, title: str) -> TPlot:
        """
        :param title: the title of the axes
        :return: self
        """
        self.ax.set_title(title)
        return self

    def xlim(self: TPlot, min_value, max_value) -> TPlot:
        """
        :param min_value: the lower limit of the x-axis
        :param max_value: the upper limit of the x-axis
        :return: self
        """
        self.ax.set_xlim(min_value, max_value)
        return self

    def ylim(self: TPlot, min_value, max_value) -> TPlot:
        """
        :param min_value: the lower limit of the y-axis
        :param max_value: the upper limit of the y-axis
        :return: self
        """
        self.ax.set_ylim(min_value, max_value)
        return self

    def save(self, path):
        """
        Saves the figure, logging the target path.

        :param path: the path to which to save the figure
        """
        log.info(f"Saving figure in {path}")
        self.fig.savefig(path)

    def xtick(self: TPlot, major=None, minor=None) -> TPlot:
        """
        Sets a tick on every integer multiple of the given base values.
        The major ticks are labelled, the minor ticks are not.

        :param major: the major tick base value
        :param minor: the minor tick base value
        :return: self
        """
        if major is not None:
            self.xtick_major(major)
        if minor is not None:
            self.xtick_minor(minor)
        return self

    def xtick_major(self: TPlot, base) -> TPlot:
        """
        :param base: the base value; a major x-axis tick is placed on every integer multiple of it
        :return: self
        """
        self.ax.xaxis.set_major_locator(plticker.MultipleLocator(base=base))
        return self

    def xtick_minor(self: TPlot, base) -> TPlot:
        """
        :param base: the base value; a minor x-axis tick is placed on every integer multiple of it
        :return: self
        """
        self.ax.xaxis.set_minor_locator(plticker.MultipleLocator(base=base))
        return self

    def ytick(self: TPlot, major=None, minor=None) -> TPlot:
        """
        Sets a tick on every integer multiple of the given base values (y-axis counterpart of `xtick`).

        :param major: the major tick base value
        :param minor: the minor tick base value
        :return: self
        """
        if major is not None:
            self.ytick_major(major)
        if minor is not None:
            self.ytick_minor(minor)
        return self

    def ytick_major(self: TPlot, base) -> TPlot:
        """
        :param base: the base value; a major y-axis tick is placed on every integer multiple of it
        :return: self
        """
        self.ax.yaxis.set_major_locator(plticker.MultipleLocator(base=base))
        return self

    def ytick_minor(self: TPlot, base) -> TPlot:
        """
        :param base: the base value; a minor y-axis tick is placed on every integer multiple of it
        :return: self
        """
        self.ax.yaxis.set_minor_locator(plticker.MultipleLocator(base=base))
        return self
class ScatterPlot(Plot):
    """
    A scatter plot whose default point opacity decreases with the number of data points.
    """
    # number of data points above which the minimum opacity is used
    N_MAX_TRANSPARENCY = 1000
    # number of data points below which the maximum opacity is used
    N_MIN_TRANSPARENCY = 100
    MAX_OPACITY = 0.5
    MIN_OPACITY = 0.05

    def __init__(self, x, y, c=None, c_base: Tuple[float, float, float] = (0, 0, 1), c_opacity=None, x_label=None, y_label=None, **kwargs):
        """
        :param x: the x values; if has name (e.g. pd.Series), will be used as axis label
        :param y: the y values; if has name (e.g. pd.Series), will be used as axis label
        :param c: the colour specification; if None, compose from ``c_base`` and ``c_opacity``
        :param c_base: the base colour as (R, G, B) floats
        :param c_opacity: the opacity; if None, automatically determine from number of data points
        :param x_label: the x-axis label; if None, use the name of ``x`` (if it has one)
        :param y_label: the y-axis label; if None, use the name of ``y`` (if it has one)
        :param kwargs: parameters to pass on to plt.scatter
        """
        if c is None:
            if c_base is None:
                c_base = (0, 0, 1)
            if c_opacity is None:
                # interpolate linearly between the maximum opacity (few points) and the minimum opacity (many points)
                n = len(x)
                if n > self.N_MAX_TRANSPARENCY:
                    transparency = 1
                elif n < self.N_MIN_TRANSPARENCY:
                    transparency = 0
                else:
                    transparency = (n - self.N_MIN_TRANSPARENCY) / (self.N_MAX_TRANSPARENCY - self.N_MIN_TRANSPARENCY)
                c_opacity = self.MIN_OPACITY + (self.MAX_OPACITY - self.MIN_OPACITY) * (1 - transparency)
            # a single RGBA colour, which is applied to all points
            c = ((*c_base, c_opacity),)

        assert len(x) == len(y)
        if x_label is None and hasattr(x, "name"):
            x_label = x.name
        if y_label is None and hasattr(y, "name"):
            y_label = y.name

        def draw():
            if x_label is not None:
                plt.xlabel(x_label)
            # the y label depends on y_label only (previously, this erroneously tested x_label)
            if y_label is not None:
                plt.ylabel(y_label)
            plt.scatter(x, y, c=c, **kwargs)

        super().__init__(draw)
class HeatMapPlot(Plot):
    """
    A heat map of the two-dimensional histogram of paired x and y values.
    """
    # creates the default colour map (white to red) for the given number of data points
    DEFAULT_CMAP_FACTORY = lambda num_points: LinearSegmentedColormap.from_list("whiteToRed",
        ((0, (1, 1, 1)), (1 / num_points, (1, 0.96, 0.96)), (1, (0.7, 0, 0))), num_points)

    def __init__(self, x, y, x_label=None, y_label=None, bins=60, cmap=None, common_range=True, diagonal=False,
            diagonal_color="green", **kwargs):
        """
        :param x: the x values
        :param y: the y values
        :param x_label: the x-axis label
        :param y_label: the y-axis label
        :param bins: the number of bins to use in each dimension
        :param cmap: the colour map to use for heat values (if None, use default)
        :param common_range: whether the heat map is to use a common range for the x- and y-axes (set to False if x and y are
            completely different quantities; set to True for use cases such as the evaluation of regression model quality)
        :param diagonal: whether to draw the diagonal line (useful for regression evaluation)
        :param diagonal_color: the colour to use for the diagonal line
        :param kwargs: parameters to pass on to plt.imshow
        """
        assert len(x) == len(y)
        # fall back to the names of the inputs (if any) for unspecified axis labels
        if x_label is None:
            x_label = getattr(x, "name", None)
        if y_label is None:
            y_label = getattr(y, "name", None)

        def draw():
            x_lo, x_hi = min(x), max(x)
            y_lo, y_hi = min(y), max(y)
            # the range which covers both the x and the y values
            rng = [min(x_lo, y_lo), max(x_hi, y_hi)]
            if common_range:
                x_range = y_range = rng
            else:
                x_range, y_range = [x_lo, x_hi], [y_lo, y_hi]

            if diagonal:
                plt.plot(rng, rng, '-', lw=0.75, label="_not in legend", color=diagonal_color, zorder=2)

            counts, _, _ = np.histogram2d(x, y, range=[x_range, y_range], bins=bins, density=False)
            colour_map = HeatMapPlot.DEFAULT_CMAP_FACTORY(len(x)) if cmap is None else cmap

            if x_label is not None:
                plt.xlabel(x_label)
            if y_label is not None:
                plt.ylabel(y_label)

            # histogram2d has x along the first dimension, imshow expects it along the second, hence the transposition
            plt.imshow(counts.T, extent=[*x_range, *y_range], origin='lower', interpolation="none", cmap=colour_map,
                zorder=1, aspect="auto", **kwargs)

        super().__init__(draw)
class HistogramPlot(Plot):
    """
    A histogram (drawn via sns.histplot), which can optionally be overlaid with a plot of the
    (complementary) cumulative distribution function.
    """
    def __init__(self, values, bins="auto", kde=False, cdf=False, cdf_complementary=False, cdf_secondary_axis=True,
            binwidth=None, stat="probability", xlabel=None,
            **kwargs):
        """
        :param values: the values to plot
        :param bins: a bin specification as understood by sns.histplot
        :param kde: whether to add a kernel density estimator
        :param cdf: whether to add a plot of the cumulative distribution function (cdf)
        :param cdf_complementary: whether to plot, if cdf is enabled, the complementary values
        :param cdf_secondary_axis: whether to use, if cdf is enabled, a secondary y-axis for the cdf
            (always enabled if `stat` is not one of "count", "proportion", "probability")
        :param binwidth: the bin width; if None, inferred
        :param stat: the statistic to plot (as understood by sns.histplot)
        :param xlabel: the label for the x-axis
        :param kwargs: arguments to pass on to sns.histplot
        """

        def draw():
            # called by Plot.__init__ after the figure and axes (self.ax) have been created
            nonlocal cdf_secondary_axis
            sns.histplot(values, bins=bins, kde=kde, binwidth=binwidth, stat=stat, **kwargs)
            plt.ylabel(stat)
            if cdf:
                ecdf_stat = stat
                if ecdf_stat not in ("count", "proportion", "probability"):
                    # fall back to proportions on a separate axis for any other statistic
                    ecdf_stat = "proportion"
                    cdf_secondary_axis = True
                cdf_ax: Optional[plt.Axes] = None
                cdf_ax_label = f"{ecdf_stat} (cdf)"
                if cdf_secondary_axis:
                    cdf_ax: plt.Axes = plt.twinx()
                    # NOTE(review): the tick base is selected based on `stat`, not on `ecdf_stat`
                    if stat in ("proportion", "probability"):
                        y_tick = 0.1
                    elif stat == "percent":
                        y_tick = 10
                    else:
                        y_tick = None
                    if y_tick is not None:
                        cdf_ax.yaxis.set_major_locator(plticker.MultipleLocator(base=y_tick))
                # NOTE(review): at this point, ecdf_stat is always one of the three values tested here (it was replaced
                # by "proportion" above otherwise), so this condition always holds and the else branch is never taken
                if cdf_complementary or ecdf_stat in ("count", "proportion", "probability"):
                    # NOTE(review): this is derived from `stat` and thus discards the fallback value computed above;
                    # for other statistics, `stat` itself is passed to ecdfplot - verify that ecdfplot supports it
                    ecdf_stat = "proportion" if stat == "probability" else stat  # same semantics but "probability" not understood by ecdfplot
                    sns.ecdfplot(values, stat=ecdf_stat, complementary=cdf_complementary, color="orange", ax=cdf_ax)
                else:
                    sns.histplot(values, bins=100, stat=stat, element="poly", fill=False, cumulative=True, color="orange", ax=cdf_ax)
                # the cdf axis label is only applied if a secondary axis was created
                if cdf_ax is not None:
                    cdf_ax.set_ylabel(cdf_ax_label)
            if xlabel is not None:
                self.xlabel(xlabel)

        super().__init__(draw)