Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 33 additions & 23 deletions base/random/generation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,45 +244,55 @@ rand(rng::AbstractRNG, sp::SamplerSimple{<:AbstractArray,<:Sampler}) =
@inbounds return sp[][rand(rng, sp.data)]


## random values from Dict, Set, BitSet
## random values from Dict

for x in (1, Inf) # eval because of ambiguity otherwise
for T in (Dict, Set, BitSet)
@eval Sampler(::AbstractRNG, t::$T, ::Val{$x}) = SamplerTrivial(t)
end
function Sampler(rng::AbstractRNG, t::Dict, ::Repetition)
isempty(t) && throw(ArgumentError("collection must be non-empty"))
# we use Val(Inf) below as rand is called repeatedly internally
# even for generating only one random value from t
SamplerSimple(t, Sampler(rng, linearindices(t.slots), Val(Inf)))
end

function rand(rng::AbstractRNG, sp::SamplerTrivial{<:Dict})
isempty(sp[]) && throw(ArgumentError("collection must be non-empty"))
rsp = Sampler(rng, 1:length(sp[].slots))
function rand(rng::AbstractRNG, sp::SamplerSimple{<:Dict,<:Sampler})
while true
i = rand(rng, rsp)
i = rand(rng, sp.data)
Base.isslotfilled(sp[], i) && @inbounds return (sp[].keys[i] => sp[].vals[i])
end
end

rand(rng::AbstractRNG, sp::SamplerTrivial{<:Set}) = rand(rng, sp[].dict).first
## random values from Set

Sampler(rng::AbstractRNG, t::Set, n::Repetition) = SamplerTag{Set}(Sampler(rng, t.dict, n))

rand(rng::AbstractRNG, sp::SamplerTag{Set,<:Sampler}) = rand(rng, sp.data).first

function rand(rng::AbstractRNG, sp::SamplerTrivial{BitSet})
isempty(sp[]) && throw(ArgumentError("collection must be non-empty"))
# sp[] can be empty while sp[].bits is not, so we cannot rely on the
# length check in Sampler below
rsp = Sampler(rng, 1:length(sp[].bits))
## random values from BitSet

function Sampler(rng::AbstractRNG, t::BitSet, n::Repetition)
isempty(t) && throw(ArgumentError("collection must be non-empty"))
SamplerSimple(t, Sampler(rng, linearindices(t.bits), Val(Inf)))
end

function rand(rng::AbstractRNG, sp::SamplerSimple{BitSet,<:Sampler})
while true
n = rand(rng, rsp)
n = rand(rng, sp.data)
@inbounds b = sp[].bits[n]
b && return n
end
end

## random values from Associative/AbstractSet

# avoid linear complexity for repeated calls
# we defer to _Sampler to avoid ambiguities with a call like Sampler(rng, Set(1), Val(1))
Sampler(rng::AbstractRNG, t::Union{Associative,AbstractSet}, n::Repetition) =
_Sampler(rng, t, n)

# avoid linear complexity for repeated calls
_Sampler(rng::AbstractRNG, t::Union{Associative,AbstractSet}, n::Val{Inf}) =
Sampler(rng, collect(t), n)

# when generating only one element, avoid the call to collect
Sampler(::AbstractRNG, t::Union{Associative,AbstractSet}, ::Val{1}) =
_Sampler(::AbstractRNG, t::Union{Associative,AbstractSet}, ::Val{1}) =
SamplerTrivial(t)

function nth(iter, n::Integer)::eltype(iter)
Expand All @@ -299,22 +309,22 @@ rand(rng::AbstractRNG, sp::SamplerTrivial{<:Union{Associative,AbstractSet}}) =

# we use collect(str), which is most of the time more efficient than specialized methods
# (except maybe for very small arrays)
Sampler(rng::AbstractRNG, str::AbstractString, n::Repetition) = Sampler(rng, collect(str), n)
Sampler(rng::AbstractRNG, str::AbstractString, n::Val{Inf}) = Sampler(rng, collect(str), n)

# when generating only one char from a string, the specialized method below
# is usually more efficient
Sampler(::AbstractRNG, str::AbstractString, ::Val{1}) = SamplerTrivial(str)
Sampler(rng::AbstractRNG, str::AbstractString, ::Val{1}) =
SamplerSimple(str, Sampler(rng, 1:_endof(str), Val(Inf)))

isvalid_unsafe(s::String, i) = !Base.is_valid_continuation(Base.@gc_preserve s unsafe_load(pointer(s), i))
isvalid_unsafe(s::AbstractString, i) = isvalid(s, i)
_endof(s::String) = sizeof(s)
_endof(s::AbstractString) = endof(s)

function rand(rng::AbstractRNG, sp::SamplerTrivial{<:AbstractString})::Char
function rand(rng::AbstractRNG, sp::SamplerSimple{<:AbstractString,<:Sampler})::Char
str = sp[]
sp_pos = Sampler(rng, 1:_endof(str))
while true
pos = rand(rng, sp_pos)
pos = rand(rng, sp.data)
isvalid_unsafe(str, pos) && return str[pos]
end
end
6 changes: 6 additions & 0 deletions base/random/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ end

Base.getindex(sp::SamplerSimple) = sp.self

# simple sampler carrying a (type) tag T and data
struct SamplerTag{T,S} <: Sampler
data::S
SamplerTag{T}(s::S) where {T,S} = new{T,S}(s)
end


### machinery for generation with Sampler

Expand Down