diff --git a/Project.toml b/Project.toml
index 241c7eb5c..f74fe823f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -13,6 +13,7 @@ KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94"
+NextLA = "d37ed344-79c4-486d-9307-6d11355a15a3"
 OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
 PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
 Preferences = "21216c6a-2e73-6563-6e65-726566657250"
diff --git a/src/Dagger.jl b/src/Dagger.jl
index c0cb23526..1f5f2afd5 100644
--- a/src/Dagger.jl
+++ b/src/Dagger.jl
@@ -8,7 +8,7 @@ import MemPool
 import MemPool: DRef, FileRef, poolget, poolset
 
 import Base: collect, reduce
-
+import NextLA
 import LinearAlgebra
 import LinearAlgebra: Adjoint, BLAS, Diagonal, Bidiagonal, Tridiagonal, LAPACK, LowerTriangular, PosDefException, Transpose, UpperTriangular, UnitLowerTriangular, UnitUpperTriangular, diagind, ishermitian, issymmetric
 import Random
@@ -109,6 +109,7 @@ include("array/linalg.jl")
 include("array/mul.jl")
 include("array/cholesky.jl")
 include("array/lu.jl")
+include("array/qr.jl")
 
 import KernelAbstractions, Adapt
diff --git a/src/array/alloc.jl b/src/array/alloc.jl
index a95e070ae..51215307a 100644
--- a/src/array/alloc.jl
+++ b/src/array/alloc.jl
@@ -1,6 +1,7 @@
 import Base: cat
 import Random: MersenneTwister
 export partition
+import LinearAlgebra: UniformScaling
 
 mutable struct AllocateArray{T,N} <: ArrayOp{T,N}
     eltype::Type{T}
@@ -83,17 +84,15 @@ function stage(ctx, a::AllocateArray)
     chunks = map(CartesianIndices(a.domainchunks)) do I
         x = a.domainchunks[I]
         i = LinearIndices(a.domainchunks)[I]
-        args = a.want_index ? (i, size(x)) : (size(x),)
-
         if isnothing(a.procgrid)
             scope = get_compute_scope()
         else
             scope = ExactScope(a.procgrid[CartesianIndex(mod1.(Tuple(I), size(a.procgrid))...)])
         end
         if a.want_index
-            Dagger.@spawn compute_scope=scope allocate_array(a.f, a.eltype, i, args...)
+            Dagger.@spawn compute_scope=scope allocate_array(a.f, a.eltype, i, size(x))
         else
-            Dagger.@spawn compute_scope=scope allocate_array(a.f, a.eltype, args...)
+            Dagger.@spawn compute_scope=scope allocate_array(a.f, a.eltype, size(x))
         end
     end
     return DArray(a.eltype, a.domain, a.domainchunks, chunks, a.partitioning)
@@ -159,6 +158,7 @@ Base.zeros(p::BlocksOrAuto, dims::Dims; assignment::AssignmentType = :arbitrary)
 Base.zeros(::AutoBlocks, eltype::Type, dims::Dims; assignment::AssignmentType = :arbitrary) =
     zeros(auto_blocks(dims), eltype, dims; assignment)
 
+
 function Base.zero(x::DArray{T,N}) where {T,N}
     dims = ntuple(i->x.domain.indexes[i].stop, N)
     sd = first(x.subdomains)
@@ -167,6 +167,39 @@ function Base.zero(x::DArray{T,N}) where {T,N}
     return _to_darray(a)
 end
 
+# Allocate one block of a distributed identity-scaled matrix: the block stays zero
+# unless its row and column ranges overlap the global diagonal.
+function _allocate_diag(i, T, _dims, subdomain)
+    sA = zeros(T, _dims)
+    if !isempty(intersect(subdomain.indexes[1], subdomain.indexes[2]))
+        for j in range(1, min(_dims[1], _dims[2]))
+            sA[j,j] = one(T)
+        end
+    end
+    return sA
+end
+
+function DMatrix(p::BlocksOrAuto, s::UniformScaling, dims::Dims; assignment::AssignmentType = :arbitrary)
+    d = ArrayDomain(map(x->1:x, dims))
+    sd = partition(p, d)
+    T = eltype(s)
+    a = AllocateArray(T, (i, T, _dims) -> _allocate_diag(i, T, _dims, sd[i]), true, d, partition(p, d), p, assignment)
+    return _to_darray(a)
+end
+DMatrix(p::BlocksOrAuto, s::UniformScaling, dims::Integer...; assignment::AssignmentType = :arbitrary) =
+    DMatrix(p, s, dims; assignment)
+DMatrix(::AutoBlocks, s::UniformScaling, dims::Dims; assignment::AssignmentType = :arbitrary) =
+    DMatrix(auto_blocks(dims), s, dims; assignment)
+
+function DArray{T}(p::BlocksOrAuto, ::UndefInitializer, dims::Dims; assignment::AssignmentType = :arbitrary) where {T}
+    d = ArrayDomain(map(x->1:x, dims))
+    a = AllocateArray(T, AllocateUndef{T}(), false, d, partition(p, d), p, assignment)
+    return _to_darray(a)
+end
+
+DArray{T}(p::BlocksOrAuto, ::UndefInitializer, dims::Integer...; assignment::AssignmentType = :arbitrary) where {T} =
+    DArray{T}(p, undef, dims; assignment)
+DArray{T}(p::AutoBlocks, ::UndefInitializer, dims::Dims; assignment::AssignmentType = :arbitrary) where {T} =
+    DArray{T}(auto_blocks(dims), undef, dims; assignment)
+
 function Base.view(A::AbstractArray{T,N}, p::Blocks{N}) where {T,N}
     d = ArrayDomain(Base.index_shape(A))
     dc = partition(p, d)
diff --git a/src/array/copy.jl b/src/array/copy.jl
index 30bfeac4d..e3c9d4cfa 100644
--- a/src/array/copy.jl
+++ b/src/array/copy.jl
@@ -124,7 +124,6 @@ function darray_copyto!(B::DArray{TB,NB}, A::DArray{TA,NA}, Binds=parentindices(
             end
         end
     end
-    return B
 end
 
 function copyto_view!(Bpart, Brange, Apart, Arange)
diff --git a/src/array/darray.jl b/src/array/darray.jl
index 8dbd0e719..80fdade75 100644
--- a/src/array/darray.jl
+++ b/src/array/darray.jl
@@ -65,6 +65,9 @@ ArrayDomain((1:15), (1:80))
 
 alignfirst(a::ArrayDomain) = ArrayDomain(map(r->1:length(r), indexes(a)))
 
+alignfirst(a::CartesianIndices{N}) where N =
+    ArrayDomain(map(r->1:length(r), a.indices))
+
 function size(a::ArrayDomain, dim)
     idxs = indexes(a)
     length(idxs) < dim ? 1 : length(idxs[dim])
@@ -365,7 +368,7 @@ function group_indices(cumlength, idxs::AbstractRange)
 end
 
 _cumsum(x::AbstractArray) = length(x) == 0 ? Int[] : cumsum(x)
-function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{N}) where N
+function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{N}; slice::Bool=false) where N
     groups = map(group_indices, subdmns.cumlength, indexes(d))
     sz = map(length, groups)
     pieces = Array{Any}(undef, sz)
@@ -379,15 +382,17 @@ function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d:
     out_dmn = DomainBlocks(ntuple(x->1,Val(N)), out_cumlength)
     return pieces, out_dmn
 end
-function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{S}) where {N,S}
+function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{S}; slice::Bool=false) where {N,S}
     if S != 1
         throw(BoundsError(A, d.indexes))
     end
     inds = CartesianIndices(A)[d.indexes...]
     new_d = ntuple(i->first(inds).I[i]:last(inds).I[i], N)
-    return lookup_parts(A, ps, subdmns, ArrayDomain(new_d))
+    return lookup_parts(A, ps, subdmns, ArrayDomain(new_d); slice)
 end
 
+lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::CartesianIndices; slice::Bool=false) where {N} = lookup_parts(A, ps, subdmns, ArrayDomain(d.indices); slice)
+
 """
     Base.fetch(c::DArray)
diff --git a/src/array/indexing.jl b/src/array/indexing.jl
index 3d6fe56b6..74f8d5b77 100644
--- a/src/array/indexing.jl
+++ b/src/array/indexing.jl
@@ -1,4 +1,4 @@
-### getindex
+#get index
 
 const GETINDEX_CACHE = TaskLocalValue{Dict{Tuple,Any}}(()->Dict{Tuple,Any}())
 const GETINDEX_CACHE_SIZE = ScopedValue{Int}(0)
@@ -36,6 +36,7 @@ with_index_caching(f, size::Integer=1) = with(f, GETINDEX_CACHE_SIZE=>size)
     # Return the value
     return part[offset_idx...]
 end
+
 function partition_for(A::DArray, idx::NTuple{N,Int}) where N
     part_idx = zeros(Int, N)
     offset_idx = zeros(Int, N)
diff --git a/src/array/qr.jl b/src/array/qr.jl
new file mode 100644
index 000000000..6ac390963
--- /dev/null
+++ b/src/array/qr.jl
@@ -0,0 +1,257 @@
+export geqrf!, porgqr!, pormqr!, cageqrf!
+import LinearAlgebra: QRCompactWY, AdjointQ, BlasFloat, QRCompactWYQ, AbstractQ, StridedVecOrMat, I
+import Base.:*
+
+(*)(Q::QRCompactWYQ{T, M}, b::Number) where {T<:Number, M<:DMatrix{T}} = DMatrix(Q) * b
+(*)(b::Number, Q::QRCompactWYQ{T, M}) where {T<:Number, M<:DMatrix{T}} = DMatrix(Q) * b
+
+(*)(Q::AdjointQ{T, QRCompactWYQ{T, M, C}}, b::Number) where {T<:Number, M<:DMatrix{T}, C<:M} = DMatrix(Q) * b
+(*)(b::Number, Q::AdjointQ{T, QRCompactWYQ{T, M, C}}) where {T<:Number, M<:DMatrix{T}, C<:M} = DMatrix(Q) * b
+
+LinearAlgebra.lmul!(B::QRCompactWYQ{T, <:DMatrix{T}}, A::DMatrix{T}) where {T} = pormqr!('L', 'N', B.factors, B.T, A)
+function LinearAlgebra.lmul!(B::AdjointQ{T, <:QRCompactWYQ{T, <:Dagger.DMatrix{T}}}, A::Dagger.DMatrix{T}) where {T}
+    trans = T <: Complex ? 'C' : 'T'
+    pormqr!('L', trans, B.Q.factors, B.Q.T, A)
+end
+
+LinearAlgebra.rmul!(A::Dagger.DMatrix{T}, B::QRCompactWYQ{T, <:Dagger.DMatrix{T}}) where {T} = pormqr!('R', 'N', B.factors, B.T, A)
+function LinearAlgebra.rmul!(A::Dagger.DMatrix{T}, B::AdjointQ{T, <:QRCompactWYQ{T, <:Dagger.DMatrix{T}}}) where {T}
+    trans = T <: Complex ? 'C' : 'T'
+    pormqr!('R', trans, B.Q.factors, B.Q.T, A)
+end
+
+function Dagger.DMatrix(Q::QRCompactWYQ{T, <:Dagger.DMatrix{T}}) where {T}
+    DQ = DMatrix(Q.factors.partitioning, I*one(T), size(Q))
+    porgqr!('N', Q.factors, Q.T, DQ)
+    return DQ
+end
+
+function Dagger.DMatrix(AQ::AdjointQ{T, <:QRCompactWYQ{T, <:Dagger.DMatrix{T}}}) where {T}
+    DQ = DMatrix(AQ.Q.factors.partitioning, I*one(T), size(AQ))
+    trans = T <: Complex ? 'C' : 'T'
+    porgqr!(trans, AQ.Q.factors, AQ.Q.T, DQ)
+    return DQ
+end
+
+Base.collect(Q::QRCompactWYQ{T, <:Dagger.DMatrix{T}}) where {T} = collect(Dagger.DMatrix(Q))
+Base.collect(AQ::AdjointQ{T, <:QRCompactWYQ{T, <:Dagger.DMatrix{T}}}) where {T} = collect(Dagger.DMatrix(AQ))
+
+function _repartition_pormqr(A, Tm, C, side::Char, trans::Char)
+    partA = A.partitioning.blocksize
+    partTm = Tm.partitioning.blocksize
+    partC = C.partitioning.blocksize
+
+    # The pormqr! kernels assume that the number of row tiles (index k)
+    # matches between the reflector matrix A and the target matrix C.
+    # Adjust C's block size accordingly but avoid reshaping A or Tm,
+    # as their chunking encodes the factorisation structure.
+    partC_new = partC
+    if side == 'L'
+        # Q * C or Q' * C: C's row blocking must match A's row blocking.
+        partC_new = (partA[1], partC[2])
+    else
+        # C * Q or C * Q': C's column blocking must match A's row blocking
+        # because the kernels iterate over the k index along columns.
+        partC_new = (partC[1], partA[1])
+    end
+
+    return Blocks(partA...), Blocks(partTm...), Blocks(partC_new...)
+end
+
+# Apply Q or Q' (stored as block reflectors in A with triangular T factors in Tm)
+# to C from the left ('L') or the right ('R').
+function pormqr!(side::Char, trans::Char, A::Dagger.DMatrix{T}, Tm::Dagger.DMatrix{T}, C::Dagger.DMatrix{T}) where {T<:Number}
+    partA, partTm, partC = _repartition_pormqr(A, Tm, C, side, trans)
+
+    return maybe_copy_buffered(A=>partA, Tm=>partTm, C=>partC) do A, Tm, C
+        return _pormqr_impl!(side, trans, A, Tm, C)
+    end
+end
+
+function _pormqr_impl!(side::Char, trans::Char, A::Dagger.DMatrix{T}, Tm::Dagger.DMatrix{T}, C::Dagger.DMatrix{T}) where {T<:Number}
+    m, n = size(C)
+    Ac = A.chunks
+    Tc = Tm.chunks
+    Cc = C.chunks
+
+    Amt, Ant = size(Ac)
+    Tmt, Tnt = size(Tc)
+    Cmt, Cnt = size(Cc)
+    minMT = min(Amt, Ant)
+
+    Dagger.spawn_datadeps() do
+        if side == 'L'
+            if (trans == 'T' || trans == 'C')
+                for k in 1:minMT
+                    for n in 1:Cnt
+                        Dagger.@spawn NextLA.unmqr!(side, trans, In(Ac[k, k]), In(Tc[k, k]), InOut(Cc[k,n]))
+                    end
+                    for m in k+1:Cmt, n in 1:Cnt
+                        Dagger.@spawn NextLA.tsmqr!(side, trans, InOut(Cc[k, n]), InOut(Cc[m, n]), In(Ac[m, k]), In(Tc[m, k]))
+                    end
+                end
+            end
+            if trans == 'N'
+                for k in minMT:-1:1
+                    for m in Cmt:-1:k+1, n in 1:Cnt
+                        Dagger.@spawn NextLA.tsmqr!(side, trans, InOut(Cc[k, n]), InOut(Cc[m, n]), In(Ac[m, k]), In(Tc[m, k]))
+                    end
+                    for n in 1:Cnt
+                        Dagger.@spawn NextLA.unmqr!(side, trans, In(Ac[k, k]), In(Tc[k, k]), InOut(Cc[k, n]))
+                    end
+                end
+            end
+        else
+            if side == 'R'
+                if trans == 'T' || trans == 'C'
+                    for k in minMT:-1:1
+                        for n in Cmt:-1:k+1, m in 1:Cmt
+                            Dagger.@spawn NextLA.tsmqr!(side, trans, InOut(Cc[m, k]), InOut(Cc[m, n]), In(Ac[n, k]), In(Tc[n, k]))
+                        end
+                        for m in 1:Cmt
+                            Dagger.@spawn NextLA.unmqr!(side, trans, In(Ac[k, k]), In(Tc[k, k]), InOut(Cc[m, k]))
+                        end
+                    end
+                end
+                if trans == 'N'
+                    for k in 1:minMT
+                        for m in 1:Cmt
+                            Dagger.@spawn NextLA.unmqr!(side, trans, In(Ac[k, k]), In(Tc[k, k]), InOut(Cc[m, k]))
+                        end
+                        for n in k+1:Cmt, m in 1:Cmt
+                            Dagger.@spawn NextLA.tsmqr!(side, trans, InOut(Cc[m, k]), InOut(Cc[m, n]), In(Ac[n, k]), In(Tc[n, k]))
+                        end
+                    end
+                end
+            end
+        end
+    end
+    return C
+end
+
+# Communication-avoiding tiled QR: the row tiles are split into p domains that are
+# factorized independently and then merged with triangle-on-triangle (tt) kernels.
+function cageqrf!(A::Dagger.DMatrix{T}, Tm::Dagger.DMatrix{T}; static::Bool=true, traversal::Symbol=:inorder, p::Int64=1) where {T<: Number}
+    if p == 1
+        return geqrf!(A, Tm; static, traversal)
+    end
+    Ac = A.chunks
+    mt, nt = size(Ac)
+    @assert mt % p == 0 "Number of tiles must be divisible by the number of domains"
+    mtd = Int64(mt/p)
+    Tc = Tm.chunks
+    proot = 1
+    nxtmt = mtd
+    trans = T <: Complex ? 'C' : 'T'
+    Dagger.spawn_datadeps(;static, traversal) do
+        for k in 1:min(mt, nt)
+            if k > nxtmt
+                proot += 1
+                nxtmt += mtd
+            end
+            for pt in proot:p
+                ibeg = 1 + (pt-1) * mtd
+                if pt == proot
+                    ibeg = k
+                end
+                Dagger.@spawn NextLA.geqrt!(InOut(Ac[ibeg, k]), Out(Tc[ibeg,k]))
+                for n in k+1:nt
+                    Dagger.@spawn NextLA.unmqr!('L', trans, Deps(Ac[ibeg, k], In(LowerTriangular)), In(Tc[ibeg,k]), InOut(Ac[ibeg, n]))
+                end
+                for m in ibeg+1:(pt * mtd)
+                    Dagger.@spawn NextLA.tsqrt!(Deps(Ac[ibeg, k], InOut(UpperTriangular)), InOut(Ac[m, k]), Out(Tc[m,k]))
+                    for n in k+1:nt
+                        Dagger.@spawn NextLA.tsmqr!('L', trans, InOut(Ac[ibeg, n]), InOut(Ac[m, n]), In(Ac[m, k]), In(Tc[m,k]))
+                    end
+                end
+            end
+            for m in 1:ceil(Int64, log2(p-proot+1))
+                p1 = proot
+                p2 = p1 + 2^(m-1)
+                while p2 ≤ p
+                    i1 = 1 + (p1-1) * mtd
+                    i2 = 1 + (p2-1) * mtd
+                    if p1 == proot
+                        i1 = k
+                    end
+                    Dagger.@spawn NextLA.ttqrt!(Deps(Ac[i1, k], InOut(UpperTriangular)), Deps(Ac[i2, k], InOut(UpperTriangular)), Out(Tc[i2, k]))
+                    for n in k+1:nt
+                        Dagger.@spawn NextLA.ttmqr!('L', trans, InOut(Ac[i1, n]), InOut(Ac[i2, n]), Deps(Ac[i2, k], In(UpperTriangular)), In(Tc[i2, k]))
+                    end
+                    p1 += 2^m
+                    p2 += 2^m
+                end
+            end
+        end
+    end
+end
+
+# Standard tiled QR factorization: block reflectors overwrite A and the
+# triangular T factors are stored in Tm.
+function geqrf!(A::Dagger.DMatrix{T}, Tm::Dagger.DMatrix{T}; static::Bool=true, traversal::Symbol=:inorder) where {T<: Number}
+    Ac = A.chunks
+    mt, nt = size(Ac)
+    Tc = Tm.chunks
+    trans = T <: Complex ? 'C' : 'T'
+
+    Dagger.spawn_datadeps(;static, traversal) do
+        for k in 1:min(mt, nt)
+            Dagger.@spawn NextLA.geqrt!(InOut(Ac[k, k]), Out(Tc[k,k]))
+            for n in k+1:nt
+                Dagger.@spawn NextLA.unmqr!('L', trans, Deps(Ac[k,k], In(LowerTriangular)), In(Tc[k,k]), InOut(Ac[k, n]))
+            end
+            for m in k+1:mt
+                Dagger.@spawn NextLA.tsqrt!(Deps(Ac[k, k], InOut(UpperTriangular)), InOut(Ac[m, k]), Out(Tc[m,k]))
+                for n in k+1:nt
+                    Dagger.@spawn NextLA.tsmqr!('L', trans, InOut(Ac[k, n]), InOut(Ac[m, n]), In(Ac[m, k]), In(Tc[m,k]))
+                end
+            end
+        end
+    end
+end
+
+# Accumulate the explicit Q (trans == 'N') or Q' into Q, which is expected to be
+# initialized to the identity.
+function porgqr!(trans::Char, A::Dagger.DMatrix{T}, Tm::Dagger.DMatrix{T}, Q::Dagger.DMatrix{T}; static::Bool=true, traversal::Symbol=:inorder) where {T<:Number}
+    Ac = A.chunks
+    Tc = Tm.chunks
+    Qc = Q.chunks
+    mt, nt = size(Ac)
+    qmt, qnt = size(Qc)
+
+    Dagger.spawn_datadeps(;static, traversal) do
+        if trans == 'N'
+            for k in min(mt, nt):-1:1
+                for m in qmt:-1:k + 1, n in k:qnt
+                    Dagger.@spawn NextLA.tsmqr!('L', trans, InOut(Qc[k, n]), InOut(Qc[m, n]), In(Ac[m, k]), In(Tc[m, k]))
+                end
+                for n in k:qnt
+                    Dagger.@spawn NextLA.unmqr!('L', trans, In(Ac[k, k]),
+                                                In(Tc[k, k]), InOut(Qc[k, n]))
+                end
+            end
+        else
+            for k in 1:min(mt, nt)
+                for n in 1:k
+                    Dagger.@spawn NextLA.unmqr!('L', trans, In(Ac[k, k]),
+                                                In(Tc[k, k]), InOut(Qc[k, n]))
+                end
+                for m in k+1:qmt, n in 1:qnt
+                    Dagger.@spawn NextLA.tsmqr!('L', trans, InOut(Qc[k, n]), InOut(Qc[m, n]), In(Ac[m, k]), In(Tc[m, k]))
+                end
+            end
+        end
+    end
+end
+
+# Workspace dimensions for the block-reflector matrix Tm, given the inner block size ib.
+function meas_ws(A::Dagger.DMatrix{T}, ib::Int64) where {T<: Number}
+    mb, nb = A.partitioning.blocksize
+    m, n = size(A)
+    MT = (mod(m,mb)==0) ? floor(Int64, (m / mb)) : floor(Int64, (m / mb) + 1)
+    NT = (mod(n,nb)==0) ? floor(Int64,(n / nb)) : floor(Int64, (n / nb) + 1) * 2
+    lm = ib * MT
+    ln = nb * NT
+    lm, ln
+end
+
+function LinearAlgebra.qr!(A::Dagger.DMatrix{T}; ib::Int64=1, p::Int64=1) where {T<:Number}
+    lm, ln = meas_ws(A, ib)
+    nb = A.partitioning.blocksize[2]
+    Tm = DArray{T}(Blocks(ib, nb), undef, (lm, ln))
+    cageqrf!(A, Tm; p=p)
+    return QRCompactWY(A, Tm)
+end
+
diff --git a/test/array/linalg/qr.jl b/test/array/linalg/qr.jl
new file mode 100644
index 000000000..ad17e0823
--- /dev/null
+++ b/test/array/linalg/qr.jl
@@ -0,0 +1,37 @@
+@testset "Tile QR: $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
+    ## Square matrices
+    A = rand(T, 128, 128)
+    Q, R = qr(A)
+    DA = distribute(A, Blocks(32,32))
+    DQ, DR = qr!(DA)
+    @test abs.(DR) ≈ abs.(R)
+    @test DQ * DR ≈ A
+    @test DQ' * DQ ≈ I
+    @test I * DQ ≈ collect(DQ)
+    @test I * DQ' ≈ collect(DQ')
+    @test triu(collect(DR)) ≈ collect(DR)
+    ## Rectangular matrices (block and element wise)
+    # Tall Element and Block
+    A = rand(T, 128, 64)
+    Q, R = qr(A)
+    DA = distribute(A, Blocks(32,32))
+    DQ, DR = qr!(DA)
+    @test abs.(DR) ≈ abs.(R)
+    @test DQ * DR ≈ A
+    @test DQ' * DQ ≈ I
+    @test I * DQ ≈ collect(DQ)
+    @test I * DQ' ≈ collect(DQ')
+    @test triu(collect(DR)) ≈ collect(DR)
+
+    # Wide Element and Block
+    A = rand(T, 64, 128)
+    Q, R = qr(A)
+    DA = distribute(A, Blocks(16,16))
+    DQ, DR = qr!(DA)
+    @test abs.(DR) ≈ abs.(R)
+    @test DQ * DR ≈ A
+    @test DQ' * DQ ≈ I
+    @test I * DQ ≈ collect(DQ)
+    @test I * DQ' ≈ collect(DQ')
+    @test triu(collect(DR)) ≈ collect(DR)
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 264eb4603..32ede0c9b 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -30,7 +30,8 @@ tests = [
     ("Array - LinearAlgebra - Core", "array/linalg/core.jl"),
     ("Array - LinearAlgebra - Matmul", "array/linalg/matmul.jl"),
     ("Array - LinearAlgebra - Cholesky", "array/linalg/cholesky.jl"),
-    ("Array - LinearAlgebra - LU", "array/linalg/lu.jl"),
+    ("Array - LinearAlgebra - LU", "array/linalg/lu.jl"),
+    ("Array - LinearAlgebra - QR", "array/linalg/qr.jl"),
     ("Array - Random", "array/random.jl"),
     ("Array - Stencils", "array/stencil.jl"),
     ("Array - FFT", "array/fft.jl"),
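
Usage sketch: the snippet below shows how the distributed QR entry points added by this diff are driven from user code, mirroring the new tests in test/array/linalg/qr.jl. The element type, matrix size, block size, and the p=2 domain count are illustrative assumptions, not values prescribed by the diff.

    using LinearAlgebra, Dagger

    A  = rand(Float64, 128, 128)
    DA = distribute(A, Blocks(32, 32))   # 4x4 grid of 32x32 tiles
    DQ, DR = qr!(DA)                     # tiled QR; DA now holds the block reflectors

    DQ * DR ≈ A                          # reconstructs A
    DQ' * DQ ≈ I                         # Q is orthogonal/unitary
    Qmat = Dagger.DMatrix(DQ)            # materialize Q explicitly via porgqr!

    # Communication-avoiding variant: split the 4 row tiles into 2 domains
    # (the row-tile count must be divisible by p).
    DA2 = distribute(A, Blocks(32, 32))
    DQ2, DR2 = qr!(DA2; p=2)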