Skip to content

Commit 59a2397

Browse files
authored
speed up map_blocks (#4149)
* replace the object array with generator expressions and zip/enumerate
* remove a leftover grouping pair of parentheses
* reuse is_array instead of comparing again
1 parent 8f688ea commit 59a2397

File tree

1 file changed

+22
-15
lines changed

1 file changed

+22
-15
lines changed

xarray/core/parallel.py

Lines changed: 22 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -34,11 +34,8 @@
3434
T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)
3535

3636

37-
def to_object_array(iterable):
38-
# using empty_like calls compute
39-
npargs = np.empty((len(iterable),), dtype=np.object)
40-
npargs[:] = iterable
41-
return npargs
37+
def unzip(iterable):
38+
return zip(*iterable)
4239

4340

4441
def assert_chunks_compatible(a: Dataset, b: Dataset):
@@ -335,23 +332,33 @@ def _wrapper(
335332
if not dask.is_dask_collection(obj):
336333
return func(obj, *args, **kwargs)
337334

338-
npargs = to_object_array([obj] + list(args))
339-
is_xarray = [isinstance(arg, (Dataset, DataArray)) for arg in npargs]
340-
is_array = [isinstance(arg, DataArray) for arg in npargs]
335+
all_args = [obj] + list(args)
336+
is_xarray = [isinstance(arg, (Dataset, DataArray)) for arg in all_args]
337+
is_array = [isinstance(arg, DataArray) for arg in all_args]
338+
339+
# there should be a better way to group this. partition?
340+
xarray_indices, xarray_objs = unzip(
341+
(index, arg) for index, arg in enumerate(all_args) if is_xarray[index]
342+
)
343+
others = [
344+
(index, arg) for index, arg in enumerate(all_args) if not is_xarray[index]
345+
]
341346

342347
# all xarray objects must be aligned. This is consistent with apply_ufunc.
343-
aligned = align(*npargs[is_xarray], join="exact")
344-
# assigning to object arrays works better when RHS is object array
345-
# https://stackoverflow.com/questions/43645135/boolean-indexing-assignment-of-a-numpy-array-to-a-numpy-array
346-
npargs[is_xarray] = to_object_array(aligned)
347-
npargs[is_array] = to_object_array(
348-
[dataarray_to_dataset(da) for da in npargs[is_array]]
348+
aligned = align(*xarray_objs, join="exact")
349+
xarray_objs = tuple(
350+
dataarray_to_dataset(arg) if is_da else arg
351+
for is_da, arg in zip(is_array, aligned)
352+
)
353+
354+
_, npargs = unzip(
355+
sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0])
349356
)
350357

351358
# check that chunk sizes are compatible
352359
input_chunks = dict(npargs[0].chunks)
353360
input_indexes = dict(npargs[0].indexes)
354-
for arg in npargs[1:][is_xarray[1:]]:
361+
for arg in xarray_objs[1:]:
355362
assert_chunks_compatible(npargs[0], arg)
356363
input_chunks.update(arg.chunks)
357364
input_indexes.update(arg.indexes)

0 commit comments

Comments
 (0)