(WIP) Implement a ForwardDiff version of FFJORD #614

Closed
wants to merge 2 commits into from

ChrisRackauckas
Member

This implements a forward-mode version of FFJORD via ForwardDiff. However, Dual tag ordering issues are showing up, so it's failing.

```julia
using DiffEqFlux, DifferentialEquations, GalacticOptim, Distributions

nn = Chain(
    Dense(1, 3, tanh),
    Dense(3, 1, tanh),
) |> f32
tspan = (0.0f0, 10.0f0)
ffjord_mdl = FFJORD(nn, tspan, Tsit5())

data_dist = Normal(6.0f0, 0.7f0)
train_data = rand(data_dist, 1, 100)

function loss(θ)
    logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ)
    -mean(logpx)
end

adtype = GalacticOptim.AutoZygote()
res1 = DiffEqFlux.sciml_train(loss, ffjord_mdl.p, ADAM(0.1), adtype; maxiters=100)
```

The same failure also occurs with `AutoForwardDiff`.

@mateuszbaran

I'm trying to fix this, but the failing test looks really strange, even after I've fixed (I think) the Dual tag problem. The `e` variable is always just `true`, even when `monte_carlo` is `true`. Zygote somehow manages to work with it, but `auto_jacvec` does not. Does backpropagating `true` have some special meaning in Zygote?

@ChrisRackauckas
Member Author

Yeah I'm not sure, and that's why I dropped this for a bit. I'm not convinced it's not a Zygote bug.

@mateuszbaran

OK, in any case, what I did to `auto_jacvec` was replace broadcasting with `map`. Zygote's broadcast differentiation is very likely not perfect, so avoiding unnecessary broadcasting in differentiated code seems like a good idea.

```julia
function auto_jacvec(f, x, v)
    fval = f(map((xi, vi) -> Dual{typeof(ForwardDiff.Tag(f,eltype(x)))}(xi, vi), x, v))
    map(u -> partials(u)[1], fval)
end
```

I've also changed the tag here to more closely resemble what ForwardDiff.jl does. I'm not sure why `auto_jacvec` has its own separate tag there.
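
For anyone following along, here is a minimal standalone sketch of the same idea, checked against a full Jacobian. It assumes only that ForwardDiff.jl is installed. The names `jacvec_sketch` and `g` and the inputs `x` and `v` are hypothetical and not part of this PR. It illustrates the `map`-based seeding and the `ForwardDiff.Tag(f, eltype(x))` tag, and is not the library's actual `auto_jacvec`.

```julia
using ForwardDiff
using ForwardDiff: Dual, Tag, partials

# Hypothetical helper: Jacobian-vector product J(x) * v from one dual evaluation.
function jacvec_sketch(f, x, v)
    # Tag keyed on the function and the element type, as ForwardDiff itself does.
    T = typeof(Tag(f, eltype(x)))
    # Seed each input with its direction component, using map instead of broadcasting.
    duals = map((xi, vi) -> Dual{T}(xi, vi), x, v)
    # The single partial of each output is the corresponding entry of J(x) * v.
    map(u -> partials(u)[1], f(duals))
end

# Hypothetical test function and inputs.
g(x) = [x[1]^2 + x[2], sin(x[2])]
x = [1.0, 2.0]
v = [0.5, -1.0]

jv = jacvec_sketch(g, x, v)

# Cross-check against the dense Jacobian.
@assert jv ≈ ForwardDiff.jacobian(g, x) * v
```

The dense Jacobian is only there as a cross-check. The point of the dual-number version is that it never forms `J` and needs a single evaluation of `g`.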

@ChrisRackauckas
Member Author

Yeah, that's a better tag. Would be worth upstreaming.

ChrisRackauckas deleted the forward_ffjord branch on May 20, 2024, 02:31