Coverage for src/sensai/util/plot.py: 23%

217 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-29 18:29 +0000

1import collections 

2import logging 

3from typing import Sequence, Callable, TypeVar, Tuple, Optional, List, Any, Union, Dict 

4 

5import matplotlib.figure 

6import matplotlib.ticker as plticker 

7import numpy as np 

8import pandas as pd 

9import seaborn as sns 

10from matplotlib import pyplot as plt 

11from matplotlib.colors import LinearSegmentedColormap 

12 

13from sensai.util.pandas import SeriesInterpolation 

14 

# module-level logger, named after this module
log = logging.getLogger(__name__)

# (width, height) pair; per its name, matplotlib's default figure size (matplotlib specifies figure sizes in inches)
MATPLOTLIB_DEFAULT_FIGURE_SIZE: Tuple[float, float] = (6.4, 4.8)

18 

19 

class Color:
    """
    Represents a colour as an RGBA tuple (member `rgba`) and provides operations which derive modified colours.
    None of the operations modify the instance; each returns a new `Color`.
    """
    def __init__(self, c: Any):
        """
        :param c: any color specification that is understood by matplotlib
        """
        self.rgba = matplotlib.colors.to_rgba(c)

    def _with_lightness(self, transform: Callable[[float], float]) -> "Color":
        """
        Applies the given function to the colour's lightness in HLS space, retaining hue, saturation and alpha.

        :param transform: function mapping the current lightness to the new lightness
        :return: the resulting color
        """
        import colorsys
        h, l, s = colorsys.rgb_to_hls(*self.rgba[:3])
        rgb = colorsys.hls_to_rgb(h, transform(l), s)
        return Color((*rgb, self.rgba[3]))

    def darken(self, amount: float) -> "Color":
        """
        :param amount: amount to darken in [0,1], where 1 results in black and 0 leaves the color unchanged
        :return: the darkened color
        """
        # Scale the lightness by the complement of the amount, such that the behaviour matches the
        # documented contract and mirrors `lighten` (the lightness was previously multiplied by
        # `amount` itself, which inverted the documented meaning of the parameter).
        return self._with_lightness(lambda l: l * (1 - amount))

    def lighten(self, amount: float) -> "Color":
        """
        :param amount: amount to lighten in [0,1], where 1 results in white and 0 leaves the color unchanged
        :return: the lightened color
        """
        # move the lightness towards 1 by the given fraction of the remaining distance
        return self._with_lightness(lambda l: l + (1 - l) * amount)

    def alpha(self, opacity: float) -> "Color":
        """
        Returns a new color with modified alpha channel (opacity)

        :param opacity: the opacity between 0 (transparent) and 1 (fully opaque)
        :return: the modified color
        :raises ValueError: if the opacity is not in [0, 1]
        """
        if not (0 <= opacity <= 1):
            raise ValueError(f"Opacity must be between 0 and 1, got {opacity}")
        return Color((*self.rgba[:3], opacity))

    def to_hex(self, keep_alpha=True) -> str:
        """
        :param keep_alpha: whether to include the alpha channel in the representation
        :return: the hex string representation as generated by matplotlib's `to_hex`
        """
        return matplotlib.colors.to_hex(self.rgba, keep_alpha)

63 

64 

class LinearColorMap:
    """
    Facilitates usage of linear segmented colour maps by combining a colour map (member `cmap`), which transforms normalised values in [0,1]
    into colours, with a normaliser (member `norm`) that transforms the original values. The member `scalarMapper`
    combines the two, mapping original (non-normalised) values directly to RGBA colours.
    """
    def __init__(self, norm_min, norm_max, cmap_points: List[Tuple[float, Any]], cmap_points_normalised=False):
        """
        :param norm_min: the value that shall be mapped to 0 in the normalised representation (any smaller values are also clipped to 0)
        :param norm_max: the value that shall be mapped to 1 in the normalised representation (any larger values are also clipped to 1)
        :param cmap_points: a list (of at least two) tuples (v, c) where v is the value and c is the colour associated with the value;
            any colour specification supported by matplotlib is admissible
        :param cmap_points_normalised: whether the values in `cmap_points` are already normalised
        """
        self.norm = matplotlib.colors.Normalize(vmin=norm_min, vmax=norm_max, clip=True)
        if cmap_points_normalised:
            normalised_points = cmap_points
        else:
            normalised_points = [(self.norm(v), c) for v, c in cmap_points]
        # the instance id is used in order to give each colour map a distinct name
        self.cmap = LinearSegmentedColormap.from_list(f"cmap{id(self)}", normalised_points)
        self.scalarMapper = matplotlib.cm.ScalarMappable(norm=self.norm, cmap=self.cmap)

    def get_color(self, value) -> str:
        """
        :param value: the (non-normalised) value for which to obtain the colour
        :return: the colour as a hex string of the form "#rrggbbaa"
        """
        rgba = self.scalarMapper.to_rgba(value)
        # use matplotlib's conversion (as in Color.to_hex), which rounds each channel to the nearest
        # of the 256 levels rather than truncating it
        return matplotlib.colors.to_hex(rgba, keep_alpha=True)

