Hello,
For educational purposes, I have been implementing an RNN by hand and wanted to be fancy and use accumulate
instead of recursion or a for loop. But I run into an error when one of the operands in accumulate is a tuple.
I have carved out an MWE, which looks like this:
using Zygote

x = [randn(Float32, 2) for i in 1:3]
h = randn(Float32, 2)

function f(α, h, x)
    o = accumulate(x, init = h) do h, x
        α * h + x
    end
end

function g(α, h, x)
    o = accumulate(x, init = (h, x[1])) do (h, _), x
        (α * h + x, x)
    end
    first.(o)
end

gradient(α -> sum(sum(g(α, h, x))), 1f0)[1]
gradient(α -> sum(sum(f(α, h, x))), 1f0)[1]
While computing the gradient of f succeeds, computing the gradient of g crashes with
julia> gradient(α -> sum(sum(g(α, h, x))), 1f0)[1]
ERROR: MethodError: no method matching construct(::Type{Any}, ::Tuple{FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, ChainRulesCore.NoTangent})
Closest candidates are:
construct(::Type{T}, ::T) where T<:Tuple
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/6DiyF/src/tangent_types/structural_tangent.jl:251
construct(::Type{T}, ::NamedTuple{L}) where {T, L}
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/6DiyF/src/tangent_types/structural_tangent.jl:235
Stacktrace:
[1] +(a::ChainRulesCore.Tangent{Tuple{…}, Tuple{…}}, d::ChainRulesCore.Tangent{Any, Tuple{…}})
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/6DiyF/src/tangent_arithmetic.jl:142
[2] (::ChainRules.var"#1699#1702")(::Tuple{…}, ::Tuple{…})
@ ChainRules ~/.julia/packages/ChainRules/FLsQJ/src/rulesets/Base/mapreduce.jl:541
[3] iterate(itr::Base.Iterators.Accumulate)
@ Base.Iterators ./iterators.jl:589 [inlined]
[4] collect_to!
@ ./array.jl:892 [inlined]
[5] collect_to_with_first!
@ ./array.jl:870 [inlined]
[6] _collect(c::Any, itr::Any, ::Base.EltypeUnknown, isz::Union{Base.HasLength, Base.HasShape})
@ Base ./array.jl:864 [inlined]
[7] collect(itr::Base.Generator)
@ Base ./array.jl:759 [inlined]
[8] #accumulate#893
@ ./accumulate.jl:281 [inlined]
[9] accumulate
@ ./accumulate.jl:278 [inlined]
[10] (::ChainRules.var"#decumulate#1701"{…})(dy::Vector{…})
@ ChainRules ~/.julia/packages/ChainRules/FLsQJ/src/rulesets/Base/mapreduce.jl:540
[11] ZBack
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:211 [inlined]
[12] (::Zygote.var"#kw_zpullback#53"{ChainRules.var"#decumulate#1701"{…}})(dy::Vector{Tuple{…}})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:237
[13] g
@ ./REPL[43]:2 [inlined]
[14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{FillArrays.Fill{…}, 1, Tuple{…}})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[15] #53
@ ./REPL[44]:1 [inlined]
[16] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[17] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
[18] gradient(f::Function, args::Float32)
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:148
[19] top-level scope
@ REPL[44]:1
Some type information was truncated. Use `show(err)` to see complete types.
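For reference, the forward passes of f and g should return the same values, since g only carries an extra, unused tuple element through the accumulation. The snippet below is a minimal sketch of that check using Base only (no Zygote), with the same definitions as in the MWE; it says nothing about where the pullback goes wrong, only that the two functions agree on the forward pass.

```julia
# Forward-pass sanity check only; no automatic differentiation involved.
x = [randn(Float32, 2) for i in 1:3]
h = randn(Float32, 2)

# Plain accumulation: the state is a single vector.
function f(α, h, x)
    accumulate(x; init = h) do h, x
        α * h + x
    end
end

# Tuple-state accumulation: the second tuple element is carried along but unused.
function g(α, h, x)
    o = accumulate(x; init = (h, x[1])) do (h, _), x
        (α * h + x, x)
    end
    first.(o)
end

# Both perform the same arithmetic, so the results should match exactly.
@assert f(1f0, h, x) == g(1f0, h, x)
```

So the difference between the two cases appears only when differentiating, when the state passed through accumulate is a tuple.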
Julia and environment
julia> versioninfo()
Julia Version 1.10.0-rc2
Commit dbb9c46795b (2023-12-03 15:25 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: macOS (x86_64-apple-darwin22.4.0)
CPU: 8 × Intel(R) Core(TM) i5-8279U CPU @ 2.40GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, skylake)
Threads: 1 on 8 virtual cores
(tmp) pkg> st
Status `/private/tmp/Project.toml`
[082447d4] ChainRules v1.63.0
[d360d2e6] ChainRulesCore v1.21.1
[26cc04aa] FiniteDifferences v0.12.31
[587475ba] Flux v0.14.11
[3bd65402] Optimisers v0.3.2
[eeda0dda] SafeTensors v1.0.0
[2913bbd2] StatsBase v0.34.2
[e88e6eb3] Zygote v0.6.69
Thanks for your help.