Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 27 additions & 94 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,11 +650,9 @@ def _render_images(
stacked = np.stack([layers[c] for c in channels], axis=-1)
else: # -> use given cmap for each channel
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
# Apply cmaps to each channel, add up and normalize to [0, 1]
stacked = (
np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0) / n_channels
)
# Remove alpha channel so we can overwrite it from render_params.alpha
stacked = stacked[:, :, :3]
logger.warning(
"One cmap was given for multiple channels and is now used for each channel. "
Expand All @@ -676,11 +674,7 @@ def _render_images(
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))

channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]

# Apply cmaps to each channel and add up
colored = np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0)

# Remove alpha channel so we can overwrite it from render_params.alpha
colored = colored[:, :, :3]

_ax_show_and_transform(colored, trans_data, ax, render_params.alpha, zorder=render_params.zorder)
Expand All @@ -691,24 +685,16 @@ def _render_images(
raise ValueError("If 'palette' is provided, its length must match the number of channels.")

channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in palette if isinstance(c, str)]

# Apply cmaps to each channel and add up
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)

# Remove alpha channel so we can overwrite it from render_params.alpha
colored = colored[:, :, :3]

_ax_show_and_transform(colored, trans_data, ax, render_params.alpha, zorder=render_params.zorder)

elif palette is None and got_multiple_cmaps:
channel_cmaps = [cp.cmap for cp in render_params.cmap_params] # type: ignore[union-attr]

# Apply cmaps to each channel, add up and normalize to [0, 1]
colored = (
np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0) / n_channels
)

# Remove alpha channel so we can overwrite it from render_params.alpha
colored = colored[:, :, :3]

_ax_show_and_transform(colored, trans_data, ax, render_params.alpha, zorder=render_params.zorder)
Expand Down Expand Up @@ -794,119 +780,66 @@ def _render_labels(
table_name=table_name,
)

# default case: no contour, just fill
# if fill_alpha and outline_alpha are the same, we're technically also at a no-outline situation
if render_params.outline_alpha == 0.0 or render_params.outline_alpha == render_params.fill_alpha:
labels_infill = _map_color_seg(
def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) -> matplotlib.image.AxesImage:
labels = _map_color_seg(
seg=label.values,
cell_id=instance_id,
color_vector=color_vector,
color_source_vector=color_source_vector,
cmap_params=render_params.cmap_params,
seg_erosionpx=None,
seg_boundaries=False,
seg_erosionpx=seg_erosionpx,
seg_boundaries=seg_boundaries,
na_color=render_params.cmap_params.na_color,
na_color_modified_by_user=render_params.cmap_params.na_color_modified_by_user,
)

_cax = ax.imshow(
labels_infill,
labels,
rasterized=True,
cmap=None if categorical else render_params.cmap_params.cmap,
norm=None if categorical else render_params.cmap_params.norm,
alpha=render_params.fill_alpha,
alpha=alpha,
origin="lower",
zorder=render_params.zorder,
)
_cax.set_transform(trans_data)
cax = ax.add_image(_cax)
return cax # noqa: RET504

# default case: no contour, just fill
# since contour_px is passed to skimage.morphology.erosion to create the contour,
# any border thickness is only within the label, not outside. Therefore, the case
# of fill_alpha == outline_alpha is equivalent to fill-only
if (render_params.fill_alpha > 0.0 and render_params.outline_alpha == 0.0) or (
render_params.fill_alpha == render_params.outline_alpha
):
cax = _draw_labels(seg_erosionpx=None, seg_boundaries=False, alpha=render_params.fill_alpha)
alpha_to_decorate_ax = render_params.fill_alpha

# outline-only case
if render_params.fill_alpha == 0.0 and render_params.outline_alpha != 0.0:
labels_contour = _map_color_seg(
seg=label.values,
cell_id=instance_id,
color_vector=color_vector,
color_source_vector=color_source_vector,
cmap_params=render_params.cmap_params,
seg_erosionpx=render_params.contour_px,
seg_boundaries=True,
na_color=render_params.cmap_params.na_color,
na_color_modified_by_user=render_params.cmap_params.na_color_modified_by_user,
elif render_params.fill_alpha == 0.0 and render_params.outline_alpha > 0.0:
cax = _draw_labels(
seg_erosionpx=render_params.contour_px, seg_boundaries=True, alpha=render_params.outline_alpha
)
_cax = ax.imshow(
labels_contour,
rasterized=True,
cmap=None if categorical else render_params.cmap_params.cmap,
norm=None if categorical else render_params.cmap_params.norm,
alpha=render_params.outline_alpha,
origin="lower",
zorder=render_params.zorder,
)
_cax.set_transform(trans_data)
cax = ax.add_image(_cax)
alpha_to_decorate_ax = render_params.outline_alpha

# pretty case: both outline and infill
if (
render_params.fill_alpha > 0.0
and render_params.outline_alpha > 0.0
and render_params.fill_alpha != render_params.outline_alpha
):
elif render_params.fill_alpha > 0.0 and render_params.outline_alpha > 0.0:
# first plot the infill ...
label_infill = _map_color_seg(
seg=label.values,
cell_id=instance_id,
color_vector=color_vector,
color_source_vector=color_source_vector,
cmap_params=render_params.cmap_params,
seg_erosionpx=None,
seg_boundaries=False,
na_color=render_params.cmap_params.na_color,
na_color_modified_by_user=render_params.cmap_params.na_color_modified_by_user,
)

_cax_infill = ax.imshow(
label_infill,
rasterized=True,
cmap=None if categorical else render_params.cmap_params.cmap,
norm=None if categorical else render_params.cmap_params.norm,
alpha=render_params.fill_alpha,
origin="lower",
zorder=render_params.zorder,
)
_cax_infill.set_transform(trans_data)
cax_infill = ax.add_image(_cax_infill)
cax_infill = _draw_labels(seg_erosionpx=None, seg_boundaries=False, alpha=render_params.fill_alpha)

# ... then overlay the contour
label_contour = _map_color_seg(
seg=label.values,
cell_id=instance_id,
color_vector=color_vector,
color_source_vector=color_source_vector,
cmap_params=render_params.cmap_params,
seg_erosionpx=render_params.contour_px,
seg_boundaries=True,
na_color=render_params.cmap_params.na_color,
na_color_modified_by_user=render_params.cmap_params.na_color_modified_by_user,
cax_contour = _draw_labels(
seg_erosionpx=render_params.contour_px, seg_boundaries=True, alpha=render_params.outline_alpha
)

_cax_contour = ax.imshow(
label_contour,
rasterized=True,
cmap=None if categorical else render_params.cmap_params.cmap,
norm=None if categorical else render_params.cmap_params.norm,
alpha=render_params.outline_alpha,
origin="lower",
zorder=render_params.zorder,
)
_cax_contour.set_transform(trans_data)
cax_contour = ax.add_image(_cax_contour)

# pass the less-transparent _cax for the legend
cax = cax_infill if render_params.fill_alpha > render_params.outline_alpha else cax_contour
alpha_to_decorate_ax = max(render_params.fill_alpha, render_params.outline_alpha)

else:
raise ValueError("Parameters 'fill_alpha' and 'outline_alpha' cannot both be 0.")

_ = _decorate_axs(
ax=ax,
cax=cax,
Expand Down
Loading