87 

88 

def plot_matrix(matrix: np.ndarray, title: str, xtick_labels: Sequence[str], ytick_labels: Sequence[str], xlabel: str,
        ylabel: str, normalize=True, figsize: Tuple[int, int] = (9, 9), title_add: str = None) -> matplotlib.figure.Figure:
    """
    Draws the given matrix as an image using a blue colour scale (with colour bar), annotating every cell
    with its value.

    :param matrix: matrix whose data to plot, where matrix[i, j] will be rendered at x=i, y=j
    :param title: the plot's title
    :param xtick_labels: the labels for the x-axis ticks
    :param ytick_labels: the labels for the y-axis ticks
    :param xlabel: the label for the x-axis
    :param ylabel: the label for the y-axis
    :param normalize: whether to normalise the matrix before plotting it (dividing each entry by the sum of all entries)
    :param figsize: an optional size of the figure to be created
    :param title_add: an optional second line to add to the title
    :return: the figure object
    """
    # the image's rows run along the y-axis, so transpose in order to render matrix[i, j] at x=i, y=j
    cells = np.transpose(matrix)

    if title_add is not None:
        title += f"\n {title_add} "

    if normalize:
        cells = cells.astype('float') / cells.sum()

    fig, ax = plt.subplots(figsize=figsize)
    fig.canvas.manager.set_window_title(title.replace("\n", " "))

    num_rows = cells.shape[0]
    num_cols = cells.shape[1]

    # one tick per row/column, each labelled with the respective list entry
    ax.set(xticks=np.arange(num_cols),
        yticks=np.arange(num_rows),
        xticklabels=xtick_labels, yticklabels=ytick_labels,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel)
    image = ax.imshow(cells, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(image, ax=ax)

    # rotate the x tick labels and anchor them at their right end
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
        rotation_mode="anchor")

    # number format of the annotations: four decimals for normalised values,
    # two decimals for other float data, plain integers otherwise
    if normalize:
        cell_format = '.4f'
    elif cells.dtype.kind == 'f':
        cell_format = '.2f'
    else:
        cell_format = 'd'

    # cells with values above half the maximum receive white text, the others black text
    threshold = cells.max() / 2.
    for row, col in np.ndindex(num_rows, num_cols):
        cell_value = cells[row, col]
        text_color = "white" if cell_value > threshold else "black"
        ax.text(col, row, format(cell_value, cell_format),
            ha="center", va="center",
            color=text_color)

    fig.tight_layout()
    return fig

137 

138 

TPlot = TypeVar("TPlot", bound="Plot")


