Skip to content
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ docs/site/
# committed for packages, but should be committed for applications that require a static
# environment.
Manifest.toml
.vscode/
50 changes: 33 additions & 17 deletions src/parallel_csr_mv.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@


"""
    Range(start::Integer, stop::Integer)

Allocation-free, read-only vector holding the consecutive integers
`start, start + 1, …, stop` (both ends included).

Element `i` equals `start + i - 1`. When `stop < start` the vector is empty.

The `start` field does not hold the requested first value: it holds that value
minus one, so that `getindex` reduces to a single addition.
"""
struct Range <: AbstractVector{Int}
    start::Int  # offset: requested first value minus one
    stop::Int   # last value (inclusive); never smaller than `start`
    function Range(start::Integer, stop::Integer)
        offset = Int(start) - 1
        # Clamp so that an inverted pair describes an empty vector instead of
        # one that reports a negative length.
        return new(offset, max(Int(stop), offset))
    end
end

# Elements are addressed by a single integer index.
Base.IndexStyle(::Type{Range}) = IndexLinear()

# `start` already holds the offset (first value - 1), so no `+ 1` is needed
# to count both ends.
Base.length(r::Range) = r.stop - r.start
Base.size(r::Range) = (length(r),)

"""
    getindex(r::Range, index::Int)

Return the `index`-th element of `r`, i.e. `index` plus the stored offset.

Throws a `BoundsError` carrying `r` and `index` unless `1 ≤ index ≤ length(r)`;
the check is skipped when the call is inlined into an `@inbounds` context.
"""
@inline function Base.getindex(r::Range, index::Int)
    @boundscheck (0 < index ≤ length(r)) || throw(BoundsError(r, index))
    return index + r.start
end


mutable struct Coordinate
x::Int
Expand All @@ -14,12 +30,12 @@ end
#
# a -> row-end offsets so really pass in a[2:end]
# b -> "natural" numbers
function merge_path_search(diagonal::Int, a_len::Int, b_len::Int, a, b)
function merge_path_search(diagonal::Int, a_len::Int, b_len::Int, a::AbstractVector, b::AbstractVector)
# Diagonal search range (in x coordinate space)
x_min = max(diagonal - b_len, 0)
x_max = min(diagonal, a_len)
# 2D binary-search along diagonal search range
while (x_min < x_max)
@inbounds while (x_min < x_max)
pivot = (x_min + x_max) >> 1
if (a[pivot + 1] <= b[diagonal - pivot])
x_min = pivot + 1
Expand All @@ -29,8 +45,8 @@ function merge_path_search(diagonal::Int, a_len::Int, b_len::Int, a, b)
end

return Coordinate(
min(x_min, a_len),
diagonal - x_min
min(x_min, a_len) + 1,
diagonal - x_min + 1
)

end
Expand All @@ -48,7 +64,7 @@ function merge_csr_mv!(α::Number,A::AbstractSparseMatrixCSC, input::StridedVect
# nrows = length(cp) - 1 can give the wrong number of rows!
nrows = A.n

nz_indices = collect(1:nnz)
nz_indices = Range(1,nnz)
row_end_offsets = cp[2:end] # nzval ordering is diff for diff formats
num_merge_items = nnz + nrows # preserve the dimensions of the original matrix

Expand All @@ -69,39 +85,38 @@ function merge_csr_mv!(α::Number,A::AbstractSparseMatrixCSC, input::StridedVect

# Consume merge items, whole rows first
running_total = zero(eltype(output))
while thread_coord.x < thread_coord_end.x
@inbounds while thread_coord.y < row_end_offsets[thread_coord.x + 1] - 1
@inbounds running_total += op(nzv[thread_coord.y + 1]) * input[rv[thread_coord.y + 1]]
@inbounds while thread_coord.x < thread_coord_end.x
while thread_coord.y < row_end_offsets[thread_coord.x]
running_total += op(nzv[thread_coord.y]) * input[rv[thread_coord.y]]
thread_coord.y += 1
end

@inbounds output[thread_coord.x + 1] += α * running_total
output[thread_coord.x] += α * running_total
running_total = zero(eltype(output))
thread_coord.x += 1
end

# A thread may end up consuming only part of a row.
# Save the result from the partial consumption and do one pass at the end to add it back to y
while thread_coord.y < thread_coord_end.y
@inbounds running_total += op(nzv[thread_coord.y + 1]) * input[rv[thread_coord.y + 1]]
@inbounds while thread_coord.y < thread_coord_end.y
running_total += op(nzv[thread_coord.y]) * input[rv[thread_coord.y]]
thread_coord.y += 1
end

# Save carry-outs
@inbounds row_carry_out[tid] = thread_coord_end.x + 1
@inbounds row_carry_out[tid] = thread_coord_end.x
@inbounds value_carry_out[tid] = running_total

end

for tid in 1:num_threads
@inbounds if row_carry_out[tid] <= nrows
@inbounds output[row_carry_out[tid]] += α * value_carry_out[tid]
@inbounds for tid in 1:num_threads
if row_carry_out[tid] <= nrows
output[row_carry_out[tid]] += α * value_carry_out[tid]
end
end

end


# C = adjoint(A)Bα + Cβ
# C = transpose(A)B + Cβ
# C = xABα + Cβ
Expand All @@ -124,4 +139,5 @@ for (T, t) in ((Adjoint, adjoint), (Transpose, transpose))
# end of @eval macro
end
# end of for loop
end

end
20 changes: 18 additions & 2 deletions test/merge_csr_mv!.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
using Test
using ParallelMergeCSR: merge_csr_mv!
using ParallelMergeCSR: merge_csr_mv!, Range
using SparseArrays


@testset "Range" begin

    # Named `r` rather than `range` so that `Base.range` is not shadowed
    # inside the testset.
    r = Range(1, 10)

    @test length(r) == 10
    @test size(r, 1) == 10
    @test size(r, 2) == 1  # dimensions beyond the first report size 1 for a vector

    for i in 1:10
        @test r[i] == i
    end

    # Indices on either side of the valid interval 1:10 must be rejected.
    @test_throws BoundsError r[0]
    @test_throws BoundsError r[11]

end

## NOTE: Sparse matrices are converted to dense form inside the `@test`
## expressions because our redefinition of SparseArrays.mul! seems to
## interfere otherwise
@testset "Extreme Cases" begin

@testset "Singleton (Real)" begin
A = sparse(reshape([1], 1, 1))

Expand Down