@@ -396,82 +396,120 @@ function lmul!(D::Diagonal, T::Tridiagonal)
396396 return T
397397end
398398
399- function __muldiag! (out, D:: Diagonal , B, _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
# Kernel for `D * B` with a left `Diagonal` factor, used once `alpha` is known to be nonzero.
# Every entry `(row, col)` of `B` is scaled by the matching diagonal entry of `D`, and the
# product is handed to `_modify!` together with `_add`, which is expected to apply the
# alpha/beta update when storing into `out`.
# Indexing is `@inbounds`: callers are presumed to have validated the sizes beforehand.
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, _add::MulAddMul)
    dvals = D.diag
    rows = axes(B, 1)
    @inbounds for col in axes(B, 2)
        @simd for row in rows
            scaled = dvals[row] * B[row, col]
            _modify!(_add, scaled, out, (row, col))
        end
    end
    return out
end
# Return the pair of arrays that the triangular kernels should read from and write to.
# Fallback: the structures are not known to agree, so both arguments are returned untouched.
function _maybe_unwrap_tri(out, A)
    return (out, A)
end
# Destination and source are both upper triangular: work directly on the parent arrays.
function _maybe_unwrap_tri(out::UpperTriangular, A::UpperOrUnitUpperTriangular)
    return (parent(out), parent(A))
end
# Destination and source are both lower triangular: work directly on the parent arrays.
function _maybe_unwrap_tri(out::LowerTriangular, A::LowerOrUnitLowerTriangular)
    return (parent(out), parent(A))
