diff --git a/stdlib/Random/docs/src/index.md b/stdlib/Random/docs/src/index.md index 6f23ea2686645..65b64cbbcee2d 100644 --- a/stdlib/Random/docs/src/index.md +++ b/stdlib/Random/docs/src/index.md @@ -99,6 +99,14 @@ Samplers can be arbitrary values that implement `rand(rng, sampler)`, but for mo 3. `SamplerSimple(self, data)` also contains the additional `data` field, which can be used to store arbitrary pre-computed values, which should be computed in a *custom method* of `Sampler`. +In general, all samplers should support the method `eltype`, which is used for determining the element type of pre-allocated containers, eg sampling an array of values. + +For `SamplerType`, it is provided automatically, but for `SamplerTrivial` and `SamplerSimple`, +```julia +eltype(::Type{T}) +``` +should be defined to determine the returned type for custom random distributions of type `T`. + We provide examples for each of these. We assume here that the choice of algorithm is independent of the RNG, so we use `AbstractRNG` in our signatures. ```@docs @@ -169,17 +177,19 @@ In order to define random generation out of objects of type `S`, the following m ```jldoctest Die; setup = :(Random.seed!(1)) julia> Random.rand(rng::AbstractRNG, d::Random.SamplerTrivial{Die}) = rand(rng, 1:d[].nsides); +julia> Base.eltype(::Type{Die}) = Int + julia> rand(Die(4)) 3 julia> rand(Die(4), 3) -3-element Array{Any,1}: +3-element Array{Int,1}: 3 4 2 ``` -Given a collection type `S`, it's currently assumed that if `rand(::S)` is defined, an object of type `eltype(S)` will be produced. In the last example, a `Vector{Any}` is produced; the reason is that `eltype(Die) == Any`. The remedy is to define `Base.eltype(::Type{Die}) = Int`. +Given a collection type `S`, if `rand(::S)` is defined, an object of type `eltype(S)` will be produced. In this example, if we did not define a method for `eltype`, a `Vector{Any}` would have been produced. #### Generating values for an `AbstractFloat` type diff --git a/stdlib/Random/src/Random.jl b/stdlib/Random/src/Random.jl index baa82df0bf114..68319111c58d9 100644 --- a/stdlib/Random/src/Random.jl +++ b/stdlib/Random/src/Random.jl @@ -16,7 +16,7 @@ using Base.GMP: Limb using Base: BitInteger, BitInteger_types, BitUnsigned, require_one_based_indexing -import Base: copymutable, copy, copy!, ==, hash, convert +import Base: copymutable, copy, copy!, ==, hash, convert, eltype using Serialization import Serialization: serialize, deserialize import Base: rand, randn @@ -40,10 +40,6 @@ Supertype for random number generators such as [`MersenneTwister`](@ref) and [`R """ abstract type AbstractRNG end -gentype(::Type{X}) where {X} = eltype(X) -gentype(x) = gentype(typeof(x)) - - ### integers # we define types which encode the generation of a specific number of bits @@ -81,7 +77,7 @@ for UI = (:UInt10, :UInt10Raw, :UInt23, :UInt23Raw, :UInt52, :UInt52Raw, end end -gentype(::Type{<:UniformBits{T}}) where {T} = T +eltype(::Type{<:UniformBits{T}}) where {T} = T ### floats @@ -97,7 +93,7 @@ const CloseOpen12_64 = CloseOpen12{Float64} CloseOpen01(::Type{T}=Float64) where {T<:AbstractFloat} = CloseOpen01{T}() CloseOpen12(::Type{T}=Float64) where {T<:AbstractFloat} = CloseOpen12{T}() -gentype(::Type{<:FloatInterval{T}}) where {T<:AbstractFloat} = T +eltype(::Type{<:FloatInterval{T}}) where {T<:AbstractFloat} = T const BitFloatType = Union{Type{Float16},Type{Float32},Type{Float64}} @@ -105,7 +101,7 @@ const BitFloatType = Union{Type{Float16},Type{Float32},Type{Float64}} abstract type Sampler{E} end -gentype(::Type{<:Sampler{E}}) where {E} = E +eltype(::Type{<:Sampler{E}}) where {E} = E # temporarily for BaseBenchmarks RangeGenerator(x) = Sampler(GLOBAL_RNG, x) @@ -135,6 +131,10 @@ the amount of precomputation, if applicable. [`Random.SamplerType`](@ref) and [`Random.SamplerTrivial`](@ref) are default fallbacks for *types* and *values*, respectively. [`Random.SamplerSimple`](@ref) can be used to store pre-computed values without defining extra types for only this purpose. + +Generally, for most custom types that yield random values, defining a new method for +`Sampler` is *not* required, as the above solutions should be sufficient. See the manual for +details and examples. """ Sampler(rng::AbstractRNG, x, r::Repetition=Val(Inf)) = Sampler(typeof(rng), x, r) Sampler(rng::AbstractRNG, ::Type{X}, r::Repetition=Val(Inf)) where {X} = Sampler(typeof(rng), X, r) @@ -170,9 +170,11 @@ end Create a sampler that just wraps the given value `x`. This is the default fall-back for values. +`eltype(x)` is used to determine the types returned by this sampler, and should be defined. + The recommended use case is sampling from values without precomputed data. """ -SamplerTrivial(x::T) where {T} = SamplerTrivial{T,gentype(T)}(x) +SamplerTrivial(x::T) where {T} = SamplerTrivial{T,eltype(T)}(x) Sampler(::Type{<:AbstractRNG}, x, ::Repetition) = SamplerTrivial(x) @@ -189,16 +191,18 @@ end Create a sampler that wraps the given value `x` and the `data`. +`eltype(x)` is used to determine the types returned by this sampler, and should be defined. + The recommended use case is sampling from values with precomputed data. """ -SamplerSimple(x::T, data::S) where {T,S} = SamplerSimple{T,S,gentype(T)}(x, data) +SamplerSimple(x::T, data::S) where {T,S} = SamplerSimple{T,S,eltype(T)}(x, data) Base.getindex(sp::SamplerSimple) = sp.self # simple sampler carrying a (type) tag T and data struct SamplerTag{T,S,E} <: Sampler{E} data::S - SamplerTag{T}(s::S) where {T,S} = new{T,S,gentype(T)}(s) + SamplerTag{T}(s::S) where {T,S} = new{T,S,eltype(T)}(s) end @@ -271,7 +275,7 @@ end rand(r::AbstractRNG, dims::Integer...) = rand(r, Float64, Dims(dims)) rand( dims::Integer...) = rand(Float64, Dims(dims)) -rand(r::AbstractRNG, X, dims::Dims) = rand!(r, Array{gentype(X)}(undef, dims), X) +rand(r::AbstractRNG, X, dims::Dims) = rand!(r, Array{eltype(X)}(undef, dims), X) rand( X, dims::Dims) = rand(GLOBAL_RNG, X, dims) rand(r::AbstractRNG, X, d::Integer, dims::Integer...) = rand(r, X, Dims((d, dims...))) diff --git a/stdlib/Random/test/runtests.jl b/stdlib/Random/test/runtests.jl index c1d68910289f2..30b54c6c00b89 100644 --- a/stdlib/Random/test/runtests.jl +++ b/stdlib/Random/test/runtests.jl @@ -689,10 +689,10 @@ end end end -@testset "gentype for UniformBits" begin - @test Random.gentype(Random.UInt52()) == UInt64 - @test Random.gentype(Random.UInt52(UInt128)) == UInt128 - @test Random.gentype(Random.UInt104()) == UInt128 +@testset "eltype for UniformBits" begin + @test eltype(Random.UInt52()) == UInt64 + @test eltype(Random.UInt52(UInt128)) == UInt128 + @test eltype(Random.UInt104()) == UInt128 end @testset "shuffle[!]" begin