diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 06ffe738..eeebf42f 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -764,6 +764,10 @@ def _binary_op(self, ], right: ArrayOrContainer ) -> ArrayOrContainer: + if not isinstance(self.broadcastee, self._stop_types): + # Defer to array container broadcast rules + return op(self.broadcastee, right) + try: serialized = serialize_container(right) except NotAnArrayContainerError: @@ -783,6 +787,10 @@ def _rev_binary_op(self, ], left: ArrayOrContainer ) -> ArrayOrContainer: + if not isinstance(self.broadcastee, self._stop_types): + # Defer to array container broadcast rules + return op(left, self.broadcastee) + try: serialized = serialize_container(left) except NotAnArrayContainerError: diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 31fa9e79..5f6078d6 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1248,6 +1248,20 @@ def test_no_leaf_array_type_broadcasting(actx_factory): np.testing.assert_allclose(45, actx.to_numpy(mc_op.mass[0])) np.testing.assert_allclose(45, actx.to_numpy(mc_op.momentum[1][0])) + with pytest.raises(TypeError): + mc_op = mc + bcast(DOFArray(actx, (actx_ary,))) + np.testing.assert_allclose(45, actx.to_numpy(mc_op.mass[0])) + + mcdofbcast = MyContainerDOFBcast( + name="hi", + mass=dof_ary, + momentum=make_obj_array([dof_ary, dof_ary]), + enthalpy=dof_ary) + + mc_op = mcdofbcast + bcast(DOFArray(actx, (actx_ary,))) + np.testing.assert_allclose(45, actx.to_numpy(mc_op.mass[0])) + np.testing.assert_allclose(45, actx.to_numpy(mc_op.momentum[1][0])) + def _actx_allows_scalar_broadcast(actx): if not isinstance(actx, PyOpenCLArrayContext): return True diff --git a/test/testlib.py b/test/testlib.py index da33deae..dc275099 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -23,6 +23,7 @@ THE SOFTWARE. """ from dataclasses import dataclass +from numbers import Number import numpy as np @@ -58,6 +59,14 @@ def __init__(self, actx, data): if not isinstance(data, tuple): raise TypeError("'data' argument must be a tuple") + if actx is not None: + for ary in data: + if ( + ary is not None + and not isinstance(ary, ( + *actx.array_types, np.ndarray, Number))): + raise TypeError(f"invalid data array type {type(ary)}.") + self.array_context = actx self.data = data