Skip to content

Commit a1d28e5

Browse files
committed
small refactor
1 parent ac6b49e commit a1d28e5

File tree

1 file changed

+21
-25
lines changed

1 file changed

+21
-25
lines changed

src/spatialdata_plot/pl/utils.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -733,10 +733,10 @@ def _set_color_source_vec(
733733
table_name: str | None = None,
734734
table_layer: str | None = None,
735735
render_type: Literal["points"] | None = None,
736-
) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]:
736+
) -> tuple[pd.Categorical | None, ArrayLike, bool]:
737737
if value_to_plot is None and element is not None:
738738
color = np.full(len(element), na_color)
739-
return color, color, False
739+
return None, color, False
740740

741741
# First check if value_to_plot is likely a color specification rather than a column name
742742
if value_to_plot is not None and _is_color_like(value_to_plot) and element is not None:
@@ -765,29 +765,27 @@ def _set_color_source_vec(
765765
table_layer=table_layer,
766766
)[value_to_plot]
767767

768-
is_categorical = isinstance(color_source_vector.dtype, pd.CategoricalDtype)
769-
is_numeric = pd.api.types.is_numeric_dtype(color_source_vector)
770-
771-
if is_numeric and not is_categorical:
772-
if (
773-
not isinstance(element, GeoDataFrame)
774-
and isinstance(palette, list)
775-
and palette[0] is not None
776-
or isinstance(element, GeoDataFrame)
777-
and isinstance(palette, list)
778-
):
779-
logger.warning(
780-
"Ignoring categorical palette which is given for a continuous variable. "
781-
"Consider using `cmap` to pass a ColorMap."
782-
)
783-
return None, color_source_vector, False
784-
785-
if not is_categorical:
768+
# Convert to categorical if not already
769+
if not isinstance(color_source_vector, pd.Categorical):
786770
try:
787771
color_source_vector = pd.Categorical(color_source_vector)
788772
except (ValueError, TypeError) as e:
789773
logger.warning(f"Could not convert '{value_to_plot}' to categorical: {e}")
790-
# Fall back to returning the original values
774+
# For numeric data, return None to indicate non-categorical
775+
if pd.api.types.is_numeric_dtype(color_source_vector):
776+
if (
777+
not isinstance(element, GeoDataFrame)
778+
and isinstance(palette, list)
779+
and palette[0] is not None
780+
or isinstance(element, GeoDataFrame)
781+
and isinstance(palette, list)
782+
):
783+
logger.warning(
784+
"Ignoring categorical palette which is given for a continuous variable. "
785+
"Consider using `cmap` to pass a ColorMap."
786+
)
787+
return None, color_source_vector, False
788+
# For other types, try to use as is
791789
return None, color_source_vector, False
792790

793791
# At this point color_source_vector should be categorical
@@ -807,7 +805,7 @@ def _set_color_source_vec(
807805
first_table = next(iter(annotator_tables))
808806
adata_with_colors = sdata.tables[first_table]
809807
adata_with_colors.uns["spatialdata_key"] = first_table
810-
808+
811809
# If no specific table is found, try using the default table
812810
elif sdata.table is not None:
813811
adata_with_colors = sdata.table
@@ -831,7 +829,6 @@ def _set_color_source_vec(
831829
raise ValueError("Unable to create color palette.")
832830

833831
# Map categorical values to colors
834-
# Do not rename categories, as colors need not be unique
835832
try:
836833
color_vector = color_source_vector.map(color_mapping)
837834
except (KeyError, TypeError, ValueError) as e:
@@ -847,7 +844,7 @@ def _set_color_source_vec(
847844

848845
logger.warning(f"Color key '{value_to_plot}' for element '{element_name}' not found, using default colors.")
849846
color = np.full(sdata[table_name].n_obs, to_hex(na_color))
850-
return color, color, False
847+
return None, color, False
851848

852849

853850
def _map_color_seg(
@@ -956,7 +953,6 @@ def _generate_base_categorial_color_mapping(
956953
na_color: ColorLike,
957954
cmap_params: CmapParams | None = None,
958955
) -> Mapping[str, str]:
959-
960956
color_key = f"{cluster_key}_colors"
961957
color_found_in_uns_msg_template = (
962958
"Using colors from '{cluster}_colors' in .uns slot of table '{table}' for plotting. "

0 commit comments

Comments
 (0)