Skip to content

Conversation

@lkdvos
Copy link
Collaborator

@lkdvos lkdvos commented Jan 3, 2024

This removes the hard restriction on the parent arrays subtyping DenseArray, but only defines a default constructor for StridedView(::DenseArray).
(note that the current implementation does however have a constructor StridedView(::AbstractArray, size, strides, offset, op), which is useful in order to not have to redefine methods like conj etc for non-DenseArray StridedViews.)

Additionally, it defines a package extension (which I currently have implemented solely as an extension, and not through requires for julia < v1.9) to be able to handle FillArrays, which shows up sometimes in for example Zygotes automatic differentiation rules.

This fixes #2

This addresses #2.

In particular, while this allows `StridedView`s to be used for non-`DenseArray` types, it does not define a constructor for them.
This reflects the fact that in general, `StridedView`s are only well-defined for `DenseArray`s, but allows users to manually tap into the machinery of `Strided`, in which case it is up to them to ensure correctness.
@codecov
Copy link

codecov bot commented Jan 3, 2024

Codecov Report

Attention: 5 lines in your changes are missing coverage. Please review.

Comparison is base (517e676) 93.67% compared to head (5ba0421) 91.53%.

Files Patch % Lines
ext/StridedViewsFillArraysExt.jl 64.28% 5 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main       #6      +/-   ##
==========================================
- Coverage   93.67%   91.53%   -2.15%     
==========================================
  Files           4        5       +1     
  Lines         174      189      +15     
==========================================
+ Hits          163      173      +10     
- Misses         11       16       +5     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@kishore-nori
Copy link

Hi @lkdvos, thank you very much for this PR, like you mentioned I came across the StridedView constructor not being available for FillArray in an AD situation (see MWE below), I tried using this branch to check if it resolves the below example, but deep down it throws another error regarding the lack of pointer function for FillArray (but only in one case below), called within unsafe_convert on a StridedView. Below is the example and the stack trace when using this branch:

using TensorOperations, StridedViews, Zygote

A = rand(4,3,2)
x = rand(2)

function f(x)
 TensorOperations.@tensor B[a,b] := A[a,b,c] * x[c]
 sum(B)
end

function g(x)
 TensorOperations.@tensor B[a,b] := A[a,b,c] * x[c]
 sum(B,dims=2)
end

Zygote.gradient(f,x) # doesn't work

Zygote.jacobian(g,x) # works and matches with ForwardDiff

