Skip to content

Commit e7e5b37

Browse files
committed
add a leaf_class parameter to flatten
1 parent e8c74a6 commit e7e5b37

File tree

3 files changed

+53
-6
lines changed

3 files changed

+53
-6
lines changed

arraycontext/container/traversal.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,10 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT:
583583

584584
# {{{ flatten / unflatten
585585

586-
def flatten(ary: ArrayOrContainerT, actx: ArrayContext) -> Any:
586+
def flatten(
587+
ary: ArrayOrContainerT, actx: ArrayContext, *,
588+
leaf_class: Optional[type] = None,
589+
) -> Any:
587590
"""Convert all arrays in the :class:`~arraycontext.ArrayContainer`
588591
into single flat array of a type :attr:`arraycontext.ArrayContext.array_types`.
589592
@@ -629,12 +632,36 @@ def _flatten(subary: ArrayOrContainerT) -> None:
629632
for _, isubary in iterable:
630633
_flatten(isubary)
631634

632-
_flatten(ary)
635+
def _flatten_without_leaf_class(subary: ArrayOrContainerT) -> Any:
636+
nonlocal result, common_dtype
637+
result = []
638+
common_dtype = None
639+
640+
_flatten(ary)
641+
642+
if len(result) == 1:
643+
return result[0]
644+
else:
645+
return actx.np.concatenate(result)
646+
647+
def _flatten_with_leaf_class(subary: ArrayOrContainerT) -> Any:
648+
if type(subary) is leaf_class:
649+
return _flatten_without_leaf_class(subary)
650+
651+
try:
652+
iterable = serialize_container(subary)
653+
except NotAnArrayContainerError:
654+
return subary
655+
else:
656+
return deserialize_container(subary, [
657+
(key, _flatten_without_leaf_class(isubary))
658+
for key, isubary in iterable
659+
])
633660

634-
if len(result) == 1:
635-
return result[0]
661+
if leaf_class is None:
662+
return _flatten_without_leaf_class(ary)
636663
else:
637-
return actx.np.concatenate(result)
664+
return _flatten_with_leaf_class(ary)
638665

639666

640667
def unflatten(

run-pylint.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ if [[ -f .pylintrc-local.yml ]]; then
2020
PYLINT_RUNNER_ARGS+=" --yaml-rcfile=.pylintrc-local.yml"
2121
fi
2222

23-
python .run-pylint.py $PYLINT_RUNNER_ARGS $(basename $PWD) test/test_*.py examples "$@"
23+
PYTHONWARNINGS=ignore python .run-pylint.py $PYLINT_RUNNER_ARGS $(basename $PWD) test/test_*.py examples "$@"

test/test_arraycontext.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,6 +1064,26 @@ def test_flatten_array_container_failure(actx_factory):
10641064
# cannot unflatten partially
10651065
unflatten(ary, flat_ary[:-1], actx)
10661066

1067+
1068+
def test_flatten_with_leaf_class(actx_factory):
1069+
actx = actx_factory()
1070+
1071+
from arraycontext import flatten
1072+
arys = _get_test_containers(actx, shapes=512)
1073+
1074+
flat = flatten(arys[0], actx, leaf_class=DOFArray)
1075+
assert isinstance(flat, actx.array_types)
1076+
1077+
flat = flatten(arys[1], actx, leaf_class=DOFArray)
1078+
assert isinstance(flat, np.ndarray) and flat.dtype == object
1079+
assert all(isinstance(entry, actx.array_types) for entry in flat)
1080+
1081+
flat = flatten(arys[3], actx, leaf_class=DOFArray)
1082+
assert isinstance(flat, MyContainer)
1083+
assert isinstance(flat.mass, actx.array_types)
1084+
assert isinstance(flat.enthalpy, actx.array_types)
1085+
assert all(isinstance(entry, actx.array_types) for entry in flat.momentum)
1086+
10671087
# }}}
10681088

10691089

0 commit comments

Comments
 (0)