Skip to content

Error in trying to use Optimization.jl for LSTM training based on Lux.jl #860

@chooron

Description

@chooron

Describe the bug 🐞

Hello, I want to try using Optimization.jl to perform model optimization based on Lux.jl. Here's my code.

Minimal Reproducible Example 👇

using Lux
using Zygote
using StableRNGs
using ComponentArrays
using Optimization
using OptimizationOptimisers

"""
    LSTMCompact(in_dims, hidden_dims, out_dims)

Build a compact Lux model: an `LSTMCell` unrolled over the second axis of the
input matrix, with a sigmoid `Dense` classifier applied to the hidden state at
every timestep. The returned layer maps an `(in_dims, timesteps)` matrix to an
`(out_dims, timesteps)` matrix — one classifier output per timestep.
"""
function LSTMCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T,2} where {T}
        # Append a trailing batch dimension of 1 so each timestep slice has the
        # (features, batch) shape the cell expects.
        x = reshape(x, size(x)..., 1)
        # Peel off the first timestep to initialize the recurrent carry state.
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        output = [vec(classifier(y))]
        for x in x_rest
            y, carry = lstm_cell((x, carry))
            # NOTE(review): growing `output` via repeated `vcat` is the line that
            # triggers the Zygote `ProjectTo` MethodError reported below — the
            # `vcat` pullback receives a tuple cotangent it cannot project.
            # AutoForwardDiff is unaffected. TODO: confirm against Lux's
            # recommended recurrence/accumulation pattern.
            output = vcat(output, [vec(classifier(y))])
        end
        @return hcat(output...)
    end
end

# Instantiate the model and initialize parameters/state with a fixed-seed RNG
# for reproducibility.
model = LSTMCompact(3, 10, 1)
ps, st = Lux.setup(StableRNGs.LehmerRNG(1234), model)
# Record the ComponentVector axes so a flat parameter vector can be
# re-structured inside the objective.
ps_axes = getaxes(ComponentVector(ps))
# Closure capturing `model` and `st`: applies the model to input `x` with
# parameters `ps`.
model_func = (x, ps) -> Lux.apply(model, x, ps, st)
# Random training data: 3 features × 10 timesteps input, 1 × 10 target.
# NOTE(review): `rand` yields Float64 while Lux parameters are Float32 in the
# stacktrace below — presumably fine for this repro; verify if eltype matters.
x = rand(3, 10)
y = rand(1, 10)

"""
    object(u, p)

Sum-of-squared-errors training objective. `u` is the flat parameter vector,
which is re-structured via the captured `ps_axes`; `p` is the (unused)
Optimization.jl problem parameter slot.
"""
function object(u, p)
    # Rebuild the structured parameter set from the flat vector.
    params = ComponentVector(u, ps_axes)
    # First element of the apply result is the model prediction.
    prediction = model_func(x, params)[1]
    return sum(abs2, prediction .- y)
end

# Wrap the objective with Zygote-based AD. (Per the report, AutoForwardDiff
# works here; AutoZygote triggers the error shown below.)
opt_func = Optimization.OptimizationFunction(object, Optimization.AutoZygote())
# Flatten the structured parameters into a plain Vector as the initial guess.
opt_prob = Optimization.OptimizationProblem(opt_func, Vector(ComponentVector(ps)))
opt_sol = Optimization.solve(opt_prob, OptimizationOptimisers.Adam(0.1), maxiters=1000)

Error & Stacktrace ⚠️
The code works when using `AutoForwardDiff` as the AD type, but when using `AutoZygote` it throws the following error:

ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{…}})(::NTuple{9, Vector{…}})

Closest candidates are:
  (::ChainRulesCore.ProjectTo{T})(::ChainRulesCore.NotImplemented) where T
   @ ChainRulesCore D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\6Pucz\src\projection.jl:121
  (::ChainRulesCore.ProjectTo{T})(::ChainRulesCore.AbstractZero) where T
   @ ChainRulesCore D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\6Pucz\src\projection.jl:120
  (::ChainRulesCore.ProjectTo{AbstractArray})(::ChainRulesCore.Tangent)
   @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:200
  ...