end
# Kernel for `D * B` where `B` is (unit) upper or lower triangular and `alpha` is nonzero.
#
# Only the stored triangle of `B` has to be multiplied. If `out` and `B` share the same
# upper/lower structure, `_maybe_unwrap_tri` returns their parents and we read and write the
# parent arrays directly. Otherwise `out` may hold arbitrary data at the positions where `B`
# is structurally zero (e.g. a dense `out`), so those positions must also be passed through
# `_modify!`; skipping them would leave stale values in `out` instead of applying the
# update encoded in `_add` to the (zero) product.
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul)
    isunit = B isa Union{UnitUpperTriangular, UnitLowerTriangular}
    isupper = B isa UpperOrUnitUpperTriangular
    nrows = size(B, 1)
    # if both B and out have the same upper/lower triangular structure,
    # we may directly read and write from the parents
    out_maybeparent, B_maybeparent = _maybe_unwrap_tri(out, B)
    # the fallback of `_maybe_unwrap_tri` hands back `out` itself, so an unchanged
    # first component means the zero patterns of `out` and `B` are not known to match
    matching = out_maybeparent !== out
    for j in axes(B, 2)
        # the implicit unit diagonal is not stored in the parent; go through `B` and `out`
        if isunit
            _modify!(_add, D.diag[j] * B[j,j], out, (j,j))
        end
        # rows of column j that lie in the stored triangle of B
        rowrange = isupper ? (1:min(j - isunit, nrows)) : ((j + isunit):nrows)
        @inbounds @simd for i in rowrange
            _modify!(_add, D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
        end
        if !matching
            # rows of column j where B is structurally zero; `out` still needs updating there
            zerorange = isupper ? ((j + 1):nrows) : (1:min(j - 1, nrows))
            @inbounds @simd for i in zerorange
                _modify!(_add, D.diag[i] * B[i,j], out, (i,j))
            end
        end
    end
    return out
end
# Entry point for `out = D * B * alpha + out * beta` with a left `Diagonal` factor.
# Requires `out` and `B` to use one-based indexing. A zero `alpha` makes the product
# irrelevant, so `out` is only passed to `_rmul_or_fill!` with `beta`; otherwise the
# work is delegated to the `__muldiag_nonzeroalpha!` kernel. Returns `out`.
function __muldiag!(out, D::Diagonal, B, _add::MulAddMul)
    require_one_based_indexing(out, B)
    if iszero(_add.alpha)
        _rmul_or_fill!(out, _add.beta)
        return out
    end
    __muldiag_nonzeroalpha!(out, D, B, _add)
    return out
end
421- function __muldiag! (out, A, D:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
436+
# Kernel for `A * D` with a right `Diagonal` factor, used once `alpha` is known to be nonzero.
# `alpha` is folded into the diagonal entry once per column (`_add(D.diag[col])`), so the
# per-element update uses a `MulAddMul` whose alpha is the literal `true` and whose beta
# is unchanged, avoiding a second multiplication by `alpha`.
# Indexing is `@inbounds`: callers are presumed to have validated the sizes beforehand.
@inline function __muldiag_nonzeroalpha!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
    betaval = _add.beta
    unitalpha_add = MulAddMul{true,bis0,Bool,typeof(betaval)}(true, betaval)
    dvals = D.diag
    rows = axes(A, 1)
    @inbounds for col in axes(A, 2)
        colscale = _add(dvals[col])
        @simd for row in rows
            _modify!(unitalpha_add, A[row, col] * colscale, out, (row, col))
        end
    end
    return out
end
# Kernel for `A * D` where `A` is (unit) upper or lower triangular and `alpha` is nonzero.
#
# Only the stored triangle of `A` has to be multiplied. If `out` and `A` share the same
# upper/lower structure, `_maybe_unwrap_tri` returns their parents and we read and write the
# parent arrays directly. Otherwise `out` may hold arbitrary data at the positions where `A`
# is structurally zero (e.g. a dense `out`), so those positions must also be passed through
# `_modify!`; skipping them would leave stale values in `out` instead of applying the
# update encoded in `_add` to the (zero) product.
@inline function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
    isunit = A isa Union{UnitUpperTriangular, UnitLowerTriangular}
    isupper = A isa UpperOrUnitUpperTriangular
    nrows = size(A, 1)
    beta = _add.beta
    # since alpha is multiplied to the diagonal element of D,
    # we may skip alpha in the second multiplication by setting ais1 to true
    _add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
    # if both A and out have the same upper/lower triangular structure,
    # we may directly read and write from the parents
    out_maybeparent, A_maybeparent = _maybe_unwrap_tri(out, A)
    # the fallback of `_maybe_unwrap_tri` hands back `out` itself, so an unchanged
    # first component means the zero patterns of `out` and `A` are not known to match
    matching = out_maybeparent !== out
    @inbounds for j in axes(A, 2)
        dja = _add(D.diag[j])
        # the implicit unit diagonal is not stored in the parent; go through `A` and `out`
        if isunit
            _modify!(_add_aisone, A[j,j] * dja, out, (j,j))
        end
        # rows of column j that lie in the stored triangle of A
        rowrange = isupper ? (1:min(j - isunit, nrows)) : ((j + isunit):nrows)
        @simd for i in rowrange
            _modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
        end
        if !matching
            # rows of column j where A is structurally zero; `out` still needs updating there
            zerorange = isupper ? ((j + 1):nrows) : (1:min(j - 1, nrows))
            @simd for i in zerorange
                _modify!(_add_aisone, A[i,j] * dja, out, (i,j))
            end
        end
    end
    return out
end
# Entry point for `out = A * D * alpha + out * beta` with a right `Diagonal` factor.
# Requires `out` and `A` to use one-based indexing. A zero `alpha` makes the product
# irrelevant, so `out` is only passed to `_rmul_or_fill!` with `beta`; otherwise the
# work is delegated to the `__muldiag_nonzeroalpha!` kernel. Returns `out`.
function __muldiag!(out, A, D::Diagonal, _add::MulAddMul)
    require_one_based_indexing(out, A)
    if iszero(_add.alpha)
        _rmul_or_fill!(out, _add.beta)
        return out
    end
    __muldiag_nonzeroalpha!(out, A, D, _add)
    return out
end
445- function __muldiag! (out:: Diagonal , D1:: Diagonal , D2:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
479+
# Kernel for `D1 * D2` stored into a `Diagonal` destination, used once `alpha` is nonzero.
# Works purely on the three diagonal vectors; `eachindex` over all of them ensures they
# share the same indices before the `@inbounds` loop runs.
@inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
    lhs = D1.diag
    rhs = D2.diag
    dest = out.diag
    @inbounds @simd for k in eachindex(lhs, rhs, dest)
        _modify!(_add, lhs[k] * rhs[k], dest, k)
    end
    return out
end
# Entry point for `out = D1 * D2 * alpha + out * beta` when the destination is `Diagonal`.
# With a zero `alpha` only the diagonal of `out` is passed to `_rmul_or_fill!` with `beta`;
# otherwise the work is delegated to the `__muldiag_nonzeroalpha!` kernel. Returns `out`.
function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
    if iszero(_add.alpha)
        _rmul_or_fill!(out.diag, _add.beta)
        return out
    end
    __muldiag_nonzeroalpha!(out, D1, D2, _add)
    return out
end
464- function __muldiag! (out, D1:: Diagonal , D2:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
465- require_one_based_indexing (out)
466- alpha, beta = _add. alpha, _add. beta
467- mA = size (D1, 1 )
# Kernel for `D1 * D2` stored into a general (non-`Diagonal`) destination, used once
# `alpha` is nonzero. Only the diagonal positions `(k, k)` of `out` are touched here.
# `out` is indexed `@inbounds`: callers are presumed to have validated its size.
@inline function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
    lhs = D1.diag
    rhs = D2.diag
    @inbounds @simd for k in eachindex(lhs, rhs)
        _modify!(_add, lhs[k] * rhs[k], out, (k, k))
    end
    return out
end
# Entry point for `out = D1 * D2 * alpha + out * beta` with a general destination.
# Requires `out` to use one-based indexing. The whole of `out` is first passed to
# `_rmul_or_fill!` with `beta`; the diagonal product is then accumulated on top with a
# `MulAddMul` whose beta is the literal `true`, so that `beta` is not applied a second
# time to the diagonal. Nothing is accumulated when `alpha` is zero. Returns `out`.
function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1}) where {ais1}
    require_one_based_indexing(out)
    alpha, beta = _add.alpha, _add.beta
    _rmul_or_fill!(out, beta)
    iszero(alpha) && return out
    accumulate_add = MulAddMul{ais1,false,typeof(alpha),Bool}(alpha, true)
    __muldiag_nonzeroalpha!(out, D1, D2, accumulate_add)
    return out
end
@@ -658,31 +696,21 @@ for Tri in (:UpperTriangular, :LowerTriangular)
658696 @eval $ fun (A:: $Tri , D:: Diagonal ) = $ Tri ($ fun (A. data, D))
659697 @eval $ fun (A:: $UTri , D:: Diagonal ) = $ Tri (_setdiag! ($ fun (A. data, D), $ f, D. diag))
660698 end
699+ @eval * (A:: $Tri{<:Any, <:StridedMaybeAdjOrTransMat} , D:: Diagonal ) =
700+ @invoke * (A:: AbstractMatrix , D:: Diagonal )
701+ @eval * (A:: $UTri{<:Any, <:StridedMaybeAdjOrTransMat} , D:: Diagonal ) =
702+ @invoke * (A:: AbstractMatrix , D:: Diagonal )
661703 for (fun, f) in zip ((:* , :lmul! , :ldiv! , :\ ), (:identity , :identity , :inv , :inv ))
662704 @eval $ fun (D:: Diagonal , A:: $Tri ) = $ Tri ($ fun (D, A. data))
663705 @eval $ fun (D:: Diagonal , A:: $UTri ) = $ Tri (_setdiag! ($ fun (D, A. data), $ f, D. diag))
664706 end
707+ @eval * (D:: Diagonal , A:: $Tri{<:Any, <:StridedMaybeAdjOrTransMat} ) =
708+ @invoke * (D:: Diagonal , A:: AbstractMatrix )
709+ @eval * (D:: Diagonal , A:: $UTri{<:Any, <:StridedMaybeAdjOrTransMat} ) =
710+ @invoke * (D:: Diagonal , A:: AbstractMatrix )
665711 # 3-arg ldiv!
666712 @eval ldiv! (C:: $Tri , D:: Diagonal , A:: $Tri ) = $ Tri (ldiv! (C. data, D, A. data))
667713 @eval ldiv! (C:: $Tri , D:: Diagonal , A:: $UTri ) = $ Tri (_setdiag! (ldiv! (C. data, D, A. data), inv, D. diag))
668- # 3-arg mul! is disambiguated in special.jl
669- # 5-arg mul!
670- @eval _mul! (C:: $Tri , D:: Diagonal , A:: $Tri , _add) = $ Tri (mul! (C. data, D, A. data, _add. alpha, _add. beta))
671- @eval function _mul! (C:: $Tri , D:: Diagonal , A:: $UTri , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
672- α, β = _add. alpha, _add. beta
673- iszero (α) && return _rmul_or_fill! (C, β)
674- diag′ = bis0 ? nothing : diag (C)
675- data = mul! (C. data, D, A. data, α, β)
676- $ Tri (_setdiag! (data, _add, D. diag, diag′))
677- end
678- @eval _mul! (C:: $Tri , A:: $Tri , D:: Diagonal , _add) = $ Tri (mul! (C. data, A. data, D, _add. alpha, _add. beta))
679- @eval function _mul! (C:: $Tri , A:: $UTri , D:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
680- α, β = _add. alpha, _add. beta
681- iszero (α) && return _rmul_or_fill! (C, β)
682- diag′ = bis0 ? nothing : diag (C)
683- data = mul! (C. data, A. data, D, α, β)
684- $ Tri (_setdiag! (data, _add, D. diag, diag′))
685- end
686714end
687715
688716@inline function kron! (C:: AbstractMatrix , A:: Diagonal , B:: Diagonal )
0 commit comments