diff --git a/.gitignore b/.gitignore index 29126e4..4b2fb27 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,4 @@ docs/site/ # committed for packages, but should be committed for applications that require a static # environment. Manifest.toml +.vscode/ \ No newline at end of file diff --git a/src/parallel_csr_mv.jl b/src/parallel_csr_mv.jl index 301d53f..e7122c4 100644 --- a/src/parallel_csr_mv.jl +++ b/src/parallel_csr_mv.jl @@ -1,5 +1,21 @@ +struct Range <: AbstractVector{Int} + start::Int + stop::Int + function Range(start::Int,stop::Int) + new(start-1,stop) + end +end + +Base.length(range::Range) = (range.stop-range.start) # includes both ends +Base.size(range::Range) = (length(range),) + +@inline function Base.getindex(range::Range,index::Int) + @boundscheck (0 < index ≤ length(range)) || throw(BoundsError("attempting to access the $(range.stop-range.start)-element Range at index [$(index)]")) + return index + range.start +end + mutable struct Coordinate x::Int @@ -14,12 +30,12 @@ end # # a -> row-end offsets so really pass in a[2:end] # b -> "natural" numbers -function merge_path_search(diagonal::Int, a_len::Int, b_len::Int, a, b) +function merge_path_search(diagonal::Int, a_len::Int, b_len::Int, a::AbstractVector, b::AbstractVector) # Diagonal search range (in x coordinate space) x_min = max(diagonal - b_len, 0) x_max = min(diagonal, a_len) # 2D binary-search along diagonal search range - while (x_min < x_max) + @inbounds while (x_min < x_max) pivot = (x_min + x_max) >> 1 if (a[pivot + 1] <= b[diagonal - pivot]) x_min = pivot + 1 @@ -29,8 +45,8 @@ function merge_path_search(diagonal::Int, a_len::Int, b_len::Int, a, b) end return Coordinate( - min(x_min, a_len), - diagonal - x_min + min(x_min, a_len) + 1, + diagonal - x_min + 1 ) end @@ -48,7 +64,7 @@ function merge_csr_mv!(α::Number,A::AbstractSparseMatrixCSC, input::StridedVect # nrows = length(cp) - 1 can give the wrong number of rows! nrows = A.n - nz_indices = collect(1:nnz) + nz_indices = Range(1,nnz) row_end_offsets = cp[2:end] # nzval ordering is diff for diff formats num_merge_items = nnz + nrows # preserve the dimensions of the original matrix @@ -69,39 +85,38 @@ function merge_csr_mv!(α::Number,A::AbstractSparseMatrixCSC, input::StridedVect # Consume merge items, whole rows first running_total = zero(eltype(output)) - while thread_coord.x < thread_coord_end.x - @inbounds while thread_coord.y < row_end_offsets[thread_coord.x + 1] - 1 - @inbounds running_total += op(nzv[thread_coord.y + 1]) * input[rv[thread_coord.y + 1]] + @inbounds while thread_coord.x < thread_coord_end.x + while thread_coord.y < row_end_offsets[thread_coord.x] + running_total += op(nzv[thread_coord.y]) * input[rv[thread_coord.y]] thread_coord.y += 1 end - @inbounds output[thread_coord.x + 1] += α * running_total + output[thread_coord.x] += α * running_total running_total = zero(eltype(output)) thread_coord.x += 1 end # May have thread end up partially consuming a row. # Save result form partial consumption and do one pass at the end to add it back to y - while thread_coord.y < thread_coord_end.y - @inbounds running_total += op(nzv[thread_coord.y + 1]) * input[rv[thread_coord.y + 1]] + @inbounds while thread_coord.y < thread_coord_end.y + running_total += op(nzv[thread_coord.y]) * input[rv[thread_coord.y]] thread_coord.y += 1 end # Save carry-outs - @inbounds row_carry_out[tid] = thread_coord_end.x + 1 + @inbounds row_carry_out[tid] = thread_coord_end.x @inbounds value_carry_out[tid] = running_total end - for tid in 1:num_threads - @inbounds if row_carry_out[tid] <= nrows - @inbounds output[row_carry_out[tid]] += α * value_carry_out[tid] + @inbounds for tid in 1:num_threads + if row_carry_out[tid] <= nrows + output[row_carry_out[tid]] += α * value_carry_out[tid] end end end - # C = adjoint(A)Bα + Cβ # C = transpose(A)B + Cβ # C = xABα + Cβ @@ -124,4 +139,5 @@ for (T, t) in ((Adjoint, adjoint), (Transpose, transpose)) # end of @eval macro end # end of for loop -end + +end \ No newline at end of file diff --git a/test/merge_csr_mv!.jl b/test/merge_csr_mv!.jl index 0eda52c..5375e85 100644 --- a/test/merge_csr_mv!.jl +++ b/test/merge_csr_mv!.jl @@ -1,12 +1,28 @@ using Test -using ParallelMergeCSR: merge_csr_mv! +using ParallelMergeCSR: merge_csr_mv!, Range using SparseArrays + +@testset "Range" begin + + range = Range(1,10) + + @test length(range) == 10 + @test size(range,1) == 10 + @test size(range,2) == 1 + + for i in 1:10 + @test range[i] == i + end + + @test_throws BoundsError range[11] + +end + ## NOTE: Sparse matrices are converted to dense form in the @test's ## considering that our redefinition of SparseArrays.mul! seems to ## interfere @testset "Extreme Cases" begin - @testset "Singleton (Real)" begin A = sparse(reshape([1], 1, 1))