Coverage for src/sensai/util/plot.py: 23%

189 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1import logging 

2from typing import Sequence, Callable, TypeVar, Tuple, Optional, List, Any 

3 

4import matplotlib.figure 

5import matplotlib.ticker as plticker 

6import numpy as np 

7import seaborn as sns 

8from matplotlib import pyplot as plt 

9from matplotlib.colors import LinearSegmentedColormap 

10 

log = logging.getLogger(name=__name__)

# width and height (in inches) of matplotlib's default figure size
MATPLOTLIB_DEFAULT_FIGURE_SIZE = 6.4, 4.8

14 

15 

class Color:
    """
    Represents a colour via its RGBA representation (member `rgba`, as returned by matplotlib's `to_rgba`).
    The modifying operations (darken, lighten, alpha) return new instances rather than changing this one.
    """
    def __init__(self, c: Any):
        """
        :param c: any color specification that is understood by matplotlib
        """
        self.rgba = matplotlib.colors.to_rgba(c)

    @staticmethod
    def _check_unit_interval(name: str, value: float) -> None:
        # shared validation for parameters which are documented to lie in [0, 1]
        if not (0 <= value <= 1):
            raise ValueError(f"{name} must be between 0 and 1, got {value}")

    def _with_lightness(self, transform: Callable[[float], float]) -> "Color":
        # applies the given transformation to the lightness component (HLS colour space),
        # keeping hue, saturation and alpha unchanged
        import colorsys
        h, lightness, s = colorsys.rgb_to_hls(*matplotlib.colors.to_rgb(self.rgba))
        rgb = colorsys.hls_to_rgb(h, transform(lightness), s)
        return Color((*rgb, self.rgba[3]))

    def darken(self, amount: float) -> "Color":
        """
        :param amount: amount to darken in [0,1], where 1 results in black and 0 leaves the color unchanged
        :return: the darkened color
        :raises ValueError: if `amount` is not in [0,1]
        """
        self._check_unit_interval("Amount", amount)
        # move the lightness towards 0 (black) by the given fraction
        return self._with_lightness(lambda lightness: lightness * (1 - amount))

    def lighten(self, amount: float) -> "Color":
        """
        :param amount: amount to lighten in [0,1], where 1 results in white and 0 leaves the color unchanged
        :return: the lightened color
        :raises ValueError: if `amount` is not in [0,1]
        """
        self._check_unit_interval("Amount", amount)
        # move the lightness towards 1 (white) by the given fraction of the remaining distance
        return self._with_lightness(lambda lightness: lightness + (1 - lightness) * amount)

    def alpha(self, opacity: float) -> "Color":
        """
        Returns a new color with modified alpha channel (opacity)

        :param opacity: the opacity between 0 (transparent) and 1 (fully opaque)
        :return: the modified color
        :raises ValueError: if `opacity` is not in [0,1]
        """
        self._check_unit_interval("Opacity", opacity)
        return Color((*self.rgba[:3], opacity))

    def to_hex(self, keep_alpha=True) -> str:
        """
        :param keep_alpha: whether to include the alpha channel in the result
        :return: the colour as a hex string (as produced by matplotlib's `to_hex`)
        """
        return matplotlib.colors.to_hex(self.rgba, keep_alpha)

59 

60 

class LinearColorMap:
    """
    Facilitates usage of linear segmented colour maps by combining a colour map (member `cmap`), which transforms
    normalised values in [0,1] into colours, with a normaliser (member `norm`), which transforms the original values
    into [0,1]. The member `scalarMapper` combines the two, mapping original values directly to colours.
    """
    def __init__(self, norm_min, norm_max, cmap_points: List[Tuple[float, Any]], cmap_points_normalised=False):
        """
        :param norm_min: the value that shall be mapped to 0 in the normalised representation (any smaller values are also clipped to 0)
        :param norm_max: the value that shall be mapped to 1 in the normalised representation (any larger values are also clipped to 1)
        :param cmap_points: a list (of at least two) tuples (v, c) where v is the value and c is the colour associated with the value;
            any colour specification supported by matplotlib is admissible
        :param cmap_points_normalised: whether the values in `cmap_points` are already normalised
        """
        self.norm = matplotlib.colors.Normalize(vmin=norm_min, vmax=norm_max, clip=True)
        if not cmap_points_normalised:
            cmap_points = [(self.norm(v), c) for v, c in cmap_points]
        self.cmap = LinearSegmentedColormap.from_list(f"cmap{id(self)}", cmap_points)
        self.scalarMapper = matplotlib.cm.ScalarMappable(norm=self.norm, cmap=self.cmap)

    def get_color(self, value) -> str:
        """
        :param value: the (original, non-normalised) value to map to a colour
        :return: the colour as a hex string of the form '#rrggbbaa'
        """
        # use matplotlib's conversion (which rounds channel values) for consistency with Color.to_hex
        return matplotlib.colors.to_hex(self.scalarMapper.to_rgba(value), keep_alpha=True)

