# Patch against base/random/generation.jl and base/random/random.jl.
#
# generation.jl: the rand methods for Dict, Set, BitSet and for drawing a single Char
# from an AbstractString used to receive a SamplerTrivial and build the index-range
# sampler inside every rand call. The patch moves that construction (and the
# non-empty check for Dict/BitSet) into dedicated Sampler methods that return a
# SamplerSimple, or a SamplerTag{Set} wrapping the Dict sampler. The generic
# Associative/AbstractSet Sampler methods are routed through _Sampler.
#
# random.jl: adds the SamplerTag{T,S} struct used by the new Set sampler.
diff --git a/base/random/generation.jl b/base/random/generation.jl
index 17a8f02a2017b..29b0b4891ce93 100644
--- a/base/random/generation.jl
+++ b/base/random/generation.jl
@@ -244,32 +244,38 @@ rand(rng::AbstractRNG, sp::SamplerSimple{<:AbstractArray,<:Sampler}) =
     @inbounds return sp[][rand(rng, sp.data)]
 
 
-## random values from Dict, Set, BitSet
+## random values from Dict
 
-for x in (1, Inf) # eval because of ambiguity otherwise
-    for T in (Dict, Set, BitSet)
-        @eval Sampler(::AbstractRNG, t::$T, ::Val{$x}) = SamplerTrivial(t)
-    end
+function Sampler(rng::AbstractRNG, t::Dict, ::Repetition)
+    isempty(t) && throw(ArgumentError("collection must be non-empty"))
+    # we use Val(Inf) below as rand is called repeatedly internally (the rand method
+    # loops until it draws a filled slot), even for generating only one value from t
+    SamplerSimple(t, Sampler(rng, linearindices(t.slots), Val(Inf)))
 end
 
-function rand(rng::AbstractRNG, sp::SamplerTrivial{<:Dict})
-    isempty(sp[]) && throw(ArgumentError("collection must be non-empty"))
-    rsp = Sampler(rng, 1:length(sp[].slots))
+function rand(rng::AbstractRNG, sp::SamplerSimple{<:Dict,<:Sampler})
     while true
-        i = rand(rng, rsp)
+        i = rand(rng, sp.data)
         Base.isslotfilled(sp[], i) && @inbounds return (sp[].keys[i] => sp[].vals[i])
     end
 end
 
-rand(rng::AbstractRNG, sp::SamplerTrivial{<:Set}) = rand(rng, sp[].dict).first
+## random values from Set
+
+Sampler(rng::AbstractRNG, t::Set, n::Repetition) = SamplerTag{Set}(Sampler(rng, t.dict, n))
+
+rand(rng::AbstractRNG, sp::SamplerTag{Set,<:Sampler}) = rand(rng, sp.data).first
 
-function rand(rng::AbstractRNG, sp::SamplerTrivial{BitSet})
-    isempty(sp[]) && throw(ArgumentError("collection must be non-empty"))
-    # sp[] can be empty while sp[].bits is not, so we cannot rely on the
-    # length check in Sampler below
-    rsp = Sampler(rng, 1:length(sp[].bits))
+## random values from BitSet
+
+function Sampler(rng::AbstractRNG, t::BitSet, n::Repetition)
+    isempty(t) && throw(ArgumentError("collection must be non-empty"))
+    SamplerSimple(t, Sampler(rng, linearindices(t.bits), Val(Inf)))
+end
+
+function rand(rng::AbstractRNG, sp::SamplerSimple{BitSet,<:Sampler})
     while true
-        n = rand(rng, rsp)
+        n = rand(rng, sp.data)
         @inbounds b = sp[].bits[n]
         b && return n
     end
@@ -277,12 +283,16 @@ end
 
 ## random values from Associative/AbstractSet
 
-# avoid linear complexity for repeated calls
+# we defer to _Sampler to avoid ambiguities with a call like Sampler(rng, Set(1), Val(1))
 Sampler(rng::AbstractRNG, t::Union{Associative,AbstractSet}, n::Repetition) =
+    _Sampler(rng, t, n)
+
+# avoid linear complexity for repeated calls
+_Sampler(rng::AbstractRNG, t::Union{Associative,AbstractSet}, n::Val{Inf}) =
     Sampler(rng, collect(t), n)
 
 # when generating only one element, avoid the call to collect
-Sampler(::AbstractRNG, t::Union{Associative,AbstractSet}, ::Val{1}) =
+_Sampler(::AbstractRNG, t::Union{Associative,AbstractSet}, ::Val{1}) =
     SamplerTrivial(t)
 
 function nth(iter, n::Integer)::eltype(iter)
@@ -299,22 +309,22 @@ rand(rng::AbstractRNG, sp::SamplerTrivial{<:Union{Associative,AbstractSet}}) =
 
 # we use collect(str), which is most of the time more efficient than specialized methods
 # (except maybe for very small arrays)
-Sampler(rng::AbstractRNG, str::AbstractString, n::Repetition) = Sampler(rng, collect(str), n)
+Sampler(rng::AbstractRNG, str::AbstractString, n::Val{Inf}) = Sampler(rng, collect(str), n)
 
 # when generating only one char from a string, the specialized method below
 # is usually more efficient
-Sampler(::AbstractRNG, str::AbstractString, ::Val{1}) = SamplerTrivial(str)
+Sampler(rng::AbstractRNG, str::AbstractString, ::Val{1}) =
+    SamplerSimple(str, Sampler(rng, 1:_endof(str), Val(Inf)))
 
 isvalid_unsafe(s::String, i) = !Base.is_valid_continuation(Base.@gc_preserve s unsafe_load(pointer(s), i))
 isvalid_unsafe(s::AbstractString, i) = isvalid(s, i)
 _endof(s::String) = sizeof(s)
 _endof(s::AbstractString) = endof(s)
 
-function rand(rng::AbstractRNG, sp::SamplerTrivial{<:AbstractString})::Char
+function rand(rng::AbstractRNG, sp::SamplerSimple{<:AbstractString,<:Sampler})::Char
     str = sp[]
-    sp_pos = Sampler(rng, 1:_endof(str))
     while true
-        pos = rand(rng, sp_pos)
+        pos = rand(rng, sp.data)
         isvalid_unsafe(str, pos) && return str[pos]
     end
 end
diff --git a/base/random/random.jl b/base/random/random.jl
index 47cc0b89c7dd4..fe082fdd495d7 100644
--- a/base/random/random.jl
+++ b/base/random/random.jl
@@ -93,6 +93,12 @@ end
 
 Base.getindex(sp::SamplerSimple) = sp.self
 
+# simple sampler carrying data and a (type) tag T; T is a type parameter only, not a field
+struct SamplerTag{T,S} <: Sampler
+    data::S
+    SamplerTag{T}(s::S) where {T,S} = new{T,S}(s)
+end
+
 
 ### machinery for generation with Sampler
 