Skip to content

Commit b7548b0

Browse files
committed
add leaf_class to decorators
1 parent 3974d67 commit b7548b0

File tree

1 file changed

+46
-11
lines changed

1 file changed

+46
-11
lines changed

arraycontext/container/traversal.py

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -268,12 +268,31 @@ def rec_map_array_container(
268268
return _map_array_container_impl(f, ary, leaf_cls=leaf_class, recursive=True)
269269

270270

271+
def _mapped_over_array_containers_factory(
272+
leaf_class: Optional[type] = None) -> Callable[[Callable[
273+
[Any], Any]], Callable[[ArrayOrContainerT], ArrayOrContainerT]]:
274+
"""Decorator factory around :func:`rec_map_array_container`."""
275+
def decorator(f: Callable[[Any], Any]) -> Callable[
276+
[ArrayOrContainerT], ArrayOrContainerT]:
277+
wrapper = partial(rec_map_array_container, f, leaf_class=leaf_class)
278+
update_wrapper(wrapper, f)
279+
return wrapper
280+
return decorator
281+
282+
271283
def mapped_over_array_containers(
272-
f: Callable[[Any], Any]) -> Callable[[ArrayOrContainerT], ArrayOrContainerT]:
284+
f: Optional[Callable[[Any], Any]] = None,
285+
leaf_class: Optional[type] = None) -> Union[
286+
Callable[[ArrayOrContainerT], ArrayOrContainerT],
287+
Callable[
288+
Callable[[Any], Any],
289+
Callable[[ArrayOrContainerT], ArrayOrContainerT]]]:
273290
"""Decorator around :func:`rec_map_array_container`."""
274-
wrapper = partial(rec_map_array_container, f)
275-
update_wrapper(wrapper, f)
276-
return wrapper
291+
decorator = _mapped_over_array_containers_factory(leaf_class=leaf_class)
292+
if f is not None:
293+
return decorator(f)
294+
else:
295+
return decorator
277296

278297

279298
def rec_multimap_array_container(
@@ -291,16 +310,32 @@ def rec_multimap_array_container(
291310
f, *args, leaf_cls=leaf_class, recursive=True)
292311

293312

313+
def _multimapped_over_array_containers_factory(
314+
leaf_class: Optional[type] = None) -> Callable[
315+
[Callable[..., Any]], Callable[..., Any]]:
316+
"""Decorator factory around :func:`rec_multimap_array_container`."""
317+
def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
318+
# can't use functools.partial, because its result is insufficiently
319+
# function-y to be used as a method definition.
320+
def wrapper(*args: Any) -> Any:
321+
return rec_multimap_array_container(f, *args, leaf_class=leaf_class)
322+
update_wrapper(wrapper, f)
323+
return wrapper
324+
return decorator
325+
326+
294327
def multimapped_over_array_containers(
295-
f: Callable[..., Any]) -> Callable[..., Any]:
328+
f: Optional[Callable[..., Any]] = None,
329+
leaf_class: Optional[type] = None) -> Union[
330+
Callable[..., Any],
331+
Callable[[Callable[..., Any]], Callable[..., Any]]]:
296332
"""Decorator around :func:`rec_multimap_array_container`."""
297-
# can't use functools.partial, because its result is insufficiently
298-
# function-y to be used as a method definition.
299-
def wrapper(*args: Any) -> Any:
300-
return rec_multimap_array_container(f, *args)
333+
decorator = _multimapped_over_array_containers_factory(leaf_class=leaf_class)
334+
if f is not None:
335+
return decorator(f)
336+
else:
337+
return decorator
301338

302-
update_wrapper(wrapper, f)
303-
return wrapper
304339

305340
# }}}
306341

0 commit comments

Comments
 (0)