83 

84 

def plot_matrix(matrix: np.ndarray, title: str, xtick_labels: Sequence[str], ytick_labels: Sequence[str], xlabel: str,
        ylabel: str, normalize=True, figsize: Tuple[int, int] = (9, 9), title_add: Optional[str] = None) -> matplotlib.figure.Figure:
    """
    Plots a matrix as a colour-coded grid, annotating each cell with its value.

    :param matrix: matrix whose data to plot, where matrix[i, j] will be rendered at x=i, y=j
    :param title: the plot's title
    :param xtick_labels: the labels for the x-axis ticks
    :param ytick_labels: the labels for the y-axis ticks
    :param xlabel: the label for the x-axis
    :param ylabel: the label for the y-axis
    :param normalize: whether to normalise the matrix before plotting it (dividing each entry by the sum of all entries);
        if all entries sum to zero, the (float-converted) matrix is left unscaled instead of producing NaN values
    :param figsize: an optional size of the figure to be created
    :param title_add: an optional second line to add to the title
    :return: the figure object
    """
    # imshow renders rows along the y-axis, so transpose to render the first index along the x-axis
    matrix = np.transpose(matrix)

    if title_add is not None:
        title += f"\n {title_add} "

    if normalize:
        matrix = matrix.astype('float')
        total = matrix.sum()
        if total != 0:  # avoid 0/0 (NaN entries) for an all-zero matrix
            matrix = matrix / total
    fig, ax = plt.subplots(figsize=figsize)
    fig.canvas.manager.set_window_title(title.replace("\n", " "))
    # show all ticks and label them with the respective list entries
    ax.set(xticks=np.arange(matrix.shape[1]),
        yticks=np.arange(matrix.shape[0]),
        xticklabels=xtick_labels, yticklabels=ytick_labels,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel)
    im = ax.imshow(matrix, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)

    # rotate the tick labels and set their alignment
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # annotate each cell with its value, using white text on cells coloured darker than half the maximum
    fmt = '.4f' if normalize else ('.2f' if matrix.dtype.kind == 'f' else 'd')
    thresh = matrix.max() / 2.
    for (i, j), value in np.ndenumerate(matrix):
        ax.text(j, i, format(value, fmt),
            ha="center", va="center",
            color="white" if value > thresh else "black")
    fig.tight_layout()
    return fig

133 

134 

TPlot = TypeVar("TPlot", bound="Plot")


class Plot:
    """
    Wraps a newly created matplotlib figure containing a single axes object (members `fig` and `ax`) and provides
    chainable configuration methods, each of which returns the plot itself.
    """
    def __init__(self, draw: Optional[Callable[[], None]] = None, name=None):
        """
        :param draw: an optional function (taking no arguments) which draws the plot's contents; it is called right
            after the figure and axes have been created via ``plt.subplots``, such that pyplot's stateful functions
            (e.g. ``plt.plot``) apply to the new axes; its return value is ignored
        :param name: name/number of the figure, which determines the window caption; it should be unique, as any plot
            with the same name will have its contents rendered in the same window. By default, figures are numbered
            sequentially.
        """
        fig, ax = plt.subplots(num=name)
        self.fig: plt.Figure = fig
        self.ax: plt.Axes = ax
        if draw is not None:
            draw()

    def xlabel(self: TPlot, label) -> TPlot:
        self.ax.set_xlabel(label)
        return self

    def ylabel(self: TPlot, label) -> TPlot:
        self.ax.set_ylabel(label)
        return self

    def title(self: TPlot, title: str) -> TPlot:
        self.ax.set_title(title)
        return self

    def xlim(self: TPlot, min_value, max_value) -> TPlot:
        self.ax.set_xlim(min_value, max_value)
        return self

    def ylim(self: TPlot, min_value, max_value) -> TPlot:
        self.ax.set_ylim(min_value, max_value)
        return self

    def save(self, path, **kwargs):
        """
        Saves the figure to the given file.

        :param path: the path of the file to write
        :param kwargs: additional arguments to pass on to matplotlib's ``Figure.savefig`` (e.g. ``dpi``)
        """
        log.info("Saving figure in %s", path)
        self.fig.savefig(path, **kwargs)

    def xtick(self: TPlot, major=None, minor=None) -> TPlot:
        """
        Sets a tick on every integer multiple of the given base values on the x-axis.
        The major ticks are labelled, the minor ticks are not.

        :param major: the major tick base value
        :param minor: the minor tick base value
        :return: self
        """
        if major is not None:
            self.xtick_major(major)
        if minor is not None:
            self.xtick_minor(minor)
        return self

    def xtick_major(self: TPlot, base) -> TPlot:
        self.ax.xaxis.set_major_locator(plticker.MultipleLocator(base=base))
        return self

    def xtick_minor(self: TPlot, base) -> TPlot:
        self.ax.xaxis.set_minor_locator(plticker.MultipleLocator(base=base))
        return self

    def ytick(self: TPlot, major=None, minor=None) -> TPlot:
        """
        Sets a tick on every integer multiple of the given base values on the y-axis.
        The major ticks are labelled, the minor ticks are not.

        :param major: the major tick base value
        :param minor: the minor tick base value
        :return: self
        """
        if major is not None:
            self.ytick_major(major)
        if minor is not None:
            self.ytick_minor(minor)
        return self

    def ytick_major(self: TPlot, base) -> TPlot:
        self.ax.yaxis.set_major_locator(plticker.MultipleLocator(base=base))
        return self

    def ytick_minor(self: TPlot, base) -> TPlot:
        self.ax.yaxis.set_minor_locator(plticker.MultipleLocator(base=base))
        return self

