Commit 48c2c83

Factor out shared functions
1 parent c587e33 commit 48c2c83

1 file changed: +147 -71

pytensor/tensor/optimize.py

Lines changed: 147 additions & 71 deletions
@@ -128,6 +128,32 @@ def _find_optimization_parameters(objective: TensorVariable, x: TensorVariable):
     ]


+def _get_parameter_grads_from_vector(
+    grad_wrt_args_vector: Variable,
+    x_star: Variable,
+    args: Sequence[Variable],
+    output_grad: Variable,
+):
+    """
+    Given a single concatenated vector of objective function gradients with respect to raveled optimization parameters,
+    returns the contribution of each parameter to the total loss function, with the unraveled shape of the parameter.
+    """
+    cursor = 0
+    grad_wrt_args = []
+
+    for arg in args:
+        arg_shape = arg.shape
+        arg_size = arg_shape.prod()
+        arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
+            (*x_star.shape, *arg_shape)
+        )
+
+        grad_wrt_args.append(dot(output_grad, arg_grad))
+        cursor += arg_size
+
+    return grad_wrt_args
+
+
 class ScipyWrapperOp(Op, HasInnerGraph):
     """Shared logic for scipy optimization ops"""

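The helper above walks a flat Jacobian with a cursor, reshaping each slice back to its parameter's shape before contracting it with the upstream gradient. Below is a minimal NumPy sketch of that bookkeeping; it is not part of the commit, all names are illustrative, and `np.tensordot` merely stands in for the symbolic `dot` contraction used in the real helper.

```python
# Minimal NumPy sketch of the cursor/reshape bookkeeping done by
# _get_parameter_grads_from_vector (illustrative only; the real helper
# operates on symbolic PyTensor variables, not NumPy arrays).
import numpy as np

rng = np.random.default_rng(0)

x_star = np.zeros(3)                                      # stand-in for x*, shape (3,)
args = [rng.normal(size=(2,)), rng.normal(size=(2, 2))]   # two "parameters" theta

total_size = sum(a.size for a in args)
# Flat Jacobian d x*/d theta with all parameters concatenated column-wise.
grad_wrt_args_vector = rng.normal(size=(x_star.size, total_size))
output_grad = rng.normal(size=x_star.shape)               # upstream gradient w.r.t. x*

cursor, grad_wrt_args = 0, []
for arg in args:
    block = grad_wrt_args_vector[:, cursor : cursor + arg.size]
    # Reshape the slice to (*x*.shape, *arg.shape), then contract with output_grad.
    arg_grad = block.reshape((*x_star.shape, *arg.shape))
    grad_wrt_args.append(np.tensordot(output_grad, arg_grad, axes=1))
    cursor += arg.size

# Each recovered gradient has the unraveled shape of its parameter.
assert [g.shape for g in grad_wrt_args] == [a.shape for a in args]
```
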
@@ -197,6 +223,98 @@ def make_node(self, *inputs):
         )


+def scalar_implict_optimization_grads(
+    inner_fx: Variable,
+    inner_x: Variable,
+    inner_args: Sequence[Variable],
+    args: Sequence[Variable],
+    x_star: Variable,
+    output_grad: Variable,
+    fgraph: FunctionGraph,
+) -> list[Variable]:
+    df_dx, *df_dthetas = grad(
+        inner_fx, [inner_x, *inner_args], disconnected_inputs="ignore"
+    )
+
+    replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
+    df_dx_star, *df_dthetas_stars = graph_replace([df_dx, *df_dthetas], replace=replace)
+
+    grad_wrt_args = [
+        (-df_dtheta_star / df_dx_star) * output_grad
+        for df_dtheta_star in df_dthetas_stars
+    ]
+
+    return grad_wrt_args
+
+
+def implict_optimization_grads(
+    df_dx: Variable,
+    df_dtheta_columns: Sequence[Variable],
+    args: Sequence[Variable],
+    x_star: Variable,
+    output_grad: Variable,
+    fgraph: FunctionGraph,
+):
+    r"""
+    Compute gradients of an optimization problem with respect to its parameters.
+
+    The gradients are computed using the implicit function theorem. Given a function `f(x, theta)`, and a function
+    `x_star(theta)` that, given input parameters theta, returns `x` such that `f(x_star(theta), theta) = 0`, we can
+    compute the gradients of `x_star` with respect to `theta` as follows:
+
+    .. math::
+
+        \underbrace{\frac{\partial f}{\partial x}\left(x^*(\theta), \theta\right)}_{\text{Jacobian wrt } x}
+        \frac{d x^*(\theta)}{d \theta}
+        +
+        \underbrace{\frac{\partial f}{\partial \theta}\left(x^*(\theta), \theta\right)}_{\text{Jacobian wrt } \theta}
+        = 0
+
+    Which, after rearranging, gives us:
+
+    .. math::
+
+        \frac{d x^*(\theta)}{d \theta} = - \left(\frac{\partial f}{\partial x}\left(x^*(\theta), \theta\right)\right)^{-1} \frac{\partial f}{\partial \theta}\left(x^*(\theta), \theta\right)
+
+    Note that this method assumes `f(x_star(theta), theta) = 0`, so it is not immediately applicable to a minimization
+    problem, where `f` is the objective function. In that case, we instead take `f` to be the gradient of the objective
+    function, which *is* indeed zero at the minimum.
+
+    Parameters
+    ----------
+    df_dx : Variable
+        The Jacobian of the objective function with respect to the variable `x`.
+    df_dtheta_columns : Sequence[Variable]
+        The Jacobians of the objective function with respect to the optimization parameters `theta`.
+        Each column (or group of columns) corresponds to a different parameter. Should be returned by pytensor.gradient.jacobian.
+    args : Sequence[Variable]
+        The optimization parameters `theta`.
+    x_star : Variable
+        Symbolic graph representing the value of the variable `x` such that `f(x_star, theta) = 0`.
+    output_grad : Variable
+        The gradient of the output with respect to the objective function.
+    fgraph : FunctionGraph
+        The function graph that contains the inputs and outputs of the optimization problem.
+    """
+    df_dtheta = concatenate(
+        [atleast_2d(jac_col, left=False) for jac_col in df_dtheta_columns],
+        axis=-1,
+    )
+
+    replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
+
+    df_dx_star, df_dtheta_star = graph_replace(
+        [atleast_2d(df_dx), df_dtheta], replace=replace
+    )
+
+    grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)
+    grad_wrt_args = _get_parameter_grads_from_vector(
+        grad_wrt_args_vector, x_star, args, output_grad
+    )
+
+    return grad_wrt_args
+
+
 class MinimizeScalarOp(ScipyWrapperOp):
     __props__ = ("method",)

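The implicit-function-theorem formula in the new docstring can be sanity-checked numerically. The sketch below is not part of the commit; it uses plain NumPy on the toy root problem f(x, theta) = x^2 - theta, whose solution x*(theta) = sqrt(theta) has the known derivative 1 / (2 sqrt(theta)).

```python
# Quick NumPy check of dx*/dtheta = -(df/dx)^{-1} df/dtheta for the root
# problem f(x, theta) = x**2 - theta, solved by x*(theta) = sqrt(theta).
# (Illustrative only; the Op builds this symbolically.)
import numpy as np

theta = 2.5
x_star = np.sqrt(theta)                 # f(x*, theta) = 0

df_dx = np.atleast_2d(2.0 * x_star)     # df/dx evaluated at (x*, theta)
df_dtheta = np.atleast_2d(-1.0)         # df/dtheta evaluated at (x*, theta)

# Computed with a linear solve, mirroring solve(-df_dx_star, df_dtheta_star).
dxstar_dtheta = np.linalg.solve(-df_dx, df_dtheta)

assert np.isclose(dxstar_dtheta.item(), 1.0 / (2.0 * np.sqrt(theta)))
```
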
@@ -242,20 +360,17 @@ def L_op(self, inputs, outputs, output_grads):
         inner_fx = self.fgraph.outputs[0]

         implicit_f = grad(inner_fx, inner_x)
-        df_dx, *df_dthetas = grad(
-            implicit_f, [inner_x, *inner_args], disconnected_inputs="ignore"
-        )

-        replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
-        df_dx_star, *df_dthetas_stars = graph_replace(
-            [df_dx, *df_dthetas], replace=replace
+        grad_wrt_args = scalar_implict_optimization_grads(
+            inner_fx=implicit_f,
+            inner_x=inner_x,
+            inner_args=inner_args,
+            args=args,
+            x_star=x_star,
+            output_grad=output_grad,
+            fgraph=self.fgraph,
         )

-        grad_wrt_args = [
-            (-df_dtheta_star / df_dx_star) * output_grad
-            for df_dtheta_star in df_dthetas_stars
-        ]
-
         return [zeros_like(x), *grad_wrt_args]

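In the scalar path refactored above, the implicit function is dL/dx, so the parameter gradient reduces to a division rather than a linear solve. A minimal numeric check of that identity follows; it is illustrative only and not part of the commit.

```python
# Scalar case: dx*/dtheta = -(d2L/dx dtheta) / (d2L/dx2), evaluated at x*.
# Toy problem: L(x, theta) = (x - theta)**2, with minimizer x*(theta) = theta.
theta = 1.7
x_star = theta                    # argmin_x (x - theta)**2

df_dx = 2.0                       # d^2 L / dx^2        at (x*, theta)
df_dtheta = -2.0                  # d^2 L / dx dtheta   at (x*, theta)
output_grad = 1.0                 # upstream gradient dC/dx*

grad_wrt_theta = (-df_dtheta / df_dx) * output_grad
assert grad_wrt_theta == 1.0      # matches d x*(theta) / d theta = 1
```
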
@@ -348,34 +463,17 @@ def L_op(self, inputs, outputs, output_grads):

         implicit_f = grad(inner_fx, inner_x)

-        df_dx = atleast_2d(concatenate(jacobian(implicit_f, [inner_x]), axis=-1))
-
-        df_dtheta = concatenate(
-            [
-                atleast_2d(x, left=False)
-                for x in jacobian(implicit_f, inner_args, disconnected_inputs="ignore")
-            ],
-            axis=-1,
+        df_dx, *df_dtheta_columns = jacobian(
+            implicit_f, [inner_x, *inner_args], disconnected_inputs="ignore"
+        )
+        grad_wrt_args = implict_optimization_grads(
+            df_dx=df_dx,
+            df_dtheta_columns=df_dtheta_columns,
+            args=args,
+            x_star=x_star,
+            output_grad=output_grad,
+            fgraph=self.fgraph,
         )
-
-        replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
-
-        df_dx_star, df_dtheta_star = graph_replace([df_dx, df_dtheta], replace=replace)
-
-        grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)
-
-        cursor = 0
-        grad_wrt_args = []
-
-        for arg in args:
-            arg_shape = arg.shape
-            arg_size = arg_shape.prod()
-            arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
-                (*x_star.shape, *arg_shape)
-            )
-
-            grad_wrt_args.append(dot(output_grad, arg_grad))
-            cursor += arg_size

         return [zeros_like(x), *grad_wrt_args]

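The vector case works the same way, except the implicit function is the full gradient of the objective, so df_dx is the Hessian and the division becomes a linear solve. Below is a small NumPy check on a quadratic objective; it assumes nothing beyond the formula in the docstring and is illustrative only, not part of the commit.

```python
# Vector minimization check: L(x, theta) = 0.5 * x @ A @ x - theta @ x has
# minimizer x*(theta) = A^{-1} theta, so d x*/d theta = A^{-1}.
import numpy as np

rng = np.random.default_rng(1)
A = np.array([[3.0, 1.0], [1.0, 2.0]])    # symmetric positive definite
theta = rng.normal(size=2)
x_star = np.linalg.solve(A, theta)        # minimizer of L

df_dx = A                                 # Jacobian of grad_x L w.r.t. x (the Hessian)
df_dtheta = -np.eye(2)                    # Jacobian of grad_x L w.r.t. theta

dxstar_dtheta = np.linalg.solve(-df_dx, df_dtheta)    # = A^{-1}
assert np.allclose(dxstar_dtheta, np.linalg.inv(A))

output_grad = rng.normal(size=2)                      # upstream gradient dC/dx*
grad_wrt_theta = output_grad @ dxstar_dtheta          # chain rule, as in the Op
assert np.allclose(grad_wrt_theta, np.linalg.solve(A, output_grad))
```
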
@@ -432,7 +530,7 @@ def minimize(


 class RootOp(ScipyWrapperOp):
-    __props__ = ("method", "jac")
+    __props__ = ("method", "jac", "optimizer_kwargs")

     def __init__(
         self,
@@ -489,35 +587,17 @@ def L_op(
         inner_fx = self.fgraph.outputs[0]

         df_dx = jacobian(inner_fx, inner_x) if not self.jac else self.fgraph.outputs[1]
-
-        df_dtheta = concatenate(
-            [
-                atleast_2d(jac_column, left=False)
-                for jac_column in jacobian(
-                    inner_fx, inner_args, disconnected_inputs="ignore"
-                )
-            ],
-            axis=-1,
+        df_dtheta_columns = jacobian(inner_fx, inner_args, disconnected_inputs="ignore")
+
+        grad_wrt_args = implict_optimization_grads(
+            df_dx=df_dx,
+            df_dtheta_columns=df_dtheta_columns,
+            args=args,
+            x_star=x_star,
+            output_grad=output_grad,
+            fgraph=self.fgraph,
         )

-        replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
-        df_dx_star, df_dtheta_star = graph_replace([df_dx, df_dtheta], replace=replace)
-
-        grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)
-
-        cursor = 0
-        grad_wrt_args = []
-
-        for arg in args:
-            arg_shape = arg.shape
-            arg_size = arg_shape.prod()
-            arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
-                (*x_star.shape, *arg_shape)
-            )
-
-            grad_wrt_args.append(dot(output_grad, arg_grad))
-            cursor += arg_size
-
         return [zeros_like(x), *grad_wrt_args]

@@ -529,11 +609,7 @@ def root(
 ):
     """Find roots of a system of equations using scipy.optimize.root."""

-    args = [
-        arg
-        for arg in truncated_graph_inputs([equations], [variables])
-        if (arg is not variables and not isinstance(arg, Constant))
-    ]
+    args = _find_optimization_parameters(equations, variables)

     root_op = RootOp(variables, *args, equations=equations, method=method, jac=jac)