class Plot:
    """
    Base class for plots, which holds the axes being drawn to (member `ax`) and, if the figure was created by
    this object, the figure (member `fig`). Most methods return the instance itself in order to support chained calls.
    """
    def __init__(self, draw: Optional[Callable[[plt.Axes], None]] = None, name=None, ax: Optional[plt.Axes] = None):
        """
        :param draw: function which draws the plot; it receives the axes to draw to and is called upon construction
        :param name: name/number of the figure, which determines the window caption; it should be unique, as any plot
            with the same name will have its contents rendered in the same window. By default, figures are numbered
            sequentially. Only relevant if `ax` is None.
        :param ax: the axes to draw to; if None, a new figure with a single axes object is created
        """
        if ax is not None:
            fig = None
        else:
            fig, ax = plt.subplots(num=name)
        self.fig: plt.Figure = fig
        self.ax: plt.Axes = ax
        # the members are assigned before drawing, such that draw functions may use this object's methods
        if draw is not None:
            draw(ax)

    def xlabel(self: TPlot, label) -> TPlot:
        """
        :param label: the label for the x-axis
        :return: self
        """
        self.ax.set_xlabel(label)
        return self

    def ylabel(self: TPlot, label) -> TPlot:
        """
        :param label: the label for the y-axis
        :return: self
        """
        self.ax.set_ylabel(label)
        return self

    def title(self: TPlot, title: str) -> TPlot:
        """
        :param title: the title of the axes
        :return: self
        """
        self.ax.set_title(title)
        return self

    def xlim(self: TPlot, min_value, max_value) -> TPlot:
        """
        :param min_value: the lower limit of the x-axis
        :param max_value: the upper limit of the x-axis
        :return: self
        """
        self.ax.set_xlim(min_value, max_value)
        return self

    def ylim(self: TPlot, min_value, max_value) -> TPlot:
        """
        :param min_value: the lower limit of the y-axis
        :param max_value: the upper limit of the y-axis
        :return: self
        """
        self.ax.set_ylim(min_value, max_value)
        return self

    def save(self, path):
        """
        Saves the figure containing the plot.

        :param path: the path to which to save the figure
        """
        # if the plot was drawn to externally provided axes, no figure was created by this object (self.fig is None),
        # so fall back to the figure the axes belong to
        fig = self.fig if self.fig is not None else self.ax.figure
        log.info(f"Saving figure in {path}")
        fig.savefig(path)

    def xtick(self: TPlot, major=None, minor=None) -> TPlot:
        """
        Sets a tick on every integer multiple of the given base values.
        The major ticks are labelled, the minor ticks are not.

        :param major: the major tick base value
        :param minor: the minor tick base value
        :return: self
        """
        if major is not None:
            self.xtick_major(major)
        if minor is not None:
            self.xtick_minor(minor)
        return self

    def xtick_major(self: TPlot, base) -> TPlot:
        """
        :param base: the value at whose integer multiples major ticks shall be placed on the x-axis
        :return: self
        """
        self.ax.xaxis.set_major_locator(plticker.MultipleLocator(base=base))
        return self

    def xtick_minor(self: TPlot, base) -> TPlot:
        """
        :param base: the value at whose integer multiples minor ticks shall be placed on the x-axis
        :return: self
        """
        self.ax.xaxis.set_minor_locator(plticker.MultipleLocator(base=base))
        return self

    def ytick(self: TPlot, major=None, minor=None) -> TPlot:
        """
        Sets a tick on every integer multiple of the given base values (y-axis counterpart of `xtick`).

        :param major: the major tick base value
        :param minor: the minor tick base value
        :return: self
        """
        if major is not None:
            self.ytick_major(major)
        if minor is not None:
            self.ytick_minor(minor)
        return self

    def ytick_major(self: TPlot, base) -> TPlot:
        """
        :param base: the value at whose integer multiples major ticks shall be placed on the y-axis
        :return: self
        """
        self.ax.yaxis.set_major_locator(plticker.MultipleLocator(base=base))
        return self

    def ytick_minor(self: TPlot, base) -> TPlot:
        """
        :param base: the value at whose integer multiples minor ticks shall be placed on the y-axis
        :return: self
        """
        self.ax.yaxis.set_minor_locator(plticker.MultipleLocator(base=base))
        return self

210 

211 

class ScatterPlot(Plot):
    """
    A scatter plot in which, unless a colour specification is given, the opacity of the points is
    chosen based on the number of data points (the more points, the more transparent).
    """
    # number of data points above which the minimum opacity is applied
    N_MAX_TRANSPARENCY = 1000
    # number of data points below which the maximum opacity is applied
    N_MIN_TRANSPARENCY = 100
    MAX_OPACITY = 0.5
    MIN_OPACITY = 0.05

    def __init__(self, x, y, c=None, c_base: Tuple[float, float, float] = (0, 0, 1), c_opacity=None, x_label=None,
            y_label=None, add_diagonal=False, **kwargs):
        """
        :param x: the x values; if has name (e.g. pd.Series), will be used as axis label
        :param y: the y values; if has name (e.g. pd.Series), will be used as axis label
        :param c: the colour specification; if None, compose from ``c_base`` and ``c_opacity``
        :param c_base: the base colour as (R, G, B) floats
        :param c_opacity: the opacity; if None, automatically determine from number of data points
        :param x_label: the label for the x-axis; if None, use the name of ``x`` (if any)
        :param y_label: the label for the y-axis; if None, use the name of ``y`` (if any)
        :param add_diagonal: whether to add a diagonal line spanning the joint value range of ``x`` and ``y``
        :param kwargs: parameters to pass on to plt.scatter
        """
        assert len(x) == len(y)

        if c is None:
            if c_base is None:
                c_base = (0, 0, 1)
            if c_opacity is None:
                # interpolate linearly between the maximum and minimum opacity depending on the number of points
                n = len(x)
                if n > self.N_MAX_TRANSPARENCY:
                    transparency = 1
                elif n < self.N_MIN_TRANSPARENCY:
                    transparency = 0
                else:
                    transparency = (n - self.N_MIN_TRANSPARENCY) / (self.N_MAX_TRANSPARENCY - self.N_MIN_TRANSPARENCY)
                c_opacity = self.MIN_OPACITY + (self.MAX_OPACITY - self.MIN_OPACITY) * (1-transparency)
            # a single RGBA colour, which applies to all points
            c = ((*c_base, c_opacity),)

        if x_label is None and hasattr(x, "name"):
            x_label = x.name
        if y_label is None and hasattr(y, "name"):
            y_label = y.name

        def draw(ax):
            if x_label is not None:
                plt.xlabel(x_label)
            # the y label is subject to its own presence check (it was previously tied to the presence of x_label)
            if y_label is not None:
                plt.ylabel(y_label)
            plt.scatter(x, y, c=c, **kwargs)
            if add_diagonal:
                # the value range is required for the diagonal only, so plots without a diagonal
                # do not depend on the data being non-empty
                value_range = [min(min(x), min(y)), max(max(x), max(y))]
                plt.plot(value_range, value_range, '-', lw=1, label="_not in legend", color="green", zorder=1)

        super().__init__(draw)

