Open
Labels: ChainRules (adjoint -> rrule, and further integration)
Description
I want to use the `stack` function, introduced in Julia 1.9, in my model, but Flux.jl (or its AD backend, Zygote) cannot differentiate through it.
using Flux
nn = Dense(3 => 2)
x = randn(Float32, 3, 5)
slicestack(x) = stack((x for x in eachslice(x, dims = 1)), dims = 1)
slicecat(x) = reduce(vcat, (x' for x in eachslice(x, dims = 1)))
@assert slicestack(nn(x)) == slicecat(nn(x))
Flux.withgradient(nn -> sum(slicecat(nn(x))), nn) # this works
Flux.withgradient(nn -> sum(slicestack(nn(x))), nn) # but this doesn't
Error (truncated):
kenta@KS-MBP ~/tmp> julia stack.jl
ERROR: LoadError: Mutating arrays is not supported -- called copyto!(SubArray{Float32, 1, Matrix{Float32}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)
Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
https://fluxml.ai/Zygote.jl/latest/limitations
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] _throw_mutation_error(f::Function, args::SubArray{Float32, 1, Matrix{Float32}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true})
@ Zygote ~/.julia/packages/Zygote/HTsWj/src/lib/array.jl:88
[3] (::Zygote.var"#555#556"{SubArray{Float32, 1, Matrix{Float32}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}})(#unused#::Nothing)
@ Zygote ~/.julia/packages/Zygote/HTsWj/src/lib/array.jl:103
[4] (::Zygote.var"#2653#back#557"{Zygote.var"#555#556"{SubArray{Float32, 1, Matrix{Float32}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}}})(Δ::Nothing)
@ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
...
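A possible stopgap, sketched under assumptions rather than offered as a confirmed fix: the trace shows the failure comes from an in-place `copyto!` into a slice view. For a matrix input, `slicestack` only splits the array into rows and reassembles them, so its pullback is the identity. A hand-written `rrule` from ChainRulesCore can state that directly, so the AD backend never has to trace into `stack`. The rule below is hypothetical (it is not part of Zygote or ChainRules), covers only this matrix method of `slicestack`, and has not been run against the environment listed below.

```julia
using ChainRulesCore

# Same helper as in the report, restricted to matrices; the generator
# variable is renamed to `r` so it does not shadow the argument `x`.
slicestack(x::AbstractMatrix) = stack((r for r in eachslice(x, dims = 1)), dims = 1)

# Hypothetical rule: for a matrix, slicestack returns a copy of its input,
# so the cotangent passes through unchanged.
function ChainRulesCore.rrule(::typeof(slicestack), x::AbstractMatrix)
    y = slicestack(x)
    project = ProjectTo(x)  # keeps the gradient in the same shape/eltype as x
    slicestack_pullback(ȳ) = (NoTangent(), project(unthunk(ȳ)))
    return y, slicestack_pullback
end
```

Zygote consults ChainRulesCore rules, so with this definition loaded before the `withgradient` call, the failing line would be expected to use the rule instead of the mutating code path. A general rule for `stack` itself would need to handle arbitrary `dims` and iterators, which this sketch does not attempt.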
Environment:
julia> versioninfo()
Julia Version 1.9.0
Commit 8e630552924 (2023-05-07 11:25 UTC)
Platform Info:
OS: macOS (arm64-apple-darwin22.4.0)
CPU: 8 × Apple M1 Pro
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-14.0.6 (ORCJIT, apple-m1)
Threads: 1 on 6 virtual cores
Environment:
JULIA_PROJECT = @.
(tmp) pkg> status
Status `~/tmp/Project.toml`
[587475ba] Flux v0.13.16