@@ -26,6 +26,7 @@ export srand,
2626
2727abstract type AbstractRNG end
2828
29+
2930# ## integers
3031
3132# we define types which encode the generation of a specific number of bits
@@ -83,7 +84,9 @@ const BitFloatType = Union{Type{Float16},Type{Float32},Type{Float64}}
8384
8485# ## Sampler
8586
86- abstract type Sampler end
87+ abstract type Sampler{E} end
88+
89+ Base. eltype (:: Sampler{E} ) where {E} = E
8790
8891# temporarily for BaseBenchmarks
8992RangeGenerator (x) = Sampler (GLOBAL_RNG, x)
@@ -109,41 +112,48 @@ Sampler(rng::AbstractRNG, ::Type{X}) where {X} = Sampler(rng, X, Val(Inf))
109112# ### pre-defined useful Sampler types
110113
111114# default fall-back for types
112- struct SamplerType{T} <: Sampler end
115+ struct SamplerType{T} <: Sampler{T} end
113116
114117Sampler (:: AbstractRNG , :: Type{T} , :: Repetition ) where {T} = SamplerType {T} ()
115118
116- Base. getindex (sp :: SamplerType{T} ) where {T} = T
119+ Base. getindex (:: SamplerType{T} ) where {T} = T
117120
118121# default fall-back for values
119- struct SamplerTrivial{T} <: Sampler
122+ struct SamplerTrivial{T,E } <: Sampler{E}
120123 self:: T
121124end
122125
123- Sampler (:: AbstractRNG , X, :: Repetition ) = SamplerTrivial (X)
126+ SamplerTrivial (x:: T ) where {T} = SamplerTrivial {T,eltype(T)} (x)
127+
128+ Sampler (:: AbstractRNG , x, :: Repetition ) = SamplerTrivial (x)
124129
125130Base. getindex (sp:: SamplerTrivial ) = sp. self
126131
127132# simple sampler carrying data (which can be anything)
128- struct SamplerSimple{T,S} <: Sampler
133+ struct SamplerSimple{T,S,E } <: Sampler{E}
129134 self:: T
130135 data:: S
131136end
132137
138+ SamplerSimple (x:: T , data:: S ) where {T,S} = SamplerSimple {T,S,eltype(T)} (x, data)
139+
133140Base. getindex (sp:: SamplerSimple ) = sp. self
134141
135142# simple sampler carrying a (type) tag T and data
136- struct SamplerTag{T,S} <: Sampler
143+ struct SamplerTag{T,S,E } <: Sampler{E}
137144 data:: S
138- SamplerTag {T} (s:: S ) where {T,S} = new {T,S} (s)
145+ SamplerTag {T} (s:: S ) where {T,S} = new {T,S,eltype(T) } (s)
139146end
140147
141148
142149# ### helper samplers
143150
151+ # TODO : make constraining constructors to enforce that those
152+ # types are <: Sampler{T}
153+
144154# #### Adapter to generate a randome value in [0, n]
145155
146- struct LessThan{T<: Integer ,S} <: Sampler
156+ struct LessThan{T<: Integer ,S} <: Sampler{T}
147157 sup:: T
148158 s:: S # the scalar specification/sampler to feed to rand
149159end
@@ -155,7 +165,7 @@ function rand(rng::AbstractRNG, sp::LessThan)
155165 end
156166end
157167
158- struct Masked{T<: Integer ,S} <: Sampler
168+ struct Masked{T<: Integer ,S} <: Sampler{T}
159169 mask:: T
160170 s:: S
161171end
@@ -164,7 +174,7 @@ rand(rng::AbstractRNG, sp::Masked) = rand(rng, sp.s) & sp.mask
164174
165175# #### Uniform
166176
167- struct UniformT{T} <: Sampler end
177+ struct UniformT{T} <: Sampler{T} end
168178
169179uniform (:: Type{T} ) where {T} = UniformT {T} ()
170180
0 commit comments