From 173495f046fda6f4e9c90d2b63ca93e8ae21e257 Mon Sep 17 00:00:00 2001 From: Sonja Stockhaus Date: Tue, 30 Sep 2025 16:32:22 +0200 Subject: [PATCH 1/2] uniform color handling between labels and points/shapes --- src/spatialdata_plot/pl/basic.py | 16 +++++----- src/spatialdata_plot/pl/render.py | 19 +++++++----- src/spatialdata_plot/pl/render_params.py | 3 +- src/spatialdata_plot/pl/utils.py | 38 ++++++++++++++++-------- 4 files changed, 49 insertions(+), 27 deletions(-) diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 025eb9ed..6ea9e471 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -598,7 +598,7 @@ def render_labels( norm: Normalize | None = None, na_color: ColorLike | None = "default", outline_alpha: float | int = 0.0, - fill_alpha: float | int = 0.4, + fill_alpha: float | int | None = None, scale: str | None = None, table_name: str | None = None, table_layer: str | None = None, @@ -643,8 +643,9 @@ def render_labels( won't be shown. outline_alpha : float | int, default 0.0 Alpha value for the outline of the labels. Invisible by default. - fill_alpha : float | int, default 0.4 - Alpha value for the fill of the labels. + fill_alpha : float | int, optional. + Alpha value for the fill of the labels. When no alpha is implied by the passed color, a default value of 0.4 + is used. scale : str | None Influences the resolution of the rendering. Possibilities for setting this parameter: 1) None (default). The image is rasterized to fit the canvas size. For multiscale images, the best scale @@ -702,6 +703,7 @@ def render_labels( sdata.plotting_tree[f"{n_steps + 1}_render_labels"] = LabelsRenderParams( element=element, color=param_values["color"], + col_for_color=param_values["col_for_color"], groups=param_values["groups"], contour_px=param_values["contour_px"], cmap_params=cmap_params, @@ -984,13 +986,13 @@ def show( if wanted_labels_on_this_cs: if (table := params_copy.table_name) is not None: - assert isinstance(params_copy.color, str) - colors = sc.get.obs_df(sdata[table], [params_copy.color]) - if isinstance(colors[params_copy.color].dtype, pd.CategoricalDtype): + assert isinstance(params_copy.col_for_color, str) + colors = sc.get.obs_df(sdata[table], [params_copy.col_for_color]) + if isinstance(colors[params_copy.col_for_color].dtype, pd.CategoricalDtype): _maybe_set_colors( source=sdata[table], target=sdata[table], - key=params_copy.color, + key=params_copy.col_for_color, palette=params_copy.palette, ) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index f74259a4..9f147fe1 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -1084,7 +1084,8 @@ def _render_labels( table_name = render_params.table_name table_layer = render_params.table_layer palette = render_params.palette - color = render_params.color + color = render_params.color.get_hex() if render_params.color else None + col_for_color = render_params.col_for_color groups = render_params.groups scale = render_params.scale @@ -1137,18 +1138,21 @@ def _render_labels( sdata=sdata_filt, element=label, element_name=element, - value_to_plot=color, + # value_to_plot=color, # TODO + value_to_plot=col_for_color, groups=groups, palette=palette, - na_color=render_params.cmap_params.na_color, + # na_color=render_params.cmap_params.na_color, # TODO + na_color=render_params.color if render_params.color is not None else render_params.cmap_params.na_color, cmap_params=render_params.cmap_params, table_name=table_name, table_layer=table_layer, + render_type="labels", ) # rasterize could have removed labels from label # only problematic if color is specified - if rasterize and color is not None: + if rasterize and (color is not None or col_for_color is not None): labels_in_rasterized_image = np.unique(label.values) mask = np.isin(instance_id, labels_in_rasterized_image) instance_id = instance_id[mask] @@ -1157,8 +1161,8 @@ def _render_labels( color_vector = color_vector.remove_unused_categories() assert color_source_vector is not None color_source_vector = color_source_vector[mask] - else: - assert color_source_vector is None + # else: + # assert color_source_vector is None # TODO: delete? def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) -> matplotlib.image.AxesImage: labels = _map_color_seg( @@ -1228,7 +1232,8 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) cax=cax, fig_params=fig_params, adata=table, - value_to_plot=color, + # value_to_plot=color, # TODO + value_to_plot=col_for_color, color_source_vector=color_source_vector, color_vector=color_vector, palette=palette, diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index 5e3af820..13235e39 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -254,7 +254,8 @@ class LabelsRenderParams: cmap_params: CmapParams element: str - color: str | None = None + color: Color | None = None + col_for_color: str | None = None groups: str | list[str] | None = None contour_px: int | None = None outline: bool = False diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 9df1f3d0..b0d9802a 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -777,10 +777,11 @@ def _set_color_source_vec( alpha: float = 1.0, table_name: str | None = None, table_layer: str | None = None, - render_type: Literal["points"] | None = None, + render_type: Literal["points", "labels"] | None = None, ) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]: if value_to_plot is None and element is not None: - color = np.full(len(element), na_color.get_hex_with_alpha()) + n_elements = len(element) if render_type != "labels" else len(dask.array.unique(element.data).compute()) + color = np.full(n_elements, na_color.get_hex_with_alpha()) return color, color, False # Figure out where to get the color from @@ -1000,7 +1001,7 @@ def _get_categorical_color_mapping( alpha: float = 1, groups: list[str] | str | None = None, palette: list[str] | str | None = None, - render_type: Literal["points"] | None = None, + render_type: Literal["points", "labels"] | None = None, ) -> Mapping[str, str]: if not isinstance(color_source_vector, Categorical): raise TypeError(f"Expected `categories` to be a `Categorical`, but got {type(color_source_vector).__name__}") @@ -1648,7 +1649,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st }: if not isinstance(color, str | tuple | list): raise TypeError("Parameter 'color' must be a string or a tuple/list of floats.") - if element_type in {"shapes", "points"}: + if element_type in {"shapes", "points", "labels"}: if _is_color_like(color): logger.info("Value for parameter 'color' appears to be a color, using it as such.") param_dict["col_for_color"] = None @@ -1656,7 +1657,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st if param_dict["color"].alpha_is_user_defined(): if element_type == "points" and param_dict.get("alpha") is None: param_dict["alpha"] = param_dict["color"].get_alpha_as_float() - elif element_type == "shapes" and param_dict.get("fill_alpha") is None: + elif element_type in ["shapes", "labels"] and param_dict.get("fill_alpha") is None: param_dict["fill_alpha"] = param_dict["color"].get_alpha_as_float() else: logger.info( @@ -1668,7 +1669,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st param_dict["color"] = None else: raise ValueError(f"{color} is not a valid RGB(A) array and therefore can't be used as 'color' value.") - elif "color" in param_dict and element_type != "labels": + elif "color" in param_dict and element_type != "images": param_dict["col_for_color"] = None if outline_width := param_dict.get("outline_width"): @@ -1754,6 +1755,8 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st elif element_type == "shapes": # set default fill_alpha for shapes if not given by user explicitly or implicitly (as part of color) param_dict["fill_alpha"] = 1.0 + elif element_type == "lables": + param_dict["fill_alpha"] = 0.4 if (cmap := param_dict.get("cmap")) is not None and (palette := param_dict.get("palette")) is not None: raise ValueError("Both `palette` and `cmap` are specified. Please specify only one of them.") @@ -1894,7 +1897,7 @@ def _validate_label_render_params( element: str | None, cmap: list[Colormap | str] | Colormap | str | None, color: str | None, - fill_alpha: float | int, + fill_alpha: float | int | None, contour_px: int | None, groups: list[str] | str | None, palette: list[str] | str | None, @@ -1939,12 +1942,23 @@ def _validate_label_render_params( element_params[el]["table_layer"] = param_dict["table_layer"] element_params[el]["table_name"] = None - element_params[el]["color"] = None - color = param_dict["color"] - if color is not None: - color, table_name = _validate_col_for_column_table(sdata, el, color, param_dict["table_name"], labels=True) + + # element_params[el]["color"] = None # TODO: delete + # color = param_dict["color"] + # if color is not None: + # color, table_name = _validate_col_for_column_table(sdata, el, color, param_dict["table_name"], + # labels=True) + # element_params[el]["table_name"] = table_name + # element_params[el]["color"] = color + element_params[el]["color"] = param_dict["color"] + + element_params[el]["col_for_color"] = None + if (col_for_color := param_dict["col_for_color"]) is not None: + col_for_color, table_name = _validate_col_for_column_table( + sdata, el, col_for_color, param_dict["table_name"], labels=True + ) element_params[el]["table_name"] = table_name - element_params[el]["color"] = color + element_params[el]["col_for_color"] = col_for_color element_params[el]["palette"] = param_dict["palette"] if element_params[el]["table_name"] is not None else None element_params[el]["groups"] = param_dict["groups"] if element_params[el]["table_name"] is not None else None From e8a940e8d724ce0a2b1626513d85c10fd4e4befa Mon Sep 17 00:00:00 2001 From: Sonja Stockhaus Date: Tue, 30 Sep 2025 16:41:50 +0200 Subject: [PATCH 2/2] typo --- src/spatialdata_plot/pl/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index b0d9802a..9ebdd786 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -1755,7 +1755,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st elif element_type == "shapes": # set default fill_alpha for shapes if not given by user explicitly or implicitly (as part of color) param_dict["fill_alpha"] = 1.0 - elif element_type == "lables": + elif element_type == "labels": param_dict["fill_alpha"] = 0.4 if (cmap := param_dict.get("cmap")) is not None and (palette := param_dict.get("palette")) is not None: