Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 53 additions & 24 deletions arraycontext/container/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,46 +256,70 @@ def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:

def rec_map_array_container(
f: Callable[[Any], Any],
ary: ArrayOrContainerT) -> ArrayOrContainerT:
ary: ArrayOrContainerT,
leaf_class: Optional[type] = None) -> ArrayOrContainerT:
r"""Applies *f* recursively to an :class:`ArrayContainer`.

For a non-recursive version see :func:`map_array_container`.

:param ary: a (potentially nested) structure of :class:`ArrayContainer`\ s,
or an instance of a base array type.
"""
return _map_array_container_impl(f, ary, recursive=True)
return _map_array_container_impl(f, ary, leaf_cls=leaf_class, recursive=True)


def mapped_over_array_containers(
f: Callable[[Any], Any]) -> Callable[[ArrayOrContainerT], ArrayOrContainerT]:
f: Optional[Callable[[Any], Any]] = None,
leaf_class: Optional[type] = None) -> Union[
Callable[[ArrayOrContainerT], ArrayOrContainerT],
Callable[
[Callable[[Any], Any]],
Callable[[ArrayOrContainerT], ArrayOrContainerT]]]:
"""Decorator around :func:`rec_map_array_container`."""
wrapper = partial(rec_map_array_container, f)
update_wrapper(wrapper, f)
return wrapper
def decorator(g: Callable[[Any], Any]) -> Callable[
[ArrayOrContainerT], ArrayOrContainerT]:
wrapper = partial(rec_map_array_container, g, leaf_class=leaf_class)
update_wrapper(wrapper, g)
return wrapper
if f is not None:
return decorator(f)
else:
return decorator


def rec_multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
def rec_multimap_array_container(
f: Callable[..., Any],
*args: Any,
leaf_class: Optional[type] = None) -> Any:
r"""Applies *f* recursively to multiple :class:`ArrayContainer`\ s.

For a non-recursive version see :func:`multimap_array_container`.

:param args: all :class:`ArrayContainer` arguments must be of the same
type and with the same structure (same number of components, etc.).
"""
return _multimap_array_container_impl(f, *args, recursive=True)
return _multimap_array_container_impl(
f, *args, leaf_cls=leaf_class, recursive=True)


def multimapped_over_array_containers(
f: Callable[..., Any]) -> Callable[..., Any]:
f: Optional[Callable[..., Any]] = None,
leaf_class: Optional[type] = None) -> Union[
Callable[..., Any],
Callable[[Callable[..., Any]], Callable[..., Any]]]:
"""Decorator around :func:`rec_multimap_array_container`."""
# can't use functools.partial, because its result is insufficiently
# function-y to be used as a method definition.
def wrapper(*args: Any) -> Any:
return rec_multimap_array_container(f, *args)
def decorator(g: Callable[..., Any]) -> Callable[..., Any]:
# can't use functools.partial, because its result is insufficiently
# function-y to be used as a method definition.
def wrapper(*args: Any) -> Any:
return rec_multimap_array_container(g, *args, leaf_class=leaf_class)
update_wrapper(wrapper, g)
return wrapper
if f is not None:
return decorator(f)
else:
return decorator

update_wrapper(wrapper, f)
return wrapper

# }}}

Expand Down Expand Up @@ -401,7 +425,8 @@ def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any
def rec_map_reduce_array_container(
reduce_func: Callable[[Iterable[Any]], Any],
map_func: Callable[[Any], Any],
ary: ArrayOrContainerT) -> "DeviceArray":
ary: ArrayOrContainerT,
leaf_class: Optional[type] = None) -> "DeviceArray":
"""Perform a map-reduce over array containers recursively.

:param reduce_func: callable used to reduce over the components of *ary*
Expand Down Expand Up @@ -440,22 +465,26 @@ def rec_map_reduce_array_container(
or any other such traversal.
"""
def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT:
try:
iterable = serialize_container(_ary)
except NotAnArrayContainerError:
if type(_ary) is leaf_class:
return map_func(_ary)
else:
return reduce_func([
rec(subary) for _, subary in iterable
])
try:
iterable = serialize_container(_ary)
except NotAnArrayContainerError:
return map_func(_ary)
else:
return reduce_func([
rec(subary) for _, subary in iterable
])

