Skip to content

Commit 17e3b35

Browse files
authored
fix: Finalize dask patch (#113)
1 parent 6645ead commit 17e3b35

File tree

1 file changed

+6
-8
lines changed
  • src/fast_array_utils/_plugins

1 file changed

+6
-8
lines changed
Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
# SPDX-License-Identifier: MPL-2.0
22
from __future__ import annotations
33

4-
import numpy as np
5-
6-
# Other lookup candidates: tensordot_lookup and take_lookup
7-
from dask.array.dispatch import concatenate_lookup
4+
from dask.array.dispatch import concatenate_lookup, take_lookup, tensordot_lookup
85
from scipy.sparse import sparray, spmatrix
96

107

@@ -13,11 +10,12 @@
1310
def patch() -> None: # pragma: no cover
1411
"""Patch dask to support sparse arrays.
1512
16-
See <https://github.com/dask/dask/blob/4d71629d1f22ced0dd780919f22e70a642ec6753/dask/array/backends.py#L212-L232>
13+
See <https://github.com/dask/dask/blob/d9b5c5b0256208f1befe94b26bfa8eaabcd0536d/dask/array/backends.py#L239-L241>
1714
"""
1815
# Avoid patch if already patched or upstream support has been added
19-
if concatenate_lookup.dispatch(sparray) is not np.concatenate:
16+
if concatenate_lookup.dispatch(sparray) is not concatenate_lookup.dispatch(spmatrix):
2017
return
2118

22-
concatenate = concatenate_lookup.dispatch(spmatrix)
23-
concatenate_lookup.register(sparray, concatenate)
19+
concatenate_lookup.register(sparray, concatenate_lookup.dispatch(spmatrix))
20+
tensordot_lookup.register(sparray, tensordot_lookup.dispatch(spmatrix))
21+
take_lookup.register(sparray, take_lookup.dispatch(spmatrix))

0 commit comments

Comments
 (0)