Skip to content

Commit c6d6153

Browse files
timtreispre-commit-ci[bot]melonora
authored
Refactor of colorbar and norm logic (#346)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Wouter-Michiel Vierdag <[email protected]>
1 parent 6cef5df commit c6d6153

20 files changed

+33
-76
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,15 @@ and this project adheres to [Semantic Versioning][].
1818

1919
- Lowered RMSE-threshold for plot-based tests from 45 to 15 (#344)
2020
- When subsetting to `groups`, `NA` isn't automatically added to legend (#344)
21+
- When rendering a single image channel, a colorbar is now shown (#346)
22+
- Removed `percentiles_for_norm` parameter (#346)
23+
- Changed `norm` to no longer accept bools, only `mpl.colors.Normalise` or `None` (#346)
2124

2225
### Fixed
2326

2427
- Filtering with `groups` now preserves original cmap (#344)
2528
- Non-selected `groups` are now not shown in `na_color` (#344)
29+
- Several issues associated with `norm` and `colorbar` (#346)
2630

2731
## [0.2.5] - 2024-08-23
2832

src/spatialdata_plot/pl/basic.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def render_shapes(
166166
outline_color: str | list[float] = "#000000ff",
167167
outline_alpha: float | int = 0.0,
168168
cmap: Colormap | str | None = None,
169-
norm: bool | Normalize = False,
169+
norm: Normalize | None = None,
170170
scale: float | int = 1.0,
171171
method: str | None = None,
172172
table_name: str | None = None,
@@ -301,7 +301,7 @@ def render_points(
301301
palette: list[str] | str | None = None,
302302
na_color: ColorLike | None = "default",
303303
cmap: Colormap | str | None = None,
304-
norm: None | Normalize = None,
304+
norm: Normalize | None = None,
305305
size: float | int = 1.0,
306306
method: str | None = None,
307307
table_name: str | None = None,
@@ -422,7 +422,6 @@ def render_images(
422422
na_color: ColorLike | None = "default",
423423
palette: list[str] | str | None = None,
424424
alpha: float | int = 1.0,
425-
percentiles_for_norm: tuple[float, float] | None = None,
426425
scale: str | None = None,
427426
**kwargs: Any,
428427
) -> sd.SpatialData:
@@ -457,8 +456,6 @@ def render_images(
457456
Palette to color images. The number of palettes should be equal to the number of channels.
458457
alpha : float | int, default 1.0
459458
Alpha value for the images. Must be a numeric between 0 and 1.
460-
percentiles_for_norm : tuple[float, float] | None
461-
Optional pair of floats (pmin < pmax, 0-100) which will be used for quantile normalization.
462459
scale : str | None
463460
Influences the resolution of the rendering. Possibilities include:
464461
1) `None` (default): The image is rasterized to fit the canvas size. For
@@ -486,20 +483,14 @@ def render_images(
486483
cmap=cmap,
487484
norm=norm,
488485
scale=scale,
489-
percentiles_for_norm=percentiles_for_norm,
490486
)
491487

492488
sdata = self._copy()
493489
sdata = _verify_plotting_tree(sdata)
494490
n_steps = len(sdata.plotting_tree.keys())
495491

496492
for element, param_values in params_dict.items():
497-
# cmap_params = _prepare_cmap_norm(
498-
# cmap=params_dict[element]["cmap"],
499-
# norm=norm,
500-
# na_color=params_dict[element]["na_color"], # type: ignore[arg-type]
501-
# **kwargs,
502-
# )
493+
503494
cmap_params: list[CmapParams] | CmapParams
504495
if isinstance(cmap, list):
505496
cmap_params = [
@@ -525,7 +516,6 @@ def render_images(
525516
cmap_params=cmap_params,
526517
palette=param_values["palette"],
527518
alpha=param_values["alpha"],
528-
percentiles_for_norm=param_values["percentiles_for_norm"],
529519
scale=param_values["scale"],
530520
zorder=n_steps,
531521
)

src/spatialdata_plot/pl/render.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import datashader as ds
1010
import geopandas as gpd
1111
import matplotlib
12+
import matplotlib.pyplot as plt
1213
import matplotlib.transforms as mtransforms
1314
import numpy as np
1415
import pandas as pd
@@ -47,7 +48,6 @@
4748
_maybe_set_colors,
4849
_mpl_ax_contains_elements,
4950
_multiscale_to_spatial_image,
50-
_normalize,
5151
_rasterize_if_necessary,
5252
_set_color_source_vec,
5353
to_hex,
@@ -128,6 +128,7 @@ def _render_shapes(
128128
shapes = shapes.reset_index()
129129
color_source_vector = color_source_vector[mask]
130130
color_vector = color_vector[mask]
131+
131132
shapes = gpd.GeoDataFrame(shapes, geometry="geometry")
132133

133134
# Using dict.fromkeys here since set returns in arbitrary order
@@ -255,9 +256,13 @@ def _render_shapes(
255256
for path in _cax.get_paths():
256257
path.vertices = trans.transform(path.vertices)
257258

258-
# Sets the limits of the colorbar to the values instead of [0, 1]
259-
if not norm and not values_are_categorical:
260-
_cax.set_clim(min(color_vector), max(color_vector))
259+
if not values_are_categorical:
260+
# If the user passed a Normalize object with vmin/vmax we'll use those,
261+
# # if not we'll use the min/max of the color_vector
262+
_cax.set_clim(
263+
vmin=render_params.cmap_params.norm.vmin or min(color_vector),
264+
vmax=render_params.cmap_params.norm.vmax or max(color_vector),
265+
)
261266

262267
if len(set(color_vector)) != 1 or list(set(color_vector))[0] != to_hex(render_params.cmap_params.na_color):
263268
# necessary in case different shapes elements are annotated with one table
@@ -603,11 +608,6 @@ def _render_images(
603608
if n_channels == 1 and not isinstance(render_params.cmap_params, list):
604609
layer = img.sel(c=channels[0]).squeeze() if isinstance(channels[0], str) else img.isel(c=channels[0]).squeeze()
605610

606-
if render_params.percentiles_for_norm != (None, None):
607-
layer = _normalize(
608-
layer, pmin=render_params.percentiles_for_norm[0], pmax=render_params.percentiles_for_norm[1], clip=True
609-
)
610-
611611
if render_params.cmap_params.norm: # type: ignore[attr-defined]
612612
layer = render_params.cmap_params.norm(layer) # type: ignore[attr-defined]
613613

@@ -623,20 +623,16 @@ def _render_images(
623623

624624
_ax_show_and_transform(layer, trans_data, ax, cmap=cmap, zorder=render_params.zorder)
625625

626+
if legend_params.colorbar:
627+
sm = plt.cm.ScalarMappable(cmap=cmap, norm=render_params.cmap_params.norm)
628+
fig_params.fig.colorbar(sm, ax=ax)
629+
626630
# 2) Image has any number of channels but 1
627631
else:
628632
layers = {}
629633
for ch_index, c in enumerate(channels):
630634
layers[c] = img.sel(c=c).copy(deep=True).squeeze()
631635

632-
if render_params.percentiles_for_norm != (None, None):
633-
layers[c] = _normalize(
634-
layers[c],
635-
pmin=render_params.percentiles_for_norm[0],
636-
pmax=render_params.percentiles_for_norm[1],
637-
clip=True,
638-
)
639-
640636
if not isinstance(render_params.cmap_params, list):
641637
if render_params.cmap_params.norm is not None:
642638
layers[c] = render_params.cmap_params.norm(layers[c])

src/spatialdata_plot/pl/utils.py

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ def _get_scalebar(
489489

490490
def _prepare_cmap_norm(
491491
cmap: Colormap | str | None = None,
492-
norm: Normalize | bool = False,
492+
norm: Normalize | None = None,
493493
na_color: ColorLike | None = None,
494494
vmin: float | None = None,
495495
vmax: float | None = None,
@@ -1623,29 +1623,6 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
16231623
if scale < 0:
16241624
raise ValueError("Parameter 'scale' must be a positive number.")
16251625

1626-
if (percentiles_for_norm := param_dict.get("percentiles_for_norm")) is None:
1627-
percentiles_for_norm = (None, None)
1628-
elif not (isinstance(percentiles_for_norm, (list, tuple)) or len(percentiles_for_norm) != 2):
1629-
raise TypeError("Parameter 'percentiles_for_norm' must be a list or tuple of exactly two floats or None.")
1630-
elif not all(
1631-
isinstance(p, (float, int, type(None)))
1632-
and isinstance(p, type(percentiles_for_norm[0]))
1633-
and (p is None or 0 <= p <= 100)
1634-
for p in percentiles_for_norm
1635-
):
1636-
raise TypeError(
1637-
"Each item in 'percentiles_for_norm' must be of the same dtype and must be a float or int within [0, 100], "
1638-
"or None"
1639-
)
1640-
elif (
1641-
percentiles_for_norm[0] is not None
1642-
and percentiles_for_norm[1] is not None
1643-
and percentiles_for_norm[0] > percentiles_for_norm[1]
1644-
):
1645-
raise ValueError("The first number in 'percentiles_for_norm' must not be smaller than the second.")
1646-
if "percentiles_for_norm" in param_dict:
1647-
param_dict["percentiles_for_norm"] = percentiles_for_norm
1648-
16491626
if size := param_dict.get("size"):
16501627
if not isinstance(size, (float, int)):
16511628
raise TypeError("Parameter 'size' must be numeric.")
@@ -1886,7 +1863,6 @@ def _validate_image_render_params(
18861863
cmap: list[Colormap | str] | Colormap | str | None,
18871864
norm: Normalize | None,
18881865
scale: str | None,
1889-
percentiles_for_norm: tuple[float | None, float | None] | None,
18901866
) -> dict[str, dict[str, Any]]:
18911867
param_dict: dict[str, Any] = {
18921868
"sdata": sdata,
@@ -1898,7 +1874,6 @@ def _validate_image_render_params(
18981874
"cmap": cmap,
18991875
"norm": norm,
19001876
"scale": scale,
1901-
"percentiles_for_norm": percentiles_for_norm,
19021877
}
19031878
param_dict = _type_check_params(param_dict, "images")
19041879

@@ -1945,8 +1920,6 @@ def _validate_image_render_params(
19451920
else:
19461921
element_params[el]["scale"] = scale
19471922

1948-
element_params[el]["percentiles_for_norm"] = param_dict["percentiles_for_norm"]
1949-
19501923
return element_params
19511924

19521925

Loading
Loading
Loading
Loading
Loading
Loading

0 commit comments

Comments
 (0)