-
Couldn't load subscription status.
- Fork 4
Remove DenseArray restriction and add support for FillArrays
#6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This addresses #2. In particular, while this allows `StridedView`s to be used for non-`DenseArray` types, it does not define a constructor for them. This reflects the fact that in general, `StridedView`s are only well-defined for `DenseArray`s, but allows users to manually tap into the machinery of `Strided`, in which case it is up to them to ensure correctness.
Codecov ReportAttention:
Additional details and impacted files@@ Coverage Diff @@
## main #6 +/- ##
==========================================
- Coverage 93.67% 91.53% -2.15%
==========================================
Files 4 5 +1
Lines 174 189 +15
==========================================
+ Hits 163 173 +10
- Misses 11 16 +5 ☔ View full report in Codecov by Sentry. |
|
Hi @lkdvos, thank you very much for this PR, like you mentioned I came across the using TensorOperations, StridedViews, Zygote
A = rand(4,3,2)
x = rand(2)
function f(x)
TensorOperations.@tensor B[a,b] := A[a,b,c] * x[c]
sum(B)
end
function g(x)
TensorOperations.@tensor B[a,b] := A[a,b,c] * x[c]
sum(B,dims=2)
end
Zygote.gradient(f,x) # doesn't work
Zygote.jacobian(g,x) # works and matches with ForwardDiff
# Following is the error stack trace
julia> Zygote.jacobian(f,x)
ERROR: conversion to pointer not defined for FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] unsafe_convert(#unused#::Type{Ptr{Float64}}, a::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
@ Base ./pointer.jl:67
[3] pointer
@ ./abstractarray.jl:1245 [inlined]
[4] unsafe_convert(#unused#::Type{Ptr{Float64}}, a::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)})
@ StridedViews ~/.julia/packages/StridedViews/cN0vi/src/stridedview.jl:191
[5] gemm!(transA::Char, transB::Char, alpha::Float64, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, B::StridedView{Float64, 2, Vector{Float64}, typeof(identity)}, beta::Float64, C::StridedView{Float64, 2, Array{Float64, 3}, typeof(identity)})
@ LinearAlgebra.BLAS ~/julia-src/julia-1.9.3/share/julia/stdlib/v1.9/LinearAlgebra/src/blas.jl:1524
[6] _threaded_blas_mul!(C::StridedView{Float64, 2, Array{Float64, 3}, typeof(identity)}, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, B::StridedView{Float64, 2, Vector{Float64}, typeof(identity)}, α::VectorInterface.One, β::VectorInterface.Zero, nthreads::Int64)
@ Strided ~/.julia/packages/Strided/l1vm3/src/linalg.jl:105
[7] _mul!
@ ~/.julia/packages/Strided/l1vm3/src/linalg.jl:91 [inlined]
[8] mul!(C::StridedView{Float64, 2, Array{Float64, 3}, typeof(identity)}, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, B::StridedView{Float64, 2, Vector{Float64}, typeof(identity)}, α::VectorInterface.One, β::VectorInterface.Zero)
@ Strided ~/.julia/packages/Strided/l1vm3/src/linalg.jl:60
[9] _unsafe_blas_contract!(C::StridedView{Float64, 3, Array{Float64, 3}, typeof(identity)}, ipC::Tuple{Tuple{Int64, Int64}, Tuple{Int64}}, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, pA::Tuple{Tuple{Int64, Int64}, Tuple{}}, conjA::Symbol, B::StridedView{Float64, 1, Vector{Float64}, typeof(identity)}, pB::Tuple{Tuple{}, Tuple{Int64}}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero)
@ TensorOperations ~/.julia/packages/TensorOperations/LAzcX/src/implementation/strided.jl:163
[10] blas_contract!(C::StridedView{Float64, 3, Array{Float64, 3}, typeof(identity)}, pC::Tuple{Tuple{Int64, Int64, Int64}, Tuple{}}, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, pA::Tuple{Tuple{Int64, Int64}, Tuple{}}, conjA::Symbol, B::StridedView{Float64, 1, Vector{Float64}, typeof(identity)}, pB::Tuple{Tuple{}, Tuple{Int64}}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero)
@ TensorOperations ~/.julia/packages/TensorOperations/LAzcX/src/implementation/strided.jl:137
[11] tensorcontract!(C::StridedView{Float64, 3, Array{Float64, 3}, typeof(identity)}, pC::Tuple{Tuple{Int64, Int64, Int64}, Tuple{}}, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, pA::Tuple{Tuple{Int64, Int64}, Tuple{}}, conjA::Symbol, B::StridedView{Float64, 1, Vector{Float64}, typeof(identity)}, pB::Tuple{Tuple{}, Tuple{Int64}}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero, backend::TensorOperations.Backend{:StridedBLAS})
@ TensorOperations ~/.julia/packages/TensorOperations/LAzcX/src/implementation/strided.jl:65
[12] tensorcontract!
@ ~/.julia/packages/TensorOperations/LAzcX/src/implementation/abstractarray.jl:63 [inlined]
[13] tensorcontract!
@ ~/.julia/packages/TensorOperations/LAzcX/src/implementation/abstractarray.jl:35 [inlined]
[14] #62
@ ~/.julia/packages/TensorOperations/LAzcX/ext/TensorOperationsChainRulesCoreExt.jl:99 [inlined]
[15] unthunk
@ ~/.julia/packages/ChainRulesCore/0t04l/src/tangent_types/thunks.jl:204 [inlined]
[16] wrap_chainrules_output
@ ~/.julia/packages/Zygote/XJ8pP/src/compiler/chainrules.jl:110 [inlined]
[17] map (repeats 4 times)
@ ./tuple.jl:276 [inlined]
[18] wrap_chainrules_output
@ ~/.julia/packages/Zygote/XJ8pP/src/compiler/chainrules.jl:111 [inlined]
[19] ZBack
@ ~/.julia/packages/Zygote/XJ8pP/src/compiler/chainrules.jl:211 [inlined]
[20] Pullback
@ ./REPL[54]:2 [inlined]
[21] (::Zygote.Pullback{Tuple{typeof(f), Vector{Float64}}, Tuple{Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.Pullback{Tuple{typeof(scalartype), Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}, Zygote.Pullback{Tuple{typeof(scalartype), Type{Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.Pullback{Tuple{typeof(TensorOperations.promote_contract), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.Pullback{Tuple{typeof(Base.promote_op), typeof(TensorOperations.tensorop), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.ZBack{ChainRules.var"#apply_type_pullback#42"{Tuple{DataType, DataType}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Core.Compiler.return_type), typeof(TensorOperations.tensorop), Type{Tuple{Float64, Float64}}}, Tuple{typeof(Core.Compiler.return_type)}}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(scalartype), Array{Float64, 3}}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}, Zygote.Pullback{Tuple{typeof(scalartype), Type{Array{Float64, 3}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}}}, Zygote.var"#3027#back#782"{Zygote.var"#776#780"{Matrix{Float64}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensoralloc_contract_pullback#41"{Tuple{DataType, Tuple{Tuple{Int64, Int64}, Tuple{}}, Array{Float64, 3}, Tuple{Tuple{Int64, Int64}, Tuple{Int64}}, Symbol, Vector{Float64}, Tuple{Tuple{Int64}, Tuple{}}, Symbol, Bool}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#pullback#67"{Matrix{Float64}, Tuple{Tuple{Int64, Int64}, Tuple{}}, Array{Float64, 3}, Tuple{Tuple{Int64, Int64}, Tuple{Int64}}, Symbol, Vector{Float64}, Tuple{Tuple{Int64}, Tuple{}}, Symbol, VectorInterface.One, VectorInterface.Zero, Tuple{}, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
[22] #291
@ ~/.julia/packages/Zygote/XJ8pP/src/lib/lib.jl:206 [inlined]
[23] (::Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}}, Zygote.Pullback{Tuple{typeof(f), Vector{Float64}}, Tuple{Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.Pullback{Tuple{typeof(scalartype), Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}, Zygote.Pullback{Tuple{typeof(scalartype), Type{Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.Pullback{Tuple{typeof(TensorOperations.promote_contract), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.Pullback{Tuple{typeof(Base.promote_op), typeof(TensorOperations.tensorop), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.ZBack{ChainRules.var"#apply_type_pullback#42"{Tuple{DataType, DataType}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Core.Compiler.return_type), typeof(TensorOperations.tensorop), Type{Tuple{Float64, Float64}}}, Tuple{typeof(Core.Compiler.return_type)}}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(scalartype), Array{Float64, 3}}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}, Zygote.Pullback{Tuple{typeof(scalartype), Type{Array{Float64, 3}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}}}, Zygote.var"#3027#back#782"{Zygote.var"#776#780"{Matrix{Float64}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensoralloc_contract_pullback#41"{Tuple{DataType, Tuple{Tuple{Int64, Int64}, Tuple{}}, Array{Float64, 3}, Tuple{Tuple{Int64, Int64}, Tuple{Int64}}, Symbol, Vector{Float64}, Tuple{Tuple{Int64}, Tuple{}}, Symbol, Bool}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#pullback#67"{Matrix{Float64}, Tuple{Tuple{Int64, Int64}, Tuple{}}, Array{Float64, 3}, Tuple{Tuple{Int64, Int64}, Tuple{Int64}}, Symbol, Vector{Float64}, Tuple{Tuple{Int64}, Tuple{}}, Symbol, VectorInterface.One, VectorInterface.Zero, Tuple{}, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}}}}})(Δ::Float64)
@ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
[24] Pullback
@ ./operators.jl:1035 [inlined]
[25] (::Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Any})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
[26] Pullback
@ ./operators.jl:1034 [inlined]
[27] Pullback
@ ./operators.jl:1031 [inlined]
[28] (::Zygote.Pullback{Tuple{Base.var"##_#97", Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), ComposedFunction{typeof(Zygote._jvec), typeof(f)}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:outer, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(Zygote._jvec)}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:inner, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(f)}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(f)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(f)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(Zygote._jvec)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(Zygote._jvec)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.var"#2017#back#204"{typeof(identity)}}}}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(Zygote._jvec), typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{Zygote.var"#2033#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, typeof(Zygote._jvec)}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Any}, Zygote.var"#2145#back#281"{Zygote.var"#277#280"}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Float64}, Tuple{Zygote.ZBack{ChainRules.var"#vcat_pullback#1416"{Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{}}, Val{1}}}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(vec), Vector{Float64}}, Tuple{}}}}}}}}}})(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
[29] #291
@ ~/.julia/packages/Zygote/XJ8pP/src/lib/lib.jl:206 [inlined]
[30] #2173#back
@ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
[31] Pullback
@ ./operators.jl:1031 [inlined]
[32] (::Zygote.Pullback{Tuple{ComposedFunction{typeof(Zygote._jvec), typeof(f)}, Vector{Float64}}, Tuple{Zygote.var"#2370#back#419"{Zygote.var"#pairs_namedtuple_pullback#418"{(), NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{Base.var"##_#97", Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), ComposedFunction{typeof(Zygote._jvec), typeof(f)}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:outer, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(Zygote._jvec)}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:inner, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(f)}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(f)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(f)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(Zygote._jvec)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(Zygote._jvec)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.var"#2017#back#204"{typeof(identity)}}}}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(Zygote._jvec), typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{Zygote.var"#2033#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, typeof(Zygote._jvec)}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Any}, Zygote.var"#2145#back#281"{Zygote.var"#277#280"}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Float64}, Tuple{Zygote.ZBack{ChainRules.var"#vcat_pullback#1416"{Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{}}, Val{1}}}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(vec), Vector{Float64}}, Tuple{}}}}}}}}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}}})(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
[33] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{ComposedFunction{typeof(Zygote._jvec), typeof(f)}, Vector{Float64}}, Tuple{Zygote.var"#2370#back#419"{Zygote.var"#pairs_namedtuple_pullback#418"{(), NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{Base.var"##_#97", Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), ComposedFunction{typeof(Zygote._jvec), typeof(f)}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:outer, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(Zygote._jvec)}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:inner, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(f)}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(f)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(f)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(Zygote._jvec)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(Zygote._jvec)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.var"#2017#back#204"{typeof(identity)}}}}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(Zygote._jvec), typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{Zygote.var"#2033#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, typeof(Zygote._jvec)}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Any}, Zygote.var"#2145#back#281"{Zygote.var"#277#280"}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Float64}, Tuple{Zygote.ZBack{ChainRules.var"#vcat_pullback#1416"{Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{}}, Val{1}}}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(vec), Vector{Float64}}, Tuple{}}}}}}}}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}}}})(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface.jl:45
[34] withjacobian(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/XJ8pP/src/lib/grad.jl:150
[35] jacobian(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/XJ8pP/src/lib/grad.jl:128
[36] top-level scope
@ REPL[63]:1
I thought it would be relevant to this PR so posting here, but if requires a new issue, let me know, thank you. |
|
Thanks for reporting this! |
|
Great, thank you very much, that would be very helpful. |
This removes the hard restriction on the parent arrays subtyping
DenseArray, but only defines a default constructor forStridedView(::DenseArray).(note that the current implementation does however have a constructor
StridedView(::AbstractArray, size, strides, offset, op), which is useful in order to not have to redefine methods likeconjetc for non-DenseArrayStridedViews.)Additionally, it defines a package extension (which I currently have implemented solely as an extension, and not through requires for julia < v1.9) to be able to handle FillArrays, which shows up sometimes in for example Zygotes automatic differentiation rules.
This fixes #2