Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 46 additions & 22 deletions test/merge_csr_mv!.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,35 @@ using SparseArrays: sprand, sparse
end

@testset "Extreme Cases" begin

@testset "Singleton (Real)" begin

A = sparse(reshape([1], 1, 1))

x = rand(1)

y = zeros(size(A, 1))
y = rand(size(A, 1))
y_original = deepcopy(y)

merge_csr_mv!(0.3, A, x, y, transpose)

@test A * x * 0.3 == y
@test (A * x * 0.3) + y_original ≈ y

end

@testset "Singleton (Complex)" begin

A = sparse(reshape([1.0+2.5im], 1, 1))

x = rand(1)

y = zeros(eltype(A), size(A, 1))
y = rand(eltype(A), size(A, 1))
y_original = deepcopy(y)

merge_csr_mv!(10.1, A, x, y, adjoint)

@test adjoint(Matrix(A)) * x * 10.1 == y
@test (adjoint(Matrix(A)) * x * 10.1) + y_original ≈ y

end

@testset "Single row (Real)" begin
Expand All @@ -52,11 +59,13 @@ end
# x needs to be 10 x 1
x = rand(size(A, 1))

y = zeros(eltype(A), 1)
y = rand(eltype(A), 1)
y_original = deepcopy(y)

merge_csr_mv!(1.1, A, x, y, transpose)

@test (transpose(A) * x) * 1.1 ≈ y
@test ((transpose(A) * x) * 1.1) + y_original ≈ y

end

@testset "Single row (Complex)" begin
Expand All @@ -67,10 +76,12 @@ end
x = rand(eltype(A), size(A, 1))

y = zeros(eltype(A), 1)
y_original = deepcopy(y)

merge_csr_mv!(1.1, A, x, y, transpose)

@test (transpose(A) * x) * 1.1 ≈ y
@test ((transpose(A) * x) * 1.1) + y_original ≈ y

end


Expand All @@ -88,30 +99,35 @@ end
end

@testset "Square (Real)" begin

A = sprand(10,10,0.3)

# 10 x 1
x = rand(10)

# 10 x 1
y = zeros(size(x))
y = rand(size(x)...)
y_original = deepcopy(y)

merge_csr_mv!(1.1, A, x, y, adjoint)

@test (adjoint(A) * x) * 1.1 ≈ y
@test ((adjoint(A) * x) * 1.1) + y_original ≈ y

end

@testset "Square (Complex)" begin

A = sprand(Complex{Float64}, 10, 10, 0.3)

x = 10 * rand(Complex{Float64}, 10)

y = zeros(eltype(A), size(x))
y = rand(eltype(A), size(x)...)
y_original = deepcopy(y)

merge_csr_mv!(1.1, A, x, y, adjoint)

@test (adjoint(A) * x) * 1.1 ≈ y
@test ((adjoint(A) * x) * 1.1) + y_original ≈ y

end

@testset "4x6 (Real)" begin
Expand All @@ -128,13 +144,14 @@ end
x = [5,2,3,1]

# create empty solution
y = zeros(Int64, size(A, 2))
y = rand(1:20, size(A, 2))
y_original = deepcopy(y)

# multiply
merge_csr_mv!(2.0, A, x, y, adjoint)

@test (adjoint(m) * x * 2.0) + y_original ≈ y

@test adjoint(m) * x * 2.0 == y
end

@testset "4 x 6 (Complex)" begin
Expand All @@ -151,12 +168,13 @@ end
x = 22.1 * rand(Complex{Float64}, 4)

# create empty solution
y = zeros(eltype(x), size(A, 2))
y = rand(eltype(x), size(A, 2))
y_original = deepcopy(y)

# multiply
merge_csr_mv!(2.0, A, x, y, adjoint)

@test adjoint(m) * x * 2.0 == y
@test (adjoint(m) * x * 2.0) + y_original ≈ y

end

Expand All @@ -168,11 +186,13 @@ end
x = rand(100)

# create empty solution
y = zeros(size(A, 1))
y = rand(size(A, 1))
y_original = deepcopy(y)

merge_csr_mv!(3.0, A, x, y, transpose)

@test transpose(A) * x * 3 ≈ y
@test (transpose(A) * x * 3) + y_original ≈ y

end

@testset "100x100 (Complex)" begin
Expand All @@ -183,11 +203,13 @@ end
x = rand(Complex{Float64}, 100)

# create empty solution
y = zeros(eltype(x), size(A, 1))
y = rand(eltype(x), size(A, 1))
y_original = deepcopy(y)

merge_csr_mv!(3.0, A, x, y, transpose)

@test transpose(A) * x * 3 ≈ y
@test (transpose(A) * x * 3) + y_original ≈ y

end

#=
Expand All @@ -204,14 +226,15 @@ end

α = 9.2

Y = zeros(2, 4)
Y = rand(2, 4)
Y_original = deepcopy(Y)

for (idx, col) in enumerate(eachcol(X))
Y_view = @view Y[:, idx]
merge_csr_mv!(α, A, col, Y_view, transpose)
end

@test transpose(A) * X * 9.2 ≈ Y
@test (transpose(A) * X * 9.2) + Y_original ≈ Y

end

Expand All @@ -226,12 +249,13 @@ end
α = 9.2

Y = zeros(eltype(A), 2, 4)
Y_original = deepcopy(Y)

for (idx, col) in enumerate(eachcol(X))
Y_view = @view Y[:, idx]
merge_csr_mv!(α, A, col, Y_view, adjoint)
end

@test adjoint(A) * X * 9.2 ≈ Y
@test (adjoint(A) * X * 9.2) + Y_original ≈ Y

end
7 changes: 6 additions & 1 deletion test/mul!.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ using SparseArrays
ParallelMergeCSR.mul!(C, A, B, α, β)
SparseArrays.mul!(C_copy, A, B, α, β)
@test C ≈ C_copy

end

@testset "Adjoint Complex Matrix" begin
Expand Down Expand Up @@ -51,6 +52,7 @@ using SparseArrays
ParallelMergeCSR.mul!(C, A, B, α, β)
SparseArrays.mul!(C_copy, A, B, α, β)
@test C ≈ C_copy

end

@testset "Transpose Complex Matrix" begin
Expand All @@ -66,7 +68,9 @@ using SparseArrays
ParallelMergeCSR.mul!(C, A, B, α, β)
SparseArrays.mul!(C_copy, A, B, α, β)
@test C ≈ C_copy

end

end

# trigger merge_csr_mv! in this repo, does not default to mul! somewhere else
Expand Down Expand Up @@ -121,7 +125,6 @@ end
SparseArrays.mul!(C_copy, A, B, α, β)
@test C ≈ C_copy


end

@testset "Transpose Square Complex" begin
Expand All @@ -137,6 +140,7 @@ end
ParallelMergeCSR.mul!(C, A, B, α, β)
SparseArrays.mul!(C_copy, A, B, α, β)
@test C ≈ C_copy

end

@testset "Transpose Rectangular Real" begin
Expand Down Expand Up @@ -172,4 +176,5 @@ end
@test C ≈ C_copy

end

end