@@ -34,9 +34,8 @@ const DualBLinearProblem = LinearProblem{
3434const DualAbstractLinearProblem = Union{
3535 DualLinearProblem, DualALinearProblem, DualBLinearProblem}
3636
37- LinearSolve. @concrete mutable struct DualLinearCache
37+ LinearSolve. @concrete mutable struct DualLinearCache{DT <: Dual }
3838 linear_cache
39- dual_type
4039
4140 partials_A
4241 partials_b
@@ -54,7 +53,13 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
5453 primal_b = copy (cache. linear_cache. b)
5554 uu = sol. u
5655
57- primal_sol = deepcopy (sol)
56+ primal_sol = (;
57+ u = recursivecopy (sol. u),
58+ resid = recursivecopy (sol. resid),
59+ retcode = recursivecopy (sol. retcode),
60+ iters = recursivecopy (sol. iters),
61+ stats = recursivecopy (sol. stats)
62+ )
5863
5964 # Solves Dual partials separately
6065 ∂_A = cache. partials_A
@@ -103,21 +108,15 @@ function xp_linsolve_rhs(
103108end
104109
105110function linearsolve_dual_solution (
106- u:: Number , partials, dual_type)
107- return dual_type (u, partials)
108- end
109-
110- function linearsolve_dual_solution (u:: Number , partials,
111- dual_type:: Type{<:Dual{T, V, P}} ) where {T, V, P}
112- # Handle single-level duals
113- return dual_type (u, partials)
111+ u:: Number , partials, cache:: DualLinearCache{DT} ) where {DT}
112+ return DT (u, partials)
114113end
115114
116115function linearsolve_dual_solution (u:: AbstractArray , partials,
117- dual_type :: Type{<:Dual{T, V, P}} ) where {T, V, P }
116+ cache :: DualLinearCache{DT} ) where {DT }
118117 # Handle single-level duals for arrays
119118 partials_list = RecursiveArrayTools. VectorOfArray (partials)
120- return map (((uᵢ, pᵢ),) -> dual_type (uᵢ, Partials (Tuple (pᵢ))),
119+ return map (((uᵢ, pᵢ),) -> DT (uᵢ, Partials (Tuple (pᵢ))),
121120 zip (u, partials_list[i, :] for i in 1 : length (partials_list. u[1 ])))
122121end
123122
@@ -167,7 +166,7 @@ function __dual_init(
167166 alias = alias, abstol = abstol, reltol = reltol,
168167 maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
169168 sensealg = sensealg, u0 = new_u0, kwargs... )
170- return DualLinearCache (non_partial_cache, dual_type , ∂_A, ∂_b,
169+ return DualLinearCache {dual_type} (non_partial_cache, ∂_A, ∂_b,
171170 ! isnothing (∂_b) ? zero .(∂_b) : ∂_b, A, b, zeros (dual_type, length (b)))
172171end
173172
@@ -176,11 +175,11 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
176175end
177176
178177function SciMLBase. solve! (
179- cache:: DualLinearCache , alg:: SciMLLinearSolveAlgorithm , args... ; kwargs... )
178+ cache:: DualLinearCache{DT} , alg:: SciMLLinearSolveAlgorithm , args... ; kwargs... ) where {DT <: ForwardDiff.Dual }
180179 sol,
181180 partials = linearsolve_forwarddiff_solve (
182181 cache:: DualLinearCache , cache. alg, args... ; kwargs... )
183- dual_sol = linearsolve_dual_solution (sol. u, partials, cache. dual_type )
182+ dual_sol = linearsolve_dual_solution (sol. u, partials, cache)
184183
185184 if cache. dual_u isa AbstractArray
186185 cache. dual_u[:] = dual_sol
0 commit comments