Skip to content

Commit 010560f

Browse files
committed
2 parents 60b191b + 0d5f729 commit 010560f

File tree

11 files changed

+144
-80
lines changed

11 files changed

+144
-80
lines changed

.github/release.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ changelog:
44
- release-ignore
55
authors:
66
- pre-commit-ci
7+
- pre-commit-ci[bot]
78
categories:
89
- title: Added
910
labels:

.github/workflows/test.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ jobs:
4949
pip install pytest-cov
5050
- name: Install dependencies
5151
run: |
52+
pip install numpy
5253
pip install --pre -e ".[dev,test,pre]"
5354
- name: Test
5455
env:

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ repos:
1313
hooks:
1414
- id: prettier
1515
- repo: https://github.com/astral-sh/ruff-pre-commit
16-
rev: v0.12.8
16+
rev: v0.12.12
1717
hooks:
1818
- id: ruff
1919
args: [--fix, --exit-non-zero-on-fix]

src/spatialdata_plot/pl/render.py

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def _render_shapes(
136136
if isinstance(groups, list) and color_source_vector is not None:
137137
mask = color_source_vector.isin(groups)
138138
shapes = shapes[mask]
139-
shapes = shapes.reset_index()
139+
shapes = shapes.reset_index(drop=True)
140140
color_source_vector = color_source_vector[mask]
141141
color_vector = color_vector[mask]
142142

@@ -338,7 +338,7 @@ def _render_shapes(
338338
cax = None
339339
if aggregate_with_reduction is not None:
340340
vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin
341-
vmax = aggregate_with_reduction[1].values if norm.vmin is None else norm.vmax
341+
vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax
342342
if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax:
343343
assert norm.vmin is not None
344344
assert norm.vmax is not None
@@ -850,20 +850,22 @@ def _render_images(
850850
# 2) Image has any number of channels but 1
851851
else:
852852
layers = {}
853-
for ch_index, c in enumerate(channels):
854-
layers[c] = img.sel(c=c).copy(deep=True).squeeze()
855-
856-
if not isinstance(render_params.cmap_params, list):
857-
if render_params.cmap_params.norm is not None:
858-
layers[c] = render_params.cmap_params.norm(layers[c])
853+
for ch_idx, ch in enumerate(channels):
854+
layers[ch] = img.sel(c=ch).copy(deep=True).squeeze()
855+
if isinstance(render_params.cmap_params, list):
856+
ch_norm = render_params.cmap_params[ch_idx].norm
857+
ch_cmap_is_default = render_params.cmap_params[ch_idx].cmap_is_default
859858
else:
860-
if render_params.cmap_params[ch_index].norm is not None:
861-
layers[c] = render_params.cmap_params[ch_index].norm(layers[c])
859+
ch_norm = render_params.cmap_params.norm
860+
ch_cmap_is_default = render_params.cmap_params.cmap_is_default
861+
862+
if not ch_cmap_is_default and ch_norm is not None:
863+
layers[ch_idx] = ch_norm(layers[ch_idx])
862864

863865
# 2A) Image has 3 channels, no palette info, and no/only one cmap was given
864866
if palette is None and n_channels == 3 and not isinstance(render_params.cmap_params, list):
865867
if render_params.cmap_params.cmap_is_default: # -> use RGB
866-
stacked = np.stack([layers[c] for c in channels], axis=-1)
868+
stacked = np.stack([layers[ch] for ch in layers], axis=-1)
867869
else: # -> use given cmap for each channel
868870
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
869871
stacked = (
@@ -896,12 +898,54 @@ def _render_images(
896898
# overwrite if n_channels == 2 for intuitive result
897899
if n_channels == 2:
898900
seed_colors = ["#ff0000ff", "#00ff00ff"]
899-
else:
901+
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
902+
colored = np.stack(
903+
[channel_cmaps[ch_ind](layers[ch]) for ch_ind, ch in enumerate(channels)],
904+
0,
905+
).sum(0)
906+
colored = colored[:, :, :3]
907+
elif n_channels == 3:
900908
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))
909+
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
910+
colored = np.stack(
911+
[channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)],
912+
0,
913+
).sum(0)
914+
colored = colored[:, :, :3]
915+
else:
916+
if isinstance(render_params.cmap_params, list):
917+
cmap_is_default = render_params.cmap_params[0].cmap_is_default
918+
else:
919+
cmap_is_default = render_params.cmap_params.cmap_is_default
901920

902-
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
903-
colored = np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0)
904-
colored = colored[:, :, :3]
921+
if cmap_is_default:
922+
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))
923+
else:
924+
# Sample n_channels colors evenly from the colormap
925+
if isinstance(render_params.cmap_params, list):
926+
seed_colors = [
927+
render_params.cmap_params[i].cmap(i / (n_channels - 1)) for i in range(n_channels)
928+
]
929+
else:
930+
seed_colors = [render_params.cmap_params.cmap(i / (n_channels - 1)) for i in range(n_channels)]
931+
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
932+
933+
# Stack (n_channels, height, width) → (height*width, n_channels)
934+
H, W = next(iter(layers.values())).shape
935+
comp_rgb = np.zeros((H, W, 3), dtype=float)
936+
937+
# For each channel: map to RGBA, apply constant alpha, then add
938+
for ch_idx, ch in enumerate(channels):
939+
layer_arr = layers[ch]
940+
rgba = channel_cmaps[ch_idx](layer_arr)
941+
rgba[..., 3] = render_params.alpha
942+
comp_rgb += rgba[..., :3] * rgba[..., 3][..., None]
943+
944+
colored = np.clip(comp_rgb, 0, 1)
945+
logger.info(
946+
f"Your image has {n_channels} channels. Sampling categorical colors and using "
947+
f"multichannel strategy 'stack' to render."
948+
) # TODO: update when pca is added as strategy
905949

