Skip to content

Commit c979996

Browse files
authored
Merge pull request #24866 from JuliaLang/rf/rand/faster-dict-set
faster rand! for Dict, Set, BitSet
2 parents 9b1a56e + 199073a commit c979996

File tree

2 files changed

+39
-23
lines changed

2 files changed

+39
-23
lines changed

base/random/generation.jl

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -244,45 +244,55 @@ rand(rng::AbstractRNG, sp::SamplerSimple{<:AbstractArray,<:Sampler}) =
244244
@inbounds return sp[][rand(rng, sp.data)]
245245

246246

247-
## random values from Dict, Set, BitSet
247+
## random values from Dict
248248

249-
for x in (1, Inf) # eval because of ambiguity otherwise
250-
for T in (Dict, Set, BitSet)
251-
@eval Sampler(::AbstractRNG, t::$T, ::Val{$x}) = SamplerTrivial(t)
252-
end
249+
function Sampler(rng::AbstractRNG, t::Dict, ::Repetition)
250+
isempty(t) && throw(ArgumentError("collection must be non-empty"))
251+
# we use Val(Inf) below as rand is called repeatedly internally
252+
# even for generating only one random value from t
253+
SamplerSimple(t, Sampler(rng, linearindices(t.slots), Val(Inf)))
253254
end
254255

255-
function rand(rng::AbstractRNG, sp::SamplerTrivial{<:Dict})
256-
isempty(sp[]) && throw(ArgumentError("collection must be non-empty"))
257-
rsp = Sampler(rng, 1:length(sp[].slots))
256+
function rand(rng::AbstractRNG, sp::SamplerSimple{<:Dict,<:Sampler})
258257
while true
259-
i = rand(rng, rsp)
258+
i = rand(rng, sp.data)
260259
Base.isslotfilled(sp[], i) && @inbounds return (sp[].keys[i] => sp[].vals[i])
261260
end
262261
end
263262

264-
rand(rng::AbstractRNG, sp::SamplerTrivial{<:Set}) = rand(rng, sp[].dict).first
263+
## random values from Set
264+
265+
# Sampler for a `Set`: build the sampler for the backing `t.dict` (forwarding the
# repetition hint `n`) and wrap it in a `SamplerTag{Set}`.
function Sampler(rng::AbstractRNG, t::Set, n::Repetition)
    dict_sampler = Sampler(rng, t.dict, n)
    return SamplerTag{Set}(dict_sampler)
end
266+
267+
# Draw one value from the wrapped sampler `sp.data` and return its `first` field.
function rand(rng::AbstractRNG, sp::SamplerTag{Set,<:Sampler})
    drawn = rand(rng, sp.data)
    return drawn.first
end
265268

266-
function rand(rng::AbstractRNG, sp::SamplerTrivial{BitSet})
267-
isempty(sp[]) && throw(ArgumentError("collection must be non-empty"))
268-
# sp[] can be empty while sp[].bits is not, so we cannot rely on the
269-
# length check in Sampler below
270-
rsp = Sampler(rng, 1:length(sp[].bits))
269+
## random values from BitSet
270+
271+
function Sampler(rng::AbstractRNG, t::BitSet, n::Repetition)
272+
isempty(t) && throw(ArgumentError("collection must be non-empty"))
273+
SamplerSimple(t, Sampler(rng, linearindices(t.bits), Val(Inf)))
274+
end
275+
276+
function rand(rng::AbstractRNG, sp::SamplerSimple{BitSet,<:Sampler})
271277
while true
272-
n = rand(rng, rsp)
278+
n = rand(rng, sp.data)
273279
@inbounds b = sp[].bits[n]
274280
b && return n
275281
end
276282
end
277283

278284
## random values from Associative/AbstractSet
279285

