From 52c66025c72295350d9dbf88b74bfb861a922f75 Mon Sep 17 00:00:00 2001 From: Joshix Date: Sat, 29 Jun 2024 16:00:00 +0000 Subject: [PATCH] multi_reduce --- typed_stream/_impl/stream.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/typed_stream/_impl/stream.py b/typed_stream/_impl/stream.py index c6868bf..f667ff1 100644 --- a/typed_stream/_impl/stream.py +++ b/typed_stream/_impl/stream.py @@ -13,7 +13,7 @@ import operator import sys import typing -from collections.abc import Callable, Iterable, Iterator, Mapping +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from numbers import Number, Real from types import EllipsisType @@ -965,6 +965,32 @@ def min( raise exceptions.StreamEmptyError() from None return self._finish(min_, close_source=True) + def multi_reduce(self, *funs: Callable[[T, T], T]) -> Sequence[T]: + """Reduce the values of this stream with multiple functions. + + >>> data = [1, 2, 3, 4, 5] + >>> sum_, count = Stream(data).multi_reduce(operator.add, lambda x, _: x + 1) + >>> sum_ + 15 + >>> count + 5 + """ + iterator = iter(self._data) + try: + first_value = next(iterator) + except StopIteration: + raise exceptions.StreamEmptyError() from None + + def multi_update(acc: list[T], value: T) -> list[T]: + for idx, fun in enumerate(funs): + acc[idx] = fun(acc[idx], value) + return acc + + return self._finish( + functools.reduce(multi_update, iterator, [first_value] * len(funs)), + True, + ) + @typing.overload def nth(self, index: int, /) -> T: ... # noqa: D102