# Following is the error stack trace 
julia> Zygote.jacobian(f,x)
ERROR: conversion to pointer not defined for FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] unsafe_convert(#unused#::Type{Ptr{Float64}}, a::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Base ./pointer.jl:67
  [3] pointer
    @ ./abstractarray.jl:1245 [inlined]
  [4] unsafe_convert(#unused#::Type{Ptr{Float64}}, a::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)})
    @ StridedViews ~/.julia/packages/StridedViews/cN0vi/src/stridedview.jl:191
  [5] gemm!(transA::Char, transB::Char, alpha::Float64, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, B::StridedView{Float64, 2, Vector{Float64}, typeof(identity)}, beta::Float64, C::StridedView{Float64, 2, Array{Float64, 3}, typeof(identity)})
    @ LinearAlgebra.BLAS ~/julia-src/julia-1.9.3/share/julia/stdlib/v1.9/LinearAlgebra/src/blas.jl:1524
  [6] _threaded_blas_mul!(C::StridedView{Float64, 2, Array{Float64, 3}, typeof(identity)}, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, B::StridedView{Float64, 2, Vector{Float64}, typeof(identity)}, α::VectorInterface.One, β::VectorInterface.Zero, nthreads::Int64)
    @ Strided ~/.julia/packages/Strided/l1vm3/src/linalg.jl:105
  [7] _mul!
    @ ~/.julia/packages/Strided/l1vm3/src/linalg.jl:91 [inlined]
  [8] mul!(C::StridedView{Float64, 2, Array{Float64, 3}, typeof(identity)}, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, B::StridedView{Float64, 2, Vector{Float64}, typeof(identity)}, α::VectorInterface.One, β::VectorInterface.Zero)
    @ Strided ~/.julia/packages/Strided/l1vm3/src/linalg.jl:60
  [9] _unsafe_blas_contract!(C::StridedView{Float64, 3, Array{Float64, 3}, typeof(identity)}, ipC::Tuple{Tuple{Int64, Int64}, Tuple{Int64}}, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, pA::Tuple{Tuple{Int64, Int64}, Tuple{}}, conjA::Symbol, B::StridedView{Float64, 1, Vector{Float64}, typeof(identity)}, pB::Tuple{Tuple{}, Tuple{Int64}}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero)
    @ TensorOperations ~/.julia/packages/TensorOperations/LAzcX/src/implementation/strided.jl:163
 [10] blas_contract!(C::StridedView{Float64, 3, Array{Float64, 3}, typeof(identity)}, pC::Tuple{Tuple{Int64, Int64, Int64}, Tuple{}}, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, pA::Tuple{Tuple{Int64, Int64}, Tuple{}}, conjA::Symbol, B::StridedView{Float64, 1, Vector{Float64}, typeof(identity)}, pB::Tuple{Tuple{}, Tuple{Int64}}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero)
    @ TensorOperations ~/.julia/packages/TensorOperations/LAzcX/src/implementation/strided.jl:137
 [11] tensorcontract!(C::StridedView{Float64, 3, Array{Float64, 3}, typeof(identity)}, pC::Tuple{Tuple{Int64, Int64, Int64}, Tuple{}}, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, pA::Tuple{Tuple{Int64, Int64}, Tuple{}}, conjA::Symbol, B::StridedView{Float64, 1, Vector{Float64}, typeof(identity)}, pB::Tuple{Tuple{}, Tuple{Int64}}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero, backend::TensorOperations.Backend{:StridedBLAS})
    @ TensorOperations ~/.julia/packages/TensorOperations/LAzcX/src/implementation/strided.jl:65
 [12] tensorcontract!
    @ ~/.julia/packages/TensorOperations/LAzcX/src/implementation/abstractarray.jl:63 [inlined]
 [13] tensorcontract!
    @ ~/.julia/packages/TensorOperations/LAzcX/src/implementation/abstractarray.jl:35 [inlined]
 [14] #62
    @ ~/.julia/packages/TensorOperations/LAzcX/ext/TensorOperationsChainRulesCoreExt.jl:99 [inlined]
 [15] unthunk
    @ ~/.julia/packages/ChainRulesCore/0t04l/src/tangent_types/thunks.jl:204 [inlined]
 [16] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/XJ8pP/src/compiler/chainrules.jl:110 [inlined]
 [17] map (repeats 4 times)
    @ ./tuple.jl:276 [inlined]
 [18] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/XJ8pP/src/compiler/chainrules.jl:111 [inlined]
 [19] ZBack
    @ ~/.julia/packages/Zygote/XJ8pP/src/compiler/chainrules.jl:211 [inlined]
 [20] Pullback
    @ ./REPL[54]:2 [inlined]
 [21] (::Zygote.Pullback{Tuple{typeof(f), Vector{Float64}}, Tuple{Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.Pullback{Tuple{typeof(scalartype), Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}, Zygote.Pullback{Tuple{typeof(scalartype), Type{Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.Pullback{Tuple{typeof(TensorOperations.promote_contract), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.Pullback{Tuple{typeof(Base.promote_op), typeof(TensorOperations.tensorop), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.ZBack{ChainRules.var"#apply_type_pullback#42"{Tuple{DataType, DataType}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Core.Compiler.return_type), typeof(TensorOperations.tensorop), Type{Tuple{Float64, Float64}}}, Tuple{typeof(Core.Compiler.return_type)}}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(scalartype), Array{Float64, 3}}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}, Zygote.Pullback{Tuple{typeof(scalartype), Type{Array{Float64, 3}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}}}, Zygote.var"#3027#back#782"{Zygote.var"#776#780"{Matrix{Float64}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensoralloc_contract_pullback#41"{Tuple{DataType, Tuple{Tuple{Int64, Int64}, Tuple{}}, Array{Float64, 3}, Tuple{Tuple{Int64, Int64}, Tuple{Int64}}, Symbol, Vector{Float64}, Tuple{Tuple{Int64}, Tuple{}}, Symbol, Bool}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#pullback#67"{Matrix{Float64}, Tuple{Tuple{Int64, Int64}, Tuple{}}, Array{Float64, 3}, Tuple{Tuple{Int64, Int64}, Tuple{Int64}}, Symbol, Vector{Float64}, Tuple{Tuple{Int64}, Tuple{}}, Symbol, VectorInterface.One, VectorInterface.Zero, Tuple{}, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
 [22] #291
    @ ~/.julia/packages/Zygote/XJ8pP/src/lib/lib.jl:206 [inlined]
 [23] (::Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}}, Zygote.Pullback{Tuple{typeof(f), Vector{Float64}}, Tuple{Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.Pullback{Tuple{typeof(scalartype), Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}, Zygote.Pullback{Tuple{typeof(scalartype), Type{Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.Pullback{Tuple{typeof(TensorOperations.promote_contract), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.Pullback{Tuple{typeof(Base.promote_op), typeof(TensorOperations.tensorop), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.ZBack{ChainRules.var"#apply_type_pullback#42"{Tuple{DataType, DataType}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Core.Compiler.return_type), typeof(TensorOperations.tensorop), Type{Tuple{Float64, Float64}}}, Tuple{typeof(Core.Compiler.return_type)}}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(scalartype), Array{Float64, 3}}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}, Zygote.Pullback{Tuple{typeof(scalartype), Type{Array{Float64, 3}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}}}, Zygote.var"#3027#back#782"{Zygote.var"#776#780"{Matrix{Float64}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensoralloc_contract_pullback#41"{Tuple{DataType, Tuple{Tuple{Int64, Int64}, Tuple{}}, Array{Float64, 3}, Tuple{Tuple{Int64, Int64}, Tuple{Int64}}, Symbol, Vector{Float64}, Tuple{Tuple{Int64}, Tuple{}}, Symbol, Bool}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#pullback#67"{Matrix{Float64}, Tuple{Tuple{Int64, Int64}, Tuple{}}, Array{Float64, 3}, Tuple{Tuple{Int64, Int64}, Tuple{Int64}}, Symbol, Vector{Float64}, Tuple{Tuple{Int64}, Tuple{}}, Symbol, VectorInterface.One, VectorInterface.Zero, Tuple{}, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
 [24] Pullback
    @ ./operators.jl:1035 [inlined]
 [25] (::Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Any})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
 [26] Pullback
    @ ./operators.jl:1034 [inlined]
 [27] Pullback
    @ ./operators.jl:1031 [inlined]
 [28] (::Zygote.Pullback{Tuple{Base.var"##_#97", Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), ComposedFunction{typeof(Zygote._jvec), typeof(f)}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:outer, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(Zygote._jvec)}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:inner, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(f)}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(f)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(f)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(Zygote._jvec)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(Zygote._jvec)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.var"#2017#back#204"{typeof(identity)}}}}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(Zygote._jvec), typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{Zygote.var"#2033#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, typeof(Zygote._jvec)}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Any}, Zygote.var"#2145#back#281"{Zygote.var"#277#280"}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Float64}, Tuple{Zygote.ZBack{ChainRules.var"#vcat_pullback#1416"{Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{}}, Val{1}}}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(vec), Vector{Float64}}, Tuple{}}}}}}}}}})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
 [29] #291
    @ ~/.julia/packages/Zygote/XJ8pP/src/lib/lib.jl:206 [inlined]
 [30] #2173#back
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
 [31] Pullback
    @ ./operators.jl:1031 [inlined]
 [32] (::Zygote.Pullback{Tuple{ComposedFunction{typeof(Zygote._jvec), typeof(f)}, Vector{Float64}}, Tuple{Zygote.var"#2370#back#419"{Zygote.var"#pairs_namedtuple_pullback#418"{(), NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{Base.var"##_#97", Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), ComposedFunction{typeof(Zygote._jvec), typeof(f)}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:outer, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(Zygote._jvec)}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:inner, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(f)}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(f)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(f)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(Zygote._jvec)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(Zygote._jvec)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.var"#2017#back#204"{typeof(identity)}}}}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(Zygote._jvec), typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{Zygote.var"#2033#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, typeof(Zygote._jvec)}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Any}, Zygote.var"#2145#back#281"{Zygote.var"#277#280"}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Float64}, Tuple{Zygote.ZBack{ChainRules.var"#vcat_pullback#1416"{Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{}}, Val{1}}}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(vec), Vector{Float64}}, Tuple{}}}}}}}}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}}})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
 [33] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{ComposedFunction{typeof(Zygote._jvec), typeof(f)}, Vector{Float64}}, Tuple{Zygote.var"#2370#back#419"{Zygote.var"#pairs_namedtuple_pullback#418"{(), NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{Base.var"##_#97", Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), ComposedFunction{typeof(Zygote._jvec), typeof(f)}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:outer, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(Zygote._jvec)}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:inner, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(f)}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(f)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(f)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(Zygote._jvec)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(Zygote._jvec)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.var"#2017#back#204"{typeof(identity)}}}}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(Zygote._jvec), typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{Zygote.var"#2033#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, typeof(Zygote._jvec)}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Any}, Zygote.var"#2145#back#281"{Zygote.var"#277#280"}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Float64}, Tuple{Zygote.ZBack{ChainRules.var"#vcat_pullback#1416"{Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{}}, Val{1}}}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(vec), Vector{Float64}}, Tuple{}}}}}}}}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}}}})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface.jl:45
 [34] withjacobian(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/lib/grad.jl:150
 [35] jacobian(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/lib/grad.jl:128
 [36] top-level scope
    @ REPL[63]:1

I thought it would be relevant to this PR so posting here, but if requires a new issue, let me know, thank you.

@lkdvos
Copy link
Collaborator Author

lkdvos commented Jan 25, 2024

Thanks for reporting this!
I think my PR indeed requires a bit more work, as multiplication should not be dispatched through to BLAS like that, and your example is the exact reason I started looking into this.
I'll try and make some time to look further into it next week, and I hope to add this as a test case to the TensorOperations suite.

@kishore-nori
Copy link

Great, thank you very much, that would be very helpful.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Do not require DenseArray

3 participants