Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Build system changes

New library functions
---------------------
* `findextrema(f, itr; [dims])` which computes `findmin(f, itr; [dims]), findmax(f, itr; [dims])` in a single pass. ([#45783])


New library features
Expand Down
2 changes: 2 additions & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,8 @@ export
findmin,
findmin!,
findmax!,
findextrema,
findextrema!,
findnext,
findprev,
match,
Expand Down
50 changes: 50 additions & 0 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,56 @@ julia> findmin([1, 7, 7, NaN])
findmin(itr) = _findmin(itr, :)
_findmin(a, ::Colon) = findmin(identity, a)

"""
findextrema(f, domain) -> ((f(x), index_mn), (f(x), index_mx))

Return the pair of pairs which would be returned by `(findmin(f, domain), findmax(f, domain))`,
but computed in a single pass.

# Examples

```jldoctest
julia> findextrema(identity, 5:9)
((5, 1), (9, 5))

julia> findextrema(-, 1:10)
((-10, 10), (-1, 1))

julia> findextrema(first, [(2, :a), (2, :b), (3, :c)])
((2, 1), (3, 3))

julia> findextrema(cos, 0:π/2:2π)
((-1.0, 3), (1.0, 1))
```
"""
findextrema(f, domain) = _findextrema(f, domain, :)
_findextrema(f, domain, ::Colon) = mapfoldl(((k, v),) -> (fv = f(v); ((fv, k), (fv, k))), _rf_findextrema, pairs(domain))
_rf_findextrema((a, b), (c, d)) = _rf_findmin(a, c), _rf_findmax(b, d)

"""
findextrema(itr) -> ((mn, index_mn), (mx, index_mx))

Return the pair of pairs which would be returned by `(findmin(itr), findmax(itr))`,
but computed in a single pass.

See also: [`findmin`](@ref), [`findmax`](@ref)

# Examples

```jldoctest
julia> findextrema([8, 0.1, -9, pi])
((-9.0, 3), (8.0, 1))

julia> findextrema([1, 7, 7, 6])
((1, 1), (7, 2))

julia> findextrema([1, 7, 7, NaN])
((NaN, 4), (NaN, 4))
```
"""
findextrema(itr) = _findextrema(itr, :)
_findextrema(a, ::Colon) = findextrema(identity, a)

"""
argmax(f, domain)

Expand Down
143 changes: 143 additions & 0 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,149 @@ end

reducedim1(R, A) = length(axes1(R)) == 1

function _findextrema!(f, op_mn, op_mx, Rval_mn, Rind_mn, Rval_mx, Rind_mx, A::AbstractArray{T,N}) where {T,N}
(isempty(Rval_mn) || isempty(Rval_mx) || isempty(A)) && return ((Rval_mn, Rind_mn), (Rval_mx, Rind_mx))
lsiz_mn = check_reducedims(Rval_mn, A)
for i = 1:N
axes(Rval_mn, i) == axes(Rind_mn, i) == axes(Rval_mx, i) == axes(Rind_mx, i) || throw(DimensionMismatch("Find-reduction: outputs must have the same indices"))
end
# Same as findminmax! implementation
indsAt, indsRt = safe_tail(axes(A)), safe_tail(axes(Rval_mn))
keep, Idefault = Broadcast.shapeindexer(indsRt)
ks = keys(A)
y = iterate(ks)
zi = zero(eltype(ks))
if reducedim1(Rval_mn, A)
i1 = first(axes1(Rval_mn))
@inbounds for IA in CartesianIndices(indsAt)
IR = Broadcast.newindex(IA, keep, Idefault)
tmpRv_mn = Rval_mn[i1,IR]
tmpRi_mn = Rind_mn[i1,IR]
tmpRv_mx = Rval_mx[i1,IR]
tmpRi_mx = Rind_mx[i1,IR]
for i in axes(A,1)
k, kss = y::Tuple
tmpAv = f(A[i,IA])
if tmpRi_mn == zi || op_mn(tmpRv_mn, tmpAv)
tmpRv_mn = tmpAv
tmpRi_mn = k
end
if tmpRi_mx == zi || op_mx(tmpRv_mx, tmpAv)
tmpRv_mx = tmpAv
tmpRi_mx = k
end
y = iterate(ks, kss)
end
Rval_mn[i1,IR] = tmpRv_mn
Rind_mn[i1,IR] = tmpRi_mn
Rval_mx[i1,IR] = tmpRv_mx
Rind_mx[i1,IR] = tmpRi_mx
end
else
@inbounds for IA in CartesianIndices(indsAt)
IR = Broadcast.newindex(IA, keep, Idefault)
for i in axes(A, 1)
k, kss = y::Tuple
tmpAv = f(A[i,IA])
tmpRv_mn = Rval_mn[i,IR]
tmpRi_mn = Rind_mn[i,IR]
tmpRv_mx = Rval_mx[i,IR]
tmpRi_mx = Rind_mx[i,IR]
if tmpRi_mn == zi || op_mn(tmpRv_mn, tmpAv)
Rval_mn[i,IR] = tmpAv
Rind_mn[i,IR] = k
end
if tmpRi_mx == zi || op_mx(tmpRv_mx, tmpAv)
Rval_mx[i,IR] = tmpAv
Rind_mx[i,IR] = k
end
y = iterate(ks, kss)
end
end
end
((Rval_mn, Rind_mn), (Rval_mx, Rind_mx))
end

"""
findextrema!(rval_mn, rind_mn, rval_mx, rind_mx, A) -> ((minval, index), (maxval, index))

Find the minimum and maximum of `A` and the respective linear index along singleton
dimensions, storing the results in `((rval_mn , rind_mn), (rval_mn , rind_mn))`,
equivalent to `(findmin!(rval_mn, rind_mn, A), findmax!(rval_mx, rind_mx, A))`
but computed in a single pass.
"""
function findextrema!(rval_mn::AbstractArray, rind_mn::AbstractArray, rval_mx::AbstractArray, rind_mx::AbstractArray, A::AbstractArray; init::Bool=true)
init && !isempty(A) && (fill!(rval_mn, first(A)); fill!(rval_mx, first(A)))
Ti = eltype(keys(A))
_findextrema!(identity, isgreater, isless, rval_mn, fill!(rind_mn,zero(Ti)), rval_mx, fill!(rind_mx,zero(Ti)), A)
end

"""
findextrema(A; dims) -> ((minval, index), (maxval, index))

For an array input, returns the value and index of the minimum and maximum over the
given dimensions. Equivalent to `(findmin(A; dims), findmax(A; dims))`, but computed
in a single pass.
`NaN` is treated as greater than all other values except `missing`.

# Examples
```jldoctest
julia> A = [1.0 2; 3 4]
2×2 Matrix{Float64}:
1.0 2.0
3.0 4.0

julia> findextrema(A, dims=1) == (findmin(A, dims=1), findmax(A, dims=1))
true

julia> findextrema(A, dims=2) == (findmin(A, dims=2), findmax(A, dims=2))
true
```
"""
findextrema(A::AbstractArray; dims=:) = _findextrema(A, dims)
_findextrema(A, dims) = _findextrema(identity, A, dims)

"""
findextrema(f, A; dims) -> ((f(x), index_mn), (f(x), index_mx))

For an array input, returns the value in the codomain and index of the corresponding value
which minimize and maximize `f` over the given dimensions. Equivalent to
`(findmin(f, A; dims), findmax(f, A; dims))`, but computed in a single pass.

# Examples
```jldoctest
julia> A = [-1.0 1; -0.5 2]
2×2 Matrix{Float64}:
-1.0 1.0
-0.5 2.0

julia> findextrema(abs2, A, dims=1) == (findmin(abs2, A, dims=1), findmax(abs2, A, dims=1))
true

julia> findextrema(abs2, A, dims=2) == (findmin(abs2, A, dims=2), findmax(abs2, A, dims=2))
true
```
"""
findextrema(f, A::AbstractArray; dims=:) = _findextrema(f, A, dims)

function _findextrema(f, A, region)
ri = reduced_indices0(A, region)
if isempty(A)
if prod(map(length, reduced_indices(A, region))) != 0
throw(ArgumentError("collection slices must be non-empty"))
end
Tr = promote_op(f, eltype(A))
Ti = eltype(keys(A))
(similar(A, Tr, ri), zeros(Ti, ri)), (similar(A, Tr, ri), zeros(Ti, ri))
else
fA = f(first(A))
Tr = _findminmax_inittype(f, A)
Ti = eltype(keys(A))
_findextrema!(f, isgreater, isless, fill!(similar(A, Tr, ri), fA), zeros(Ti, ri),
fill!(similar(A, Tr, ri), fA), zeros(Ti, ri), A)
end
end

"""
argmin(A; dims) -> indices

Expand Down
2 changes: 2 additions & 0 deletions doc/src/base/collections.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,10 @@ Base.argmax
Base.argmin
Base.findmax
Base.findmin
Base.findextrema
Base.findmax!
Base.findmin!
Base.findextrema!
Base.sum
Base.sum!
Base.prod
Expand Down
11 changes: 11 additions & 0 deletions test/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,17 @@ end
@test argmax(sum, Iterators.product(1:5, 1:5)) == (5, 5)
end

# findextrema
@testset "findextrema(f, domain)" begin
@test findextrema(-, 1:10) == ((-10, 10), (-1, 1))
@test findextrema(identity, [1, 2, 3, missing]) === ((missing, 4), (missing, 4))
@test findextrema(identity, [1, NaN, 3, missing]) === ((missing, 4), (missing, 4))
@test findextrema(identity, [1, missing, NaN, 3]) === ((missing, 2), (missing, 2))
@test findextrema(identity, [1, NaN, 3]) === ((NaN, 2), (NaN, 2))
@test findextrema(identity, [1, 3, NaN]) === ((NaN, 3), (NaN, 3))
@test findextrema(cos, 0:π/2:2π) == ((-1.0, 3), (1.0, 1))
end

# any & all

@test @inferred any([]) == false
Expand Down
Loading