2 changes: 2 additions & 0 deletions doc/whats-new.rst
@@ -73,6 +73,8 @@ Bug fixes
and :py:meth:`DataArray.str.wrap` (:issue:`4334`). By `Mathias Hauser <https://github.com/mathause>`_.
- Fixed overflow issue causing incorrect results in computing means of :py:class:`cftime.datetime`
arrays (:issue:`4341`). By `Spencer Clark <https://github.com/spencerkclark>`_.
- Fix :py:func:`xarray.apply_ufunc` with ``vectorize=True`` and ``exclude_dims`` (:issue:`3890`).
By `Mathias Hauser <https://github.com/mathause>`_.

Documentation
~~~~~~~~~~~~~
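For context, here is a minimal, self-contained sketch of the call this changelog entry refers to (GH 3890). It mirrors the new test added further down in this PR; the helper name and data are illustrative only.

# Illustrative sketch, not part of the diff: apply_ufunc with vectorize=True
# and exclude_dims, the combination fixed here.
import pandas as pd
import xarray as xr

def median_add(x, y):
    # can consume inputs of unequal length along the excluded dimension
    return pd.Series(x).median() + pd.Series(y).median()

a = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y"))
b = xr.DataArray([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]], dims=("x", "y"))

result = xr.apply_ufunc(
    median_add,
    a,
    b,
    input_core_dims=[["y"], ["y"]],
    exclude_dims={"y"},  # "y" may differ in length between a and b
    vectorize=True,
)
# result equals [3.0, 5.0] along "x"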
39 changes: 33 additions & 6 deletions xarray/core/computation.py
@@ -120,19 +120,23 @@ def __ne__(self, other):

def __repr__(self):
return "{}({!r}, {!r})".format(
type(self).__name__, list(self.input_core_dims), list(self.output_core_dims)
type(self).__name__,
list(self.input_core_dims),
list(self.output_core_dims),
)

def __str__(self):
lhs = ",".join("({})".format(",".join(dims)) for dims in self.input_core_dims)
rhs = ",".join("({})".format(",".join(dims)) for dims in self.output_core_dims)
return f"{lhs}->{rhs}"

def to_gufunc_string(self):
def to_gufunc_string(self, exclude_dims=frozenset()):
"""Create an equivalent signature string for a NumPy gufunc.

Unlike __str__, handles dimensions that don't map to Python
identifiers.

Also creates unique names for input_core_dims contained in exclude_dims.
"""
input_core_dims = [
[self.dims_map[dim] for dim in core_dims]
@@ -142,6 +146,25 @@ def to_gufunc_string(self):
[self.dims_map[dim] for dim in core_dims]
for core_dims in self.output_core_dims
]

# enumerate input_core_dims contained in exclude_dims to make them unique
if exclude_dims:

exclude_dims = [self.dims_map[dim] for dim in exclude_dims]

counter = Counter()

def _enumerate(dim):
if dim in exclude_dims:
n = counter[dim]
counter.update([dim])
dim = f"{dim}_{n}"
return dim

input_core_dims = [
[_enumerate(dim) for dim in arg] for arg in input_core_dims
]

alt_signature = type(self)(input_core_dims, output_core_dims)
return str(alt_signature)
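
To illustrate what the new exclude_dims argument does to the generated gufunc signature, here is a sketch that mirrors the assertions in the updated test_signature_properties below; _UFuncSignature is private API, so this is for explanation only.

# Illustrative sketch, not part of the diff.
from xarray.core.computation import _UFuncSignature

sig = _UFuncSignature([["x"], ["x", "y"]], [["z"]])
str(sig)                                  # "(x),(x,y)->(z)"
sig.to_gufunc_string()                    # "(dim0),(dim0,dim1)->(dim2)"
sig.to_gufunc_string(exclude_dims={"x"})  # "(dim0_0),(dim0_1,dim1)->(dim2)"
# Each occurrence of an excluded input core dimension gets a unique suffix,
# so the gufunc machinery no longer requires those axes to share a length.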

@@ -545,10 +568,12 @@ def broadcast_compat_data(
return data


def _vectorize(func, signature, output_dtypes):
def _vectorize(func, signature, output_dtypes, exclude_dims):
if signature.all_core_dims:
func = np.vectorize(
func, otypes=output_dtypes, signature=signature.to_gufunc_string()
func,
otypes=output_dtypes,
signature=signature.to_gufunc_string(exclude_dims),
)
else:
func = np.vectorize(func, otypes=output_dtypes)
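
The renaming matters because np.vectorize (and, on the dask path, apply_gufunc) requires core dimensions with the same name to have the same length across arguments, which is exactly the constraint exclude_dims is meant to relax. A plain-NumPy sketch of the difference, with an illustrative helper:

# Illustrative sketch, not part of the diff.
import numpy as np
import pandas as pd

def median_add(x, y):
    return pd.Series(x).median() + pd.Series(y).median()

# A shared core-dimension name forces equal lengths:
#   np.vectorize(median_add, signature="(n),(n)->()")(np.arange(3), np.arange(5))
#   -> ValueError (sizes for core dimension "n" are inconsistent)
# Distinct names, as produced by to_gufunc_string(exclude_dims=...), do not:
vec = np.vectorize(median_add, signature="(n_0),(n_1)->()")
vec(np.arange(3), np.arange(5))  # 1.0 + 2.0 == 3.0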
@@ -623,7 +648,7 @@ def func(*arrays):

res = da.apply_gufunc(
numpy_func,
signature.to_gufunc_string(),
signature.to_gufunc_string(exclude_dims),
*arrays,
vectorize=vectorize,
output_dtypes=output_dtypes,
@@ -649,7 +674,9 @@ def func(*arrays):
)
else:
if vectorize:
func = _vectorize(func, signature, output_dtypes=output_dtypes)
func = _vectorize(
func, signature, output_dtypes=output_dtypes, exclude_dims=exclude_dims
)

result_data = func(*input_data)

45 changes: 45 additions & 0 deletions xarray/tests/test_computation.py
@@ -45,6 +45,9 @@ def test_signature_properties():
assert sig.num_outputs == 1
assert str(sig) == "(x),(x,y)->(z)"
assert sig.to_gufunc_string() == "(dim0),(dim0,dim1)->(dim2)"
assert (
sig.to_gufunc_string(exclude_dims=set("x")) == "(dim0_0),(dim0_1,dim1)->(dim2)"
)
# dimension names matter
assert _UFuncSignature([["x"]]) != _UFuncSignature([["y"]])

@@ -895,6 +898,48 @@ def test_vectorize_dask_dtype_meta():
assert np.float == actual.dtype


def pandas_median_add(x, y):
# function which can consume input of unequal length
return pd.Series(x).median() + pd.Series(y).median()


def test_vectorize_exclude_dims():
# GH 3890
data_array_a = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y"))
data_array_b = xr.DataArray([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]], dims=("x", "y"))

expected = xr.DataArray([3, 5], dims=["x"])
actual = apply_ufunc(
pandas_median_add,
data_array_a,
data_array_b,
input_core_dims=[["y"], ["y"]],
vectorize=True,
exclude_dims=set("y"),
)
assert_identical(expected, actual)
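
As a quick hand check of the expected values in this test:

# median([0, 1, 2]) + median([0, 1, 2, 3, 4]) = 1 + 2 = 3   (row x=0)
# median([1, 2, 3]) + median([1, 2, 3, 4, 5]) = 2 + 3 = 5   (row x=1)
# hence expected = xr.DataArray([3, 5], dims=["x"])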


@requires_dask
def test_vectorize_exclude_dims_dask():
# GH 3890
data_array_a = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y"))
data_array_b = xr.DataArray([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]], dims=("x", "y"))

expected = xr.DataArray([3, 5], dims=["x"])
actual = apply_ufunc(
pandas_median_add,
data_array_a.chunk({"x": 1}),
data_array_b.chunk({"x": 1}),
input_core_dims=[["y"], ["y"]],
exclude_dims=set("y"),
vectorize=True,
dask="parallelized",
output_dtypes=[float],
)
assert_identical(expected, actual)


with raises_regex(TypeError, "Only xr.DataArray is supported"):
xr.corr(xr.Dataset(), xr.Dataset())
