From db5e5641a3b8de3effb45f1a8c831ade851239ce Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Fri, 9 May 2025 18:13:26 +0200 Subject: [PATCH 1/3] fixed categorical import; added logging; fixed code for looking for xyz_colors col --- src/spatialdata_plot/pl/utils.py | 263 ++++++++++++++++++++++++++----- 1 file changed, 227 insertions(+), 36 deletions(-) diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index a2e8f767..58b07c68 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -738,6 +738,12 @@ def _set_color_source_vec( color = np.full(len(element), na_color) return color, color, False + # First check if value_to_plot is likely a color specification rather than a column name + if value_to_plot is not None and _is_color_like(value_to_plot) and element is not None: + # User passed a color, not a column name + color = np.full(len(element), value_to_plot) + return None, color, False + # Figure out where to get the color from origins = _locate_value( value_key=value_to_plot, @@ -760,9 +766,12 @@ def _set_color_source_vec( table_layer=table_layer, )[value_to_plot] - # numerical case, return early - # TODO temporary split until refactor is complete - if color_source_vector is not None and not isinstance(color_source_vector.dtype, pd.CategoricalDtype): + # Check what type of data we're dealing with + is_categorical = isinstance(color_source_vector.dtype, pd.CategoricalDtype) + is_numeric = pd.api.types.is_numeric_dtype(color_source_vector) + + # If it's numeric data, handle it appropriately + if is_numeric and not is_categorical: if ( not isinstance(element, GeoDataFrame) and isinstance(palette, list) @@ -776,11 +785,43 @@ def _set_color_source_vec( ) return None, color_source_vector, False - color_source_vector = pd.Categorical(color_source_vector) # convert, e.g., `pd.Series` - + # For non-numeric, non-categorical data (like strings), convert to categorical + if not is_categorical: + try: + color_source_vector = pd.Categorical(color_source_vector) + except (ValueError, TypeError) as e: + logger.warning(f"Could not convert '{value_to_plot}' to categorical: {e}") + # Fall back to returning the original values + return None, color_source_vector, False + + # At this point color_source_vector should be categorical + + # Look for predefined colors in the AnnData object + adata_with_colors = None + cluster_key = value_to_plot + + # First check if the table_name is specified + if table_name is not None and table_name in sdata.tables: + adata_with_colors = sdata.tables[table_name] + adata_with_colors.uns["spatialdata_key"] = table_name + + # If not, but the element is annotated by any table, use that + elif element_name is not None: + annotator_tables = get_element_annotators(sdata, element_name) + if len(annotator_tables) > 0: + # Use the first table that annotates this element + first_table = next(iter(annotator_tables)) + adata_with_colors = sdata.tables[first_table] + adata_with_colors.uns["spatialdata_key"] = first_table + # If no specific table is found, try using the default table + elif sdata.table is not None: + adata_with_colors = sdata.table + adata_with_colors.uns["spatialdata_key"] = "default_table" + + # Now generate the color mapping using the appropriate AnnData object and cluster_key color_mapping = _get_categorical_color_mapping( - adata=sdata.table, - cluster_key=value_to_plot, + adata=adata_with_colors, + cluster_key=cluster_key, color_source_vector=color_source_vector, cmap_params=cmap_params, alpha=alpha, @@ -790,16 +831,27 @@ def _set_color_source_vec( render_type=render_type, ) + # Set categories to match the mapping keys color_source_vector = color_source_vector.set_categories(color_mapping.keys()) if color_mapping is None: raise ValueError("Unable to create color palette.") - # do not rename categories, as colors need not be unique - color_vector = color_source_vector.map(color_mapping) + # Map categorical values to colors + # Do not rename categories, as colors need not be unique + try: + color_vector = color_source_vector.map(color_mapping) + except (KeyError, TypeError, ValueError) as e: + logger.warning(f"Error mapping colors: {e}. Attempting alternate approach.") + # Try mapping with string conversion + str_mapping = {str(k): v for k, v in color_mapping.items()} + color_vector = pd.Series( + [str_mapping.get(str(x), color_mapping.get("NaN", "#d3d3d3")) for x in color_source_vector], + index=color_source_vector.index, + ) return color_source_vector, color_vector, True - logger.warning(f"Color key '{value_to_plot}' for element '{element_name}' not been found, using default colors.") + logger.warning(f"Color key '{value_to_plot}' for element '{element_name}' not found, using default colors.") color = np.full(sdata[table_name].n_obs, to_hex(na_color)) return color, color, False @@ -817,20 +869,35 @@ def _map_color_seg( ) -> ArrayLike: cell_id = np.array(cell_id) - if pd.api.types.is_categorical_dtype(color_vector.dtype): - # Case A: users wants to plot a categorical column + # Safely handle different types of color_vector + is_categorical = pd.api.types.is_categorical_dtype(getattr(color_vector, "dtype", None)) + is_numeric = pd.api.types.is_numeric_dtype(getattr(color_vector, "dtype", None)) + is_pandas_series = isinstance(color_vector, pd.Series) + + # Case A: categorical column + if is_categorical: if np.any(color_source_vector.isna()): cell_id[color_source_vector.isna()] = 0 val_im: ArrayLike = map_array(seg.copy(), cell_id, color_vector.codes + 1) cols = colors.to_rgba_array(color_vector.categories) - elif pd.api.types.is_numeric_dtype(color_vector.dtype): - # Case B: user wants to plot a continous column - if isinstance(color_vector, pd.Series): + + # Case B: continuous column + elif is_numeric: + if is_pandas_series: color_vector = color_vector.to_numpy() cols = cmap_params.cmap(cmap_params.norm(color_vector)) val_im = map_array(seg.copy(), cell_id, cell_id) + + # Case C & D: Other cases (could be strings, or hex colors) else: - # Case C: User didn't specify any colors + # Get the first color safely, regardless of index structure + first_color = None + if is_pandas_series and len(color_vector) > 0: + first_color = color_vector.iloc[0] + elif not is_pandas_series and len(color_vector) > 0: + first_color = color_vector[0] + + # Case C: Using default colors with random generation if color_source_vector is not None and ( set(color_vector) == set(color_source_vector) and len(set(color_vector)) == 1 @@ -840,14 +907,31 @@ def _map_color_seg( val_im = map_array(seg.copy(), cell_id, cell_id) RNG = default_rng(42) cols = RNG.random((len(color_vector), 3)) + + # Case D: User specified explicit colors or we're using defaults else: - # Case D: User didn't specify a column to color by, but modified the na_color val_im = map_array(seg.copy(), cell_id, cell_id) - if "#" in str(color_vector[0]): - # we have hex colors - assert all(_is_color_like(c) for c in color_vector), "Not all values are color-like." - cols = colors.to_rgba_array(color_vector) + + # Check if we're dealing with hex colors + if first_color is not None and isinstance(first_color, str) and "#" in first_color: + # We have hex colors + all_is_color = True + for c in color_vector: + if not _is_color_like(c): + all_is_color = False + break + + if all_is_color: + try: + cols = colors.to_rgba_array(color_vector) + except ValueError as e: + logger.warning(f"Error converting colors: {e}, falling back to default colormap") + cols = cmap_params.cmap(cmap_params.norm(np.arange(len(color_vector)))) + else: + # Fall back to colormap + cols = cmap_params.cmap(cmap_params.norm(color_vector)) else: + # Use the colormap cols = cmap_params.cmap(cmap_params.norm(color_vector)) if seg_erosionpx is not None: @@ -879,21 +963,118 @@ def _generate_base_categorial_color_mapping( na_color: ColorLike, cmap_params: CmapParams | None = None, ) -> Mapping[str, str]: - if adata is not None and cluster_key in adata.uns and f"{cluster_key}_colors" in adata.uns: - colors = adata.uns[f"{cluster_key}_colors"] - categories = color_source_vector.categories.tolist() + ["NaN"] - if "#" not in na_color: - # should be unreachable, but just for safety - raise ValueError("Expected `na_color` to be a hex color, but got a non-hex color.") - - colors = [to_hex(to_rgba(color)[:3]) for color in colors] - na_color = to_hex(to_rgba(na_color)[:3]) + color_key = f"{cluster_key}_colors" - if na_color and len(categories) > len(colors): - return dict(zip(categories, colors + [na_color], strict=True)) - - return dict(zip(categories, colors, strict=True)) + # Break long string template into multiple lines to fix E501 error + color_found_in_uns_msg_template = ( + "Using colors from '{cluster}_colors' in .uns slot of table '{table}' for plotting. " + "If this is unexpected, please delete the column from your AnnData object." + ) + # Check if we have a valid AnnData and if the color key exists in uns + if adata is not None and cluster_key is not None: + # Check for direct color dictionary in uns (e.g., {'A': '#FF5733', 'B': '#3498DB'}) + if cluster_key in adata.uns and isinstance(adata.uns[cluster_key], dict): + # We have a direct color mapping dictionary + color_dict = adata.uns[cluster_key] + table_name = getattr(adata, "uns", {}).get("spatialdata_key", "") + if table_name: + # Format the template with the actual values + logger.info(color_found_in_uns_msg_template.format(cluster=cluster_key, table=table_name)) + + # Ensure all values are hex colors + for k, v in color_dict.items(): + if isinstance(v, str) and not v.startswith("#"): + color_dict[k] = to_hex(to_rgba(v)) + + # Add NA color if missing + categories = color_source_vector.categories.tolist() + na_color_hex = to_hex(to_rgba(na_color)[:3]) + + return {cat: color_dict.get(str(cat), color_dict.get(cat, na_color_hex)) for cat in categories} + + if color_key in adata.uns: + colors = adata.uns[color_key] + table_name = getattr(adata, "uns", {}).get("spatialdata_key", "") + if table_name: + if isinstance(colors, dict): + # Format the template with the actual values + logger.info(color_found_in_uns_msg_template.format(cluster=cluster_key, table=table_name)) + else: + # Format the template with the actual values + logger.info(color_found_in_uns_msg_template.format(cluster=cluster_key, table=table_name)) + + # Ensure colors are in hex format + if isinstance(colors, list): + colors = [to_hex(to_rgba(color)[:3]) for color in colors] + categories = color_source_vector.categories.tolist() + + # Handle NaN values + na_color_hex = to_hex(to_rgba(na_color)[:3]) + if "NaN" not in categories: + categories.append("NaN") + + # Make sure we have enough colors + if len(colors) < len(categories) - 1: # -1 for NaN + logger.warning( + f"Not enough colors in {color_key} ({len(colors)}) for all categories ({len(categories) - 1}). " + "Some categories will use default colors." + ) + # Extend with default colors or duplicate the last color + colors.extend([na_color_hex] * (len(categories) - 1 - len(colors))) + + # Create mapping with NaN color + return dict(zip(categories, colors + [na_color_hex], strict=False)) + + if isinstance(colors, np.ndarray): + # Convert numpy array to list of hex colors + colors = [to_hex(to_rgba(color)[:3]) for color in colors] + categories = color_source_vector.categories.tolist() + + # Handle NaN values + na_color_hex = to_hex(to_rgba(na_color)[:3]) + if "NaN" not in categories: + categories.append("NaN") + + # Make sure we have enough colors + if len(colors) < len(categories) - 1: # -1 for NaN + logger.warning( + f"Not enough colors in {color_key} ({len(colors)}) for all categories ({len(categories) - 1}). " + "Some categories will use default colors." + ) + # Extend with default colors + colors.extend([na_color_hex] * (len(categories) - 1 - len(colors))) + + # Create mapping with NaN color + return dict(zip(categories, colors + [na_color_hex], strict=False)) + + # Dictionary format - direct color mapping + if isinstance(colors, dict): + # Ensure all values are hex colors + for k, v in colors.items(): + if isinstance(v, str) and not v.startswith("#"): + colors[k] = to_hex(to_rgba(v)) + + # Get categories and handle NaN + categories = color_source_vector.categories.tolist() + na_color_hex = to_hex(to_rgba(na_color)[:3]) + + # Try to match color keys to categories, accounting for string/categorical differences + result = {} + for cat in categories: + # Try direct match first + if cat in colors: + result[cat] = colors[cat] + # Then try string conversion - handles int/string mismatches + elif str(cat) in colors: + result[cat] = colors[str(cat)] + else: + result[cat] = na_color_hex + + return result + + # If we reach here, we didn't find usable colors in uns, use default color mapping + logger.info(f"No colors found for '{cluster_key}' in AnnData.uns, using default colors") return _get_default_categorial_color_mapping(color_source_vector=color_source_vector, cmap_params=cmap_params) @@ -1007,13 +1188,23 @@ def _maybe_set_colors( try: if palette is not None: raise KeyError("Unable to copy the palette when there was other explicitly specified.") - target.uns[color_key] = source.uns[color_key] + + # First check if source has the colors + if color_key in source.uns: + logger.info(f"Copying color information for '{key}' from source to target AnnData") + target.uns[color_key] = source.uns[color_key] + # Then check if the base key has colors (direct dict mapping) + elif key in source.uns and isinstance(source.uns[key], dict): + logger.info(f"Copying direct color mappings for '{key}' from source to target AnnData") + target.uns[key] = source.uns[key] + else: + raise KeyError(f"No color information found for '{key}' in source AnnData") + except KeyError: if isinstance(palette, str): palette = ListedColormap([palette]) if isinstance(palette, ListedColormap): # `scanpy` requires it palette = cycler(color=palette.colors) - palette = None add_colors_for_categorical_sample_annotation(target, key=key, force_update_colors=True, palette=palette) From ac6b49e4ec6a58bf897c66ac46657e57a674d161 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Fri, 9 May 2025 18:26:22 +0200 Subject: [PATCH 2/3] removed comments --- src/spatialdata_plot/pl/utils.py | 36 +++----------------------------- 1 file changed, 3 insertions(+), 33 deletions(-) diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 58b07c68..1cb78899 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -744,7 +744,6 @@ def _set_color_source_vec( color = np.full(len(element), value_to_plot) return None, color, False - # Figure out where to get the color from origins = _locate_value( value_key=value_to_plot, sdata=sdata, @@ -766,11 +765,9 @@ def _set_color_source_vec( table_layer=table_layer, )[value_to_plot] - # Check what type of data we're dealing with is_categorical = isinstance(color_source_vector.dtype, pd.CategoricalDtype) is_numeric = pd.api.types.is_numeric_dtype(color_source_vector) - # If it's numeric data, handle it appropriately if is_numeric and not is_categorical: if ( not isinstance(element, GeoDataFrame) @@ -785,7 +782,6 @@ def _set_color_source_vec( ) return None, color_source_vector, False - # For non-numeric, non-categorical data (like strings), convert to categorical if not is_categorical: try: color_source_vector = pd.Categorical(color_source_vector) @@ -795,8 +791,6 @@ def _set_color_source_vec( return None, color_source_vector, False # At this point color_source_vector should be categorical - - # Look for predefined colors in the AnnData object adata_with_colors = None cluster_key = value_to_plot @@ -813,12 +807,12 @@ def _set_color_source_vec( first_table = next(iter(annotator_tables)) adata_with_colors = sdata.tables[first_table] adata_with_colors.uns["spatialdata_key"] = first_table + # If no specific table is found, try using the default table elif sdata.table is not None: adata_with_colors = sdata.table adata_with_colors.uns["spatialdata_key"] = "default_table" - # Now generate the color mapping using the appropriate AnnData object and cluster_key color_mapping = _get_categorical_color_mapping( adata=adata_with_colors, cluster_key=cluster_key, @@ -869,7 +863,6 @@ def _map_color_seg( ) -> ArrayLike: cell_id = np.array(cell_id) - # Safely handle different types of color_vector is_categorical = pd.api.types.is_categorical_dtype(getattr(color_vector, "dtype", None)) is_numeric = pd.api.types.is_numeric_dtype(getattr(color_vector, "dtype", None)) is_pandas_series = isinstance(color_vector, pd.Series) @@ -963,23 +956,19 @@ def _generate_base_categorial_color_mapping( na_color: ColorLike, cmap_params: CmapParams | None = None, ) -> Mapping[str, str]: - color_key = f"{cluster_key}_colors" - # Break long string template into multiple lines to fix E501 error + color_key = f"{cluster_key}_colors" color_found_in_uns_msg_template = ( "Using colors from '{cluster}_colors' in .uns slot of table '{table}' for plotting. " "If this is unexpected, please delete the column from your AnnData object." ) - # Check if we have a valid AnnData and if the color key exists in uns if adata is not None and cluster_key is not None: - # Check for direct color dictionary in uns (e.g., {'A': '#FF5733', 'B': '#3498DB'}) if cluster_key in adata.uns and isinstance(adata.uns[cluster_key], dict): # We have a direct color mapping dictionary color_dict = adata.uns[cluster_key] table_name = getattr(adata, "uns", {}).get("spatialdata_key", "") if table_name: - # Format the template with the actual values logger.info(color_found_in_uns_msg_template.format(cluster=cluster_key, table=table_name)) # Ensure all values are hex colors @@ -987,7 +976,6 @@ def _generate_base_categorial_color_mapping( if isinstance(v, str) and not v.startswith("#"): color_dict[k] = to_hex(to_rgba(v)) - # Add NA color if missing categories = color_source_vector.categories.tolist() na_color_hex = to_hex(to_rgba(na_color)[:3]) @@ -997,24 +985,16 @@ def _generate_base_categorial_color_mapping( colors = adata.uns[color_key] table_name = getattr(adata, "uns", {}).get("spatialdata_key", "") if table_name: - if isinstance(colors, dict): - # Format the template with the actual values - logger.info(color_found_in_uns_msg_template.format(cluster=cluster_key, table=table_name)) - else: - # Format the template with the actual values - logger.info(color_found_in_uns_msg_template.format(cluster=cluster_key, table=table_name)) + logger.info(color_found_in_uns_msg_template.format(cluster=cluster_key, table=table_name)) - # Ensure colors are in hex format if isinstance(colors, list): colors = [to_hex(to_rgba(color)[:3]) for color in colors] categories = color_source_vector.categories.tolist() - # Handle NaN values na_color_hex = to_hex(to_rgba(na_color)[:3]) if "NaN" not in categories: categories.append("NaN") - # Make sure we have enough colors if len(colors) < len(categories) - 1: # -1 for NaN logger.warning( f"Not enough colors in {color_key} ({len(colors)}) for all categories ({len(categories) - 1}). " @@ -1023,39 +1003,31 @@ def _generate_base_categorial_color_mapping( # Extend with default colors or duplicate the last color colors.extend([na_color_hex] * (len(categories) - 1 - len(colors))) - # Create mapping with NaN color return dict(zip(categories, colors + [na_color_hex], strict=False)) if isinstance(colors, np.ndarray): - # Convert numpy array to list of hex colors colors = [to_hex(to_rgba(color)[:3]) for color in colors] categories = color_source_vector.categories.tolist() - # Handle NaN values na_color_hex = to_hex(to_rgba(na_color)[:3]) if "NaN" not in categories: categories.append("NaN") - # Make sure we have enough colors if len(colors) < len(categories) - 1: # -1 for NaN logger.warning( f"Not enough colors in {color_key} ({len(colors)}) for all categories ({len(categories) - 1}). " "Some categories will use default colors." ) - # Extend with default colors colors.extend([na_color_hex] * (len(categories) - 1 - len(colors))) - # Create mapping with NaN color return dict(zip(categories, colors + [na_color_hex], strict=False)) - # Dictionary format - direct color mapping if isinstance(colors, dict): # Ensure all values are hex colors for k, v in colors.items(): if isinstance(v, str) and not v.startswith("#"): colors[k] = to_hex(to_rgba(v)) - # Get categories and handle NaN categories = color_source_vector.categories.tolist() na_color_hex = to_hex(to_rgba(na_color)[:3]) @@ -1073,8 +1045,6 @@ def _generate_base_categorial_color_mapping( return result - # If we reach here, we didn't find usable colors in uns, use default color mapping - logger.info(f"No colors found for '{cluster_key}' in AnnData.uns, using default colors") return _get_default_categorial_color_mapping(color_source_vector=color_source_vector, cmap_params=cmap_params) From a1d28e5822331dde3b0d4594824c8087ae62efac Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Fri, 9 May 2025 18:58:41 +0200 Subject: [PATCH 3/3] small refactor --- src/spatialdata_plot/pl/utils.py | 46 +++++++++++++++----------------- 1 file changed, 21 insertions(+), 25 deletions(-) diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 1cb78899..ed139b4a 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -733,10 +733,10 @@ def _set_color_source_vec( table_name: str | None = None, table_layer: str | None = None, render_type: Literal["points"] | None = None, -) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]: +) -> tuple[pd.Categorical | None, ArrayLike, bool]: if value_to_plot is None and element is not None: color = np.full(len(element), na_color) - return color, color, False + return None, color, False # First check if value_to_plot is likely a color specification rather than a column name if value_to_plot is not None and _is_color_like(value_to_plot) and element is not None: @@ -765,29 +765,27 @@ def _set_color_source_vec( table_layer=table_layer, )[value_to_plot] - is_categorical = isinstance(color_source_vector.dtype, pd.CategoricalDtype) - is_numeric = pd.api.types.is_numeric_dtype(color_source_vector) - - if is_numeric and not is_categorical: - if ( - not isinstance(element, GeoDataFrame) - and isinstance(palette, list) - and palette[0] is not None - or isinstance(element, GeoDataFrame) - and isinstance(palette, list) - ): - logger.warning( - "Ignoring categorical palette which is given for a continuous variable. " - "Consider using `cmap` to pass a ColorMap." - ) - return None, color_source_vector, False - - if not is_categorical: + # Convert to categorical if not already + if not isinstance(color_source_vector, pd.Categorical): try: color_source_vector = pd.Categorical(color_source_vector) except (ValueError, TypeError) as e: logger.warning(f"Could not convert '{value_to_plot}' to categorical: {e}") - # Fall back to returning the original values + # For numeric data, return None to indicate non-categorical + if pd.api.types.is_numeric_dtype(color_source_vector): + if ( + not isinstance(element, GeoDataFrame) + and isinstance(palette, list) + and palette[0] is not None + or isinstance(element, GeoDataFrame) + and isinstance(palette, list) + ): + logger.warning( + "Ignoring categorical palette which is given for a continuous variable. " + "Consider using `cmap` to pass a ColorMap." + ) + return None, color_source_vector, False + # For other types, try to use as is return None, color_source_vector, False # At this point color_source_vector should be categorical @@ -807,7 +805,7 @@ def _set_color_source_vec( first_table = next(iter(annotator_tables)) adata_with_colors = sdata.tables[first_table] adata_with_colors.uns["spatialdata_key"] = first_table - + # If no specific table is found, try using the default table elif sdata.table is not None: adata_with_colors = sdata.table @@ -831,7 +829,6 @@ def _set_color_source_vec( raise ValueError("Unable to create color palette.") # Map categorical values to colors - # Do not rename categories, as colors need not be unique try: color_vector = color_source_vector.map(color_mapping) except (KeyError, TypeError, ValueError) as e: @@ -847,7 +844,7 @@ def _set_color_source_vec( logger.warning(f"Color key '{value_to_plot}' for element '{element_name}' not found, using default colors.") color = np.full(sdata[table_name].n_obs, to_hex(na_color)) - return color, color, False + return None, color, False def _map_color_seg( @@ -956,7 +953,6 @@ def _generate_base_categorial_color_mapping( na_color: ColorLike, cmap_params: CmapParams | None = None, ) -> Mapping[str, str]: - color_key = f"{cluster_key}_colors" color_found_in_uns_msg_template = ( "Using colors from '{cluster}_colors' in .uns slot of table '{table}' for plotting. "