-
-
Notifications
You must be signed in to change notification settings - Fork 101
Closed
Labels
bugSomething isn't workingSomething isn't working
Description
Describe the bug 🐞
Hello, I want to try using Optimization.jl to perform model optimization based on Lux.jl. Here's my code.
Minimal Reproducible Example 👇
using Lux
using Zygote
using StableRNGs
using ComponentArrays
using Optimization
using OptimizationOptimisers
function LSTMCompact(in_dims, hidden_dims, out_dims)
lstm_cell = LSTMCell(in_dims => hidden_dims)
classifier = Dense(hidden_dims => out_dims, sigmoid)
return @compact(; lstm_cell, classifier) do x::AbstractArray{T,2} where {T}
x = reshape(x, size(x)..., 1)
x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
y, carry = lstm_cell(x_init)
output = [vec(classifier(y))]
for x in x_rest
y, carry = lstm_cell((x, carry))
output = vcat(output, [vec(classifier(y))])
end
@return hcat(output...)
end
end
model = LSTMCompact(3, 10, 1)
ps, st = Lux.setup(StableRNGs.LehmerRNG(1234), model)
ps_axes = getaxes(ComponentVector(ps))
model_func = (x, ps) -> Lux.apply(model, x, ps, st)
x = rand(3, 10)
y = rand(1, 10)
function object(u, p)
ps = ComponentVector(u, ps_axes)
sum((model_func(x, ps)[1] .- y) .^ 2)
end
opt_func = Optimization.OptimizationFunction(object, Optimization.AutoZygote())
opt_prob = Optimization.OptimizationProblem(opt_func, Vector(ComponentVector(ps)))
opt_sol = Optimization.solve(opt_prob, OptimizationOptimisers.Adam(0.1), maxiters=1000)
Error & Stacktrace
Translation: The code works when using AutoForwardDiff
as the AD type, but when using AutoZygote
it encounters the following error:
ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{…}})(::NTuple{9, Vector{…}})
Closest candidates are:
(::ChainRulesCore.ProjectTo{T})(::ChainRulesCore.NotImplemented) where T
@ ChainRulesCore D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\6Pucz\src\projection.jl:121
(::ChainRulesCore.ProjectTo{T})(::ChainRulesCore.AbstractZero) where T
@ ChainRulesCore D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\6Pucz\src\projection.jl:120
(::ChainRulesCore.ProjectTo{AbstractArray})(::ChainRulesCore.Tangent)
@ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:200
...
Stacktrace:
[1] (::ChainRules.var"#480#485"{ChainRulesCore.ProjectTo{…}, Tuple{…}, ChainRulesCore.Tangent{…}})()
@ ChainRules D:\Julia\Julia-1.10.4\packages\packages\ChainRules\hShjJ\src\rulesets\Base\array.jl:314
[2] unthunk
@ D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\6Pucz\src\tangent_types\thunks.jl:205 [inlined]
[3] unthunk(x::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{…}, ChainRules.var"#479#484"{…}})
@ ChainRulesCore D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\6Pucz\src\tangent_types\thunks.jl:238
[4] wrap_chainrules_output
@ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:110 [inlined]
[5] map
@ .\tuple.jl:293 [inlined]
[6] wrap_chainrules_output
@ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:111 [inlined]
[7] (::Zygote.ZBack{ChainRules.var"#vcat_pullback#481"{Tuple{…}, Tuple{…}, Val{…}}})(dy::NTuple{10, Vector{Float64}})
@ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:211
[8] #21
@ D:\Julia\Julia-1.10.4\packages\packages\Lux\PsW4M\src\helpers\compact.jl:0 [inlined]
[9] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{Matrix{…}, Nothing})
@ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
[10] CompactLuxLayer
@ D:\Julia\Julia-1.10.4\packages\packages\Lux\PsW4M\src\helpers\compact.jl:366 [inlined]
[11] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Matrix{…}, Nothing})
@ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
[12] apply
@ D:\Julia\Julia-1.10.4\packages\packages\LuxCore\kYVM5\src\LuxCore.jl:171 [inlined]
[13] #23
@ e:\JlCode\HydroModels\temp\train_lstm_in_opt.jl:27 [inlined]
[14] object
@ e:\JlCode\HydroModels\temp\train_lstm_in_opt.jl:33 [inlined]
[15] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
[16] #291
@ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
[17] #2169#back
@ D:\Julia\Julia-1.10.4\packages\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
[18] OptimizationFunction
@ D:\Julia\Julia-1.10.4\packages\packages\SciMLBase\nftrI\src\scimlfunctions.jl:3812 [inlined]
[19] #291
@ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
[20] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.Pullback{…}}})(Δ::Float64)
@ Zygote D:\Julia\Julia-1.10.4\packages\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
[21] #37
@ D:\Julia\Julia-1.10.4\packages\packages\OptimizationBase\ni8lU\ext\OptimizationZygoteExt.jl:94 [inlined]
[22] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
[23] #291
@ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
[24] #2169#back
@ D:\Julia\Julia-1.10.4\packages\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
[25] #39
@ D:\Julia\Julia-1.10.4\packages\packages\OptimizationBase\ni8lU\ext\OptimizationZygoteExt.jl:97 [inlined]
[26] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
[27] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface.jl:91
[28] gradient(f::Function, args::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}})
@ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface.jl:148
[29] (::OptimizationZygoteExt.var"#38#56"{…})(::ComponentVector{…}, ::ComponentVector{…})
@ OptimizationZygoteExt D:\Julia\Julia-1.10.4\packages\packages\OptimizationBase\ni8lU\ext\OptimizationZygoteExt.jl:97
[30] macro expansion
@ D:\Julia\Julia-1.10.4\packages\packages\OptimizationOptimisers\AOkbT\src\OptimizationOptimisers.jl:68 [inlined]
[31] macro expansion
@ D:\Julia\Julia-1.10.4\packages\packages\Optimization\fPKIF\src\utils.jl:32 [inlined]
[32] __solve(cache::OptimizationCache{…})
@ OptimizationOptimisers D:\Julia\Julia-1.10.4\packages\packages\OptimizationOptimisers\AOkbT\src\OptimizationOptimisers.jl:66
[33] solve!(cache::OptimizationCache{…})
@ SciMLBase D:\Julia\Julia-1.10.4\packages\packages\SciMLBase\nftrI\src\solve.jl:188
[34] solve(::OptimizationProblem{…}, ::Adam; kwargs::@Kwargs{…})
@ SciMLBase D:\Julia\Julia-1.10.4\packages\packages\SciMLBase\nftrI\src\solve.jl:96
[35] top-level scope
@ REPL[2]:1
Some type information was truncated. Use `show(err)` to see complete types.
This issue seems to only occur with recurrent neural networks like LSTM, but not with regular fully connected neural networks. So I want to ask if there's a way to optimize Lux.jl's LSTMCell and other RNN models using Optimization.jl
Environment (please complete the following information):
- Output of
using Pkg; Pkg.status()
[7d9f7c33] Accessors v0.1.38
⌃ [4c88cf16] Aqua v0.8.7
[6e4b80f9] BenchmarkTools v1.5.0
⌃ [336ed68f] CSV v0.10.14
⌃ [052768ef] CUDA v5.4.3
[d360d2e6] ChainRulesCore v1.25.0
⌃ [b0b7db55] ComponentArrays v0.15.16
⌃ [a93c6f00] DataFrames v1.6.1
⌃ [82cc6244] DataInterpolations v6.1.0
⌅ [459566f4] DiffEqCallbacks v3.7.0
[ffbed154] DocStringExtensions v0.9.3
⌃ [f6369f11] ForwardDiff v0.10.36
⌃ [86223c79] Graphs v1.11.2
[cde335eb] HydroErrors v0.1.0 `D:\Julia\Julia-1.10.4\packages\dev\HydroErrors`
[de52edbc] Integrals v4.5.0
[a98d9a8b] Interpolations v0.15.1
[c8e1da08] IterTools v1.10.0
⌃ [7ed4a6bd] LinearSolve v2.34.0
⌃ [b2108857] Lux v0.5.65
⌅ [bb33d45b] LuxCore v0.1.24
⌃ [961ee093] ModelingToolkit v9.32.0
⌃ [872c559c] NNlib v0.9.22
[d9ec5142] NamedTupleTools v0.14.3
⌅ [7f7a1694] Optimization v3.27.0
⌃ [3e6eede4] OptimizationBBO v0.3.0
⌃ [42dfb2eb] OptimizationOptimisers v0.2.1
⌃ [1dea7af3] OrdinaryDiffEq v6.87.0
⌃ [d7d3b36b] ParameterSchedulers v0.4.2
⌃ [91a5bcdd] Plots v1.40.5
[92933f4c] ProgressMeter v1.10.2
⌃ [731186ca] RecursiveArrayTools v3.27.0
[189a3867] Reexport v1.2.2
[7e49a35a] RuntimeGeneratedFunctions v0.5.13
⌃ [0bca4576] SciMLBase v2.50.0
⌃ [c0aeaf25] SciMLOperators v0.3.11
⌃ [1ed8b502] SciMLSensitivity v7.64.0
[860ef19b] StableRNGs v1.0.2
⌃ [90137ffa] StaticArrays v1.9.7
⌅ [d1185830] SymbolicUtils v2.1.2
⌅ [0c5d862f] Symbolics v5.36.0
⌃ [e88e6eb3] Zygote v0.6.70
[ade2ca70] Dates
[37e2e46d] LinearAlgebra
[9a3f8284] Random
[2f01184e] SparseArrays v1.10.0
[10745b16] Statistics v1.10.0
[fa267f1f] TOML v1.0.3
- Output of
versioninfo()
Julia Version 1.10.4
Commit 48d4fd4843 (2024-06-04 10:41 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Windows (x86_64-w64-mingw32)
CPU: 24 × 12th Gen Intel(R) Core(TM) i9-12900HX
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, alderlake)
Threads: 1 default, 0 interactive, 1 GC (on 24 virtual cores)
Environment:
JULIA_DEPOT_PATH = D:\Julia\Julia-1.10.4\packages
JULIA_PKG_SERVER = https://mirrors.pku.edu.cn/julia/
JULIA_EDITOR = code
JULIA_NUM_THREADS =
Additional context
Add any other context about the problem here.
Metadata
Metadata
Assignees
Labels
bugSomething isn't workingSomething isn't working