
Commit 9716ce6

Copilot and ricardoV94 committed
Update dispatch functions and rewrite rules for new AdvancedSubtensor interface, store expected_inputs_len
Co-authored-by: ricardoV94 <[email protected]>
1 parent: d30acca

File tree: 5 files changed, +127 −45 lines


pytensor/link/jax/dispatch/subtensor.py

Lines changed: 7 additions & 2 deletions
@@ -77,6 +77,8 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
 
 @jax_funcify.register(AdvancedIncSubtensor)
 def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):
+    idx_list = getattr(op, "idx_list", None)
+
     if getattr(op, "set_instead_of_inc", False):
 
         def jax_fn(x, indices, y):
@@ -87,8 +89,11 @@ def jax_fn(x, indices, y):
         def jax_fn(x, indices, y):
             return x.at[indices].add(y)
 
-    def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
-        return jax_fn(x, ilist, y)
+    def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
+        indices = indices_from_subtensor(ilist, idx_list)
+        if len(indices) == 1:
+            indices = indices[0]
+        return jax_fn(x, indices, y)
 
     return advancedincsubtensor
 
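The JAX runtime function now rebuilds the full index tuple from the flattened tensor inputs and the op's idx_list via indices_from_subtensor before applying the .at[...] update. A minimal sketch of what that expansion is assumed to do (a mimic for illustration, not the real helper from pytensor.tensor.subtensor):

    # Sketch only: recombine flattened inputs with idx_list metadata.
    # Assumes idx_list stores slices as-is and Type placeholders for
    # tensor indices, each placeholder consuming one flattened input.
    def expand_indices_sketch(flattened, idx_list):
        if idx_list is None:
            return tuple(flattened)  # no metadata: inputs already are the indices
        inputs = iter(flattened)
        expanded = []
        for entry in idx_list:
            if isinstance(entry, slice):
                expanded.append(entry)  # slices live in idx_list itself
            else:
                expanded.append(next(inputs))  # Type entry: take next input
        return tuple(expanded)

    # e.g. an idx_list of (slice(None), <Type>) with one tensor input i
    # yields (slice(None), i), i.e. the index tuple for x[:, i]
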
pytensor/link/numba/dispatch/subtensor.py

Lines changed: 23 additions & 21 deletions
@@ -107,28 +107,30 @@ def {function_name}({", ".join(input_names)}):
 @numba_funcify.register(AdvancedIncSubtensor)
 def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
     if isinstance(op, AdvancedSubtensor):
-        x, y, idxs = node.inputs[0], None, node.inputs[1:]
+        x, y, tensor_inputs = node.inputs[0], None, node.inputs[1:]
     else:
-        x, y, *idxs = node.inputs
-
-    basic_idxs = [
-        idx
-        for idx in idxs
-        if (
-            isinstance(idx.type, NoneTypeT)
-            or (isinstance(idx.type, SliceType) and not is_full_slice(idx))
-        )
-    ]
-    adv_idxs = [
-        {
-            "axis": i,
-            "dtype": idx.type.dtype,
-            "bcast": idx.type.broadcastable,
-            "ndim": idx.type.ndim,
-        }
-        for i, idx in enumerate(idxs)
-        if isinstance(idx.type, TensorType)
-    ]
+        x, y, *tensor_inputs = node.inputs
+
+    # Reconstruct indexing information from idx_list and tensor inputs
+    basic_idxs = []
+    adv_idxs = []
+    input_idx = 0
+
+    for i, entry in enumerate(op.idx_list):
+        if isinstance(entry, slice):
+            # Basic slice index
+            basic_idxs.append(entry)
+        elif isinstance(entry, Type):
+            # Advanced tensor index
+            if input_idx < len(tensor_inputs):
+                idx_input = tensor_inputs[input_idx]
+                adv_idxs.append({
+                    "axis": i,
+                    "dtype": idx_input.type.dtype,
+                    "bcast": idx_input.type.broadcastable,
+                    "ndim": idx_input.type.ndim,
+                })
+            input_idx += 1
 
     # Special implementation for consecutive integer vector indices
     if (
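The Numba backend now derives the basic/advanced split from op.idx_list instead of inspecting the node's variable types. A runnable toy version of that classification, with a hypothetical TensorMarker standing in for the PyTensor Type entries:

    # Toy classification of an idx_list into basic and advanced indices.
    class TensorMarker:  # hypothetical stand-in for a PyTensor Type entry
        def __init__(self, ndim):
            self.ndim = ndim

    def split_idx_list(idx_list):
        basic, advanced = [], []
        for axis, entry in enumerate(idx_list):
            if isinstance(entry, slice):
                basic.append((axis, entry))     # stored literally in idx_list
            else:
                advanced.append((axis, entry))  # consumes a tensor input at runtime
        return basic, advanced

    basic, advanced = split_idx_list((slice(None, 2), TensorMarker(ndim=1)))
    assert [axis for axis, _ in advanced] == [1]  # the vector index sits on axis 1
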
pytensor/link/pytorch/dispatch/subtensor.py

Lines changed: 15 additions & 6 deletions
@@ -63,7 +63,10 @@ def makeslice(start, stop, step):
 @pytorch_funcify.register(AdvancedSubtensor1)
 @pytorch_funcify.register(AdvancedSubtensor)
 def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
-    def advsubtensor(x, *indices):
+    idx_list = getattr(op, "idx_list", None)
+
+    def advsubtensor(x, *flattened_indices):
+        indices = indices_from_subtensor(flattened_indices, idx_list)
         check_negative_steps(indices)
         return x[indices]
 
@@ -102,12 +105,14 @@ def inc_subtensor(x, y, *flattened_indices):
 @pytorch_funcify.register(AdvancedIncSubtensor)
 @pytorch_funcify.register(AdvancedIncSubtensor1)
 def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
+    idx_list = getattr(op, "idx_list", None)
    inplace = op.inplace
    ignore_duplicates = getattr(op, "ignore_duplicates", False)
 
    if op.set_instead_of_inc:
 
-        def adv_set_subtensor(x, y, *indices):
+        def adv_set_subtensor(x, y, *flattened_indices):
+            indices = indices_from_subtensor(flattened_indices, idx_list)
             check_negative_steps(indices)
             if isinstance(op, AdvancedIncSubtensor1):
                 op._check_runtime_broadcasting(node, x, y, indices)
@@ -120,7 +125,8 @@ def adv_set_subtensor(x, y, *indices):
 
     elif ignore_duplicates:
 
-        def adv_inc_subtensor_no_duplicates(x, y, *indices):
+        def adv_inc_subtensor_no_duplicates(x, y, *flattened_indices):
+            indices = indices_from_subtensor(flattened_indices, idx_list)
             check_negative_steps(indices)
             if isinstance(op, AdvancedIncSubtensor1):
                 op._check_runtime_broadcasting(node, x, y, indices)
@@ -132,13 +138,16 @@ def adv_inc_subtensor_no_duplicates(x, y, *indices):
         return adv_inc_subtensor_no_duplicates
 
     else:
-        if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]):
+        # Check if we have slice indexing in idx_list
+        has_slice_indexing = any(isinstance(entry, slice) for entry in idx_list) if idx_list else False
+        if has_slice_indexing:
             raise NotImplementedError(
                 "IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
             )
 
-        def adv_inc_subtensor(x, y, *indices):
-            # Not needed because slices aren't supported
+        def adv_inc_subtensor(x, y, *flattened_indices):
+            indices = indices_from_subtensor(flattened_indices, idx_list)
+            # Not needed because slices aren't supported in this path
             # check_negative_steps(indices)
             if not inplace:
                 x = x.clone()
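The three code paths in the PyTorch dispatcher map onto plain torch operations. A rough standalone sketch of the set / no-duplicates / duplicate-safe variants, illustrating the underlying torch calls rather than the generated PyTensor code:

    import torch

    x = torch.zeros(5)
    idx = torch.tensor([1, 1, 3])
    y = torch.tensor([1.0, 2.0, 3.0])

    # set_instead_of_inc: plain advanced assignment; with duplicate
    # indices the last write wins
    a = x.clone()
    a[idx] = y

    # ignore_duplicates: increment without accumulation guarantees;
    # repeated indices do not sum their contributions
    b = x.clone()
    b[idx] += y

    # duplicate-safe increment: index_put_ with accumulate=True makes
    # repeated indices accumulate, so c[1] == 1.0 + 2.0
    c = x.clone()
    c.index_put_((idx,), y, accumulate=True)
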
pytensor/tensor/rewriting/subtensor.py

Lines changed: 71 additions & 7 deletions
@@ -228,7 +228,18 @@ def local_replace_AdvancedSubtensor(fgraph, node):
         return
 
     indexed_var = node.inputs[0]
-    indices = node.inputs[1:]
+    tensor_inputs = node.inputs[1:]
+
+    # Reconstruct indices from idx_list and tensor inputs
+    indices = []
+    input_idx = 0
+    for entry in node.op.idx_list:
+        if isinstance(entry, slice):
+            indices.append(entry)
+        elif isinstance(entry, Type):
+            if input_idx < len(tensor_inputs):
+                indices.append(tensor_inputs[input_idx])
+            input_idx += 1
 
     axis = get_advsubtensor_axis(indices)
 
@@ -255,7 +266,18 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
 
     res = node.inputs[0]
     val = node.inputs[1]
-    indices = node.inputs[2:]
+    tensor_inputs = node.inputs[2:]
+
+    # Reconstruct indices from idx_list and tensor inputs
+    indices = []
+    input_idx = 0
+    for entry in node.op.idx_list:
+        if isinstance(entry, slice):
+            indices.append(entry)
+        elif isinstance(entry, Type):
+            if input_idx < len(tensor_inputs):
+                indices.append(tensor_inputs[input_idx])
+            input_idx += 1
 
     axis = get_advsubtensor_axis(indices)
 
@@ -1751,9 +1773,22 @@ def ravel_multidimensional_bool_idx(fgraph, node):
     x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape)
     """
     if isinstance(node.op, AdvancedSubtensor):
-        x, *idxs = node.inputs
+        x = node.inputs[0]
+        tensor_inputs = node.inputs[1:]
     else:
-        x, y, *idxs = node.inputs
+        x, y = node.inputs[0], node.inputs[1]
+        tensor_inputs = node.inputs[2:]
+
+    # Reconstruct indices from idx_list and tensor inputs
+    idxs = []
+    input_idx = 0
+    for entry in node.op.idx_list:
+        if isinstance(entry, slice):
+            idxs.append(entry)
+        elif isinstance(entry, Type):
+            if input_idx < len(tensor_inputs):
+                idxs.append(tensor_inputs[input_idx])
+            input_idx += 1
 
     if any(
         (
@@ -1791,12 +1826,41 @@ def ravel_multidimensional_bool_idx(fgraph, node):
     new_idxs[bool_idx_pos] = raveled_bool_idx
 
     if isinstance(node.op, AdvancedSubtensor):
-        new_out = node.op(raveled_x, *new_idxs)
+        # Create new AdvancedSubtensor with updated idx_list
+        new_idx_list = list(node.op.idx_list)
+        new_tensor_inputs = list(tensor_inputs)
+
+        # Update the idx_list and tensor_inputs for the raveled boolean index
+        input_idx = 0
+        for i, entry in enumerate(node.op.idx_list):
+            if isinstance(entry, Type):
+                if input_idx == bool_idx_pos:
+                    new_tensor_inputs[input_idx] = raveled_bool_idx
+                input_idx += 1
+
+        new_out = AdvancedSubtensor(new_idx_list)(raveled_x, *new_tensor_inputs)
     else:
+        # Create new AdvancedIncSubtensor with updated idx_list
+        new_idx_list = list(node.op.idx_list)
+        new_tensor_inputs = list(tensor_inputs)
+
+        # Update the tensor_inputs for the raveled boolean index
+        input_idx = 0
+        for i, entry in enumerate(node.op.idx_list):
+            if isinstance(entry, Type):
+                if input_idx == bool_idx_pos:
+                    new_tensor_inputs[input_idx] = raveled_bool_idx
+                input_idx += 1
+
         # The dimensions of y that correspond to the boolean indices
         # must already be raveled in the original graph, so we don't need to do anything to it
-        new_out = node.op(raveled_x, y, *new_idxs)
-        # But we must reshape the output to math the original shape
+        new_out = AdvancedIncSubtensor(
+            new_idx_list,
+            inplace=node.op.inplace,
+            set_instead_of_inc=node.op.set_instead_of_inc,
+            ignore_duplicates=node.op.ignore_duplicates
+        )(raveled_x, y, *new_tensor_inputs)
+        # But we must reshape the output to match the original shape
     new_out = new_out.reshape(x_shape)
 
     return [copy_stack_trace(node.outputs[0], new_out)]
 
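The same slice-versus-Type reconstruction loop now appears in three rewrites in this file. A hedged sketch of a shared helper that could express it once (reconstruct_indices is hypothetical and not part of this diff; is_type_entry stands in for the isinstance(entry, Type) test in the real code):

    # Hypothetical helper consolidating the repeated reconstruction pattern.
    def reconstruct_indices(idx_list, tensor_inputs, is_type_entry):
        indices = []
        input_idx = 0
        for entry in idx_list:
            if isinstance(entry, slice):
                indices.append(entry)  # slices live in idx_list itself
            elif is_type_entry(entry):
                if input_idx < len(tensor_inputs):
                    indices.append(tensor_inputs[input_idx])
                input_idx += 1
        return indices

Each rewrite would then call it with its own input slice, e.g. reconstruct_indices(node.op.idx_list, node.inputs[1:], ...) for AdvancedSubtensor and node.inputs[2:] for AdvancedIncSubtensor.
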
pytensor/tensor/subtensor.py

Lines changed: 11 additions & 9 deletions
@@ -2805,10 +2805,12 @@ def __init__(self, idx_list):
         Parameters
         ----------
         idx_list : tuple
-            List of indices where slices and newaxis are stored as-is,
+            List of indices where slices are stored as-is,
             and numerical indices are replaced by their types.
         """
         self.idx_list = tuple(map(index_vars_to_types, idx_list))
+        # Store expected number of tensor inputs for validation
+        self.expected_inputs_len = len(get_slice_elements(self.idx_list, lambda entry: isinstance(entry, Type)))
 
     def make_node(self, x, *inputs):
         """
@@ -2824,15 +2826,14 @@ def make_node(self, x, *inputs):
         inputs = tuple(as_tensor_variable(a) for a in inputs)
 
         idx_list = list(self.idx_list)
-        if (len([entry for entry in idx_list if entry is not np.newaxis]) > x.type.ndim):
+        if len(idx_list) > x.type.ndim:
             raise IndexError("too many indices for array")
 
         # Validate input count matches expected from idx_list
-        expected_inputs = get_slice_elements(idx_list, lambda entry: isinstance(entry, Type))
-        if len(inputs) != len(expected_inputs):
-            raise ValueError(f"Expected {len(expected_inputs)} inputs but got {len(inputs)}")
+        if len(inputs) != self.expected_inputs_len:
+            raise ValueError(f"Expected {self.expected_inputs_len} inputs but got {len(inputs)}")
 
-        # Build explicit_indices for shape inference (newaxis handled by __getitem__)
+        # Build explicit_indices for shape inference
         explicit_indices = []
         input_idx = 0
 
@@ -3202,6 +3203,8 @@ def __init__(
         self, idx_list, inplace=False, set_instead_of_inc=False, ignore_duplicates=False
     ):
         self.idx_list = tuple(map(index_vars_to_types, idx_list))
+        # Store expected number of tensor inputs for validation
+        self.expected_inputs_len = len(get_slice_elements(self.idx_list, lambda entry: isinstance(entry, Type)))
         self.set_instead_of_inc = set_instead_of_inc
         self.inplace = inplace
         if inplace:
@@ -3220,9 +3223,8 @@ def make_node(self, x, y, *inputs):
         y = as_tensor_variable(y)
 
         # Validate that we have the right number of tensor inputs for our idx_list
-        expected_tensor_inputs = sum(1 for entry in self.idx_list if isinstance(entry, Type))
-        if len(inputs) != expected_tensor_inputs:
-            raise ValueError(f"Expected {expected_tensor_inputs} tensor inputs but got {len(inputs)}")
+        if len(inputs) != self.expected_inputs_len:
+            raise ValueError(f"Expected {self.expected_inputs_len} tensor inputs but got {len(inputs)}")
 
         new_inputs = []
         for inp in inputs:

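Precomputing expected_inputs_len in __init__ turns the per-make_node recount into a stored attribute, so both ops validate against the same cached value. A self-contained toy illustration of the pattern (DummyType and Indexer are hypothetical stand-ins, not PyTensor classes):

    # Toy version of the input-count validation: idx_list mixes literal
    # slices with Type placeholders that each consume one tensor input.
    class DummyType:  # hypothetical stand-in for a PyTensor Type entry
        pass

    class Indexer:
        def __init__(self, idx_list):
            self.idx_list = tuple(idx_list)
            # cached once, instead of recounting on every make_node call
            self.expected_inputs_len = sum(
                isinstance(entry, DummyType) for entry in self.idx_list
            )

        def make_node(self, x, *inputs):
            if len(inputs) != self.expected_inputs_len:
                raise ValueError(
                    f"Expected {self.expected_inputs_len} inputs but got {len(inputs)}"
                )
            return (x, *inputs)

    op = Indexer((slice(None), DummyType()))
    op.make_node("x", "i")  # ok: one Type entry, one tensor input
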