906950
_ax_show_and_transform(
907951
colored,
@@ -947,6 +991,7 @@ def _render_images(
947991
zorder=render_params.zorder,
948992
)
949993

994+
# 2D) Image has n channels, no palette but cmap info
950995
elif palette is not None and got_multiple_cmaps:
951996
raise ValueError("If 'palette' is provided, 'cmap' must be None.")
952997

src/spatialdata_plot/pl/utils.py

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -780,8 +780,9 @@ def _set_color_source_vec(
780780

781781
color_source_vector = pd.Categorical(color_source_vector) # convert, e.g., `pd.Series`
782782

783+
# TODO check why table_name is not passed here.
783784
color_mapping = _get_categorical_color_mapping(
784-
adata=sdata.table,
785+
adata=sdata["table"],
785786
cluster_key=value_to_plot,
786787
color_source_vector=color_source_vector,
787788
cmap_params=cmap_params,
@@ -2008,7 +2009,7 @@ def _validate_col_for_column_table(
20082009
table_name = next(iter(tables))
20092010
if len(tables) > 1:
20102011
warnings.warn(
2011-
f"Multiple tables contain color column, using {table_name}",
2012+
f"Multiple tables contain column '{col_for_color}', using table '{table_name}'.",
20122013
UserWarning,
20132014
stacklevel=2,
20142015
)
@@ -2044,44 +2045,57 @@ def _validate_image_render_params(
20442045
element_params[el] = {}
20452046
spatial_element = param_dict["sdata"][el]
20462047

2048+
# robustly get channel names from image or multiscale image
20472049
spatial_element_ch = (
2048-
spatial_element.c if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c
2050+
spatial_element.c.values if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c.values
20492051
)
2050-
20512052
channel = param_dict["channel"]
2052-
channel_list: list[str] | list[int] | None
2053-
if isinstance(channel, list):
2054-
type_ = type(channel[0])
2055-
assert all(isinstance(ch, type_) for ch in channel), "All channels must be of the same type."
2056-
# mypy complains that channel_list can be also of type list[str | int]
2057-
channel_list = [channel] if isinstance(channel, int | str) else channel # type: ignore[assignment]
2058-
2059-
if channel_list is not None and (
2060-
(isinstance(channel_list[0], int) and max([abs(ch) for ch in channel_list]) <= len(spatial_element_ch)) # type: ignore[arg-type]
2061-
or all(ch in spatial_element_ch for ch in channel_list)
2062-
):
2063-
element_params[el]["channel"] = channel_list
2053+
if channel is not None:
2054+
# Normalize channel to always be a list of str or a list of int
2055+
if isinstance(channel, str):
2056+
channel = [channel]
2057+
2058+
if isinstance(channel, int):
2059+
channel = [channel]
2060+
2061+
# If channel is a list, ensure all elements are the same type
2062+
if not (isinstance(channel, list) and channel and all(isinstance(c, type(channel[0])) for c in channel)):
2063+
raise TypeError("Each item in 'channel' list must be of the same type, either string or integer.")
2064+
2065+
invalid = [c for c in channel if c not in spatial_element_ch]
2066+
if invalid:
2067+
raise ValueError(
2068+
f"Invalid channel(s): {', '.join(str(c) for c in invalid)}. Valid choices are: {spatial_element_ch}"
2069+
)
2070+
element_params[el]["channel"] = channel
20642071
else:
20652072
element_params[el]["channel"] = None
20662073

20672074
element_params[el]["alpha"] = param_dict["alpha"]
20682075

2069-
if isinstance(palette := param_dict["palette"], list):
2076+
palette = param_dict["palette"]
2077+
assert isinstance(palette, list | type(None)) # if present, was converted to list, just to make sure
2078+
2079+
if isinstance(palette, list):
2080+
# case A: single palette for all channels
20702081
if len(palette) == 1:
2071-
palette_length = len(channel_list) if channel_list is not None else len(spatial_element_ch)
2082+
palette_length = len(channel) if channel is not None else len(spatial_element_ch)
20722083
palette = palette * palette_length
2073-
if (channel_list is not None and len(palette) != len(channel_list)) and len(palette) != len(
2074-
spatial_element_ch
2075-
):
2076-
palette = None
2084+
# case B: one palette per channel (either given or derived from channel length)
2085+
channels_to_use = spatial_element_ch if element_params[el]["channel"] is None else channel
2086+
if channels_to_use is not None and len(palette) != len(channels_to_use):
2087+
raise ValueError(
2088+
f"Palette length ({len(palette)}) does not match channel length "
2089+
f"({', '.join(str(c) for c in channels_to_use)})."
2090+
)
20772091
element_params[el]["palette"] = palette
20782092
element_params[el]["na_color"] = param_dict["na_color"]
20792093

20802094
if (cmap := param_dict["cmap"]) is not None:
20812095
if len(cmap) == 1:
2082-
cmap_length = len(channel_list) if channel_list is not None else len(spatial_element_ch)
2096+
cmap_length = len(channel) if channel is not None else len(spatial_element_ch)
20832097
cmap = cmap * cmap_length
2084-
if (channel_list is not None and len(cmap) != len(channel_list)) or len(cmap) != len(spatial_element_ch):
2098+
if (channel is not None and len(cmap) != len(channel)) or len(cmap) != len(spatial_element_ch):
20852099
cmap = None
20862100
element_params[el]["cmap"] = cmap
20872101
element_params[el]["norm"] = param_dict["norm"]
@@ -2099,7 +2113,7 @@ def _validate_image_render_params(
20992113
def _get_wanted_render_elements(
21002114
sdata: SpatialData,
21012115
sdata_wanted_elements: list[str],
2102-
params: (ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams),
2116+
params: ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams,
21032117
cs: str,
21042118
element_type: Literal["images", "labels", "points", "shapes"],
21052119
) -> tuple[list[str], list[str], bool]:
@@ -2256,7 +2270,7 @@ def _create_image_from_datashader_result(
22562270

22572271

22582272
def _datashader_aggregate_with_function(
2259-
reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None),
2273+
reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None,
22602274
cvs: Canvas,
22612275
spatial_element: GeoDataFrame | dask.dataframe.core.DataFrame,
22622276
col_for_color: str | None,
@@ -2320,7 +2334,7 @@ def _datashader_aggregate_with_function(
23202334

23212335

23222336
def _datshader_get_how_kw_for_spread(
2323-
reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None),
2337+
reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None,
23242338
) -> str:
23252339
# Get the best input for the how argument of ds.tf.spread(), needed for numerical values
23262340
reduction = reduction or "sum"
-11.2 KB
Loading

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def test_sdata_multiple_images_diverging_dims():
154154
def sdata_blobs_shapes_annotated() -> SpatialData:
155155
"""Get blobs sdata with continuous annotation of polygons."""
156156
blob = blobs()
157-
blob["table"].obs["region"] = "blobs_polygons"
157+
blob["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * blob["table"].n_obs)
158158
blob["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
159159
blob.shapes["blobs_polygons"]["value"] = [1, 2, 3, 4, 5]
160160
return blob

tests/pl/test_render_labels.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ def test_plot_can_render_labels(self, sdata_blobs: SpatialData):
3333
sdata_blobs.pl.render_labels(element="blobs_labels").pl.show()
3434

3535
def test_plot_can_render_multiscale_labels(self, sdata_blobs: SpatialData):
36-
sdata_blobs["table"].obs["region"] = "blobs_multiscale_labels"
36+
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_multiscale_labels"] * sdata_blobs["table"].n_obs)
3737
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels"
3838
sdata_blobs.pl.render_labels("blobs_multiscale_labels").pl.show()
3939

4040
def test_plot_can_render_given_scale_of_multiscale_labels(self, sdata_blobs: SpatialData):
41-
sdata_blobs["table"].obs["region"] = "blobs_multiscale_labels"
41+
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_multiscale_labels"] * sdata_blobs["table"].n_obs)
4242
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels"
4343
sdata_blobs.pl.render_labels("blobs_multiscale_labels", scale="scale1").pl.show()
4444

@@ -50,7 +50,7 @@ def test_plot_can_do_rasterization(self, sdata_blobs: SpatialData):
5050
img.attrs["transform"] = sdata_blobs["blobs_labels"].transform
5151
sdata_blobs["blobs_giant_labels"] = img
5252

53-
sdata_blobs["table"].obs["region"] = "blobs_giant_labels"
53+
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_giant_labels"] * sdata_blobs["table"].n_obs)
5454
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_giant_labels"
5555

5656
sdata_blobs.pl.render_labels("blobs_giant_labels").pl.show()
@@ -63,7 +63,7 @@ def test_plot_can_stop_rasterization_with_scale_full(self, sdata_blobs: SpatialD
6363
img.attrs["transform"] = sdata_blobs["blobs_labels"].transform
6464
sdata_blobs["blobs_giant_labels"] = img
6565

66-
sdata_blobs["table"].obs["region"] = "blobs_giant_labels"
66+
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_giant_labels"] * sdata_blobs["table"].n_obs)
6767
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_giant_labels"
6868

6969
sdata_blobs.pl.render_labels("blobs_giant_labels", scale="full").pl.show()
@@ -110,7 +110,7 @@ def _make_tablemodel_with_categorical_labels(sdata_blobs, label):
110110
max_col = max_col.str.replace("channel_", "ch").str.replace("_sum", "")
111111
max_col = pd.Categorical(max_col, categories=set(max_col), ordered=True)
112112
adata.obs["which_max"] = max_col
113-
adata.obs["region"] = label
113+
adata.obs["region"] = pd.Categorical([label] * adata.n_obs)
114114
del adata.uns["spatialdata_attrs"]
115115
table = TableModel.parse(
116116
adata=adata,
@@ -142,7 +142,7 @@ def test_plot_two_calls_with_coloring_result_in_two_colorbars(self, sdata_blobs:
142142
sdata_blobs_local = deepcopy(sdata_blobs)
143143

144144
table = sdata_blobs_local["table"].copy()
145-
table.obs["region"] = "blobs_multiscale_labels"
145+
table.obs["region"] = pd.Categorical(["blobs_multiscale_labels"] * table.n_obs)
146146
table.uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels"
147147
table = table[:, ~table.var_names.isin(["channel_0_sum"])]
148148
sdata_blobs_local["multi_table"] = table
@@ -187,7 +187,7 @@ def test_plot_label_colorbar_uses_alpha_of_less_transparent_outline(
187187

188188
def test_can_plot_with_one_element_color_table(self, sdata_blobs: SpatialData):
189189
table = sdata_blobs["table"].copy()
190-
table.obs["region"] = "blobs_multiscale_labels"
190+
table.obs["region"] = pd.Categorical(["blobs_multiscale_labels"] * table.n_obs)
191191
table.uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels"
192192
table = table[:, ~table.var_names.isin(["channel_0_sum"])]
193193
sdata_blobs["multi_table"] = table
@@ -196,9 +196,9 @@ def test_can_plot_with_one_element_color_table(self, sdata_blobs: SpatialData):
196196
).pl.show()
197197

198198
def test_plot_subset_categorical_label_maintains_order(self, sdata_blobs: SpatialData):
199-
max_col = sdata_blobs.table.to_df().idxmax(axis=1)
200-
max_col = pd.Categorical(max_col, categories=sdata_blobs.table.to_df().columns, ordered=True)
201-
sdata_blobs.table.obs["which_max"] = max_col
199+
max_col = sdata_blobs["table"].to_df().idxmax(axis=1)
200+
max_col = pd.Categorical(max_col, categories=sdata_blobs["table"].to_df().columns, ordered=True)
201+
sdata_blobs["table"].obs["which_max"] = max_col
202202

203203
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")
204204

@@ -210,9 +210,9 @@ def test_plot_subset_categorical_label_maintains_order(self, sdata_blobs: Spatia
210210
).pl.show(ax=axs[1])
211211

212212
def test_plot_subset_categorical_label_maintains_order_when_palette_overwrite(self, sdata_blobs: SpatialData):
213-
max_col = sdata_blobs.table.to_df().idxmax(axis=1)
214-
max_col = pd.Categorical(max_col, categories=sdata_blobs.table.to_df().columns, ordered=True)
215-
sdata_blobs.table.obs["which_max"] = max_col
213+
max_col = sdata_blobs["table"].to_df().idxmax(axis=1)
214+
max_col = pd.Categorical(max_col, categories=sdata_blobs["table"].to_df().columns, ordered=True)
215+
sdata_blobs["table"].obs["which_max"] = max_col
216216

217217
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")
218218

0 commit comments

Comments
 (0)