Skip to content

Commit 62e8b5a

Browse files
allow plotting with mismatch between element and table (#396)
2 parents e4a0170 + 4be40aa commit 62e8b5a

6 files changed

+19
-6
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from matplotlib.cm import ScalarMappable
1818
from matplotlib.colors import ListedColormap, Normalize
1919
from scanpy._settings import settings as sc_settings
20-
from spatialdata import get_extent
20+
from spatialdata import get_extent, join_spatialelement_table
2121
from spatialdata.models import PointsModel, ShapesModel, get_table_keys
2222
from spatialdata.transformations import get_transformation, set_transformation
2323
from spatialdata.transformations.transformations import Identity
@@ -76,13 +76,18 @@ def _render_shapes(
7676
filter_tables=bool(render_params.table_name),
7777
)
7878

79-
shapes = sdata[element]
80-
8179
if (table_name := render_params.table_name) is None:
8280
table = None
81+
shapes = sdata_filt[element]
8382
else:
84-
_, region_key, _ = get_table_keys(sdata[table_name])
85-
table = sdata[table_name][sdata[table_name].obs[region_key].isin([element])]
83+
element_dict, joined_table = join_spatialelement_table(
84+
sdata, spatial_element_names=element, table_name=table_name, how="inner"
85+
)
86+
sdata_filt[element] = shapes = element_dict[element]
87+
joined_table.uns["spatialdata_attrs"]["region"] = (
88+
joined_table.obs[joined_table.uns["spatialdata_attrs"]["region_key"]].unique().tolist()
89+
)
90+
sdata_filt[table_name] = table = joined_table
8691

8792
if (
8893
col_for_color is not None
355 Bytes
Loading
22 KB
Loading
338 Bytes
Loading
-2 Bytes
Loading

tests/pl/test_render_shapes.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,6 @@ def test_plot_can_plot_shapes_after_spatial_query(self, sdata_blobs: SpatialData
169169
cropped_blob.pl.render_shapes().pl.show()
170170

171171
def test_plot_can_plot_with_annotation_despite_random_shuffling(self, sdata_blobs: SpatialData):
172-
new_table = sdata_blobs["table"].copy()
173172
sdata_blobs["table"].obs["region"] = "blobs_circles"
174173
new_table = sdata_blobs["table"][:5]
175174
new_table.uns["spatialdata_attrs"]["region"] = "blobs_circles"
@@ -447,3 +446,12 @@ def test_plot_datashader_can_transform_circles(self, sdata_blobs: SpatialData):
447446
_set_transformations(sdata_blobs["blobs_circles"], {"global": seq})
448447

449448
sdata_blobs.pl.render_shapes("blobs_circles", method="datashader", outline_alpha=1.0).pl.show()
449+
450+
def test_plot_can_do_non_matching_table(self, sdata_blobs: SpatialData):
451+
table_shapes = sdata_blobs["table"][:3].copy()
452+
table_shapes.obs.instance_id = list(range(3))
453+
table_shapes.obs["region"] = "blobs_circles"
454+
table_shapes.uns["spatialdata_attrs"]["region"] = "blobs_circles"
455+
sdata_blobs["new_table"] = table_shapes
456+
457+
sdata_blobs.pl.render_shapes("blobs_circles", color="instance_id").pl.show()

0 commit comments

Comments
 (0)