202 

203 

class ScatterPlot(Plot):
    """
    A scatter plot whose default point opacity decreases as the number of points grows.
    """
    N_MAX_TRANSPARENCY = 1000  # number of points at (and above) which MIN_OPACITY is used
    N_MIN_TRANSPARENCY = 100  # number of points at (and below) which MAX_OPACITY is used
    MAX_OPACITY = 0.5
    MIN_OPACITY = 0.05

    def __init__(self, x, y, c=None, c_base: Tuple[float, float, float] = (0, 0, 1), c_opacity=None, x_label=None, y_label=None,
            **kwargs):
        """
        :param x: the x values; if has name (e.g. pd.Series), will be used as axis label
        :param y: the y values; if has name (e.g. pd.Series), will be used as axis label
        :param c: the colour specification; if None, compose from ``c_base`` and ``c_opacity``
        :param c_base: the base colour as (R, G, B) floats
        :param c_opacity: the opacity; if None, automatically determine from number of data points
        :param x_label: the x-axis label; if None, use the name of ``x`` (if any)
        :param y_label: the y-axis label; if None, use the name of ``y`` (if any)
        :param kwargs: parameters to pass on to plt.scatter
        """
        if c is None:
            if c_base is None:
                c_base = (0, 0, 1)
            if c_opacity is None:
                c_opacity = self._default_opacity(len(x))
            c = ((*c_base, c_opacity),)

        assert len(x) == len(y)
        if x_label is None and hasattr(x, "name"):
            x_label = x.name
        if y_label is None and hasattr(y, "name"):
            y_label = y.name

        def draw():
            if x_label is not None:
                plt.xlabel(x_label)
            if y_label is not None:
                plt.ylabel(y_label)
            plt.scatter(x, y, c=c, **kwargs)

        super().__init__(draw)

    @classmethod
    def _default_opacity(cls, n: int) -> float:
        # linearly interpolates between MAX_OPACITY (few points) and MIN_OPACITY (many points)
        if n > cls.N_MAX_TRANSPARENCY:
            transparency = 1
        elif n < cls.N_MIN_TRANSPARENCY:
            transparency = 0
        else:
            transparency = (n - cls.N_MIN_TRANSPARENCY) / (cls.N_MAX_TRANSPARENCY - cls.N_MIN_TRANSPARENCY)
        return cls.MIN_OPACITY + (cls.MAX_OPACITY - cls.MIN_OPACITY) * (1 - transparency)

249 

250 

