@@ -27,15 +27,9 @@ neutral_element(::typeof(Base._extrema_rf), ::Type{<:NTuple{2,T}}) where {T} = t
2727
2828# resolve ambiguities
2929Base. mapreduce (f, op, A:: AnyGPUArray , As:: AbstractArrayOrBroadcasted... ;
30- dims= :, init= nothing ) = _mapreduce (f, op, A, As... ; dims= dims, init= init)
31- # dims=:, init=nothing) = AK._mapreduce(f, op, A, As...; dims=dims, init=init)
30+ dims= :, init= nothing ) = _mapreduce (f, op, A, As... ; dims, init)
3231Base. mapreduce (f, op, A:: Broadcast.Broadcasted{<:AbstractGPUArrayStyle} , As:: AbstractArrayOrBroadcasted... ;
33- dims= :, init= nothing ) = _mapreduce (f, op, A, As... ; dims= dims, init= init)
34- # dims=:, init=nothing) = AK.mapreduce(f, op, #_mapreduce(f, op, A, As...; dims=dims, init=init)
35- Base. mapreduce (f, op, A:: AnyGPUArray ;
36- dims= :, init= nothing ) = AK. mapreduce (f, op, A; init, dims= dims isa Colon ? nothing : dims)
37- Base. mapreduce (f, op, A:: Broadcast.Broadcasted{<:AbstractGPUArrayStyle} ;
38- dims= :, init= nothing ) = AK. mapreduce (f, op, A; init, dims= dims isa Colon ? nothing : dims)
32+ dims= :, init= nothing ) = _mapreduce (f, op, A, As... ; dims, init)
3933
4034function _mapreduce (f:: F , op:: OP , As:: Vararg{Any,N} ; dims:: D , init) where {F,OP,N,D}
4135 # figure out the destination container type by looking at the initializer element,
@@ -72,9 +66,25 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
7266 end
7367
7468 # allocate an output container
69+ block_size = 256 # Hard-code AK default to prevent mismatches
7570 sz = size (A)
7671 red = ntuple (i-> (dims== Colon () || i in dims) ? 1 : sz[i], length (sz))
77- R = similar (A, ET, red)
72+ R = if dims isa Colon
73+ num_per_block = 2 * block_size
74+ blocks = (prod (sz) + num_per_block - 1 ) ÷ num_per_block
75+ similar (A, ET, 2 * blocks)
76+ else
77+ similar (A, ET, red)
78+ end
79+
80+ # Use AcceleratedKernels if possible
81+ if dims isa Colon || dims isa Integer
82+ return AK. mapreduce (f, op, Base. materialize (A), get_backend (R);
83+ block_size, init,
84+ neutral= init,
85+ dims= dims isa Colon ? nothing : dims,
86+ temp = R)
87+ end
7888
7989 # perform the reduction
8090 if prod (sz) == 0
0 commit comments