We expose a pullback/vjp-based API for gradients (y, back = forward(f, x); x̄ = back(ȳ)), with x̄ = gradient(f, x) as simple syntactic sugar on top of this. This interface is pretty awesome: gradient aligns nicely with the mathematical and intuitive notions of a derivative operator, it naturally expresses nested derivatives, and you can build pretty much any other AD-related functionality (checkpointing, forward mode, gradient hooks, etc.) on top of pullbacks, without having to go into AD internals. So far I haven't come across anything that pullbacks can't do straightforwardly; in one case the PyTorch-style back! may be slightly more convenient, but it's overall more cumbersome and requires more knowledge of internals.
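As a rough sketch of how this looks in practice (assuming Zygote.pullback for what the text above calls forward; my_gradient is a made-up name, only to show that gradient is sugar over the pullback):

using Zygote

# The forward pass returns the result and a pullback `back`;
# calling `back` with a seed ȳ gives the gradients of the inputs.
y, back = Zygote.pullback(sin, 0.5)
x̄, = back(1.0)                        # cos(0.5)

# `gradient` is then a thin wrapper over the same machinery
# (returns a tuple of gradients, like Zygote.gradient does).
my_gradient(f, x) = Zygote.pullback(f, x)[2](1.0)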
However, a challenge of the "mathematical" gradient operator is that it's cumbersome to pass in all our parameter arrays explicitly (gradient(resnet, W1, b1, W2, b2, ...)). So we need a way to take gradients of large models without that boilerplate. There are currently two ideas about how to do this: the structural approach and the implicit approach.
Edit: Since writing this I have convinced myself that we can get the convenience of implicit params by slightly generalising the structural approach. I think this gives us a clear path forward, though unfortunately it does mean additional API churn.
Structural Gradients
The first approach (which Zygote will support whatever happens, and could be added to Flux) is to take the gradients w.r.t. some structure containing all of the parameters. The structure could be a dictionary or list, but it's usually convenient to combine the weight structure with the definition of the forward pass. This is effectively a closure, which we refer to as a "layer". Layers can contain other layers (we often call a compound layer a "model", but there's no fundamental difference). Taking a gradient looks like this:
m = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)
x, y = ...
m̄ = gradient(m -> loss(m(x), y), m)
This looks pretty weird at first but makes a lot of sense once it clicks. One then carries out the update step m .+= m̄.
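For concreteness, here is a rough sketch of what such a structural update could look like (update! is a hypothetical helper, not an existing Flux/Zygote API; it assumes m̄ mirrors the field structure of the model, as Zygote's structural gradients do):

# Hypothetical helper: walk the model and its structural gradient in
# parallel, updating parameter arrays in place and recursing into layers.
function update!(m, m̄; η = 0.1)
    m̄ === nothing && return m          # e.g. activation functions have no gradient
    if m isa AbstractArray{<:Number}
        m .-= η .* m̄
    else
        for f in fieldnames(typeof(m))
            update!(getfield(m, f), getfield(m̄, f); η = η)
        end
    end
    return m
end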
Implicit Gradients
The implicit approach is what Flux supports natively, though it works in Zygote as well. Here we ask for gradients of a shapeless set of parameters, which are implicitly used at some point during the forward pass. The code looks something like:
m = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)
x, y = ...
θ = params(m)
θ̄ = gradient(() -> loss(m(x), y), θ)
θ̄ is a dictionary from param to gradient. One then loops over the parameters, doing p .+= θ̄[p].
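Spelled out, that update loop is just (a sketch; the learning rate η and the descent sign are my additions):

η = 0.1
for p in θ
    p .-= η .* θ̄[p]      # θ̄[p] is the gradient array for parameter p
end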
Do we need implicit parameters?
Implicit parameters have some downsides. They feel somewhat less clean and functional than structural ones. They do not support scalars or immutable arrays well, which are needed for more restrictive backends like TPUs; supporting both means having more than one way to do things.
However, implicit parameters have a huge advantage: they make it easy to write "script-like" models. I see them as being a lot like global variables: sure they're a bit unclean, but sometimes it's just convenient to build up a model gradually in a notebook, without a lot of structure (and if I have one non-negotiable rule of API design, it's that you should never have to define a struct/class or satisfy an interface to use a library). Our VAE model is a nice example of this style which I think would be made significantly more cumbersome otherwise.
A potential solution is to make it easier to define "anonymous layers". In my ideal world these would also just be closures, but unfortunately this isn't workable (see discussion below): closures don't explicitly store parameters that they close over from global scope, making those parameters invisible to structural AD. Functions that return closures would be completely fine, but the distinction is too subtle / too tied to implementation details.
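To illustrate the distinction (a sketch; the names are made up):

# Closure over globals: W and b are looked up in global scope, so the
# function object `predict` has no fields and structural AD can't see them.
W, b = randn(5, 10), randn(5)
predict(x) = W * x .+ b

# Closure returned by a function: W and b are captured as fields of the
# returned function, so gradient(m -> loss(m(x), y), m) can differentiate them.
make_predict(W, b) = x -> W * x .+ b
m = make_predict(randn(5, 10), randn(5))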
Other concerns
A couple of other small subtleties.
In the implicit style, parameter identity matters, which means we can reuse parameters when creating a model. For example:
d = Dense(10, 10, relu)
m = Chain(d, d, softmax)
dm = gradient(...)
In the implicit parameter style the weights of d are shared, but in the structural version we get two separate gradients for d, one for each position in the chain, and we'd have to construct the chain inside the forward pass to get the gradient we want. Similar issues come up in nested AD. I don't think either behaviour is more correct or better (both are well-defined and predictable), but they are different.
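For example, to recover the shared-weight gradient structurally, you'd differentiate with respect to d itself and build the chain inside the differentiated function (a sketch, reusing loss, x and y from above):

# Both uses of d refer to the same argument, so their contributions
# are accumulated into a single gradient d̄.
d̄, = gradient(d -> loss(Chain(d, d, softmax)(x), y), d)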
[This is further complicated by the fact that in-place updates mean the weight is effectively shared even in the structural case, just with weird semantics in optimiser state.]
This has some relevance to RNNs, which currently weight-tie the initial and current state fields. I don't think this needs to result in user-facing changes though. It's also possible to make a nice pure-functional RNN interface (see e.g. JAX); you just can't abstract over RNNs quite like we currently do, and state needs to be a bit more explicitly managed (which isn't necessarily a deal breaker, but worth considering).
[To be clear though, while the structural approach is more functional, it does not force differentiated programs to be functional themselves, so something very like our current RNN design is still possible.]
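A rough sketch of what a more explicitly state-managed, pure-functional RNN could look like (the names are illustrative, not an existing Flux API; the caller threads the hidden state through rather than the layer mutating it):

# Hypothetical explicit-state cell: calling it returns (new state, output).
struct SimpleCell{M,V}
    Wx::M
    Wh::M
    b::V
end

function (c::SimpleCell)(h, x)
    h′ = tanh.(c.Wx * x .+ c.Wh * h .+ c.b)
    return h′, h′
end

# Unrolling over a sequence is an explicit fold over the state.
function run(cell, h0, xs)
    h = h0
    ys = map(xs) do x
        h, y = cell(h, x)
        y
    end
    return h, ys
end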