Stack is not differentiable #1423

@bicycle1885

I want to use the stack function introduced in Julia 1.9 in my model, but Flux.jl (or its AD backend, Zygote, judging from the error below) cannot differentiate it.

using Flux

nn = Dense(3 => 2)
x = randn(Float32, 3, 5)

slicestack(x) = stack((row for row in eachslice(x, dims = 1)), dims = 1)
slicecat(x) = reduce(vcat, (row' for row in eachslice(x, dims = 1)))
@assert slicestack(nn(x)) == slicecat(nn(x))
Flux.withgradient(nn -> sum(slicecat(nn(x))), nn)  # this works
Flux.withgradient(nn -> sum(slicestack(nn(x))), nn)  # but this doesn't

Error (truncated):

kenta@KS-MBP ~/tmp> julia stack.jl
ERROR: LoadError: Mutating arrays is not supported -- called copyto!(SubArray{Float32, 1, Matrix{Float32}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::SubArray{Float32, 1, Matrix{Float32}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true})
    @ Zygote ~/.julia/packages/Zygote/HTsWj/src/lib/array.jl:88
  [3] (::Zygote.var"#555#556"{SubArray{Float32, 1, Matrix{Float32}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}})(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/HTsWj/src/lib/array.jl:103
  [4] (::Zygote.var"#2653#back#557"{Zygote.var"#555#556"{SubArray{Float32, 1, Matrix{Float32}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
...
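
The error message suggests avoiding mutating operations. Until stack is supported, one possible workaround is to build the same matrix from non-mutating Base calls only. The sketch below is untested under Zygote: slicehcat is a hypothetical helper name (not part of Flux, Zygote, or Base), and it assumes a matrix input sliced along dims = 1, as in the example above.

```julia
# Hypothetical helper: mirrors stack(eachslice(x, dims = 1), dims = 1) for a matrix,
# using only indexing, a comprehension, reduce(hcat, ...) and permutedims.
# Assumption: these calls are differentiable by Zygote; not verified here.
slicehcat(x) = permutedims(reduce(hcat, [x[i, :] for i in axes(x, 1)]))

x = randn(Float32, 3, 5)
# Each row becomes a column under hcat; permutedims turns the columns back into rows.
@assert slicehcat(x) == x
```

If this differentiates, it could replace slicestack in the withgradient call above, at the cost of an extra copy from permutedims.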

Environment:

julia> versioninfo()
Julia Version 1.9.0
Commit 8e630552924 (2023-05-07 11:25 UTC)
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 8 × Apple M1 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, apple-m1)
  Threads: 1 on 6 virtual cores
Environment:
  JULIA_PROJECT = @.

(tmp) pkg> status
Status `~/tmp/Project.toml`
  [587475ba] Flux v0.13.16

Labels: ChainRules (adjoint -> rrule, and further integration)