Coverage for src/colorspace/swatchplot.py: 94%

180 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-23 19:54 +0000

1 

2 

def swatchplot(pals, show_names = True, nrow = 20, n = 5, cvd = None, **kwargs):
    """Palette Swatch Plot

    Visualization of color palettes in columns of color swatches.
    The first argument `pals` is very flexible and can be:

    * List of hex colors,
    * a single object which inherits from `colorspace.palettes.palette`,
      `colorspace.palettes.hclpalette`,
      `colorspace.colorlib.colorobject`,
    * a list of objects listed above (all of the same type or mixed),
    * a dictionary with lists of objects as above. If a dictionary is used
      the keys of the dictionary are used as 'subtitles' to group sets
      of palettes,
    * an object of class `colorspace.palettes.hclpalettes`,
    * or an object of class `matplotlib.colors.LinearSegmentedColormap` or
      `matplotlib.colors.ListedColormap`.

    Requires `matplotlib` to be installed.

    Args:
        pals: The color palettes or color objects to be visualized.
            See description for details and examples to demonstrate different
            usages.
        show_names (bool): Should palette names be shown (if available), defaults to True.
            Automatically disabled if none of the palettes has a name.
        nrow (int): Maximum number of rows of swatches, defaults to `20`.
        n (int): Number of colors to be drawn from palette objects, defaults to `5`.
        cvd (None, str, or list): Allows to display one or multiple palettes and how they look
            with emulated color vision deficiencies. If `None`, this is not applied.
            Can be set to a single string or a list of strings. Allowed:
            `"protan"`, `"tritan"`, `"deutan"`, `"desaturate"` corresponding to the functions
            :py:func:`protan <colorspace.CVD.protan>`,
            :py:func:`tritan <colorspace.CVD.tritan>`,
            :py:func:`deutan <colorspace.CVD.deutan>`,
            :py:func:`desaturate <colorspace.CVD.desaturate>`.
            Each palette is then shown as its own group: the original colors
            followed by one emulated version per entry in `cvd`. If combined
            with a dictionary input, the dictionary grouping is dropped and a
            warning is issued.
        **kwargs: forwarded to `matplotlib.pyplot.subplots`, can be used to control e.g.,
            `figsize`.

    Returns:
        The matplotlib figure created by `matplotlib.pyplot.subplots`.

    Example:

        >>> from colorspace import swatchplot, palette
        >>> from colorspace import sequential_hcl, diverging_hcl, heat_hcl
        >>>
        >>> # List of hex colors
        >>> swatchplot(['#7FBFF5', '#2A4962', '#111111', '#633C39', '#F8A29E'],
        >>>            figsize = (7, 0.5));
        >>>
        >>> #: Create a custom 'palette' (named):
        >>> pal = palette(['#7FBFF5', '#2A4962', '#111111', '#633C39', '#F8A29E'],
        >>>               "Custom Named Palette")
        >>> swatchplot(pal, figsize = (7, 0.5));
        >>>
        >>> #: A HCL palette. 'n' defines the number of colors.
        >>> swatchplot(sequential_hcl("PuBu"), n = 10,
        >>>            figsize = (7, 0.5));
        >>>
        >>> #: Combine all three
        >>> swatchplot([['#7FBFF5', '#2A4962', '#111111', '#633C39', '#F8A29E'],
        >>>             pal, sequential_hcl("PuBu")], n = 7,
        >>>             figsize = (7, 1.5));
        >>>
        >>> #: A color object (e.g., RGB, HCL, CIELUV, ...)
        >>> from colorspace.colorlib import hexcols
        >>> cobject = hexcols(heat_hcl()(5))
        >>> cobject.to("HCL")
        >>> print(cobject)
        >>> #:
        >>> swatchplot(cobject, figsize = (7, 0.5));
        >>>
        >>> #: Using dictionaries to add subtitles
        >>> # to 'group' different palettes.
        >>> swatchplot({"Diverging": [diverging_hcl(), diverging_hcl("Red-Green")],
        >>>             "Sequential": [sequential_hcl("ag_Sunset"), sequential_hcl("OrRd")],
        >>>             "Others": [['#7FBFF5', '#2A4962', '#111111', '#633C39', '#F8A29E'],
        >>>                        pal, sequential_hcl("PuBu")]}, n = 15);

    Raises:
        TypeError: If `nrow` or `n` is not an int.
        TypeError: If `show_names` is not bool.
        TypeError: If `cvd` is neither None, a str, nor a list.
        ValueError: If `nrow` or `n` are not positive.
        ValueError: If `cvd` contains non-string or unsupported values.
        ValueError: If `figsize` is given but is not a tuple of two numbers.
        ImportError: If `matplotlib.pyplot` cannot be imported, maybe `matplotlib` not installed?
        TypeError: If `pals` (or one of its elements) is of an unsupported type.
        ValueError: If `pals` contains no palettes, or a palette without colors.
    """

    # ---------------------------------------------------------------
    # Sanity checks. Done before importing matplotlib so that invalid
    # arguments are reported even if matplotlib is unavailable.
    # ---------------------------------------------------------------
    if not isinstance(nrow, int): raise TypeError("argument `nrow` must be int")
    if not isinstance(n, int): raise TypeError("argument `n` must be int")
    if not isinstance(show_names, bool): raise TypeError("argument `show_names` must be bool")
    if not nrow > 0: raise ValueError("argument `nrow` must be positive")
    if not n > 0: raise ValueError("argument `n` must be positive")

    # Optional `cvd` argument: None, a single string, or a list of strings.
    if not isinstance(cvd, (str, list, type(None))):
        raise TypeError("unexpected input on argument `cvd`")
    if isinstance(cvd, str):
        cvd = [cvd]
    elif isinstance(cvd, list) and not all(isinstance(x, str) for x in cvd):
        raise ValueError("unexpected input on argument for `cvd`")
    if cvd is not None:
        valid_cvd_types = ["protan", "tritan", "deutan", "desaturate"]
        if not all(x in valid_cvd_types for x in cvd):
            raise ValueError(f"allowed values for argument `cvd` are: {', '.join(valid_cvd_types)}")

    # `figsize` is forwarded to plt.subplots() via **kwargs and only
    # validated here. If not provided, matplotlib's default size applies.
    if "figsize" in kwargs:
        figsize = kwargs["figsize"]
        if not isinstance(figsize, tuple) or not len(figsize) == 2:
            raise ValueError("argument `figsize` must be a tuple of length 2")
        for i, val in enumerate(figsize):
            if not isinstance(val, (int, float)):
                raise ValueError(f"element [{i}] in `figsize` not int or float.")

    # Requires matplotlib. If not available, throw ImportError.
    try:
        import matplotlib.pyplot as plt
    except ImportError as e:
        raise ImportError("problems importing matplotlib.pyplot (not installed?)") from e

    from math import ceil
    from .palettes import palette, defaultpalette, hclpalette, hclpalettes
    from .colorlib import colorobject
    from .utils import check_hex_colors

    # Palette classes handled via their `.name()` and `.colors(n)` methods.
    allowed = (palette, defaultpalette, hclpalette)

    # ---------------------------------------------------------------
    # Prepare the palettes for plotting the swatches.
    # All supported inputs are converted to a list of dicts of the form
    #   {"name": <name or None>, "colors": <list of hex colors>}
    # and, for grouped input (dict, hclpalettes, or cvd != None), into
    # an ordered list of (<group title>, <list of such dicts>) pairs.
    # ---------------------------------------------------------------

    def _pal_to_dict(x, n):
        """Helper function: converts one palette or color object.

        Converts all supported types of color palettes or color objects
        into a dictionary. Used to prepare the inputs to swatchplot for
        the plot itself.

        Args:
            x: Some kind of a color-representing object. See swatchplot
                description for more details.
            n (int): Number of colors to be drawn from non-fixed palettes.

        Returns:
            dict: A single dict with `name` (name of the palette, or None)
            and `colors` (list of hex colors).

        Raises:
            TypeError: If input `x` is of unknown type/format and cannot be converted.
            ValueError: If the palette does not provide any color at all.
        """
        if isinstance(x, list):
            # Plain list; must be a valid list of hex colors.
            res = {"name": None, "colors": check_hex_colors(x)}
        elif isinstance(x, colorobject):
            # Color object (e.g., RGB, HCL, CIELUV, ...); carries no name.
            res = {"name": None, "colors": x.colors(n)}
        elif isinstance(x, allowed):
            # hclpalette (diverging_hcl, sequential_hcl, ...) or custom palette.
            res = {"name": x.name(), "colors": x.colors(n)}
        else:
            raise TypeError(f"could not convert `pals`, improper input (type {type(x)}).")

        if not len(res["colors"]) > 0:
            raise ValueError("got at least one color object/palette with 0 colors")
        return res

    def _convert_pals_to_list(pals, n):
        """Helper function: convert palettes or color objects to a list.

        Used as a generic function to convert a series of palettes or color
        objects given by the user into the format needed later on for
        creating the swatchplot.

        Args:
            pals: forwarded from main swatchplot call.
            n (int): number of colors for palette objects, forwarded from main swatchplot call.

        Returns:
            tuple: A meta dictionary and a list of dictionaries.
            The meta dictionary contains the number of named palettes
            (`n_named`) and the number of palettes (`n_palettes`).
            The list contains one dictionary per palette with two elements:
            `name` (str or None), the name of the palette, and `colors`
            (list), a list of hex colors (str) to be displayed.

        Raises:
            TypeError: If `pals` (or an element of it) is of an unsupported type.
            ValueError: If no palette was found or a palette has no colors.
        """
        from matplotlib.colors import LinearSegmentedColormap, ListedColormap

        # A str or list may be a valid hex color list; if so, wrap it as
        # a single palette. Otherwise it may be a list of palette objects,
        # which is processed below.
        if isinstance(pals, (str, list)):
            try:
                pals = [check_hex_colors(pals)]
            except Exception:
                pass

        if isinstance(pals, (colorobject,) + allowed):
            # Single colorobject or palette object.
            res = [_pal_to_dict(pals, n)]
        elif isinstance(pals, list):
            # Convert each entry; fails if an entry has an unsupported type.
            res = [_pal_to_dict(x, n) for x in pals]
        elif isinstance(pals, (LinearSegmentedColormap, ListedColormap)):
            # Matplotlib colormap; convert to a (named) custom palette.
            from .cmap import cmap_to_sRGB
            tmp_cols = cmap_to_sRGB(pals, n).colors()
            res = [_pal_to_dict(palette(tmp_cols, name = pals.name), n)]
        elif isinstance(pals, dict):
            # Keep the keys as names, take the colors from the objects.
            res = [{"name": key, "colors": _pal_to_dict(pal, n)["colors"]}
                   for key, pal in pals.items()]
        else:
            raise TypeError(f"cannot deal with object of type \"{type(pals)}\"")

        if not res:
            raise ValueError("no palettes or colors found in `pals`")

        meta = {"n_named": sum(1 for x in res if x["name"] is not None),
                "n_palettes": len(res)}
        return meta, res

    # hclpalettes object: convert into a dictionary {type: palettes}
    # which is then processed like a user-provided dictionary.
    if isinstance(pals, hclpalettes):
        pals = {type_: pals.get_palettes(type_) for type_ in pals.get_palette_types()}

    if not isinstance(pals, dict):
        # No groups: one flat list of palettes.
        meta, data = _convert_pals_to_list(pals, n)
        groups = None
    else:
        # Dictionary: each item forms one titled group of palettes.
        if not pals:
            raise ValueError("no palettes or colors found in `pals`")
        data = None
        groups = []
        meta = {"n_named": 0, "n_palettes": 0}
        for key, pal in pals.items():
            tmp_meta, tmp_data = _convert_pals_to_list(pal, n)
            groups.append((key, tmp_data))
            meta["n_named"] += tmp_meta["n_named"]
            meta["n_palettes"] += tmp_meta["n_palettes"]

    # ---------------------------------------------------------------
    # User request for CVD emulated palettes?
    # Each palette becomes its own group: "original" followed by one
    # entry per CVD type. Stored as a list of pairs (not a dict) so that
    # unnamed or equally named palettes do not overwrite each other.
    # ---------------------------------------------------------------
    if cvd is not None:
        from colorspace import CVD

        if groups is not None:
            from warnings import warn
            warn("Dictionary inputs to swatchplot in combination with cvd not allowed, " + \
                 "dictionary will be reduced to a list.")
            data = [rec for _, recs in groups for rec in recs]

        groups = []
        counter = 0
        for rec in data:
            entries = [{"name": "original", "colors": rec["colors"]}]
            entries += [{"name": fn, "colors": getattr(CVD, fn)(rec["colors"])} for fn in cvd]
            counter += len(entries)
            groups.append((rec["name"], entries))
        data = None

        # All entries are named ("original", "protan", ...).
        meta["n_named"] = counter
        meta["n_palettes"] = counter

    # No named palettes? Then there is nothing to show; use full width.
    # Evaluated after the CVD step, which relabels all entries.
    if meta["n_named"] == 0: show_names = False

    # ---------------------------------------------------------------
    # Plotting helpers
    # ---------------------------------------------------------------
    def _plot_swatches(ax, data, xpos, ypos, xstep, ystep, show_names, name_args,
                       single_palette = False):
        """Helper function: plotting a series of swatches.

        Args:
            ax: Matplotlib axes to draw on.
            data (list): List of dicts as prepared in the upper part
                of the swatchplot function.
            xpos (float): Current X position on the plot.
            ypos (float): Current Y position on the plot.
            xstep (float): Step in X-direction for columns.
            ystep (float): Step in Y-direction for rows.
            show_names (bool): Whether to draw the palette names.
            name_args (dict): Font properties for the palette names.
            single_palette (bool): Set to true if we have one single palette;
                changes the x-offset to make use of the full canvas.

        Returns:
            tuple: Updated (xpos, ypos) to continue plotting from.

        Raises:
            TypeError: Wrong unexpected type of input argument (xpos, ypos, xstep, ystep,
                single_palette, and show_names).
            ValueError: Arguments out of valid bounds (xpos, ypos, xstep, ystep).
        """
        if not isinstance(xpos, float) or not isinstance(ypos, float) or \
           not isinstance(xstep, float) or not isinstance(ystep, float) or \
           not isinstance(show_names, bool) or not isinstance(single_palette, bool):
            raise TypeError("non-suitable input argument (wrong type)")
        if not 0. <= xpos <= 1. or not 0. <= ypos <= 1. or \
           not 0. <= xstep <= 1. or not 0. <= ystep <= 1.:
            raise ValueError("at least one of xpos/ypos/xstep/ystep out of valid bounds")

        with_label = show_names and not single_palette
        # Leave space on the left for the name if labels are drawn.
        xoff = 0.35 if with_label else 0.

        for pal in data:
            if with_label:
                ax.text(xpos + xstep * 0.02, ypos, pal["name"], name_args)

            _swatch(ax, pal["colors"], len(pal["colors"]), ypos - 0.8 * ystep / 2.,
                    ypos + 0.8 * ystep / 2., xpos + xoff * xstep, xpos + 0.99 * xstep)

            ypos -= ystep
            # Start new column
            if ypos < 0:
                ypos = 1. - ystep / 2.
                xpos = xpos + xstep

        return xpos, ypos

    def _swatch(ax, cols, ncols, ylo, yhi, xmin, xmax, boxedupto = 6, frameupto = 9):
        """Helper function: draw one palette as a row of rectangles.

        Args:
            ax: Matplotlib axes to draw on.
            cols (list): Colors to draw; `None` entries are drawn white.
            ncols (int): Number of colors (callers pass `len(cols)`).
            ylo, yhi (float): Lower/upper y-limit of the row.
            xmin, xmax (float): Left/right x-limit of the row.
            boxedupto (int): Up to this many colors the rectangles are drawn
                as separate boxes with gaps in between.
            frameupto (int): Up to this many colors each rectangle gets a
                frame; with more colors one outer frame is drawn instead.
        """
        from numpy import linspace
        from matplotlib.patches import Rectangle
        framecol = "#cecece"

        if ncols == 1:
            space = 0.
            step = xmax - xmin
            xlo = [xmin]
            edgecolor = framecol
        elif ncols <= boxedupto:
            # Separate boxes: box i starts at xlo[i] and is (step - space)
            # wide; the first box starts at xmin, the last one ends at xmax.
            deltax = float(xmax - xmin)
            space = deltax * 0.05 / (ncols - 1)
            step = (deltax - float(ncols - 1.) * space) / float(ncols)
            xlo = linspace(xmin, xmax - step + space, ncols)
            edgecolor = framecol
        else:
            # Adjacent rectangles without gaps.
            space = 0.
            step = float(xmax - xmin) / float(ncols)
            xlo = linspace(xmin, xmax - step, ncols)
            edgecolor = framecol if ncols <= frameupto else None

        for x0, col in zip(xlo, cols):
            ax.add_patch(Rectangle((x0, ylo), step - space, yhi - ylo,
                                   facecolor = "#FFFFFF" if col is None else col,
                                   edgecolor = edgecolor))

        # Outer frame
        if ncols > frameupto:
            ax.add_patch(Rectangle((xmin, ylo), step * len(cols), yhi - ylo,
                                   facecolor = "none", edgecolor = framecol))

    # ---------------------------------------------------------------
    # Plot
    # ---------------------------------------------------------------
    # `files_regex` is not a matplotlib argument (presumably injected by
    # documentation tooling); drop it before forwarding kwargs.
    kwargs.pop("files_regex", None)
    fig, ax = plt.subplots(**kwargs)

    # Number of rows needed: one per palette plus one per group title.
    nblocks = meta["n_palettes"] + (len(groups) if groups is not None else 0)
    if nblocks <= nrow:
        ncol = 1
        nrow = nblocks
    else:
        ncol = ceil(nblocks / nrow)

    # Starting top left
    ystep = 1. / float(nrow)
    ypos = 1. - ystep / 2
    xstep = 1. / float(ncol)
    xpos = 0.

    # Adjusting outer margins
    fig.subplots_adjust(left = 0., bottom = 0., right = 1.,
                        top = 1., wspace = 0., hspace = 0.)
    ax.axis("off")
    # Small white margin around the plot
    ax.set_xlim(-0.01, 1.01); ax.set_ylim(-0.01, 1.01)

    # Styling of the texts
    type_args = {"weight": "bold", "va": "center", "ha": "left"}
    type_args["size"] = "large" if meta["n_palettes"] > 20 else "xx-large"
    name_args = {"va": "center", "ha": "left"}

    if groups is None:
        # One or multiple palettes but no titles/grouping
        _plot_swatches(ax, data, xpos, ypos, xstep, ystep, show_names, name_args)
    else:
        # Grouped: add a title row before each group
        for title, entries in groups:
            ax.text(xpos + xstep * 0.02, ypos, title, type_args)
            ypos -= ystep
            # Start new column
            if ypos < 0:
                ypos = 1. - ystep / 2.
                xpos = xpos + xstep
            xpos, ypos = _plot_swatches(ax, entries, xpos, ypos, xstep, ystep,
                                        show_names, name_args)

    # Show figure
    plt.show()

    return fig

525 

526