@@ -128,6 +128,32 @@ def _find_optimization_parameters(objective: TensorVariable, x: TensorVariable):
     ]


+def _get_parameter_grads_from_vector(
+    grad_wrt_args_vector: Variable,
+    x_star: Variable,
+    args: Sequence[Variable],
+    output_grad: Variable,
+):
+    """
+    Given a single concatenated vector of objective function gradients with respect to raveled optimization parameters,
+    returns the contribution of each parameter to the total loss function, with the unraveled shape of the parameter.
+    """
+    cursor = 0
+    grad_wrt_args = []
+
+    for arg in args:
+        arg_shape = arg.shape
+        arg_size = arg_shape.prod()
+        arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
+            (*x_star.shape, *arg_shape)
+        )
+
+        grad_wrt_args.append(dot(output_grad, arg_grad))
+        cursor += arg_size
+
+    return grad_wrt_args
+
+
 class ScipyWrapperOp(Op, HasInnerGraph):
     """Shared logic for scipy optimization ops"""

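The unraveling loop above is easier to follow with concrete shapes. The NumPy sketch below (illustrative only; shapes and values are made up, and `np.tensordot` stands in for the symbolic `dot`) mirrors the cursor/slice/reshape bookkeeping for an `x_star` of length 3 and two parameters of shapes `(2,)` and `(2, 2)`:

```python
import numpy as np

# Hypothetical shapes: x_star has 3 entries; theta_1 has shape (2,), theta_2 has shape (2, 2).
x_star = np.zeros(3)
param_shapes = [(2,), (2, 2)]

# Concatenated Jacobian d(x_star)/d(theta): one row per entry of x_star,
# one column per raveled parameter entry, i.e. 3 x (2 + 4).
grad_wrt_args_vector = np.arange(3 * 6, dtype=float).reshape(3, 6)
output_grad = np.ones(3)  # upstream gradient with respect to x_star

cursor = 0
grad_wrt_args = []
for shape in param_shapes:
    size = int(np.prod(shape))
    # Cut out this parameter's block of columns and restore its original shape.
    block = grad_wrt_args_vector[:, cursor : cursor + size].reshape(*x_star.shape, *shape)
    # Contract the upstream gradient over the x_star axis.
    grad_wrt_args.append(np.tensordot(output_grad, block, axes=1))
    cursor += size

print([g.shape for g in grad_wrt_args])  # [(2,), (2, 2)]
```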
@@ -197,6 +223,98 @@ def make_node(self, *inputs):
         )


+def scalar_implict_optimization_grads(
+    inner_fx: Variable,
+    inner_x: Variable,
+    inner_args: Sequence[Variable],
+    args: Sequence[Variable],
+    x_star: Variable,
+    output_grad: Variable,
+    fgraph: FunctionGraph,
+) -> list[Variable]:
+    df_dx, *df_dthetas = grad(
+        inner_fx, [inner_x, *inner_args], disconnected_inputs="ignore"
+    )
+
+    replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
+    df_dx_star, *df_dthetas_stars = graph_replace([df_dx, *df_dthetas], replace=replace)
+
+    grad_wrt_args = [
+        (-df_dtheta_star / df_dx_star) * output_grad
+        for df_dtheta_star in df_dthetas_stars
+    ]
+
+    return grad_wrt_args
+
+
+def implict_optimization_grads(
+    df_dx: Variable,
+    df_dtheta_columns: Sequence[Variable],
+    args: Sequence[Variable],
+    x_star: Variable,
+    output_grad: Variable,
+    fgraph: FunctionGraph,
+):
+    r"""
+    Compute gradients of an optimization problem with respect to its parameters.
+
+    The gradients are computed using the implicit function theorem. Given a function `f(x, theta)` and a function
+    `x_star(theta)` that, given input parameters `theta`, returns `x` such that `f(x_star(theta), theta) = 0`, we can
+    compute the gradients of `x_star` with respect to `theta` as follows:
+
+    .. math::
+
+        \underbrace{\frac{\partial f}{\partial x}\left(x^*(\theta), \theta\right)}_{\text{Jacobian wrt } x}
+        \frac{d x^*(\theta)}{d \theta}
+        +
+        \underbrace{\frac{\partial f}{\partial \theta}\left(x^*(\theta), \theta\right)}_{\text{Jacobian wrt } \theta}
+        = 0
+
+    Which, after rearranging, gives us:
+
+    .. math::
+
+        \frac{d x^*(\theta)}{d \theta} = - \left(\frac{\partial f}{\partial x}\left(x^*(\theta), \theta\right)\right)^{-1} \frac{\partial f}{\partial \theta}\left(x^*(\theta), \theta\right)
+
+    Note that this method assumes `f(x_star(theta), theta) = 0`, so it is not immediately applicable to a minimization
+    problem, where `f` is the objective function. In that case, we instead take `f` to be the gradient of the objective
+    function, which *is* indeed zero at the minimum.
+
+    Parameters
+    ----------
+    df_dx : Variable
+        The Jacobian of the objective function with respect to the variable `x`.
+    df_dtheta_columns : Sequence[Variable]
+        The Jacobians of the objective function with respect to the optimization parameters `theta`.
+        Each column (or block of columns) corresponds to a different parameter. These should be returned by pytensor.gradient.jacobian.
+    args : Sequence[Variable]
+        The optimization parameters `theta`.
+    x_star : Variable
+        Symbolic graph representing the value of the variable `x` such that `f(x_star, theta) = 0`.
+    output_grad : Variable
+        The upstream gradient of the cost with respect to the output of the optimization problem, `x_star`.
+    fgraph : FunctionGraph
+        The function graph that contains the inputs and outputs of the optimization problem.
+    """
+    df_dtheta = concatenate(
+        [atleast_2d(jac_col, left=False) for jac_col in df_dtheta_columns],
+        axis=-1,
+    )
+
+    replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
+
+    df_dx_star, df_dtheta_star = graph_replace(
+        [atleast_2d(df_dx), df_dtheta], replace=replace
+    )
+
+    grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)
+    grad_wrt_args = _get_parameter_grads_from_vector(
+        grad_wrt_args_vector, x_star, args, output_grad
+    )
+
+    return grad_wrt_args
+
+
 class MinimizeScalarOp(ScipyWrapperOp):
     __props__ = ("method",)

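For the scalar helper, the implicit-function-theorem formula reduces to a plain quotient, `dx_star/dtheta = -(dg/dtheta) / (dg/dx)` evaluated at the solution, which is what the `(-df_dtheta_star / df_dx_star) * output_grad` term computes symbolically. A small numeric check on an assumed toy objective `f(x, theta) = (x - theta**2)**2`, whose minimizer is known in closed form:

```python
import numpy as np

# Toy check of the scalar implicit-function-theorem formula (illustrative only).
# Objective: f(x, theta) = (x - theta**2)**2, so x_star(theta) = theta**2 and dx_star/dtheta = 2*theta.
theta = 1.5
x_star = theta**2

# g(x, theta) = df/dx = 2*(x - theta**2) is zero at the minimum.
dg_dx = 2.0               # dg/dx evaluated at (x_star, theta)
dg_dtheta = -4.0 * theta  # dg/dtheta evaluated at (x_star, theta)

# Implicit function theorem: dx_star/dtheta = -(dg/dtheta) / (dg/dx)
dx_star_dtheta = -dg_dtheta / dg_dx
assert np.isclose(dx_star_dtheta, 2 * theta)  # matches the analytic derivative
```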
@@ -242,20 +360,17 @@ def L_op(self, inputs, outputs, output_grads):
         inner_fx = self.fgraph.outputs[0]

         implicit_f = grad(inner_fx, inner_x)
-        df_dx, *df_dthetas = grad(
-            implicit_f, [inner_x, *inner_args], disconnected_inputs="ignore"
-        )

-        replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
-        df_dx_star, *df_dthetas_stars = graph_replace(
-            [df_dx, *df_dthetas], replace=replace
+        grad_wrt_args = scalar_implict_optimization_grads(
+            inner_fx=implicit_f,
+            inner_x=inner_x,
+            inner_args=inner_args,
+            args=args,
+            x_star=x_star,
+            output_grad=output_grad,
+            fgraph=self.fgraph,
         )

-        grad_wrt_args = [
-            (-df_dtheta_star / df_dx_star) * output_grad
-            for df_dtheta_star in df_dthetas_stars
-        ]
-
         return [zeros_like(x), *grad_wrt_args]


@@ -348,34 +463,17 @@ def L_op(self, inputs, outputs, output_grads):

         implicit_f = grad(inner_fx, inner_x)

-        df_dx = atleast_2d(concatenate(jacobian(implicit_f, [inner_x]), axis=-1))
-
-        df_dtheta = concatenate(
-            [
-                atleast_2d(x, left=False)
-                for x in jacobian(implicit_f, inner_args, disconnected_inputs="ignore")
-            ],
-            axis=-1,
+        df_dx, *df_dtheta_columns = jacobian(
+            implicit_f, [inner_x, *inner_args], disconnected_inputs="ignore"
+        )
+        grad_wrt_args = implict_optimization_grads(
+            df_dx=df_dx,
+            df_dtheta_columns=df_dtheta_columns,
+            args=args,
+            x_star=x_star,
+            output_grad=output_grad,
+            fgraph=self.fgraph,
         )
-
-        replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
-
-        df_dx_star, df_dtheta_star = graph_replace([df_dx, df_dtheta], replace=replace)
-
-        grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)
-
-        cursor = 0
-        grad_wrt_args = []
-
-        for arg in args:
-            arg_shape = arg.shape
-            arg_size = arg_shape.prod()
-            arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
-                (*x_star.shape, *arg_shape)
-            )
-
-            grad_wrt_args.append(dot(output_grad, arg_grad))
-            cursor += arg_size

         return [zeros_like(x), *grad_wrt_args]

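In the vector case the scalar quotient becomes a linear solve, `d x_star/d theta = -J_x^{-1} J_theta`, computed above as `solve(-df_dx_star, df_dtheta_star)`. A NumPy sketch on an assumed linear root-finding problem where the answer is known in closed form (illustrative, not pytensor code):

```python
import numpy as np

# Root problem: f(x, theta) = A @ x - theta, so x_star(theta) = inv(A) @ theta
# and d(x_star)/d(theta) = inv(A).
rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3)) + 3 * np.eye(3)  # well-conditioned Jacobian wrt x

df_dx = A               # Jacobian of f wrt x, evaluated at the solution
df_dtheta = -np.eye(3)  # Jacobian of f wrt theta

# Same algebra as solve(-df_dx_star, df_dtheta_star) in the diff:
dx_star_dtheta = np.linalg.solve(-df_dx, df_dtheta)
assert np.allclose(dx_star_dtheta, np.linalg.inv(A))

# The upstream gradient is then contracted with this Jacobian, parameter by
# parameter, which is the role of _get_parameter_grads_from_vector.
output_grad = np.ones(3)
grad_wrt_theta = output_grad @ dx_star_dtheta
```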
@@ -432,7 +530,7 @@ def minimize(


 class RootOp(ScipyWrapperOp):
-    __props__ = ("method", "jac")
+    __props__ = ("method", "jac", "optimizer_kwargs")

     def __init__(
         self,
@@ -489,35 +587,17 @@ def L_op(
         inner_fx = self.fgraph.outputs[0]

         df_dx = jacobian(inner_fx, inner_x) if not self.jac else self.fgraph.outputs[1]
-
-        df_dtheta = concatenate(
-            [
-                atleast_2d(jac_column, left=False)
-                for jac_column in jacobian(
-                    inner_fx, inner_args, disconnected_inputs="ignore"
-                )
-            ],
-            axis=-1,
+        df_dtheta_columns = jacobian(inner_fx, inner_args, disconnected_inputs="ignore")
+
+        grad_wrt_args = implict_optimization_grads(
+            df_dx=df_dx,
+            df_dtheta_columns=df_dtheta_columns,
+            args=args,
+            x_star=x_star,
+            output_grad=output_grad,
+            fgraph=self.fgraph,
         )

-        replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
-        df_dx_star, df_dtheta_star = graph_replace([df_dx, df_dtheta], replace=replace)
-
-        grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)
-
-        cursor = 0
-        grad_wrt_args = []
-
-        for arg in args:
-            arg_shape = arg.shape
-            arg_size = arg_shape.prod()
-            arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
-                (*x_star.shape, *arg_shape)
-            )
-
-            grad_wrt_args.append(dot(output_grad, arg_grad))
-            cursor += arg_size
-
         return [zeros_like(x), *grad_wrt_args]


@@ -529,11 +609,7 @@ def root(
 ):
     """Find roots of a system of equations using scipy.optimize.root."""

-    args = [
-        arg
-        for arg in truncated_graph_inputs([equations], [variables])
-        if (arg is not variables and not isinstance(arg, Constant))
-    ]
+    args = _find_optimization_parameters(equations, variables)

     root_op = RootOp(variables, *args, equations=equations, method=method, jac=jac)

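The helper `_find_optimization_parameters` is defined outside this hunk; judging from the inlined list comprehension it replaces and the typed signature shown in the first hunk header (both arguments annotated as `TensorVariable`), it plausibly looks like the sketch below. The body and the import path are inferred, not part of the patch:

```python
from pytensor.graph.basic import Constant, truncated_graph_inputs  # import path assumed


def _find_optimization_parameters(objective, x):
    """Collect the inputs of `objective` that are parameters, i.e. not `x` and not constants."""
    return [
        arg
        for arg in truncated_graph_inputs([objective], [x])
        if (arg is not x and not isinstance(arg, Constant))
    ]
```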