diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index f5c334e4..49cbaed1 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -45,6 +45,7 @@ _get_extent_and_range_for_datashader_canvas, _get_linear_colormap, _get_transformation_matrix_for_datashader, + _hex_no_alpha, _is_coercable_to_float, _map_color_seg, _maybe_set_colors, @@ -191,7 +192,10 @@ def _render_shapes( lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm)[:, :2] ) transformed_element = ShapesModel.parse( - gpd.GeoDataFrame(data=sdata_filt.shapes[element].drop("geometry", axis=1), geometry=transformed_element) + gpd.GeoDataFrame( + data=sdata_filt.shapes[element].drop("geometry", axis=1), + geometry=transformed_element, + ) ) plot_width, plot_height, x_ext, y_ext, factor = _get_extent_and_range_for_datashader_canvas( @@ -208,7 +212,11 @@ def _render_shapes( aggregate_with_reduction = None if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1): if color_by_categorical: - agg = cvs.polygons(transformed_element, geometry="geometry", agg=ds.by(col_for_color, ds.count())) + agg = cvs.polygons( + transformed_element, + geometry="geometry", + agg=ds.by(col_for_color, ds.count()), + ) else: reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "mean" logger.info( @@ -216,7 +224,11 @@ def _render_shapes( "to the matplotlib result." ) agg = _datashader_aggregate_with_function( - render_params.ds_reduction, cvs, transformed_element, col_for_color, "shapes" + render_params.ds_reduction, + cvs, + transformed_element, + col_for_color, + "shapes", ) # save min and max values for drawing the colorbar aggregate_with_reduction = (agg.min(), agg.max()) @@ -246,7 +258,7 @@ def _render_shapes( agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5) color_key = ( - [x[:-2] for x in color_vector.categories.values] + [_hex_no_alpha(x) for x in color_vector.categories.values] if (type(color_vector) is pd.core.arrays.categorical.Categorical) and (len(color_vector.categories.values) > 1) else None @@ -257,7 +269,7 @@ def _render_shapes( if color_vector is not None: ds_cmap = color_vector[0] if isinstance(ds_cmap, str) and ds_cmap[0] == "#": - ds_cmap = ds_cmap[:-2] + ds_cmap = _hex_no_alpha(ds_cmap) ds_result = _datashader_map_aggregate_to_color( agg, @@ -272,7 +284,10 @@ def _render_shapes( # else: all elements would get alpha=0 and the color bar would have a weird range if aggregate_with_reduction[0] == aggregate_with_reduction[1]: ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False) - aggregate_with_reduction = (aggregate_with_reduction[0], aggregate_with_reduction[0] + 1) + aggregate_with_reduction = ( + aggregate_with_reduction[0], + aggregate_with_reduction[0] + 1, + ) ds_result = _datashader_map_aggregate_to_color( agg, @@ -468,7 +483,9 @@ def _render_points( # we construct an anndata to hack the plotting functions if table_name is None: adata = AnnData( - X=points[["x", "y"]].values, obs=points[coords].reset_index(), dtype=points[["x", "y"]].values.dtype + X=points[["x", "y"]].values, + obs=points[coords].reset_index(), + dtype=points[["x", "y"]].values.dtype, ) else: adata_obs = sdata_filt[table_name].obs @@ -496,7 +513,9 @@ def _render_points( sdata_filt.points[element] = PointsModel.parse(points, coordinates={"x": "x", "y": "y"}) # restore transformation in coordinate system of interest set_transformation( - element=sdata_filt.points[element], 
transformation=transformation_in_cs, to_coordinate_system=coordinate_system + element=sdata_filt.points[element], + transformation=transformation_in_cs, + to_coordinate_system=coordinate_system, ) if col_for_color is not None: @@ -586,7 +605,11 @@ def _render_points( "to the matplotlib result." ) agg = _datashader_aggregate_with_function( - render_params.ds_reduction, cvs, transformed_element, col_for_color, "points" + render_params.ds_reduction, + cvs, + transformed_element, + col_for_color, + "points", ) # save min and max values for drawing the colorbar aggregate_with_reduction = (agg.min(), agg.max()) @@ -642,7 +665,10 @@ def _render_points( # else: all elements would get alpha=0 and the color bar would have a weird range if aggregate_with_reduction[0] == aggregate_with_reduction[1] and (ds_span is None or ds_span != [0, 1]): ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False) - aggregate_with_reduction = (aggregate_with_reduction[0], aggregate_with_reduction[0] + 1) + aggregate_with_reduction = ( + aggregate_with_reduction[0], + aggregate_with_reduction[0] + 1, + ) ds_result = _datashader_map_aggregate_to_color( agg, @@ -805,7 +831,12 @@ def _render_images( # norm needs to be passed directly to ax.imshow(). If we normalize before, that method would always clip. _ax_show_and_transform( - layer, trans_data, ax, cmap=cmap, zorder=render_params.zorder, norm=render_params.cmap_params.norm + layer, + trans_data, + ax, + cmap=cmap, + zorder=render_params.zorder, + norm=render_params.cmap_params.norm, ) if legend_params.colorbar: @@ -832,7 +863,11 @@ def _render_images( else: # -> use given cmap for each channel channel_cmaps = [render_params.cmap_params.cmap] * n_channels stacked = ( - np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0) / n_channels + np.stack( + [channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], + 0, + ).sum(0) + / n_channels ) stacked = stacked[:, :, :3] logger.warning( @@ -844,7 +879,13 @@ def _render_images( "Consider using 'palette' instead." 
) - _ax_show_and_transform(stacked, trans_data, ax, render_params.alpha, zorder=render_params.zorder) + _ax_show_and_transform( + stacked, + trans_data, + ax, + render_params.alpha, + zorder=render_params.zorder, + ) # 2B) Image has n channels, no palette/cmap info -> sample n categorical colors elif palette is None and not got_multiple_cmaps: @@ -858,7 +899,13 @@ def _render_images( colored = np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0) colored = colored[:, :, :3] - _ax_show_and_transform(colored, trans_data, ax, render_params.alpha, zorder=render_params.zorder) + _ax_show_and_transform( + colored, + trans_data, + ax, + render_params.alpha, + zorder=render_params.zorder, + ) # 2C) Image has n channels and palette info elif palette is not None and not got_multiple_cmaps: @@ -869,16 +916,32 @@ def _render_images( colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0) colored = colored[:, :, :3] - _ax_show_and_transform(colored, trans_data, ax, render_params.alpha, zorder=render_params.zorder) + _ax_show_and_transform( + colored, + trans_data, + ax, + render_params.alpha, + zorder=render_params.zorder, + ) elif palette is None and got_multiple_cmaps: channel_cmaps = [cp.cmap for cp in render_params.cmap_params] # type: ignore[union-attr] colored = ( - np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0) / n_channels + np.stack( + [channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], + 0, + ).sum(0) + / n_channels ) colored = colored[:, :, :3] - _ax_show_and_transform(colored, trans_data, ax, render_params.alpha, zorder=render_params.zorder) + _ax_show_and_transform( + colored, + trans_data, + ax, + render_params.alpha, + zorder=render_params.zorder, + ) elif palette is not None and got_multiple_cmaps: raise ValueError("If 'palette' is provided, 'cmap' must be None.") @@ -999,7 +1062,9 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) # outline-only case elif render_params.fill_alpha == 0.0 and render_params.outline_alpha > 0.0: cax = _draw_labels( - seg_erosionpx=render_params.contour_px, seg_boundaries=True, alpha=render_params.outline_alpha + seg_erosionpx=render_params.contour_px, + seg_boundaries=True, + alpha=render_params.outline_alpha, ) alpha_to_decorate_ax = render_params.outline_alpha @@ -1010,7 +1075,9 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) # ... 
then overlay the contour cax_contour = _draw_labels( - seg_erosionpx=render_params.contour_px, seg_boundaries=True, alpha=render_params.outline_alpha + seg_erosionpx=render_params.contour_px, + seg_boundaries=True, + alpha=render_params.outline_alpha, ) # pass the less-transparent _cax for the legend @@ -1035,7 +1102,7 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) legend_fontweight=legend_params.legend_fontweight, legend_loc=legend_params.legend_loc, legend_fontoutline=legend_params.legend_fontoutline, - na_in_legend=legend_params.na_in_legend if groups is None else len(groups) == len(set(color_vector)), + na_in_legend=(legend_params.na_in_legend if groups is None else len(groups) == len(set(color_vector))), colorbar=legend_params.colorbar, scalebar_dx=scalebar_params.scalebar_dx, scalebar_units=scalebar_params.scalebar_units, diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 5d745b71..a2e8f767 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -55,7 +55,13 @@ from skimage.morphology import erosion, square from skimage.segmentation import find_boundaries from skimage.util import map_array -from spatialdata import SpatialData, get_element_annotators, get_extent, get_values, rasterize +from spatialdata import ( + SpatialData, + get_element_annotators, + get_extent, + get_values, + rasterize, +) from spatialdata._core.query.relational_query import _locate_value from spatialdata._types import ArrayLike from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement @@ -182,7 +188,12 @@ def _prepare_params_plot( dpi = rcParams["figure.dpi"] if dpi is None else dpi if num_panels > 1 and ax is None: fig, grid = _panel_grid( - num_panels=num_panels, hspace=hspace, wspace=wspace, ncols=ncols, dpi=dpi, figsize=figsize + num_panels=num_panels, + hspace=hspace, + wspace=wspace, + ncols=ncols, + dpi=dpi, + figsize=figsize, ) axs: None | Sequence[Axes] = [plt.subplot(grid[c]) for c in range(num_panels)] elif num_panels > 1: @@ -384,7 +395,11 @@ def _get_collection_shape( shapes_df = shapes_df.reset_index(drop=True) def _assign_fill_and_outline_to_row( - fill_c: list[Any], outline_c: list[Any], row: dict[str, Any], idx: int, is_multiple_shapes: bool + fill_c: list[Any], + outline_c: list[Any], + row: dict[str, Any], + idx: int, + is_multiple_shapes: bool, ) -> None: try: if is_multiple_shapes and len(fill_c) == 1: @@ -400,7 +415,10 @@ def _process_polygon(row: pd.Series, s: float) -> dict[str, Any]: coords = np.array(row["geometry"].exterior.coords) centroid = np.mean(coords, axis=0) scaled_coords = (centroid + (coords - centroid) * s).tolist() - return {**row.to_dict(), "geometry": mpatches.Polygon(scaled_coords, closed=True)} + return { + **row.to_dict(), + "geometry": mpatches.Polygon(scaled_coords, closed=True), + } def _process_multipolygon(row: pd.Series, s: float) -> list[dict[str, Any]]: mp = _make_patch_from_multipolygon(row["geometry"]) @@ -721,7 +739,12 @@ def _set_color_source_vec( return color, color, False # Figure out where to get the color from - origins = _locate_value(value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name) + origins = _locate_value( + value_key=value_to_plot, + sdata=sdata, + element_name=element_name, + table_name=table_name, + ) if len(origins) > 1: raise ValueError( @@ -854,6 +877,7 @@ def _generate_base_categorial_color_mapping( cluster_key: str, color_source_vector: ArrayLike | pd.Series[CategoricalDtype], 
na_color: ColorLike, + cmap_params: CmapParams | None = None, ) -> Mapping[str, str]: if adata is not None and cluster_key in adata.uns and f"{cluster_key}_colors" in adata.uns: colors = adata.uns[f"{cluster_key}_colors"] @@ -870,7 +894,7 @@ def _generate_base_categorial_color_mapping( return dict(zip(categories, colors, strict=True)) - return _get_default_categorial_color_mapping(color_source_vector) + return _get_default_categorial_color_mapping(color_source_vector=color_source_vector, cmap_params=cmap_params) def _modify_categorical_color_mapping( @@ -894,22 +918,36 @@ def _modify_categorical_color_mapping( def _get_default_categorial_color_mapping( color_source_vector: ArrayLike | pd.Series[CategoricalDtype], + cmap_params: CmapParams | None = None, ) -> Mapping[str, str]: len_cat = len(color_source_vector.categories.unique()) - if len_cat <= 20: - palette = default_20 - elif len_cat <= 28: - palette = default_28 - elif len_cat <= len(default_102): # 103 colors - palette = default_102 + # Try to use provided colormap first + if cmap_params is not None and cmap_params.cmap is not None and not cmap_params.cmap_is_default: + # Generate evenly spaced indices for the colormap + color_idx = np.linspace(0, 1, len_cat) + if isinstance(cmap_params.cmap, ListedColormap): + palette = [to_hex(x) for x in cmap_params.cmap(color_idx)] + elif isinstance(cmap_params.cmap, LinearSegmentedColormap): + palette = [to_hex(cmap_params.cmap(x)) for x in color_idx] + else: + # Fall back to default palettes if cmap is not of expected type + palette = None else: - palette = ["grey" for _ in range(len_cat)] - logger.info("input has more than 103 categories. Uniform 'grey' color will be used for all categories.") + palette = None - return { - cat: to_hex(to_rgba(col)[:3]) - for cat, col in zip(color_source_vector.categories, palette[:len_cat], strict=True) - } + # Fall back to default palettes if needed + if palette is None: + if len_cat <= 20: + palette = default_20 + elif len_cat <= 28: + palette = default_28 + elif len_cat <= len(default_102): # 103 colors + palette = default_102 + else: + palette = ["grey"] * len_cat + logger.info("input has more than 103 categories. 
Uniform 'grey' color will be used for all categories.") + + return dict(zip(color_source_vector.categories, palette[:len_cat], strict=True)) def _get_categorical_color_mapping( @@ -944,15 +982,26 @@ def _get_categorical_color_mapping( if cluster_key is None: # user didn't specify a column to use for coloring - base_mapping = _get_default_categorial_color_mapping(color_source_vector) + base_mapping = _get_default_categorial_color_mapping( + color_source_vector=color_source_vector, cmap_params=cmap_params + ) else: - base_mapping = _generate_base_categorial_color_mapping(adata, cluster_key, color_source_vector, na_color) + base_mapping = _generate_base_categorial_color_mapping( + adata=adata, + cluster_key=cluster_key, + color_source_vector=color_source_vector, + na_color=na_color, + cmap_params=cmap_params, + ) return _modify_categorical_color_mapping(mapping=base_mapping, groups=groups, palette=palette) def _maybe_set_colors( - source: AnnData, target: AnnData, key: str, palette: str | ListedColormap | Cycler | Sequence[Any] | None = None + source: AnnData, + target: AnnData, + key: str, + palette: str | ListedColormap | Cycler | Sequence[Any] | None = None, ) -> None: color_key = f"{key}_colors" try: @@ -1074,7 +1123,13 @@ def _get_list( raise ValueError(f"Can't make a list from variable: `{var}`") -def save_fig(fig: Figure, path: str | Path, make_dir: bool = True, ext: str = "png", **kwargs: Any) -> None: +def save_fig( + fig: Figure, + path: str | Path, + make_dir: bool = True, + ext: str = "png", + **kwargs: Any, +) -> None: """ Save a figure. @@ -1366,7 +1421,12 @@ def _multiscale_to_spatial_image( def _get_elements_to_be_rendered( - render_cmds: list[tuple[str, ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams]], + render_cmds: list[ + tuple[ + str, + ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams, + ] + ], cs_contents: pd.DataFrame, cs: str, ) -> list[str]: @@ -1440,7 +1500,15 @@ def _validate_show_parameters( f"the following strings: {readable_font_weights}.", ) - font_sizes = ["xx-small", "x-small", "small", "medium", "large", "x-large", "xx-large"] + font_sizes = [ + "xx-small", + "x-small", + "small", + "medium", + "large", + "x-large", + "xx-large", + ] if legend_fontsize is not None and ( not isinstance(legend_fontsize, int | float | str) @@ -1533,7 +1601,11 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st if (contour_px := param_dict.get("contour_px")) and not isinstance(contour_px, int): raise TypeError("Parameter 'contour_px' must be an integer.") - if (color := param_dict.get("color")) and element_type in {"shapes", "points", "labels"}: + if (color := param_dict.get("color")) and element_type in { + "shapes", + "points", + "labels", + }: if not isinstance(color, str): raise TypeError("Parameter 'color' must be a string.") if element_type in {"shapes", "points"}: @@ -1908,7 +1980,11 @@ def _validate_shape_render_params( def _validate_col_for_column_table( - sdata: SpatialData, element_name: str, col_for_color: str | None, table_name: str | None, labels: bool = False + sdata: SpatialData, + element_name: str, + col_for_color: str | None, + table_name: str | None, + labels: bool = False, ) -> tuple[str | None, str | None]: if not labels and col_for_color in sdata[element_name].columns: table_name = None @@ -1929,7 +2005,11 @@ def _validate_col_for_column_table( elif len(tables) >= 1: table_name = next(iter(tables)) if len(tables) > 1: - warnings.warn(f"Multiple tables contain 
color column, using {table_name}", UserWarning, stacklevel=2) + warnings.warn( + f"Multiple tables contain color column, using {table_name}", + UserWarning, + stacklevel=2, + ) return col_for_color, table_name @@ -2006,12 +2086,17 @@ def _validate_image_render_params( def _get_wanted_render_elements( sdata: SpatialData, sdata_wanted_elements: list[str], - params: ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams, + params: (ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams), cs: str, element_type: Literal["images", "labels", "points", "shapes"], ) -> tuple[list[str], list[str], bool]: wants_elements = True - if element_type in ["images", "labels", "points", "shapes"]: # Prevents eval security risk + if element_type in [ + "images", + "labels", + "points", + "shapes", + ]: # Prevents eval security risk wanted_elements: list[str] = [params.element] wanted_elements_on_cs = [ element for element in wanted_elements if cs in set(get_transformation(sdata[element], get_all=True).keys()) @@ -2097,7 +2182,10 @@ def set_zero_in_cmap_to_transparent(cmap: Colormap | str, steps: int | None = No def _get_extent_and_range_for_datashader_canvas( - spatial_element: SpatialElement, coordinate_system: str, ax: Axes, fig_params: FigParams + spatial_element: SpatialElement, + coordinate_system: str, + ax: Axes, + fig_params: FigParams, ) -> tuple[Any, Any, list[Any], list[Any], Any]: extent = get_extent(spatial_element, coordinate_system=coordinate_system) x_ext = [min(0, extent["x"][0]), extent["x"][1]] @@ -2133,7 +2221,9 @@ def _get_extent_and_range_for_datashader_canvas( def _create_image_from_datashader_result( - ds_result: ds.transfer_functions.Image | np.ndarray[Any, np.dtype[np.uint8]], factor: float, ax: Axes + ds_result: ds.transfer_functions.Image | np.ndarray[Any, np.dtype[np.uint8]], + factor: float, + ax: Axes, ) -> tuple[MaskedArray[tuple[int, ...], Any], matplotlib.transforms.Transform]: # create SpatialImage from datashader output to get it back to original size rgba_image_data = ds_result.copy() if isinstance(ds_result, np.ndarray) else ds_result.to_numpy().base @@ -2153,7 +2243,7 @@ def _create_image_from_datashader_result( def _datashader_aggregate_with_function( - reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None, + reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None), cvs: Canvas, spatial_element: GeoDataFrame | dask.dataframe.core.DataFrame, col_for_color: str | None, @@ -2217,7 +2307,7 @@ def _datashader_aggregate_with_function( def _datshader_get_how_kw_for_spread( - reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None, + reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None), ) -> str: # Get the best input for the how argument of ds.tf.spread(), needed for numerical values reduction = reduction or "sum" @@ -2243,8 +2333,13 @@ def _datshader_get_how_kw_for_spread( def _prepare_transformation( - element: DataArray | GeoDataFrame | dask.dataframe.core.DataFrame, coordinate_system: str, ax: Axes | None = None -) -> tuple[matplotlib.transforms.Affine2D, matplotlib.transforms.CompositeGenericTransform | None]: + element: DataArray | GeoDataFrame | dask.dataframe.core.DataFrame, + coordinate_system: str, + ax: Axes | None = None, +) -> tuple[ + matplotlib.transforms.Affine2D, + matplotlib.transforms.CompositeGenericTransform | None, +]: trans = get_transformation(element, 
get_all=True)[coordinate_system] affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y")) trans = mtransforms.Affine2D(matrix=affine_trans) @@ -2312,11 +2407,19 @@ def _datashader_map_aggregate_to_color( agg_under = agg.where(agg < span[0]) img_under = ds.tf.shade( - agg_under, cmap=[to_hex(cmap.get_under())[:7]], min_alpha=min_alpha, color_key=color_key + agg_under, + cmap=[to_hex(cmap.get_under())[:7]], + min_alpha=min_alpha, + color_key=color_key, ) agg_over = agg.where(agg > span[1]) - img_over = ds.tf.shade(agg_over, cmap=[to_hex(cmap.get_over())[:7]], min_alpha=min_alpha, color_key=color_key) + img_over = ds.tf.shade( + agg_over, + cmap=[to_hex(cmap.get_over())[:7]], + min_alpha=min_alpha, + color_key=color_key, + ) # stack the 3 arrays manually: go from under, through in to over and always overlay the values where alpha=0 stack = img_under.to_numpy().base @@ -2329,4 +2432,49 @@ def _datashader_map_aggregate_to_color( stack[stack[:, :, 3] == 0] = img_over[stack[:, :, 3] == 0] return stack - return ds.tf.shade(agg, cmap=cmap, color_key=color_key, min_alpha=min_alpha, span=span, how="linear") + return ds.tf.shade( + agg, + cmap=cmap, + color_key=color_key, + min_alpha=min_alpha, + span=span, + how="linear", + ) + + +def _hex_no_alpha(hex: str) -> str: + """ + Return a hex color string without an alpha component. + + Parameters + ---------- + hex : str + The input hex color string. Must be in one of the following formats: + - "#RRGGBB": a hex color without an alpha channel. + - "#RRGGBBAA": a hex color with an alpha channel that will be removed. + + Returns + ------- + str + The hex color string in "#RRGGBB" format. + """ + if not isinstance(hex, str): + raise TypeError("Input must be a string") + if not hex.startswith("#"): + raise ValueError("Invalid hex color: must start with '#'") + + hex_digits = hex[1:] + length = len(hex_digits) + + if length == 6: + if not all(c in "0123456789abcdefABCDEF" for c in hex_digits): + raise ValueError("Invalid hex color: contains non-hex characters") + return hex # Already in #RRGGBB format. + + if length == 8: + if not all(c in "0123456789abcdefABCDEF" for c in hex_digits): + raise ValueError("Invalid hex color: contains non-hex characters") + # Return only the first 6 characters, stripping the alpha. 
+ return "#" + hex_digits[:6] + + raise ValueError("Invalid hex color length: must be either '#RRGGBB' or '#RRGGBBAA'") diff --git a/tests/_images/Shapes_can_color_by_category_with_cmap.png b/tests/_images/Shapes_can_color_by_category_with_cmap.png new file mode 100644 index 00000000..e7d300f4 Binary files /dev/null and b/tests/_images/Shapes_can_color_by_category_with_cmap.png differ diff --git a/tests/_images/Shapes_datashader_can_color_by_category_with_cmap.png b/tests/_images/Shapes_datashader_can_color_by_category_with_cmap.png new file mode 100644 index 00000000..dcef7d21 Binary files /dev/null and b/tests/_images/Shapes_datashader_can_color_by_category_with_cmap.png differ diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index affeebd0..953eb843 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -315,11 +315,54 @@ def test_plot_datashader_can_color_by_category(self, sdata_blobs: SpatialData): adata.obs["category"] = RNG.choice(["a", "b", "c"], size=adata.n_obs) adata.obs["instance_id"] = list(range(adata.n_obs)) adata.obs["region"] = "blobs_polygons" - table = TableModel.parse(adata=adata, region_key="region", instance_key="instance_id", region="blobs_polygons") + table = TableModel.parse( + adata=adata, + region_key="region", + instance_key="instance_id", + region="blobs_polygons", + ) sdata_blobs["table"] = table sdata_blobs.pl.render_shapes(element="blobs_polygons", color="category", method="datashader").pl.show() + def test_plot_datashader_can_color_by_category_with_cmap(self, sdata_blobs: SpatialData): + RNG = np.random.default_rng(seed=42) + n_obs = len(sdata_blobs["blobs_polygons"]) + adata = AnnData(RNG.normal(size=(n_obs, 10))) + adata.obs = pd.DataFrame(RNG.normal(size=(n_obs, 3)), columns=["a", "b", "c"]) + adata.obs["category"] = RNG.choice(["a", "b", "c"], size=adata.n_obs) + adata.obs["instance_id"] = list(range(adata.n_obs)) + adata.obs["region"] = "blobs_polygons" + table = TableModel.parse( + adata=adata, + region_key="region", + instance_key="instance_id", + region="blobs_polygons", + ) + sdata_blobs["table"] = table + + sdata_blobs.pl.render_shapes( + element="blobs_polygons", color="category", method="datashader", cmap="cool" + ).pl.show() + + def test_plot_can_color_by_category_with_cmap(self, sdata_blobs: SpatialData): + RNG = np.random.default_rng(seed=42) + n_obs = len(sdata_blobs["blobs_polygons"]) + adata = AnnData(RNG.normal(size=(n_obs, 10))) + adata.obs = pd.DataFrame(RNG.normal(size=(n_obs, 3)), columns=["a", "b", "c"]) + adata.obs["category"] = RNG.choice(["a", "b", "c"], size=adata.n_obs) + adata.obs["instance_id"] = list(range(adata.n_obs)) + adata.obs["region"] = "blobs_polygons" + table = TableModel.parse( + adata=adata, + region_key="region", + instance_key="instance_id", + region="blobs_polygons", + ) + sdata_blobs["table"] = table + + sdata_blobs.pl.render_shapes(element="blobs_polygons", color="category", cmap="cool").pl.show() + def test_plot_datashader_can_color_by_value(self, sdata_blobs: SpatialData): sdata_blobs["table"].obs["region"] = "blobs_polygons" sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"