class HeatMapPlot(Plot):
    """
    A heat map of a two-dimensional histogram of (x, y) value pairs.
    """
    # default colour map: white for 0, a very light red just above 0 and dark red for the maximum;
    # presumably intended to make non-empty cells stand out from empty ones
    DEFAULT_CMAP_FACTORY = lambda num_points: LinearSegmentedColormap.from_list("whiteToRed",
        ((0, (1, 1, 1)), (1 / num_points, (1, 0.96, 0.96)), (1, (0.7, 0, 0))), num_points)

    def __init__(self, x, y, x_label=None, y_label=None, bins=60, cmap=None, common_range=True, diagonal=False,
            diagonal_color="green", **kwargs):
        """
        :param x: the x values
        :param y: the y values
        :param x_label: the x-axis label; if None, use the name of ``x`` (if any)
        :param y_label: the y-axis label; if None, use the name of ``y`` (if any)
        :param bins: the number of bins to use in each dimension
        :param cmap: the colour map to use for heat values (if None, use default)
        :param common_range: whether the heat map is to use a common range for the x- and y-axes (set to False if x and y
            are completely different quantities; set to True for use cases such as the evaluation of regression model quality)
        :param diagonal: whether to draw the diagonal line (useful for regression evaluation)
        :param diagonal_color: the colour to use for the diagonal line
        :param kwargs: parameters to pass on to plt.imshow
        """
        assert len(x) == len(y)
        if x_label is None and hasattr(x, "name"):
            x_label = x.name
        if y_label is None and hasattr(y, "name"):
            y_label = y.name

        def draw():
            x_range = [min(x), max(x)]
            y_range = [min(y), max(y)]
            # the range spanning both value sets (used for a common range and for the diagonal)
            rng = [min(x_range[0], y_range[0]), max(x_range[1], y_range[1])]
            if common_range:
                x_range = y_range = rng
            if diagonal:
                plt.plot(rng, rng, '-', lw=0.75, label="_not in legend", color=diagonal_color, zorder=2)
            heatmap, _, _ = np.histogram2d(x, y, range=[x_range, y_range], bins=bins, density=False)
            extent = [x_range[0], x_range[1], y_range[0], y_range[1]]
            effective_cmap = cmap if cmap is not None else HeatMapPlot.DEFAULT_CMAP_FACTORY(len(x))
            if x_label is not None:
                plt.xlabel(x_label)
            if y_label is not None:
                plt.ylabel(y_label)
            # histogram2d indexes [x_bin, y_bin], whereas imshow expects rows to correspond to y, hence the transpose
            plt.imshow(heatmap.T, extent=extent, origin='lower', interpolation="none", cmap=effective_cmap, zorder=1,
                aspect="auto", **kwargs)

        super().__init__(draw)

296 

297 

class HistogramPlot(Plot):
    """
    A histogram (drawn with seaborn), optionally complemented by a kernel density estimate and a plot of the
    cumulative distribution function.
    """
    def __init__(self, values, bins="auto", kde=False, cdf=False, cdf_complementary=False, cdf_secondary_axis=True,
            binwidth=None, stat="probability", xlabel=None,
            **kwargs):
        """
        :param values: the values to plot
        :param bins: a bin specification as understood by sns.histplot
        :param kde: whether to add a kernel density estimator
        :param cdf: whether to add a plot of the cumulative distribution function (cdf)
        :param cdf_complementary: whether to plot, if cdf is enabled, the complementary values
        :param cdf_secondary_axis: whether to use, if cdf is enabled, a secondary y-axis for the cdf; a secondary axis
            is always used if `stat` is not one of "count", "proportion" or "probability", in which case the cdf is
            plotted as proportions
        :param binwidth: the bin width; if None, inferred
        :param stat: the statistic to plot (as understood by sns.histplot)
        :param xlabel: the label for the x-axis
        :param kwargs: arguments to pass on to sns.histplot
        """

        def draw():
            sns.histplot(values, bins=bins, kde=kde, binwidth=binwidth, stat=stat, **kwargs)
            plt.ylabel(stat)
            if cdf:
                self._draw_cdf(values, stat, cdf_complementary, cdf_secondary_axis)
            if xlabel is not None:
                self.xlabel(xlabel)

        super().__init__(draw)

    @staticmethod
    def _draw_cdf(values, stat: str, complementary: bool, secondary_axis: bool) -> None:
        # draws the (complementary) empirical cdf of the values onto the current axes or onto a new secondary y-axis
        if stat in ("count", "proportion", "probability"):
            cdf_stat = stat
        else:
            # the histogram's scale (e.g. "density" or "percent") cannot be shared by the cdf,
            # so the cdf is plotted as proportions on a separate axis
            cdf_stat = "proportion"
            secondary_axis = True
        cdf_ax: Optional[plt.Axes] = None
        if secondary_axis:
            cdf_ax = plt.twinx()
            if cdf_stat in ("proportion", "probability"):
                cdf_ax.yaxis.set_major_locator(plticker.MultipleLocator(base=0.1))
        # "probability" has the same semantics as "proportion" but is not understood by ecdfplot
        ecdfplot_stat = "proportion" if cdf_stat == "probability" else cdf_stat
        sns.ecdfplot(values, stat=ecdfplot_stat, complementary=complementary, color="orange", ax=cdf_ax)
        if cdf_ax is not None:
            cdf_ax.set_ylabel(f"{cdf_stat} (cdf)")