    as_tensor_variable,
    cast,
    constant,
+    expand_dims,
    get_underlying_scalar_constant_value,
    moveaxis,
    ones_like,
    register_infer_shape,
    switch,
    zeros_like,
)
-from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays
...
    Sum,
    _conj,
    _dot,
-    _inner_prod,
-    _matrix_matrix_matmul,
-    _matrix_vec_prod,
-    _vec_matrix_prod,
+    _matmul,
    add,
    digamma,
    dot,
@@ -182,7 +179,7 @@ def local_lift_transpose_through_dot(fgraph, node):
    if not (
        is_matrix_transpose(node.out)
        and node.inputs[0].owner
-        and ((dot_op := node.inputs[0].owner.op) in (_dot, _matrix_matrix_matmul))
+        and ((dot_op := node.inputs[0].owner.op) in (_dot, _matmul))
    ):
        return False

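Note: the lift-transpose rewrite above rests on the transpose-of-a-product identity (A @ B).T == B.T @ A.T for matrices, and the change simply lets it track the new `_matmul` op alongside `_dot`. A quick NumPy sanity check of that identity (illustrative sketch, not part of the diff):

import numpy as np

A = np.random.normal(size=(3, 4))
B = np.random.normal(size=(4, 5))
# transposing a product equals the reversed product of the transposes
assert np.allclose((A @ B).T, B.T @ A.T)
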
@@ -197,60 +194,157 @@ def local_lift_transpose_through_dot(fgraph, node):
    return ret


-@register_stabilize
-@register_specialize
-@node_rewriter(tracks=[Blockwise])
-def local_batched_matmul_to_core_matmul(fgraph, node):
-    """Rewrite matmul where only one of the inputs has batch dimensions to a reshaped core matmul.
+def _batched_matmul_to_core_matmul(fgraph, node, allow_reshape: bool):
+    """Move batch dimensions of matmul operands to core matmul

-    Example, if x has batch dimensions, but y not:
+    Example, if x has batch dimensions that don't overlap with batch dimensions of y
        x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])

-    It also works when y has batch dimensions, but x not.
-    """
+    It also works for batch dimensions of y that don't overlap with batch dimensions of x

-    # Check whether we have a matmul operation in this node
-    if not (
-        isinstance(node.op.core_op, Dot)
-        and len(node.op.inputs_sig[0]) == 2
-        and len(node.op.inputs_sig[1]) == 2
-    ):
-        return None
+    The rewrite only uses reshape when mixing dimensions, and it can refuse to apply if `allow_reshape=False`
+    """

    x, y = node.inputs
    batch_ndim = node.op.batch_ndim(node)

-    # Check if x has batch dimensions, but y not (or only broadcastable dimensions)
-    if any(not b_dim for b_dim in x.type.broadcastable[:-2]) and all(
-        y.type.broadcastable[:-2]
-    ):
-        x_stacked = x.reshape((-1, x.shape[-1]))
-        out_stacked = x_stacked @ y.squeeze(tuple(range(batch_ndim)))
-        out = out_stacked.reshape((*x.shape[:-1], y.shape[-1]))
-        return [out]
-
-    # Otherwise, check if y has batch dimension, but x not
-    elif any(not b_dim for b_dim in y.type.broadcastable[:-2]) and all(
-        x.type.broadcastable[:-2]
-    ):
-        # For the y batch case we need to first move the batch axes and then reshape
-        # y.shape == (*b, k, n)
-        y_tr = moveaxis(y, -2, 0)  # (k, *b, n)
-        y_stacked = y_tr.reshape((y.shape[-2], -1))  # (k, *b * n)
-        out_stacked = x.squeeze(tuple(range(batch_ndim))) @ y_stacked  # (m, *b * n)
-        out_stacked_tr = out_stacked.reshape(
-            (x.shape[-2], *y.shape[:-2], y.shape[-1])
-        )  # (m, *b, n)
-        out = moveaxis(out_stacked_tr, 0, -2)  # (*b, m, n)
-        return [out]
-
-    # Both x and y have batch dimensions, nothing to do here
-    return None
+    x_axis_to_merge = [
+        i
+        for i, (bcast_x, bcast_y) in enumerate(
+            zip(x.type.broadcastable[:-2], y.type.broadcastable[:-2])
+        )
+        if bcast_y and not bcast_x
+    ]
+
+    y_axis_to_merge = [
+        i
+        for i, (bcast_x, bcast_y) in enumerate(
+            zip(x.type.broadcastable[:-2], y.type.broadcastable[:-2])
+        )
+        if bcast_x and not bcast_y
+    ]
+
+    if not (x_axis_to_merge or y_axis_to_merge):
+        return None
+
+    x_shape = tuple(x.shape)
+    y_shape = tuple(y.shape)
+    x_is_row = x.type.broadcastable[-2]
+    y_is_col = y.type.broadcastable[-1]
+    n_x_axis_to_merge = len(x_axis_to_merge)
+    n_y_axis_to_merge = len(y_axis_to_merge)
+    n_axis_to_merge = n_x_axis_to_merge + n_y_axis_to_merge
+
+    x_stacked, y_stacked = x, y
+    dims_were_merged = False
+
+    if n_x_axis_to_merge:
+        # ravel batch dimensions of x on the core (m) axis
+        x_axis_destination = tuple(range(-n_x_axis_to_merge - 2, -2))
+        x_stacked = moveaxis(x, x_axis_to_merge, x_axis_destination)
+        if x_is_row:
+            # x was a row matrix, squeeze it to clean up the graph
+            x_stacked = x_stacked.squeeze(-2)
+        if n_x_axis_to_merge > 1 or not x_is_row:
+            if not allow_reshape:
+                # TODO: We could allow the y rewrite to go on
+                # Or just move one axis (the largest) if x is row
+                return None
+
+            # Ravel moved batch dims together with (m) if needed
+            x_stacked_shape = tuple(x_stacked.shape)
+            x_stacked = x_stacked.reshape(
+                (*x_stacked_shape[: batch_ndim - n_x_axis_to_merge], -1, x_shape[-1])
+            )
+            dims_were_merged = True
+
+    if n_y_axis_to_merge:
+        # ravel batch dimensions of y on the core (n) axis
+        y_axis_destination = tuple(range(-n_y_axis_to_merge - 1, -1))
+        y_stacked = moveaxis(y, y_axis_to_merge, y_axis_destination)
+        if y_is_col:
+            # y was a column matrix, squeeze it to clean up the graph
+            y_stacked = y_stacked.squeeze(-1)
+        if n_y_axis_to_merge > 1 or not y_is_col:
+            if not allow_reshape:
+                # TODO: We could allow the x rewrite to go on
+                # Or just move one axis (the largest) if y is col
+                return False
+            # Ravel moved batch dims together with (n) if needed
+            y_stacked_shape = tuple(y_stacked.shape)
+            y_stacked = y_stacked.reshape(
+                (*y_stacked_shape[: batch_ndim - n_y_axis_to_merge], y_shape[-2], -1)
+            )
+            dims_were_merged = True
+
+    # Squeeze x_dims corresponding to merged dimensions of y
+    x_axis_to_squeeze = np.array(y_axis_to_merge)
+    for i in reversed(x_axis_to_merge):
+        # The corresponding dimensions of y may have shifted when we merged dimensions of x
+        x_axis_to_squeeze[x_axis_to_squeeze > i] -= 1
+    x_stacked = x_stacked.squeeze(tuple(x_axis_to_squeeze))
+
+    # Same for y
+    y_axis_to_squeeze = np.array(x_axis_to_merge)
+    for i in reversed(y_axis_to_merge):
+        y_axis_to_squeeze[y_axis_to_squeeze > i] -= 1
+    y_stacked = y_stacked.squeeze(tuple(y_axis_to_squeeze))
+
+    out_stacked = x_stacked @ y_stacked
+
+    # Split back any merged dimensions
+    if dims_were_merged:
+        x_merged_shapes = [x_shape[i] for i in x_axis_to_merge]
+        if not x_is_row:
+            # Otherwise we handle that later with expand_dims, which is cleaner
+            x_merged_shapes.append(x_shape[-2])
+        y_merged_shapes = [y_shape[i] for i in y_axis_to_merge]
+        if not y_is_col:
+            # Otherwise we handle that later with expand_dims, which is cleaner
+            y_merged_shapes.append(y_shape[-1])
+        out_stacked_shape = tuple(out_stacked.shape)
+        out_unstacked = out_stacked.reshape(
+            (
+                *out_stacked_shape[: batch_ndim - n_axis_to_merge],
+                *x_merged_shapes,
+                *y_merged_shapes,
+            )
+        )
+    else:
+        out_unstacked = out_stacked
+
+    # Add back dummy row, col axis
+    # We do this separately to avoid the reshape as much as we can
+    if y_is_col and (n_y_axis_to_merge or dims_were_merged):
+        out_unstacked = expand_dims(out_unstacked, -1)
+    if x_is_row and (n_x_axis_to_merge or dims_were_merged):
+        out_unstacked = expand_dims(out_unstacked, -n_y_axis_to_merge - 2)
+
+    # Move batch axis back to their original location
+    source = range(-n_axis_to_merge - 2, 0)
+    destination = (*x_axis_to_merge, -2, *y_axis_to_merge, -1)
+    out = moveaxis(out_unstacked, source, destination)
+    return [out]
+
+
+@register_canonicalize
+@node_rewriter(tracks=[_matmul])
+def local_batched_matmul_to_core_matmul(fgraph, node):
+    # Allow passing batch dimensions of matmul to core vector / column matrices
+    return _batched_matmul_to_core_matmul(fgraph, node, allow_reshape=False)
+
+
+@register_specialize
+@node_rewriter(tracks=[_matmul])
+def local_batched_matmul_to_core_matmul_with_reshape(fgraph, node):
+    # Allow stacking batch dimensions of matmul with core dimensions, with a reshape operation
+    # We only apply this in specialize, because graphs with reshape are hard to work with
+    return _batched_matmul_to_core_matmul(fgraph, node, allow_reshape=True)
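The transformation described in the `_batched_matmul_to_core_matmul` docstring can be sanity-checked with NumPy; the sketch below (illustrative only, not part of the diff) covers the case where only x carries batch dimensions:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 7, 3, 4))  # batch dims (5, 7), core matrix (3, 4)
y = rng.normal(size=(4, 6))        # no batch dims, core matrix (4, 6)

batched = x @ y  # batched matmul, shape (5, 7, 3, 6)
core = (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])
assert np.allclose(batched, core)

Splitting the rewrite into a canonicalize pass (allow_reshape=False, only moveaxis/squeeze/expand_dims) and a specialize pass (allow_reshape=True) keeps canonicalized graphs free of reshape, as the comment above notes.
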


@register_canonicalize
@register_specialize
-@node_rewriter([_inner_prod, _matrix_vec_prod, _vec_matrix_prod, _matrix_matrix_matmul])
+@node_rewriter([_matmul])
def local_blockwise_dot_to_mul(fgraph, node):
    """Rewrite blockwise dots that correspond to multiplication without summation.

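For reference, `local_blockwise_dot_to_mul` (whose docstring opens above) targets dots that amount to multiplication without summation, i.e. where the contracted dimension has length 1. A NumPy illustration of that case (not part of the diff):

import numpy as np

a = np.random.normal(size=(3, 1))  # column matrix
b = np.random.normal(size=(1, 4))  # row matrix
# the contracted dimension has length 1, so the dot is just a broadcasted multiply
assert np.allclose(a @ b, a * b)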