280-
# avoid linear complexity for repeated calls
286+
# we defer to _Sampler to avoid ambiguities with a call like Sampler(rng, Set(1), Val(1))
281287
Sampler(rng::AbstractRNG, t::Union{Associative,AbstractSet}, n::Repetition) =
288+
_Sampler(rng, t, n)
289+
290+
# avoid linear complexity for repeated calls
291+
_Sampler(rng::AbstractRNG, t::Union{Associative,AbstractSet}, n::Val{Inf}) =
282292
Sampler(rng, collect(t), n)
283293

284294
# when generating only one element, avoid the call to collect
285-
Sampler(::AbstractRNG, t::Union{Associative,AbstractSet}, ::Val{1}) =
295+
_Sampler(::AbstractRNG, t::Union{Associative,AbstractSet}, ::Val{1}) =
286296
SamplerTrivial(t)
287297

288298
function nth(iter, n::Integer)::eltype(iter)
@@ -299,22 +309,22 @@ rand(rng::AbstractRNG, sp::SamplerTrivial{<:Union{Associative,AbstractSet}}) =
299309

300310
# we use collect(str), which is most of the time more efficient than specialized methods
301311
# (except maybe for very small arrays)
302-
Sampler(rng::AbstractRNG, str::AbstractString, n::Repetition) = Sampler(rng, collect(str), n)
312+
# Repeated sampling (`Val{Inf}`) from a string: materialize the characters once with
# `collect` and delegate to the sampler for the resulting array.
function Sampler(rng::AbstractRNG, str::AbstractString, n::Val{Inf})
    chars = collect(str)
    return Sampler(rng, chars, n)
end
303313

304314
# when generating only one char from a string, the specialized method below
305315
# is usually more efficient
306-
Sampler(::AbstractRNG, str::AbstractString, ::Val{1}) = SamplerTrivial(str)
316+
Sampler(rng::AbstractRNG, str::AbstractString, ::Val{1}) =
317+
SamplerSimple(str, Sampler(rng, 1:_endof(str), Val(Inf)))
307318

308319
# `String` method: load the byte at position `i` directly through the string's pointer
# (`unsafe_load` does no bounds checking; `s` is named in `@gc_preserve` for the load)
# and return `true` when that byte is not a continuation byte.
function isvalid_unsafe(s::String, i)
    byte = Base.@gc_preserve s unsafe_load(pointer(s), i)
    return !Base.is_valid_continuation(byte)
end
309320
# Generic fallback: for any other `AbstractString`, defer to `isvalid(s, i)`.
function isvalid_unsafe(s::AbstractString, i)
    return isvalid(s, i)
end
310321
# `String` method: the upper bound is `sizeof(s)`, i.e. the size of the string in bytes.
function _endof(s::String)
    return sizeof(s)
end
311322
# Generic fallback: for any other `AbstractString`, defer to `endof(s)`.
function _endof(s::AbstractString)
    return endof(s)
end
312323

313-
function rand(rng::AbstractRNG, sp::SamplerTrivial{<:AbstractString})::Char
324+
function rand(rng::AbstractRNG, sp::SamplerSimple{<:AbstractString,<:Sampler})::Char
314325
str = sp[]
315-
sp_pos = Sampler(rng, 1:_endof(str))
316326
while true
317-
pos = rand(rng, sp_pos)
327+
pos = rand(rng, sp.data)
318328
isvalid_unsafe(str, pos) && return str[pos]
319329
end
320330
end

base/random/random.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,12 @@ end
9393

9494
# Zero-index `getindex`, so that `sp[]` yields the `self` field stored in the sampler.
function Base.getindex(sp::SamplerSimple)
    return sp.self
end
9595

96+
# simple sampler carrying a (type) tag T and data
97+
struct SamplerTag{T,S} <: Sampler
98+
data::S
99+
SamplerTag{T}(s::S) where {T,S} = new{T,S}(s)
100+
end
101+
96102

97103
### machinery for generation with Sampler
98104

0 commit comments

Comments
 (0)