Skip to content

Commit c77d1ef

Browse files
authored
Implement Scan based filter helper (#1717)
1 parent 0b731c2 commit c77d1ef

File tree

2 files changed

+100
-0
lines changed

2 files changed

+100
-0
lines changed

pytensor/scan/views.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,63 @@ def foldr(
198198
name=name,
199199
return_updates=return_updates,
200200
)
201+
202+
203+
def filter(
204+
fn,
205+
sequences,
206+
non_sequences=None,
207+
go_backwards=False,
208+
mode=None,
209+
name=None,
210+
):
211+
"""Construct a `Scan` `Op` that functions like `filter`.
212+
213+
Parameters
214+
----------
215+
fn : callable
216+
Predicate function returning a boolean tensor.
217+
sequences : list
218+
Sequences to filter.
219+
non_sequences : list
220+
Non-iterated arguments passed to `fn`.
221+
go_backwards : bool
222+
Whether to iterate in reverse.
223+
mode : str or None
224+
See ``scan``.
225+
name : str or None
226+
See ``scan``.
227+
228+
Notes
229+
-----
230+
If the predicate function `fn` returns multiple boolean masks (one per sequence),
231+
each mask will be applied to its corresponding sequence. If it returns a single mask,
232+
that mask will be broadcast to all sequences.
233+
"""
234+
mask, _ = scan(
235+
fn=fn,
236+
sequences=sequences,
237+
outputs_info=None,
238+
non_sequences=non_sequences,
239+
go_backwards=go_backwards,
240+
mode=mode,
241+
name=name,
242+
)
243+
244+
if isinstance(mask, (list, tuple)):
245+
# One mask per sequence
246+
if not isinstance(sequences, (list, tuple)):
247+
raise TypeError(
248+
"If multiple masks are returned, sequences must be a list or tuple."
249+
)
250+
if len(mask) != len(sequences):
251+
raise ValueError("Number of masks must match number of sequences.")
252+
filtered_sequences = [seq[m] for seq, m in zip(sequences, mask)]
253+
else:
254+
# Single mask applied to all sequences
255+
if isinstance(sequences, (list, tuple)):
256+
filtered_sequences = [seq[mask] for seq in sequences]
257+
else:
258+
filtered_sequences = sequences[mask]
259+
260+
return filtered_sequences

tests/scan/test_views.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytensor.tensor as pt
55
from pytensor import config, function, grad, shared
66
from pytensor.compile.mode import FAST_RUN
7+
from pytensor.scan.views import filter as pt_filter
78
from pytensor.scan.views import foldl, foldr
89
from pytensor.scan.views import map as pt_map
910
from pytensor.scan.views import reduce as pt_reduce
@@ -166,3 +167,42 @@ def test_foldr_memory_consumption(return_updates):
166167
gx = grad(o, x)
167168
f2 = function([], gx)
168169
utt.assert_allclose(f2(), np.ones((10,)))
170+
171+
172+
def test_filter():
173+
v = pt.vector("v")
174+
175+
def fn(x):
176+
return pt.eq(x % 2, 0)
177+
178+
filtered = pt_filter(fn, v)
179+
f = function([v], filtered, allow_input_downcast=True)
180+
181+
rng = np.random.default_rng(utt.fetch_seed())
182+
vals = rng.integers(0, 10, size=(10,))
183+
expected = vals[vals % 2 == 0]
184+
result = f(vals)
185+
utt.assert_allclose(expected, result)
186+
187+
188+
def test_filter_multiple_masks():
189+
v1 = pt.vector("v1")
190+
v2 = pt.vector("v2")
191+
192+
def fn(x1, x2):
193+
# Mask v1 for even numbers, mask v2 for numbers > 5
194+
return pt.eq(x1 % 2, 0), pt.gt(x2, 5)
195+
196+
filtered_v1, filtered_v2 = pt_filter(fn, [v1, v2])
197+
f = function([v1, v2], [filtered_v1, filtered_v2], allow_input_downcast=True)
198+
199+
vals1 = np.arange(10)
200+
vals2 = np.arange(10)
201+
202+
expected_v1 = vals1[vals1 % 2 == 0]
203+
expected_v2 = vals2[vals2 > 5]
204+
205+
result_v1, result_v2 = f(vals1, vals2)
206+
207+
utt.assert_allclose(expected_v1, result_v1)
208+
utt.assert_allclose(expected_v2, result_v2)

0 commit comments

Comments
 (0)