return rec(ary)


def rec_multimap_reduce_array_container(
reduce_func: Callable[[Iterable[Any]], Any],
map_func: Callable[..., Any],
*args: Any) -> "DeviceArray":
*args: Any,
leaf_class: Optional[type] = None) -> "DeviceArray":
r"""Perform a map-reduce over multiple array containers recursively.

:param reduce_func: callable used to reduce over the components of any
Expand All @@ -478,7 +507,7 @@ def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any

return _multimap_array_container_impl(
map_func, *args,
reduce_func=_reduce_wrapper, leaf_cls=None, recursive=True)
reduce_func=_reduce_wrapper, leaf_cls=leaf_class, recursive=True)

# }}}

Expand Down
94 changes: 86 additions & 8 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,59 @@ def test_container_scalar_map(actx_factory):
assert result is not None


def test_container_map(actx_factory):
actx = actx_factory()
ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \
_get_test_containers(actx)

# {{{ check

def _check_allclose(f, arg1, arg2, atol=2.0e-14):
from arraycontext import NotAnArrayContainerError
try:
arg1_iterable = serialize_container(arg1)
arg2_iterable = serialize_container(arg2)
except NotAnArrayContainerError:
assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol
else:
arg1_subarrays = [
subarray for _, subarray in arg1_iterable]
arg2_subarrays = [
subarray for _, subarray in arg2_iterable]
for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays):
_check_allclose(f, subarray1, subarray2)

def func(x):
return x + 1

from arraycontext import rec_map_array_container
result = rec_map_array_container(func, 1)
assert result == 2

for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
result = rec_map_array_container(func, ary)
_check_allclose(func, ary, result)

from arraycontext import mapped_over_array_containers

@mapped_over_array_containers
def mapped_func(x):
return func(x)

for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
result = mapped_func(ary)
_check_allclose(func, ary, result)

@mapped_over_array_containers(leaf_class=DOFArray)
def check_leaf(x):
assert isinstance(x, DOFArray)

for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
check_leaf(ary)

# }}}


def test_container_multimap(actx_factory):
actx = actx_factory()
ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \
Expand All @@ -764,7 +817,19 @@ def test_container_multimap(actx_factory):
# {{{ check

def _check_allclose(f, arg1, arg2, atol=2.0e-14):
assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol
from arraycontext import NotAnArrayContainerError
try:
arg1_iterable = serialize_container(arg1)
arg2_iterable = serialize_container(arg2)
except NotAnArrayContainerError:
assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol
else:
arg1_subarrays = [
subarray for _, subarray in arg1_iterable]
arg2_subarrays = [
subarray for _, subarray in arg2_iterable]
for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays):
_check_allclose(f, subarray1, subarray2)

def func_all_scalar(x, y):
return x + y
Expand All @@ -779,17 +844,30 @@ def func_multiple_scalar(a, subary1, b, subary2):
result = rec_multimap_array_container(func_all_scalar, 1, 2)
assert result == 3

from functools import partial
for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
result = rec_multimap_array_container(func_first_scalar, 1, ary)
rec_multimap_array_container(
partial(_check_allclose, lambda x: 1 + x),
ary, result)
_check_allclose(lambda x: 1 + x, ary, result)

result = rec_multimap_array_container(func_multiple_scalar, 2, ary, 2, ary)
rec_multimap_array_container(
partial(_check_allclose, lambda x: 4 * x),
ary, result)
_check_allclose(lambda x: 4 * x, ary, result)

from arraycontext import multimapped_over_array_containers

@multimapped_over_array_containers
def mapped_func(a, subary1, b, subary2):
return func_multiple_scalar(a, subary1, b, subary2)

for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
result = mapped_func(2, ary, 2, ary)
_check_allclose(lambda x: 4 * x, ary, result)

@multimapped_over_array_containers(leaf_class=DOFArray)
def check_leaf(a, subary1, b, subary2):
assert isinstance(subary1, DOFArray)
assert isinstance(subary2, DOFArray)

for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
check_leaf(2, ary, 2, ary)

with pytest.raises(AssertionError):
rec_multimap_array_container(func_multiple_scalar, 2, ary_dof, 2, dc_of_dofs)
Expand Down