261 

262 

class HeatMapPlot(Plot):
    """
    A heat map of the joint distribution of two equally long value sequences, based on a 2D histogram of counts.
    """
    # creates the default colour map: white at 0, a very light red at 1/num_points and dark red at 1,
    # where num_points is also passed as the number of colour levels
    DEFAULT_CMAP_FACTORY = lambda num_points: LinearSegmentedColormap.from_list(
        "whiteToRed",
        ((0, (1, 1, 1)), (1 / num_points, (1, 0.96, 0.96)), (1, (0.7, 0, 0))),
        num_points)

    def __init__(self, x, y, x_label=None, y_label=None, bins=60, cmap=None, common_range=True, diagonal=False,
            diagonal_color="green", **kwargs):
        """
        :param x: the x values; if has name (e.g. pd.Series), it will be used as the axis label unless `x_label` is given
        :param y: the y values; if has name (e.g. pd.Series), it will be used as the axis label unless `y_label` is given
        :param x_label: the x-axis label
        :param y_label: the y-axis label
        :param bins: the number of bins to use in each dimension
        :param cmap: the colour map to use for heat values (if None, use default)
        :param common_range: whether the heat map is to use a common range for the x- and y-axes (set to False if x and y are
            completely different quantities; set to True for use cases such as the evaluation of regression model quality)
        :param diagonal: whether to draw the diagonal line (useful for regression evaluation)
        :param diagonal_color: the colour to use for the diagonal line
        :param kwargs: parameters to pass on to plt.imshow
        """
        assert len(x) == len(y)
        if x_label is None and hasattr(x, "name"):
            x_label = x.name
        if y_label is None and hasattr(y, "name"):
            y_label = y.name

        def draw(ax):
            x_bounds = [min(x), max(x)]
            y_bounds = [min(y), max(y)]
            # the range spanning the values of both axes
            joint_bounds = [min(x_bounds[0], y_bounds[0]), max(x_bounds[1], y_bounds[1])]
            if common_range:
                x_bounds = y_bounds = joint_bounds
            if diagonal:
                plt.plot(joint_bounds, joint_bounds, '-', lw=0.75, label="_not in legend", color=diagonal_color, zorder=2)

            counts = np.histogram2d(x, y, range=[x_bounds, y_bounds], bins=bins, density=False)[0]

            # by default, the colour map is created based on the number of data points
            heat_cmap = cmap if cmap is not None else HeatMapPlot.DEFAULT_CMAP_FACTORY(len(x))

            for apply_label, label in ((plt.xlabel, x_label), (plt.ylabel, y_label)):
                if label is not None:
                    apply_label(label)

            # histogram2d uses the first dimension for x, whereas imshow renders the first dimension along
            # the y-axis, hence the transposition
            plt.imshow(counts.T, extent=[*x_bounds, *y_bounds], origin='lower', interpolation="none", cmap=heat_cmap,
                zorder=1, aspect="auto", **kwargs)

        super().__init__(draw)

308 

309 

