| 
17 | 17 | from matplotlib.cm import ScalarMappable  | 
18 | 18 | from matplotlib.colors import ListedColormap, Normalize  | 
19 | 19 | from scanpy._settings import settings as sc_settings  | 
20 |  | -from spatialdata import get_extent  | 
 | 20 | +from spatialdata import get_extent, join_spatialelement_table  | 
21 | 21 | from spatialdata.models import PointsModel, ShapesModel, get_table_keys  | 
22 | 22 | from spatialdata.transformations import get_transformation, set_transformation  | 
23 | 23 | from spatialdata.transformations.transformations import Identity  | 
@@ -76,13 +76,18 @@ def _render_shapes(  | 
76 | 76 |         filter_tables=bool(render_params.table_name),  | 
77 | 77 |     )  | 
78 | 78 | 
 
  | 
79 |  | -    shapes = sdata[element]  | 
80 |  | - | 
81 | 79 |     if (table_name := render_params.table_name) is None:  | 
82 | 80 |         table = None  | 
 | 81 | +        shapes = sdata_filt[element]  | 
83 | 82 |     else:  | 
84 |  | -        _, region_key, _ = get_table_keys(sdata[table_name])  | 
85 |  | -        table = sdata[table_name][sdata[table_name].obs[region_key].isin([element])]  | 
 | 83 | +        element_dict, joined_table = join_spatialelement_table(  | 
 | 84 | +            sdata, spatial_element_names=element, table_name=table_name, how="inner"  | 
 | 85 | +        )  | 
 | 86 | +        sdata_filt[element] = shapes = element_dict[element]  | 
 | 87 | +        joined_table.uns["spatialdata_attrs"]["region"] = (  | 
 | 88 | +            joined_table.obs[joined_table.uns["spatialdata_attrs"]["region_key"]].unique().tolist()  | 
 | 89 | +        )  | 
 | 90 | +        sdata_filt[table_name] = table = joined_table  | 
86 | 91 | 
 
  | 
87 | 92 |     if (  | 
88 | 93 |         col_for_color is not None  | 
 | 
0 commit comments