Stacktrace:
  [1] (::ChainRules.var"#480#485"{ChainRulesCore.ProjectTo{}, Tuple{}, ChainRulesCore.Tangent{}})()
    @ ChainRules D:\Julia\Julia-1.10.4\packages\packages\ChainRules\hShjJ\src\rulesets\Base\array.jl:314
  [2] unthunk
    @ D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\6Pucz\src\tangent_types\thunks.jl:205 [inlined]
  [3] unthunk(x::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{…}, ChainRules.var"#479#484"{…}})
    @ ChainRulesCore D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\6Pucz\src\tangent_types\thunks.jl:238
  [4] wrap_chainrules_output
    @ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:110 [inlined]
  [5] map
    @ .\tuple.jl:293 [inlined]
  [6] wrap_chainrules_output
    @ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:111 [inlined]
  [7] (::Zygote.ZBack{ChainRules.var"#vcat_pullback#481"{Tuple{…}, Tuple{…}, Val{…}}})(dy::NTuple{10, Vector{Float64}})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:211
  [8] #21
    @ D:\Julia\Julia-1.10.4\packages\packages\Lux\PsW4M\src\helpers\compact.jl:0 [inlined]
  [9] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{Matrix{…}, Nothing})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [10] CompactLuxLayer
    @ D:\Julia\Julia-1.10.4\packages\packages\Lux\PsW4M\src\helpers\compact.jl:366 [inlined]
 [11] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Matrix{…}, Nothing})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [12] apply
    @ D:\Julia\Julia-1.10.4\packages\packages\LuxCore\kYVM5\src\LuxCore.jl:171 [inlined]
 [13] #23
    @ e:\JlCode\HydroModels\temp\train_lstm_in_opt.jl:27 [inlined]
 [14] object
    @ e:\JlCode\HydroModels\temp\train_lstm_in_opt.jl:33 [inlined]
 [15] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [16] #291
    @ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
 [17] #2169#back
    @ D:\Julia\Julia-1.10.4\packages\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [18] OptimizationFunction
    @ D:\Julia\Julia-1.10.4\packages\packages\SciMLBase\nftrI\src\scimlfunctions.jl:3812 [inlined]
 [19] #291
    @ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
 [20] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{}, Zygote.Pullback{}}})(Δ::Float64)
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
 [21] #37
    @ D:\Julia\Julia-1.10.4\packages\packages\OptimizationBase\ni8lU\ext\OptimizationZygoteExt.jl:94 [inlined]
 [22] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [23] #291
    @ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
 [24] #2169#back
    @ D:\Julia\Julia-1.10.4\packages\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [25] #39
    @ D:\Julia\Julia-1.10.4\packages\packages\OptimizationBase\ni8lU\ext\OptimizationZygoteExt.jl:97 [inlined]
 [26] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [27] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float64)
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface.jl:91
 [28] gradient(f::Function, args::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface.jl:148
 [29] (::OptimizationZygoteExt.var"#38#56"{})(::ComponentVector{…}, ::ComponentVector{…})
    @ OptimizationZygoteExt D:\Julia\Julia-1.10.4\packages\packages\OptimizationBase\ni8lU\ext\OptimizationZygoteExt.jl:97
 [30] macro expansion
    @ D:\Julia\Julia-1.10.4\packages\packages\OptimizationOptimisers\AOkbT\src\OptimizationOptimisers.jl:68 [inlined]
 [31] macro expansion
    @ D:\Julia\Julia-1.10.4\packages\packages\Optimization\fPKIF\src\utils.jl:32 [inlined]
 [32] __solve(cache::OptimizationCache{…})
    @ OptimizationOptimisers D:\Julia\Julia-1.10.4\packages\packages\OptimizationOptimisers\AOkbT\src\OptimizationOptimisers.jl:66
 [33] solve!(cache::OptimizationCache{…})
    @ SciMLBase D:\Julia\Julia-1.10.4\packages\packages\SciMLBase\nftrI\src\solve.jl:188
 [34] solve(::OptimizationProblem{…}, ::Adam; kwargs::@Kwargs{})
    @ SciMLBase D:\Julia\Julia-1.10.4\packages\packages\SciMLBase\nftrI\src\solve.jl:96
 [35] top-level scope
    @ REPL[2]:1
Some type information was truncated. Use `show(err)` to see complete types.

This issue seems to occur only with recurrent neural networks such as LSTMs, not with regular fully connected networks. Is there a way to optimize Lux.jl's `LSTMCell` and other RNN models using Optimization.jl?

Environment (please complete the following information):

  • Output of using Pkg; Pkg.status()
  [7d9f7c33] Accessors v0.1.38
⌃ [4c88cf16] Aqua v0.8.7
  [6e4b80f9] BenchmarkTools v1.5.0
⌃ [336ed68f] CSV v0.10.14
⌃ [052768ef] CUDA v5.4.3
  [d360d2e6] ChainRulesCore v1.25.0
⌃ [b0b7db55] ComponentArrays v0.15.16
⌃ [a93c6f00] DataFrames v1.6.1
⌃ [82cc6244] DataInterpolations v6.1.0
⌅ [459566f4] DiffEqCallbacks v3.7.0
  [ffbed154] DocStringExtensions v0.9.3
⌃ [f6369f11] ForwardDiff v0.10.36
⌃ [86223c79] Graphs v1.11.2
  [cde335eb] HydroErrors v0.1.0 `D:\Julia\Julia-1.10.4\packages\dev\HydroErrors`
  [de52edbc] Integrals v4.5.0
  [a98d9a8b] Interpolations v0.15.1
  [c8e1da08] IterTools v1.10.0
⌃ [7ed4a6bd] LinearSolve v2.34.0
⌃ [b2108857] Lux v0.5.65
⌅ [bb33d45b] LuxCore v0.1.24
⌃ [961ee093] ModelingToolkit v9.32.0
⌃ [872c559c] NNlib v0.9.22
  [d9ec5142] NamedTupleTools v0.14.3
⌅ [7f7a1694] Optimization v3.27.0
⌃ [3e6eede4] OptimizationBBO v0.3.0
⌃ [42dfb2eb] OptimizationOptimisers v0.2.1
⌃ [1dea7af3] OrdinaryDiffEq v6.87.0
⌃ [d7d3b36b] ParameterSchedulers v0.4.2
⌃ [91a5bcdd] Plots v1.40.5
  [92933f4c] ProgressMeter v1.10.2
⌃ [731186ca] RecursiveArrayTools v3.27.0
  [189a3867] Reexport v1.2.2
  [7e49a35a] RuntimeGeneratedFunctions v0.5.13
⌃ [0bca4576] SciMLBase v2.50.0
⌃ [c0aeaf25] SciMLOperators v0.3.11
⌃ [1ed8b502] SciMLSensitivity v7.64.0
  [860ef19b] StableRNGs v1.0.2
⌃ [90137ffa] StaticArrays v1.9.7
⌅ [d1185830] SymbolicUtils v2.1.2
⌅ [0c5d862f] Symbolics v5.36.0
⌃ [e88e6eb3] Zygote v0.6.70
  [ade2ca70] Dates
  [37e2e46d] LinearAlgebra
  [9a3f8284] Random
  [2f01184e] SparseArrays v1.10.0
  [10745b16] Statistics v1.10.0
  [fa267f1f] TOML v1.0.3
  • Output of versioninfo()
Julia Version 1.10.4
Commit 48d4fd4843 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 24 × 12th Gen Intel(R) Core(TM) i9-12900HX
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, alderlake)
Threads: 1 default, 0 interactive, 1 GC (on 24 virtual cores)
Environment:
  JULIA_DEPOT_PATH = D:\Julia\Julia-1.10.4\packages
  JULIA_PKG_SERVER = https://mirrors.pku.edu.cn/julia/
  JULIA_EDITOR = code
  JULIA_NUM_THREADS =

Additional context

Add any other context about the problem here.

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions