julep: extended iteration API, with proof-of-principle for fixing performance of Cartesian iteration #16878

@timholy

Description

This proposes a change to our iteration API (refs #9182, #9178, #6125) that narrows the performance gap of #9080 (which seems to have gotten worse), following the convoluted process of discovery in #16035. Let's suppose this function:

function mysum(A)
    s = zero(eltype(A))
    for I in eachindex(A)
        @inbounds s += A[I]
    end
    s
end

got expanded to something like this (note: I'm passing in the iterator as an argument so I can compare performance of existing and new implementations):

function itergeneral(A, iter)
    # We'll sum over the elements of A that lie within bounds of iter
    s = zero(eltype(A))
    state = start(iter)
    isdone = maybe_done(iter, state)
    if isdone
        isdone, state = done(iter, state, true)  # use dummy 3rd arg to indicate 2 returns
    end
    if !isdone
        while true
            item, state = next(iter, state)
            @inbounds s += A[item]
            if maybe_done(iter, state)
                isdone, state = done(iter, state, true)
                isdone && break
            end
        end
    end
    s
end

To avoid breakage of current code, we need the following fallback definitions:

import Base: start, next, done, getindex

# fallback definitions
maybe_done(iter, state) = true
# For 3-argument code when only the 2-argument version is defined
done(iter, state, ::Bool) = done(iter, state), state
# For explicit 2-argument code when only the 3-argument version is defined
done(iter, state) = done(iter, state, true)[1]
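To check that the fallbacks really are non-breaking, consider a hypothetical iterator written purely against the old 2-argument protocol (CountTo is my own illustration, not from Base):

```julia
# An old-style iterator: only start/next/2-argument done are defined
immutable CountTo
    n::Int
end
start(c::CountTo) = 1
next(c::CountTo, state) = state, state + 1
done(c::CountTo, state) = state > c.n
```

Inside itergeneral, the fallback maybe_done(iter, state) = true fires on every iteration, so the 3-argument done(c, state, true) is called; that dispatches back to the existing 2-argument method via the first fallback, and itergeneral(A, CountTo(length(A))) sums A by linear indexing just as before.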

The only downside I see is that if no valid definition of done exists, you now get a StackOverflowError rather than a MethodError (the two fallback definitions call each other). Aside from that, this appears to be non-breaking. Note that if the compiler knows that maybe_done always returns true for iter, the code above simplifies to

function itergeneral(A, iter)
    # We'll sum over the elements of A that lie within bounds of iter
    s = zero(eltype(A))
    state = start(iter)
    isdone, state = done(iter, state, true)  # use dummy 3rd arg to indicate 2 returns
    if !isdone
        while true
            item, state = next(iter, state)
            @inbounds s += A[item]
            isdone, state = done(iter, state, true)
            isdone && break
        end
    end
    s
end

An advantage of this more symmetric interface is that either next or done can increment the state; therefore, this subsumes the nextval/nextstate split first proposed in #6125. But unlike the implementation in #9182, thanks to maybe_done this also introduces a key advantage for #9080: it ensures that typically (when dimension 1 has size larger than 1) there's only one branch per iteration, using two or more branches only when the "carry" operation needs to be performed.
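To make the "one branch per iteration" point concrete in the simplest setting, here is a hypothetical 1-d range under the new protocol (my own illustration, not part of the benchmark below): maybe_done carries the entire termination test, so the 3-argument done is reached only when iteration is truly finished.

```julia
immutable LinRange1
    start::Int
    stop::Int
end
@inline start(r::LinRange1) = r.start
@inline next(r::LinRange1, state) = state, state + 1
@inline maybe_done(r::LinRange1, state) = state > r.stop  # the single per-iteration branch
@inline done(r::LinRange1, state, ::Bool) = true, state   # reached only at termination
```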

Proof that this helps #9080 (with the previous definitions all loaded):

using Base: tail, @propagate_inbounds

immutable CartIdx{N}
    I::NTuple{N,Int}
    CartIdx(index::NTuple{N,Integer}) = new(index)
end
CartIdx{N}(index::NTuple{N,Integer}) = CartIdx{N}(index)
getindex(ci::CartIdx, d::Integer) = ci.I[d]
@propagate_inbounds getindex(A::AbstractArray, ci::CartIdx{2}) = A[ci.I[1], ci.I[2]]

immutable CartRange{I}
    start::I
    stop::I
end

@inline start(cr::CartRange) = cr.start
@inline next(cr::CartRange, state::CartIdx) = state, inc1(state)
@inline inc1(state) = CartIdx((state[1]+1, tail(state.I)...))  # a "dumb" incrementer
@inline maybe_done(cr::CartRange, state::CartIdx) = state[1] > cr.stop[1]
@inline function done{N}(cr::CartRange, state::CartIdx{N}, ::Bool)
    # handle the carry operation (clean up after the "dumb" incrementer)
    state = CartIdx(carry((), state.I, cr.start.I, cr.stop.I))
    state[N] > cr.stop[N], state
end
@inline carry(out, ::Tuple{}, ::Tuple{}, ::Tuple{}) = out
@inline carry(out, s::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int}) = (out..., s[1])
@inline function carry(out, s, start, stop)
    if s[1] > stop[1]
        return carry((out..., start[1]), (s[2]+1, tail(tail(s))...), tail(start), tail(stop))
    end
    return (out..., s...)  # re-attach already-carried leading dimensions (plain `s` would drop them for N >= 3)
end

with the following test script (sumcart_manual and sumcart_iter are from #9080):

A = rand(10^4, 10^4)
sumcart_manual(A)
sumcart_iter(A)
iter = CartRange(CartIdx((1,1)), CartIdx(size(A)))
R = CartesianRange(size(A))
itergeneral(A, R)
itergeneral(A, iter)
@time 1
@time sumcart_manual(A)
@time sumcart_iter(A)
@time itergeneral(A, R)
@time itergeneral(A, iter)

with results

  0.000003 seconds (149 allocations: 9.013 KB)  # @time 1 (warmup)
  0.106422 seconds (5 allocations: 176 bytes)   # sumcart_manual(A)
  0.127288 seconds (5 allocations: 176 bytes)   # sumcart_iter(A)
  0.228374 seconds (5 allocations: 176 bytes)   # itergeneral(A, R)
  0.106804 seconds (5 allocations: 176 bytes)   # itergeneral(A, iter)
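(For reference, sumcart_manual and sumcart_iter are roughly as follows; this is my reconstruction, see #9080 for the originals.)

```julia
# Manual nested loops: the performance baseline
function sumcart_manual(A::AbstractMatrix)
    s = zero(eltype(A))
    @inbounds for j = 1:size(A, 2), i = 1:size(A, 1)
        s += A[i, j]
    end
    s
end

# Cartesian iteration via the current start/next/done protocol
function sumcart_iter(A)
    s = zero(eltype(A))
    @inbounds for I in CartesianRange(size(A))
        s += A[I]
    end
    s
end
```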

It's awesome that this is almost 3x faster than our current scheme. It stinks that it's still not as good as the manual version. I profiled it, and almost all the time is spent in @inbounds s += A[item]. Other than the possibility that somehow the CPU isn't as good at cache prefetching with this code as with traditional for-loop code, I'm at a loss to explain the gap.

Note that we currently have one other way of circumventing #9080: add @simd. Even when it doesn't actually turn on vectorization, the @simd macro splits the cartesian iterator into an "inner index" and a "remaining index," and thus achieves parity with the manual version (i.e., it beats this julep) even without vectorization. Now, @simd is limited, so this julep is still attractive, but I would feel more confident pushing for it if we could achieve parity.
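Schematically, the inner/outer split that @simd performs is equivalent to hoisting dimension 1 into a hand-written integer loop. This is only a sketch of the transformation, not the macro's actual expansion, and it assumes mixed Int/CartesianIndex indexing:

```julia
using Base: tail

# Sketch: iterate the trailing dimensions with a Cartesian iterator,
# but run dimension 1 as a plain integer loop
function sumcart_split(A::AbstractMatrix)
    s = zero(eltype(A))
    for Itail in CartesianRange(tail(size(A)))  # "remaining index"
        @inbounds for i = 1:size(A, 1)          # "inner index": one cheap branch
            s += A[i, Itail]                    # trailing CartesianIndex indexing
        end
    end
    s
end
```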