class HistogramPlot(Plot):
    """
    A histogram (via sns.histplot), which can optionally be complemented with a plot of the empirical
    cumulative distribution function (via sns.ecdfplot).
    """
    def __init__(self, values, bins="auto", kde=False, cdf=False, cdf_complementary=False, cdf_secondary_axis=True,
            binwidth=None, stat="probability", xlabel=None,
            **kwargs):
        """
        :param values: the values to plot
        :param bins: a bin specification as understood by sns.histplot
        :param kde: whether to add a kernel density estimator
        :param cdf: whether to add a plot of the cumulative distribution function (cdf)
        :param cdf_complementary: whether to plot, if cdf is enabled, the complementary values
        :param cdf_secondary_axis: whether to use, if cdf is enabled, a secondary y-axis for the cdf;
            a secondary axis is always used if `stat` cannot be used for the cdf, in which case the cdf is
            plotted as a proportion
        :param binwidth: the bin width; if None, inferred
        :param stat: the statistic to plot (as understood by sns.histplot)
        :param xlabel: the label for the x-axis
        :param kwargs: arguments to pass on to sns.histplot
        """

        def draw(ax):
            sns.histplot(values, bins=bins, kde=kde, binwidth=binwidth, stat=stat, **kwargs)
            plt.ylabel(stat)

            if cdf:
                if stat in ("count", "proportion", "probability"):
                    cdf_label_stat = stat
                    use_secondary_axis = cdf_secondary_axis
                else:
                    # the histogram's statistic is not applicable to the cdf, so plot the cdf as a proportion,
                    # which requires a separate axis, as the scales differ
                    cdf_label_stat = "proportion"
                    use_secondary_axis = True
                # same semantics but "probability" not understood by ecdfplot
                ecdf_stat = "proportion" if cdf_label_stat == "probability" else cdf_label_stat

                # if no secondary axis is used, ecdfplot receives ax=None and thus draws to the current axes
                cdf_ax: Optional[plt.Axes] = None
                if use_secondary_axis:
                    cdf_ax = plt.twinx()
                    # the tick spacing is based on the statistic that is actually plotted on the secondary axis
                    if ecdf_stat == "proportion":
                        cdf_ax.yaxis.set_major_locator(plticker.MultipleLocator(base=0.1))

                sns.ecdfplot(values, stat=ecdf_stat, complementary=cdf_complementary, color="orange", ax=cdf_ax)
                if cdf_ax is not None:
                    cdf_ax.set_ylabel(f"{cdf_label_stat} (cdf)")

            if xlabel is not None:
                self.xlabel(xlabel)

        super().__init__(draw)

359 

360 

class AverageSeriesLinePlot(Plot):
    """
    Plots the average of a collection of series or the averages of several collections of series,
    establishing a common index (the unification of all indices) for each collection via interpolation.
    The variation is additionally shown as an area around each line, using seaborn's default for
    sns.lineplot (no explicit error bar specification is passed).

    NOTE(review): This was previously documented as showing the standard deviation; current seaborn versions
    default to a confidence interval instead. If the standard deviation is intended, an explicit error bar
    specification must be passed to sns.lineplot - to be confirmed against the seaborn version in use.
    """
    def __init__(self,
            series_collection: Union[List[pd.Series], Dict[str, List[pd.Series]]],
            interpolation: SeriesInterpolation,
            collection_name="collection",
            ax: Optional[plt.Axes] = None,
            hue_order=None, palette=None):
        """
        :param series_collection: either a list of series to average or a dictionary mapping the name of a collection
            to a list of series to average
        :param interpolation: the interpolation with which to obtain series values for the unified index of a collection of series
        :param collection_name: a name indicating what a key in `series_collection` refers to, which will appear in the legend
            for the case where more than one collection is passed
        :param ax: the axis to plot to; if None, create a new figure and axis
        :param hue_order: the hue order (for the case where there is more than one collection of series)
        :param palette: the colour palette to use
        :raises ValueError: if no collection is given or the first collection contains no series
        """
        if isinstance(series_collection, dict):
            series_dict = series_collection
        else:
            series_dict = {"_": series_collection}

        if len(series_dict) == 0:
            raise ValueError("series_collection must contain at least one collection of series")
        first_collection = next(iter(series_dict.values()))
        if len(first_collection) == 0:
            raise ValueError("The first collection in series_collection must contain at least one series")

        # the axis names are taken from the very first series
        first_series = first_collection[0]
        x_name = first_series.index.name or "x"
        y_name = first_series.name or "y"

        # build data frame with all series, interpolating each sub-collection
        dfs = []
        for name, collection in series_dict.items():
            for series in interpolation.interpolate_all_with_combined_index(collection):
                df = pd.DataFrame({y_name: series, x_name: series.index})
                df["series_id"] = id(series)
                df[collection_name] = name
                dfs.append(df)
        full_df = pd.concat(dfs, axis=0).reset_index(drop=True)

        def draw(ax):
            sns.lineplot(
                data=full_df,
                x=x_name,
                y=y_name,
                estimator="mean",
                # distinguish collections by colour only if there is more than one
                hue=collection_name if len(series_dict) > 1 else None,
                hue_order=hue_order,
                palette=palette,
                ax=ax,
            )

        super().__init__(draw, ax=ax)