Skip to content

Commit 1e40ea1

Browse files
committed
small refactor
1 parent c658853 commit 1e40ea1

File tree

1 file changed

+21
-25
lines changed

1 file changed

+21
-25
lines changed

src/spatialdata_plot/pl/utils.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -732,10 +732,10 @@ def _set_color_source_vec(
732732
table_name: str | None = None,
733733
table_layer: str | None = None,
734734
render_type: Literal["points"] | None = None,
735-
) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]:
735+
) -> tuple[pd.Categorical | None, ArrayLike, bool]:
736736
if value_to_plot is None and element is not None:
737737
color = np.full(len(element), na_color)
738-
return color, color, False
738+
return None, color, False
739739

740740
# First check if value_to_plot is likely a color specification rather than a column name
741741
if value_to_plot is not None and _is_color_like(value_to_plot) and element is not None:
@@ -764,29 +764,27 @@ def _set_color_source_vec(
764764
table_layer=table_layer,
765765
)[value_to_plot]
766766

767-
is_categorical = isinstance(color_source_vector.dtype, pd.CategoricalDtype)
768-
is_numeric = pd.api.types.is_numeric_dtype(color_source_vector)
769-
770-
if is_numeric and not is_categorical:
771-
if (
772-
not isinstance(element, GeoDataFrame)
773-
and isinstance(palette, list)
774-
and palette[0] is not None
775-
or isinstance(element, GeoDataFrame)
776-
and isinstance(palette, list)
777-
):
778-
logger.warning(
779-
"Ignoring categorical palette which is given for a continuous variable. "
780-
"Consider using `cmap` to pass a ColorMap."
781-
)
782-
return None, color_source_vector, False
783-
784-
if not is_categorical:
767+
# Convert to categorical if not already
768+
if not isinstance(color_source_vector, pd.Categorical):
785769
try:
786770
color_source_vector = pd.Categorical(color_source_vector)
787771
except (ValueError, TypeError) as e:
788772
logger.warning(f"Could not convert '{value_to_plot}' to categorical: {e}")
789-
# Fall back to returning the original values
773+
# For numeric data, return None to indicate non-categorical
774+
if pd.api.types.is_numeric_dtype(color_source_vector):
775+
if (
776+
not isinstance(element, GeoDataFrame)
777+
and isinstance(palette, list)
778+
and palette[0] is not None
779+
or isinstance(element, GeoDataFrame)
780+
and isinstance(palette, list)
781+
):
782+
logger.warning(
783+
"Ignoring categorical palette which is given for a continuous variable. "
784+
"Consider using `cmap` to pass a ColorMap."
785+
)
786+
return None, color_source_vector, False
787+
# For other types, try to use as is
790788
return None, color_source_vector, False
791789

792790
# At this point color_source_vector should be categorical
@@ -806,7 +804,7 @@ def _set_color_source_vec(
806804
first_table = next(iter(annotator_tables))
807805
adata_with_colors = sdata.tables[first_table]
808806
adata_with_colors.uns["spatialdata_key"] = first_table
809-
807+
810808
# If no specific table is found, try using the default table
811809
elif sdata.table is not None:
812810
adata_with_colors = sdata.table
@@ -830,7 +828,6 @@ def _set_color_source_vec(
830828
raise ValueError("Unable to create color palette.")
831829

832830
# Map categorical values to colors
833-
# Do not rename categories, as colors need not be unique
834831
try:
835832
color_vector = color_source_vector.map(color_mapping)
836833
except (KeyError, TypeError, ValueError) as e:
@@ -846,7 +843,7 @@ def _set_color_source_vec(
846843

847844
logger.warning(f"Color key '{value_to_plot}' for element '{element_name}' not found, using default colors.")
848845
color = np.full(sdata[table_name].n_obs, to_hex(na_color))
849-
return color, color, False
846+
return None, color, False
850847

851848

852849
def _map_color_seg(
@@ -955,7 +952,6 @@ def _generate_base_categorial_color_mapping(
955952
na_color: ColorLike,
956953
cmap_params: CmapParams | None = None,
957954
) -> Mapping[str, str]:
958-
959955
color_key = f"{cluster_key}_colors"
960956
color_found_in_uns_msg_template = (
961957
"Using colors from '{cluster}_colors' in .uns slot of table '{table}' for plotting. "

0 commit comments

Comments
 (0)