Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 46 additions & 9 deletions arraycontext/container/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,19 +583,31 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT:

# {{{ flatten / unflatten

def flatten(ary: ArrayOrContainerT, actx: ArrayContext) -> Any:
def flatten(
ary: ArrayOrContainerT, actx: ArrayContext, *,
leaf_class: Optional[type] = None,
) -> Any:
"""Convert all arrays in the :class:`~arraycontext.ArrayContainer`
into single flat array of a type :attr:`arraycontext.ArrayContext.array_types`.

The operation requires :attr:`arraycontext.ArrayContext.np` to have
``ravel`` and ``concatenate`` methods implemented. The order in which the
individual leaf arrays appear in the final array is dependent on the order
given by :func:`~arraycontext.serialize_container`.

If *leaf_class* is given, then :func:`unflatten` will not be able to recover
the original *ary*.

:arg leaf_class: an :class:`~arraycontext.ArrayContainer` class on which
the recursion is stopped (subclasses are not considered). If given, only
the entries of this type are flattened and the rest of the tree
structure is left as is. By default, the recursion is stopped when
a non-:class:`~arraycontext.ArrayContainer` is found, which results in
the whole input container *ary* being flattened.
"""
common_dtype = None
result: List[Any] = []

def _flatten(subary: ArrayOrContainerT) -> None:
def _flatten(subary: ArrayOrContainerT) -> List[Any]:
nonlocal common_dtype

try:
Expand Down Expand Up @@ -624,17 +636,40 @@ def _flatten(subary: ArrayOrContainerT) -> None:
"This functionality needs to be implemented by the "
"array context.") from exc

result.append(flat_subary)
result = [flat_subary]
else:
result = []
for _, isubary in iterable:
_flatten(isubary)
result.extend(_flatten(isubary))

return result

def _flatten_without_leaf_class(subary: ArrayOrContainerT) -> Any:
result = _flatten(subary)

if len(result) == 1:
return result[0]
else:
return actx.np.concatenate(result)

def _flatten_with_leaf_class(subary: ArrayOrContainerT) -> Any:
if type(subary) is leaf_class:
return _flatten_without_leaf_class(subary)

_flatten(ary)
try:
iterable = serialize_container(subary)
except NotAnArrayContainerError:
return subary
else:
return deserialize_container(subary, [
(key, _flatten_with_leaf_class(isubary))
for key, isubary in iterable
])

if len(result) == 1:
return result[0]
if leaf_class is None:
return _flatten_without_leaf_class(ary)
else:
return actx.np.concatenate(result)
return _flatten_with_leaf_class(ary)


def unflatten(
Expand All @@ -647,6 +682,8 @@ def unflatten(
The order and sizes of each slice into *ary* are determined by the
array container *template*.

:arg ary: a flat one-dimensional array with a size that matches the
number of entries in *template*.
:arg strict: if *True* additional :class:`~numpy.dtype` and stride
checking is performed on the unflattened array. Otherwise, these
checks are skipped.
Expand Down
2 changes: 1 addition & 1 deletion run-pylint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ if [[ -f .pylintrc-local.yml ]]; then
PYLINT_RUNNER_ARGS+=" --yaml-rcfile=.pylintrc-local.yml"
fi

python .run-pylint.py $PYLINT_RUNNER_ARGS $(basename $PWD) test/test_*.py examples "$@"
PYTHONWARNINGS=ignore python .run-pylint.py $PYLINT_RUNNER_ARGS $(basename $PWD) test/test_*.py examples "$@"
28 changes: 28 additions & 0 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ def _deserialize_init_arrays_code(cls, template_instance_name, args):
# Why tuple([...])? https://stackoverflow.com/a/48592299
return (f"{template_instance_name}.array_context, tuple([{arg}])")

@property
def size(self):
return sum(ary.size for ary in self.data)

@property
def real(self):
return DOFArray(self.array_context, tuple([subary.real for subary in self]))
Expand Down Expand Up @@ -1064,6 +1068,30 @@ def test_flatten_array_container_failure(actx_factory):
# cannot unflatten partially
unflatten(ary, flat_ary[:-1], actx)


def test_flatten_with_leaf_class(actx_factory):
actx = actx_factory()

from arraycontext import flatten
arys = _get_test_containers(actx, shapes=512)

flat = flatten(arys[0], actx, leaf_class=DOFArray)
assert isinstance(flat, actx.array_types)
assert flat.shape == (arys[0].size,)

flat = flatten(arys[1], actx, leaf_class=DOFArray)
assert isinstance(flat, np.ndarray) and flat.dtype == object
assert all(isinstance(entry, actx.array_types) for entry in flat)
assert all(entry.shape == (arys[0].size,) for entry in flat)

flat = flatten(arys[3], actx, leaf_class=DOFArray)
assert isinstance(flat, MyContainer)
assert isinstance(flat.mass, actx.array_types)
assert flat.mass.shape == (arys[3].mass.size,)
assert isinstance(flat.enthalpy, actx.array_types)
assert flat.enthalpy.shape == (arys[3].enthalpy.size,)
assert all(isinstance(entry, actx.array_types) for entry in flat.momentum)

# }}}


Expand Down