From 94801500d352995e0b80448df6b1b25de060dd63 Mon Sep 17 00:00:00 2001 From: Martin Cornejo Date: Mon, 5 Jun 2023 14:44:14 +0200 Subject: [PATCH 1/3] Add `directsum` --- Project.toml | 3 ++- src/TensorCore.jl | 31 +++++++++++++++++++++++++++++++ test/runtests.jl | 30 ++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 6f2a06e..beecbcc 100644 --- a/Project.toml +++ b/Project.toml @@ -10,7 +10,8 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" julia = "1" [extras] +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["Test", "SparseArrays"] diff --git a/src/TensorCore.jl b/src/TensorCore.jl index e5ac156..e1e6500 100644 --- a/src/TensorCore.jl +++ b/src/TensorCore.jl @@ -5,6 +5,7 @@ using LinearAlgebra export ⊙, hadamard, hadamard! export ⊗, tensor, tensor! export ⊡, boxdot, boxdot! +export ⊕, directsum """ hadamard(a, b) @@ -282,6 +283,36 @@ else end +""" + directsum(A, B) + A ⊕ B + + The direct sum of matrices `A` of size m × n and `B` of size p × q constructs a block matrix of size (m + p)×(n + q), + with `A` and `B` as diagonal elements and zero matrices for the off-diagonal blocks. + + `A ⊕ B = [A 0; 0 B]` + + # Examples + ```jldoctest; setup=:(using TensorCore) + julia> A = [1 3 2; 2 3 1]; B = [1 6; 0 1]; + + julia> A ⊕ B + 4×5 Matrix{Int64}: + 1 3 2 0 0 + 2 3 1 0 0 + 0 0 0 1 6 + 0 0 0 0 1 + ``` +""" +function directsum(A::AbstractArray, B::AbstractArray) + Z1 = zeros(Bool, size(A, 1), size(B, 2)) # upper right + Z2 = zeros(Bool, size(B, 1), size(A, 2)) # lower left + + return [A Z1; Z2 B] +end + +const ⊕ = directsum + """ TensorCore._adjoint(A) diff --git a/test/runtests.jl b/test/runtests.jl index 9279090..754d689 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using TensorCore using LinearAlgebra +using SparseArrays using Test @testset "Ambiguities" begin @@ -279,6 +280,35 @@ end @test boxdot!(similar(c,1), c', d) == [dot(c, d)] end +@testset "directsum" begin + A = rand(3, 2) + B = rand(2, 4) + b = rand(2) + + # size + @test size(A ⊕ B) == (5, 6) + @test size(A ⊕ B') == (7, 4) + @test size(A ⊕ b) == (5, 3) + @test size(A ⊕ b') == (4, 4) + + # eltype + eltypes = [(ComplexF64, Float64), (Float64, Float32), (Float32, Int), (Int, Bool)] + for (Ta, Tb) in eltypes + A = rand(Ta, 2, 2) + B = rand(Tb, 2, 2) + C = A ⊕ B + @test eltype(C) == Ta + end + + # sparse + A = sprand(4, 4, 0.5) + B = sprand(2, 2, 0.5) + B´ = Array(B) + + @test A ⊕ B isa SparseMatrixCSC + @test A ⊕ B´ isa SparseMatrixCSC +end + @testset "_adjoint" begin A = [1 2+im; 3 4im] E3 = cat(A, -A, dims=3) From d91ce1cf9ffcfaa3babaa274cead8026d8788958 Mon Sep 17 00:00:00 2001 From: Martin Cornejo Date: Mon, 5 Jun 2023 14:53:46 +0200 Subject: [PATCH 2/3] Update README --- README.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index effda00..0d412b9 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,8 @@ [![Codecov](https://codecov.io/gh/JuliaMath/TensorCore.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaMath/TensorCore.jl) This package is intended as a lightweight foundation for tensor operations across the Julia ecosystem. -Currently it exports three operations: +Currently it exports four operations: +* `directsum` of matrices, with unicode operator `⊕`, * `hadamard` elementwise multiplication, with unicode operator `⊙`, * `tensor` product preserves all dimensions, operator `⊗`, and * `boxdot` contracts neighbouring dimensions, named after the unicode `⊡`. @@ -15,6 +16,13 @@ julia> using TensorCore julia> A = [1 2 3; 4 5 6]; B = [7 8 9; 0 10 20]; +julia> A ⊕ B # directsum(A, B) +4×6 Matrix{Int64}: + 1 2 3 0 0 0 + 4 5 6 0 0 0 + 0 0 0 7 8 9 + 0 0 0 0 10 20 + julia> A ⊙ B # hadamard(A, B) 2×3 Matrix{Int64}: 7 16 27 From c14dcec28afbb429a8ae002f1ad40a885ba3619a Mon Sep 17 00:00:00 2001 From: Martin Cornejo Date: Mon, 5 Jun 2023 15:00:01 +0200 Subject: [PATCH 3/3] Update docs --- docs/src/index.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/src/index.md b/docs/src/index.md index d6f2eb3..eb6485f 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,7 +1,7 @@ # TensorCore.jl This package is intended as a lightweight foundation for tensor operations across the Julia ecosystem. -Currently it exports three operations, `hadamard`, `tensor` and `boxdot`, and corresponding unicode operators `⊙`, `⊗` and `⊡`. +Currently it exports four operations, `hadamard`, `tensor`, `boxdot` and `directsum`, and corresponding unicode operators `⊙`, `⊗`, `⊡` and `⊕`. ## API @@ -15,4 +15,5 @@ tensor tensor! boxdot boxdot! +directsum ```