Skip to content

Commit c4940e6

Browse files
committed
moved changes over
1 parent 31b1ee3 commit c4940e6

File tree

2 files changed

+2
-100
lines changed

2 files changed

+2
-100
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def _render_shapes(
189189
element_trans = get_transformation(sdata_filt.shapes[element], to_coordinate_system=coordinate_system)
190190
tm = _get_transformation_matrix_for_datashader(element_trans)
191191
transformed_element = sdata_filt.shapes[element].transform(
192-
lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm)[:, :2]
192+
lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm.T)[:, :2]
193193
)
194194
transformed_element = ShapesModel.parse(
195195
gpd.GeoDataFrame(

src/spatialdata_plot/pl/utils.py

Lines changed: 1 addition & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,8 @@
6666
from spatialdata._types import ArrayLike
6767
from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement
6868

69-
# from spatialdata.transformations.transformations import Scale
70-
from spatialdata.transformations import Affine, Identity, MapAxis, Scale, Translation
71-
from spatialdata.transformations import Sequence as SDSequence
7269
from spatialdata.transformations.operations import get_transformation
70+
from spatialdata.transformations.transformations import Scale
7371
from xarray import DataArray, DataTree
7472

7573
from spatialdata_plot._logging import logger
@@ -2381,102 +2379,6 @@ def _prepare_transformation(
23812379
return trans, trans_data
23822380

23832381

2384-
def _get_datashader_trans_matrix_of_single_element(
2385-
trans: Identity | Scale | Affine | MapAxis | Translation,
2386-
) -> npt.NDArray[Any]:
2387-
flip_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]])
2388-
tm: npt.NDArray[Any] = trans.to_affine_matrix(("x", "y"), ("x", "y"))
2389-
2390-
if isinstance(trans, Identity):
2391-
return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
2392-
if isinstance(trans, (Scale | Affine)):
2393-
# idea: "flip the y-axis", apply transformation, flip back
2394-
flip_and_transform: npt.NDArray[Any] = flip_matrix @ tm @ flip_matrix
2395-
return flip_and_transform
2396-
if isinstance(trans, MapAxis):
2397-
# no flipping needed
2398-
return tm
2399-
# for a Translation, we need the transposed transformation matrix
2400-
tm_T = tm.T
2401-
assert isinstance(tm_T, np.ndarray)
2402-
return tm_T
2403-
2404-
2405-
def _get_transformation_matrix_for_datashader(
2406-
trans: Scale | Identity | Affine | MapAxis | Translation | SDSequence,
2407-
) -> npt.NDArray[Any]:
2408-
"""Get the affine matrix needed to transform shapes for rendering with datashader."""
2409-
if isinstance(trans, SDSequence):
2410-
tm = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
2411-
for x in trans.transformations:
2412-
tm = tm @ _get_datashader_trans_matrix_of_single_element(x)
2413-
return tm
2414-
return _get_datashader_trans_matrix_of_single_element(trans)
2415-
2416-
2417-
def _datashader_map_aggregate_to_color(
2418-
agg: DataArray,
2419-
cmap: str | list[str] | ListedColormap,
2420-
color_key: None | list[str] = None,
2421-
min_alpha: float = 40,
2422-
span: None | list[float] = None,
2423-
clip: bool = True,
2424-
) -> ds.tf.Image | np.ndarray[Any, np.dtype[np.uint8]]:
2425-
"""ds.tf.shade() part, ensuring correct clipping behavior.
2426-
2427-
If necessary (norm.clip=False), split shading in 3 parts and in the end, stack results.
2428-
This ensures the correct clipping behavior, because else datashader would always automatically clip.
2429-
"""
2430-
if not clip and isinstance(cmap, Colormap) and span is not None:
2431-
# in case we use datashader together with a Normalize object where clip=False
2432-
# why we need this is documented in https://github.com/scverse/spatialdata-plot/issues/372
2433-
agg_in = agg.where((agg >= span[0]) & (agg <= span[1]))
2434-
img_in = ds.tf.shade(
2435-
agg_in,
2436-
cmap=cmap,
2437-
span=(span[0], span[1]),
2438-
how="linear",
2439-
color_key=color_key,
2440-
min_alpha=min_alpha,
2441-
)
2442-
2443-
agg_under = agg.where(agg < span[0])
2444-
img_under = ds.tf.shade(
2445-
agg_under,
2446-
cmap=[to_hex(cmap.get_under())[:7]],
2447-
min_alpha=min_alpha,
2448-
color_key=color_key,
2449-
)
2450-
2451-
agg_over = agg.where(agg > span[1])
2452-
img_over = ds.tf.shade(
2453-
agg_over,
2454-
cmap=[to_hex(cmap.get_over())[:7]],
2455-
min_alpha=min_alpha,
2456-
color_key=color_key,
2457-
)
2458-
2459-
# stack the 3 arrays manually: go from under, through in to over and always overlay the values where alpha=0
2460-
stack = img_under.to_numpy().base
2461-
if stack is None:
2462-
stack = img_in.to_numpy().base
2463-
else:
2464-
stack[stack[:, :, 3] == 0] = img_in.to_numpy().base[stack[:, :, 3] == 0]
2465-
img_over = img_over.to_numpy().base
2466-
if img_over is not None:
2467-
stack[stack[:, :, 3] == 0] = img_over[stack[:, :, 3] == 0]
2468-
return stack
2469-
2470-
return ds.tf.shade(
2471-
agg,
2472-
cmap=cmap,
2473-
color_key=color_key,
2474-
min_alpha=min_alpha,
2475-
span=span,
2476-
how="linear",
2477-
)
2478-
2479-
24802382
def _hex_no_alpha(hex: str) -> str:
24812383
"""
24822384
Return a hex color string without an alpha component.

0 commit comments

Comments
 (0)