Skip to content

Commit 968e2be

Browse files
authored
Add several itertools recipes to the test_cases directory (#10992)
1 parent 6afb72f commit 968e2be

File tree

1 file changed

+371
-0
lines changed

1 file changed

+371
-0
lines changed
Lines changed: 371 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,371 @@
1+
"""Type-annotated versions of the recipes from the itertools docs.
2+
3+
These are all meant to be examples of idiomatic itertools usage,
4+
so they should all type-check without error.
5+
"""
6+
from __future__ import annotations
7+
8+
import collections
9+
import math
10+
import operator
11+
import sys
12+
from itertools import chain, combinations, count, cycle, filterfalse, islice, repeat, starmap, tee, zip_longest
13+
from typing import Any, Callable, Hashable, Iterable, Iterator, Sequence, Tuple, Type, TypeVar, Union, overload
14+
from typing_extensions import Literal, TypeAlias, TypeVarTuple, Unpack
15+
16+
_T = TypeVar("_T")
17+
_T1 = TypeVar("_T1")
18+
_T2 = TypeVar("_T2")
19+
_HashableT = TypeVar("_HashableT", bound=Hashable)
20+
_Ts = TypeVarTuple("_Ts")
21+
22+
23+
def take(n: int, iterable: Iterable[_T]) -> list[_T]:
24+
"Return first n items of the iterable as a list"
25+
return list(islice(iterable, n))
26+
27+
28+
# Note: the itertools docs uses the parameter name "iterator",
29+
# but the function actually accepts any iterable
30+
# as its second argument
31+
def prepend(value: _T1, iterator: Iterable[_T2]) -> Iterator[_T1 | _T2]:
32+
"Prepend a single value in front of an iterator"
33+
# prepend(1, [2, 3, 4]) --> 1 2 3 4
34+
return chain([value], iterator)
35+
36+
37+
def tabulate(function: Callable[[int], _T], start: int = 0) -> Iterator[_T]:
38+
"Return function(0), function(1), ..."
39+
return map(function, count(start))
40+
41+
42+
def repeatfunc(func: Callable[[Unpack[_Ts]], _T], times: int | None = None, *args: Unpack[_Ts]) -> Iterator[_T]:
43+
"""Repeat calls to func with specified arguments.
44+
45+
Example: repeatfunc(random.random)
46+
"""
47+
if times is None:
48+
return starmap(func, repeat(args))
49+
return starmap(func, repeat(args, times))
50+
51+
52+
def flatten(list_of_lists: Iterable[Iterable[_T]]) -> Iterator[_T]:
53+
"Flatten one level of nesting"
54+
return chain.from_iterable(list_of_lists)
55+
56+
57+
def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]:
58+
"Returns the sequence elements n times"
59+
return chain.from_iterable(repeat(tuple(iterable), n))
60+
61+
62+
def tail(n: int, iterable: Iterable[_T]) -> Iterator[_T]:
63+
"Return an iterator over the last n items"
64+
# tail(3, 'ABCDEFG') --> E F G
65+
return iter(collections.deque(iterable, maxlen=n))
66+
67+
68+
# This function *accepts* any iterable,
69+
# but it only *makes sense* to use it with an iterator
70+
def consume(iterator: Iterator[object], n: int | None = None) -> None:
71+
"Advance the iterator n-steps ahead. If n is None, consume entirely."
72+
# Use functions that consume iterators at C speed.
73+
if n is None:
74+
# feed the entire iterator into a zero-length deque
75+
collections.deque(iterator, maxlen=0)
76+
else:
77+
# advance to the empty slice starting at position n
78+
next(islice(iterator, n, n), None)
79+
80+
81+
@overload
82+
def nth(iterable: Iterable[_T], n: int, default: None = None) -> _T | None:
83+
...
84+
85+
86+
@overload
87+
def nth(iterable: Iterable[_T], n: int, default: _T1) -> _T | _T1:
88+
...
89+
90+
91+
def nth(iterable: Iterable[object], n: int, default: object = None) -> object:
92+
"Returns the nth item or a default value"
93+
return next(islice(iterable, n, None), default)
94+
95+
96+
@overload
97+
def quantify(iterable: Iterable[object]) -> int:
98+
...
99+
100+
101+
@overload
102+
def quantify(iterable: Iterable[_T], pred: Callable[[_T], bool]) -> int:
103+
...
104+
105+
106+
def quantify(iterable: Iterable[object], pred: Callable[[Any], bool] = bool) -> int:
107+
"Given a predicate that returns True or False, count the True results."
108+
return sum(map(pred, iterable))
109+
110+
111+
@overload
112+
def first_true(
113+
iterable: Iterable[_T], default: Literal[False] = False, pred: Callable[[_T], bool] | None = None
114+
) -> _T | Literal[False]:
115+
...
116+
117+
118+
@overload
119+
def first_true(iterable: Iterable[_T], default: _T1, pred: Callable[[_T], bool] | None = None) -> _T | _T1:
120+
...
121+
122+
123+
def first_true(iterable: Iterable[object], default: object = False, pred: Callable[[Any], bool] | None = None) -> object:
124+
"""Returns the first true value in the iterable.
125+
If no true value is found, returns *default*
126+
If *pred* is not None, returns the first item
127+
for which pred(item) is true.
128+
"""
129+
# first_true([a,b,c], x) --> a or b or c or x
130+
# first_true([a,b], x, f) --> a if f(a) else b if f(b) else x
131+
return next(filter(pred, iterable), default)
132+
133+
134+
_ExceptionOrExceptionTuple: TypeAlias = Union[Type[BaseException], Tuple[Type[BaseException], ...]]
135+
136+
137+
@overload
138+
def iter_except(func: Callable[[], _T], exception: _ExceptionOrExceptionTuple, first: None = None) -> Iterator[_T]:
139+
...
140+
141+
142+
@overload
143+
def iter_except(func: Callable[[], _T], exception: _ExceptionOrExceptionTuple, first: Callable[[], _T1]) -> Iterator[_T | _T1]:
144+
...
145+
146+
147+
def iter_except(
148+
func: Callable[[], object], exception: _ExceptionOrExceptionTuple, first: Callable[[], object] | None = None
149+
) -> Iterator[object]:
150+
"""Call a function repeatedly until an exception is raised.
151+
Converts a call-until-exception interface to an iterator interface.
152+
Like builtins.iter(func, sentinel) but uses an exception instead
153+
of a sentinel to end the loop.
154+
Examples:
155+
iter_except(functools.partial(heappop, h), IndexError) # priority queue iterator
156+
iter_except(d.popitem, KeyError) # non-blocking dict iterator
157+
iter_except(d.popleft, IndexError) # non-blocking deque iterator
158+
iter_except(q.get_nowait, Queue.Empty) # loop over a producer Queue
159+
iter_except(s.pop, KeyError) # non-blocking set iterator
160+
"""
161+
try:
162+
if first is not None:
163+
yield first() # For database APIs needing an initial cast to db.first()
164+
while True:
165+
yield func()
166+
except exception:
167+
pass
168+
169+
170+
def sliding_window(iterable: Iterable[_T], n: int) -> Iterator[tuple[_T, ...]]:
171+
# sliding_window('ABCDEFG', 4) --> ABCD BCDE CDEF DEFG
172+
it = iter(iterable)
173+
window = collections.deque(islice(it, n - 1), maxlen=n)
174+
for x in it:
175+
window.append(x)
176+
yield tuple(window)
177+
178+
179+
def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]:
180+
"roundrobin('ABC', 'D', 'EF') --> A D E B F C"
181+
# Recipe credited to George Sakkis
182+
num_active = len(iterables)
183+
nexts: Iterator[Callable[[], _T]] = cycle(iter(it).__next__ for it in iterables)
184+
while num_active:
185+
try:
186+
for next in nexts:
187+
yield next()
188+
except StopIteration:
189+
# Remove the iterator we just exhausted from the cycle.
190+
num_active -= 1
191+
nexts = cycle(islice(nexts, num_active))
192+
193+
194+
def partition(pred: Callable[[_T], bool], iterable: Iterable[_T]) -> tuple[Iterator[_T], Iterator[_T]]:
195+
"""Partition entries into false entries and true entries.
196+
If *pred* is slow, consider wrapping it with functools.lru_cache().
197+
"""
198+
# partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9
199+
t1, t2 = tee(iterable)
200+
return filterfalse(pred, t1), filter(pred, t2)
201+
202+
203+
def subslices(seq: Sequence[_T]) -> Iterator[Sequence[_T]]:
204+
"Return all contiguous non-empty subslices of a sequence"
205+
# subslices('ABCD') --> A AB ABC ABCD B BC BCD C CD D
206+
slices = starmap(slice, combinations(range(len(seq) + 1), 2))
207+
return map(operator.getitem, repeat(seq), slices)
208+
209+
210+
def before_and_after(predicate: Callable[[_T], bool], it: Iterable[_T]) -> tuple[Iterator[_T], Iterator[_T]]:
211+
"""Variant of takewhile() that allows complete
212+
access to the remainder of the iterator.
213+
>>> it = iter('ABCdEfGhI')
214+
>>> all_upper, remainder = before_and_after(str.isupper, it)
215+
>>> ''.join(all_upper)
216+
'ABC'
217+
>>> ''.join(remainder) # takewhile() would lose the 'd'
218+
'dEfGhI'
219+
Note that the first iterator must be fully
220+
consumed before the second iterator can
221+
generate valid results.
222+
"""
223+
it = iter(it)
224+
transition: list[_T] = []
225+
226+
def true_iterator() -> Iterator[_T]:
227+
for elem in it:
228+
if predicate(elem):
229+
yield elem
230+
else:
231+
transition.append(elem)
232+
return
233+
234+
def remainder_iterator() -> Iterator[_T]:
235+
yield from transition
236+
yield from it
237+
238+
return true_iterator(), remainder_iterator()
239+
240+
241+
@overload
242+
def unique_everseen(iterable: Iterable[_HashableT], key: None = None) -> Iterator[_HashableT]:
243+
...
244+
245+
246+
@overload
247+
def unique_everseen(iterable: Iterable[_T], key: Callable[[_T], Hashable]) -> Iterator[_T]:
248+
...
249+
250+
251+
def unique_everseen(iterable: Iterable[_T], key: Callable[[_T], Hashable] | None = None) -> Iterator[_T]:
252+
"List unique elements, preserving order. Remember all elements ever seen."
253+
# unique_everseen('AAAABBBCCDAABBB') --> A B C D
254+
# unique_everseen('ABBcCAD', str.lower) --> A B c D
255+
seen: set[Hashable] = set()
256+
if key is None:
257+
for element in filterfalse(seen.__contains__, iterable):
258+
seen.add(element)
259+
yield element
260+
# For order preserving deduplication,
261+
# a faster but non-lazy solution is:
262+
# yield from dict.fromkeys(iterable)
263+
else:
264+
for element in iterable:
265+
k = key(element)
266+
if k not in seen:
267+
seen.add(k)
268+
yield element
269+
# For use cases that allow the last matching element to be returned,
270+
# a faster but non-lazy solution is:
271+
# t1, t2 = tee(iterable)
272+
# yield from dict(zip(map(key, t1), t2)).values()
273+
274+
275+
def powerset(iterable: Iterable[_T]) -> Iterator[tuple[_T, ...]]:
276+
"powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
277+
s = list(iterable)
278+
return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
279+
280+
281+
def polynomial_derivative(coefficients: Sequence[float]) -> list[float]:
282+
"""Compute the first derivative of a polynomial.
283+
f(x) = x³ -4x² -17x + 60
284+
f'(x) = 3x² -8x -17
285+
"""
286+
# polynomial_derivative([1, -4, -17, 60]) -> [3, -8, -17]
287+
n = len(coefficients)
288+
powers = reversed(range(1, n))
289+
return list(map(operator.mul, coefficients, powers))
290+
291+
292+
if sys.version_info >= (3, 10):
293+
294+
@overload
295+
def grouper(
296+
iterable: Iterable[_T], n: int, *, incomplete: Literal["fill"] = "fill", fillvalue: None = None
297+
) -> Iterator[tuple[_T | None, ...]]:
298+
...
299+
300+
@overload
301+
def grouper(
302+
iterable: Iterable[_T], n: int, *, incomplete: Literal["fill"] = "fill", fillvalue: _T1
303+
) -> Iterator[tuple[_T | _T1, ...]]:
304+
...
305+
306+
@overload
307+
def grouper(
308+
iterable: Iterable[_T], n: int, *, incomplete: Literal["strict", "ignore"], fillvalue: None = None
309+
) -> Iterator[tuple[_T, ...]]:
310+
...
311+
312+
def grouper(
313+
iterable: Iterable[object], n: int, *, incomplete: Literal["fill", "strict", "ignore"] = "fill", fillvalue: object = None
314+
) -> Iterator[tuple[object, ...]]:
315+
"Collect data into non-overlapping fixed-length chunks or blocks"
316+
# grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx
317+
# grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError
318+
# grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF
319+
args = [iter(iterable)] * n
320+
if incomplete == "fill":
321+
return zip_longest(*args, fillvalue=fillvalue)
322+
if incomplete == "strict":
323+
return zip(*args, strict=True)
324+
if incomplete == "ignore":
325+
return zip(*args)
326+
else:
327+
raise ValueError("Expected fill, strict, or ignore")
328+
329+
def transpose(it: Iterable[Iterable[_T]]) -> Iterator[tuple[_T, ...]]:
330+
"Swap the rows and columns of the input."
331+
# transpose([(1, 2, 3), (11, 22, 33)]) --> (1, 11) (2, 22) (3, 33)
332+
return zip(*it, strict=True)
333+
334+
335+
if sys.version_info >= (3, 12):
336+
337+
def sum_of_squares(it: Iterable[float]) -> float:
338+
"Add up the squares of the input values."
339+
# sum_of_squares([10, 20, 30]) -> 1400
340+
return math.sumprod(*tee(it))
341+
342+
def convolve(signal: Iterable[float], kernel: Iterable[float]) -> Iterator[float]:
343+
"""Discrete linear convolution of two iterables.
344+
The kernel is fully consumed before the calculations begin.
345+
The signal is consumed lazily and can be infinite.
346+
Convolutions are mathematically commutative.
347+
If the signal and kernel are swapped,
348+
the output will be the same.
349+
Article: https://betterexplained.com/articles/intuitive-convolution/
350+
Video: https://www.youtube.com/watch?v=KuXjwB4LzSA
351+
"""
352+
# convolve(data, [0.25, 0.25, 0.25, 0.25]) --> Moving average (blur)
353+
# convolve(data, [1/2, 0, -1/2]) --> 1st derivative estimate
354+
# convolve(data, [1, -2, 1]) --> 2nd derivative estimate
355+
kernel = tuple(kernel)[::-1]
356+
n = len(kernel)
357+
padded_signal = chain(repeat(0, n - 1), signal, repeat(0, n - 1))
358+
windowed_signal = sliding_window(padded_signal, n)
359+
return map(math.sumprod, repeat(kernel), windowed_signal)
360+
361+
def polynomial_eval(coefficients: Sequence[float], x: float) -> float:
362+
"""Evaluate a polynomial at a specific value.
363+
Computes with better numeric stability than Horner's method.
364+
"""
365+
# Evaluate x³ -4x² -17x + 60 at x = 2.5
366+
# polynomial_eval([1, -4, -17, 60], x=2.5) --> 8.125
367+
n = len(coefficients)
368+
if not n:
369+
return type(x)(0)
370+
powers = map(pow, repeat(x), reversed(range(n)))
371+
return math.sumprod(coefficients, powers)

0 commit comments

Comments
 (0)