|
34 | 34 | T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)
|
35 | 35 |
|
36 | 36 |
|
37 |
| -def to_object_array(iterable): |
38 |
| - # using empty_like calls compute |
39 |
| - npargs = np.empty((len(iterable),), dtype=np.object) |
40 |
| - npargs[:] = iterable |
41 |
| - return npargs |
| 37 | +def unzip(iterable): |
| 38 | + return zip(*iterable) |
42 | 39 |
|
43 | 40 |
|
44 | 41 | def assert_chunks_compatible(a: Dataset, b: Dataset):
|
@@ -335,23 +332,33 @@ def _wrapper(
|
335 | 332 | if not dask.is_dask_collection(obj):
|
336 | 333 | return func(obj, *args, **kwargs)
|
337 | 334 |
|
338 |
| - npargs = to_object_array([obj] + list(args)) |
339 |
| - is_xarray = [isinstance(arg, (Dataset, DataArray)) for arg in npargs] |
340 |
| - is_array = [isinstance(arg, DataArray) for arg in npargs] |
| 335 | + all_args = [obj] + list(args) |
| 336 | + is_xarray = [isinstance(arg, (Dataset, DataArray)) for arg in all_args] |
| 337 | + is_array = [isinstance(arg, DataArray) for arg in all_args] |
| 338 | + |
| 339 | + # there should be a better way to group this. partition? |
| 340 | + xarray_indices, xarray_objs = unzip( |
| 341 | + (index, arg) for index, arg in enumerate(all_args) if is_xarray[index] |
| 342 | + ) |
| 343 | + others = [ |
| 344 | + (index, arg) for index, arg in enumerate(all_args) if not is_xarray[index] |
| 345 | + ] |
341 | 346 |
|
342 | 347 | # all xarray objects must be aligned. This is consistent with apply_ufunc.
|
343 |
| - aligned = align(*npargs[is_xarray], join="exact") |
344 |
| - # assigning to object arrays works better when RHS is object array |
345 |
| - # https://stackoverflow.com/questions/43645135/boolean-indexing-assignment-of-a-numpy-array-to-a-numpy-array |
346 |
| - npargs[is_xarray] = to_object_array(aligned) |
347 |
| - npargs[is_array] = to_object_array( |
348 |
| - [dataarray_to_dataset(da) for da in npargs[is_array]] |
| 348 | + aligned = align(*xarray_objs, join="exact") |
| 349 | + xarray_objs = tuple( |
| 350 | + dataarray_to_dataset(arg) if is_da else arg |
| 351 | + for is_da, arg in zip(is_array, aligned) |
| 352 | + ) |
| 353 | + |
| 354 | + _, npargs = unzip( |
| 355 | + sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0]) |
349 | 356 | )
|
350 | 357 |
|
351 | 358 | # check that chunk sizes are compatible
|
352 | 359 | input_chunks = dict(npargs[0].chunks)
|
353 | 360 | input_indexes = dict(npargs[0].indexes)
|
354 |
| - for arg in npargs[1:][is_xarray[1:]]: |
| 361 | + for arg in xarray_objs[1:]: |
355 | 362 | assert_chunks_compatible(npargs[0], arg)
|
356 | 363 | input_chunks.update(arg.chunks)
|
357 | 364 | input_indexes.update(arg.indexes)
|
|
0 commit comments