Error in accumulate when I have one argument as a tuple #664

@pevnak

Hello,

For educational purposes, I have been implementing an RNN by hand and wanted to be fancy and use accumulate instead of recursion or a for loop. But I run into an error when one of the operands in accumulate is a tuple.
I have carved out an MWE, which looks like this:

using Zygote

x = [randn(Float32, 2) for i in 1:3]
h = randn(Float32, 2)


function f(α, h, x)
	o = accumulate(x, init = h) do h, x
		α * h + x
	end
end

function g(α, h, x)
	o = accumulate(x, init = (h, x[1])) do (h,_),x
		(α * h + x, x)
	end
	first.(o)
end

gradient(α -> sum(sum(g(α, h, x))), 1f0)[1]
gradient(α -> sum(sum(f(α, h, x))), 1f0)[1]

While computing the gradient of f succeeds, computing the gradient of g crashes with:

julia> gradient(α -> sum(sum(g(α, h, x))), 1f0)[1]
ERROR: MethodError: no method matching construct(::Type{Any}, ::Tuple{FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, ChainRulesCore.NoTangent})

Closest candidates are:
  construct(::Type{T}, ::T) where T<:Tuple
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/6DiyF/src/tangent_types/structural_tangent.jl:251
  construct(::Type{T}, ::NamedTuple{L}) where {T, L}
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/6DiyF/src/tangent_types/structural_tangent.jl:235

Stacktrace:
  [1] +(a::ChainRulesCore.Tangent{Tuple{…}, Tuple{…}}, d::ChainRulesCore.Tangent{Any, Tuple{…}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/6DiyF/src/tangent_arithmetic.jl:142
  [2] (::ChainRules.var"#1699#1702")(::Tuple{…}, ::Tuple{…})
    @ ChainRules ~/.julia/packages/ChainRules/FLsQJ/src/rulesets/Base/mapreduce.jl:541
  [3] iterate(itr::Base.Iterators.Accumulate)
    @ Base.Iterators ./iterators.jl:589 [inlined]
  [4] collect_to!
    @ ./array.jl:892 [inlined]
  [5] collect_to_with_first!
    @ ./array.jl:870 [inlined]
  [6] _collect(c::Any, itr::Any, ::Base.EltypeUnknown, isz::Union{Base.HasLength, Base.HasShape})
    @ Base ./array.jl:864 [inlined]
  [7] collect(itr::Base.Generator)
    @ Base ./array.jl:759 [inlined]
  [8] #accumulate#893
    @ ./accumulate.jl:281 [inlined]
  [9] accumulate
    @ ./accumulate.jl:278 [inlined]
 [10] (::ChainRules.var"#decumulate#1701"{})(dy::Vector{…})
    @ ChainRules ~/.julia/packages/ChainRules/FLsQJ/src/rulesets/Base/mapreduce.jl:540
 [11] ZBack
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:211 [inlined]
 [12] (::Zygote.var"#kw_zpullback#53"{ChainRules.var"#decumulate#1701"{}})(dy::Vector{Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:237
 [13] g
    @ ./REPL[43]:2 [inlined]
 [14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{FillArrays.Fill{…}, 1, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [15] #53
    @ ./REPL[44]:1 [inlined]
 [16] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [17] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
 [18] gradient(f::Function, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:148
 [19] top-level scope
    @ REPL[44]:1
Some type information was truncated. Use `show(err)` to see complete types.
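
Since g only carries an extra tuple element through accumulate and then returns first.(o), it should compute the same values as f, so I would expect both gradients to agree. Below is a minimal sketch of that check using only Base Julia and a central finite difference, with no AD package involved. The fixed Float64 inputs and the names loss_f, loss_g, fd and ε are illustrative assumptions, not part of the MWE above.

```julia
# Reference check in plain Base Julia (no Zygote / ChainRules involved).
# Fixed inputs are used instead of randn so the check is reproducible.
x = [Float64[1, 2], Float64[3, 4], Float64[5, 6]]
h = Float64[0.5, -0.5]

# Same recurrence as f in the MWE: o_t = α * o_{t-1} + x_t, starting from h.
f(α, h, x) = accumulate((h, x) -> α * h + x, x; init = h)

# Same recurrence as g in the MWE: the state is a tuple, only its first
# element is used in the result.
function g(α, h, x)
    o = accumulate(x; init = (h, x[1])) do (h, _), x
        (α * h + x, x)
    end
    first.(o)
end

loss_f(α) = sum(sum(f(α, h, x)))
loss_g(α) = sum(sum(g(α, h, x)))

# Hypothetical helper: central finite difference of a scalar function.
fd(loss, α; ε = 1e-6) = (loss(α + ε) - loss(α - ε)) / (2ε)

@show loss_f(1.0) ≈ loss_g(1.0)
@show fd(loss_f, 1.0) ≈ fd(loss_g, 1.0)
```

If both comparisons hold, the forward pass of g is fine and the failure is limited to the reverse pass through accumulate when the accumulated state is a tuple.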

Julia and environment

julia> versioninfo()
Julia Version 1.10.0-rc2
Commit dbb9c46795b (2023-12-03 15:25 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (x86_64-apple-darwin22.4.0)
  CPU: 8 × Intel(R) Core(TM) i5-8279U CPU @ 2.40GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, skylake)
  Threads: 1 on 8 virtual cores

(tmp) pkg> st
Status `/private/tmp/Project.toml`
  [082447d4] ChainRules v1.63.0
  [d360d2e6] ChainRulesCore v1.21.1
  [26cc04aa] FiniteDifferences v0.12.31
  [587475ba] Flux v0.14.11
  [3bd65402] Optimisers v0.3.2
  [eeda0dda] SafeTensors v1.0.0
  [2913bbd2] StatsBase v0.34.2
  [e88e6eb3] Zygote v0.6.69

Thanks for the help.
