Skip to content

Commit 8bec2de

Browse files
committed
Make Uniform and its helper traits use arguments of type Borrow<X> rather than type X.
1 parent ec3d7ef commit 8bec2de

File tree

2 files changed

+120
-36
lines changed

2 files changed

+120
-36
lines changed

src/distributions/uniform.rs

Lines changed: 105 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,17 @@
5353
//! Those methods should include an assert to check the range is valid (i.e.
5454
//! `low < high`). The example below merely wraps another back-end.
5555
//!
56+
//! The `new`, `new_inclusive` and `sample_single` functions use arguments of
57+
//! type Borrow<X> in order to support passing in values by reference or by
58+
//! value. In the implementation of these functions, you can choose to simply
59+
//! use the reference returned by `Borrow::borrow`, or you can choose to copy
60+
//! or clone the value, whatever is appropriate for your type.
61+
//!
5662
//! ```
5763
//! use rand::prelude::*;
5864
//! use rand::distributions::uniform::{Uniform, SampleUniform,
5965
//! UniformSampler, UniformFloat};
66+
//! use std::borrow::Borrow;
6067
//!
6168
//! struct MyF32(f32);
6269
//!
@@ -67,12 +74,18 @@
6774
//!
6875
//! impl UniformSampler for UniformMyF32 {
6976
//! type X = MyF32;
70-
//! fn new(low: Self::X, high: Self::X) -> Self {
77+
//! fn new<B1, B2>(low: B1, high: B2) -> Self
78+
//! where B1: Borrow<Self::X> + Sized,
79+
//! B2: Borrow<Self::X> + Sized
80+
//! {
7181
//! UniformMyF32 {
72-
//! inner: UniformFloat::<f32>::new(low.0, high.0),
82+
//! inner: UniformFloat::<f32>::new(low.borrow().0, high.borrow().0),
7383
//! }
7484
//! }
75-
//! fn new_inclusive(low: Self::X, high: Self::X) -> Self {
85+
//! fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
86+
//! where B1: Borrow<Self::X> + Sized,
87+
//! B2: Borrow<Self::X> + Sized
88+
//! {
7689
//! UniformSampler::new(low, high)
7790
//! }
7891
//! fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
@@ -99,6 +112,7 @@
99112
100113
#[cfg(feature = "std")]
101114
use std::time::Duration;
115+
use core::borrow::Borrow;
102116

103117
use Rng;
104118
use distributions::Distribution;
@@ -155,13 +169,19 @@ pub struct Uniform<X: SampleUniform> {
155169
impl<X: SampleUniform> Uniform<X> {
156170
/// Create a new `Uniform` instance which samples uniformly from the half
157171
/// open range `[low, high)` (excluding `high`). Panics if `low >= high`.
158-
pub fn new(low: X, high: X) -> Uniform<X> {
172+
pub fn new<B1, B2>(low: B1, high: B2) -> Uniform<X>
173+
where B1: Borrow<X> + Sized,
174+
B2: Borrow<X> + Sized
175+
{
159176
Uniform { inner: X::Sampler::new(low, high) }
160177
}
161178

162179
/// Create a new `Uniform` instance which samples uniformly from the closed
163180
/// range `[low, high]` (inclusive). Panics if `low > high`.
164-
pub fn new_inclusive(low: X, high: X) -> Uniform<X> {
181+
pub fn new_inclusive<B1, B2>(low: B1, high: B2) -> Uniform<X>
182+
where B1: Borrow<X> + Sized,
183+
B2: Borrow<X> + Sized
184+
{
165185
Uniform { inner: X::Sampler::new_inclusive(low, high) }
166186
}
167187
}
@@ -206,14 +226,18 @@ pub trait UniformSampler: Sized {
206226
///
207227
/// Usually users should not call this directly but instead use
208228
/// `Uniform::new`, which asserts that `low < high` before calling this.
209-
fn new(low: Self::X, high: Self::X) -> Self;
229+
fn new<B1, B2>(low: B1, high: B2) -> Self
230+
where B1: Borrow<Self::X> + Sized,
231+
B2: Borrow<Self::X> + Sized;
210232

211233
/// Construct self, with inclusive bounds `[low, high]`.
212234
///
213235
/// Usually users should not call this directly but instead use
214236
/// `Uniform::new_inclusive`, which asserts that `low <= high` before
215237
/// calling this.
216-
fn new_inclusive(low: Self::X, high: Self::X) -> Self;
238+
fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
239+
where B1: Borrow<Self::X> + Sized,
240+
B2: Borrow<Self::X> + Sized;
217241

218242
/// Sample a value.
219243
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X;
@@ -229,8 +253,10 @@ pub trait UniformSampler: Sized {
229253
/// sampling only a single value from the specified range. The default
230254
/// implementation simply calls `UniformSampler::new` then `sample` on the
231255
/// result.
232-
fn sample_single<R: Rng + ?Sized>(low: Self::X, high: Self::X, rng: &mut R)
256+
fn sample_single<R: Rng + ?Sized, B1, B2>(low: B1, high: B2, rng: &mut R)
233257
-> Self::X
258+
where B1: Borrow<Self::X> + Sized,
259+
B2: Borrow<Self::X> + Sized
234260
{
235261
let uniform: Self = UniformSampler::new(low, high);
236262
uniform.sample(rng)
@@ -311,14 +337,24 @@ macro_rules! uniform_int_impl {
311337

312338
#[inline] // if the range is constant, this helps LLVM to do the
313339
// calculations at compile-time.
314-
fn new(low: Self::X, high: Self::X) -> Self {
340+
fn new<B1, B2>(low_b: B1, high_b: B2) -> Self
341+
where B1: Borrow<Self::X> + Sized,
342+
B2: Borrow<Self::X> + Sized
343+
{
344+
let low = *low_b.borrow();
345+
let high = *high_b.borrow();
315346
assert!(low < high, "Uniform::new called with `low >= high`");
316347
UniformSampler::new_inclusive(low, high - 1)
317348
}
318349

319350
#[inline] // if the range is constant, this helps LLVM to do the
320351
// calculations at compile-time.
321-
fn new_inclusive(low: Self::X, high: Self::X) -> Self {
352+
fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Self
353+
where B1: Borrow<Self::X> + Sized,
354+
B2: Borrow<Self::X> + Sized
355+
{
356+
let low = *low_b.borrow();
357+
let high = *high_b.borrow();
322358
assert!(low <= high,
323359
"Uniform::new_inclusive called with `low > high`");
324360
let unsigned_max = ::core::$unsigned::MAX;
@@ -362,10 +398,13 @@ macro_rules! uniform_int_impl {
362398
}
363399
}
364400

365-
fn sample_single<R: Rng + ?Sized>(low: Self::X,
366-
high: Self::X,
367-
rng: &mut R) -> Self::X
401+
fn sample_single<R: Rng + ?Sized, B1, B2>(low_b: B1, high_b: B2, rng: &mut R)
402+
-> Self::X
403+
where B1: Borrow<Self::X> + Sized,
404+
B2: Borrow<Self::X> + Sized
368405
{
406+
let low = *low_b.borrow();
407+
let high = *high_b.borrow();
369408
assert!(low < high,
370409
"Uniform::sample_single called with low >= high");
371410
let range = high.wrapping_sub(low) as $unsigned as $u_large;
@@ -532,7 +571,12 @@ macro_rules! uniform_float_impl {
532571
impl UniformSampler for UniformFloat<$ty> {
533572
type X = $ty;
534573

535-
fn new(low: Self::X, high: Self::X) -> Self {
574+
fn new<B1, B2>(low_b: B1, high_b: B2) -> Self
575+
where B1: Borrow<Self::X> + Sized,
576+
B2: Borrow<Self::X> + Sized
577+
{
578+
let low = *low_b.borrow();
579+
let high = *high_b.borrow();
536580
assert!(low < high, "Uniform::new called with `low >= high`");
537581
let scale = high - low;
538582
let offset = low - scale;
@@ -542,7 +586,12 @@ macro_rules! uniform_float_impl {
542586
}
543587
}
544588

545-
fn new_inclusive(low: Self::X, high: Self::X) -> Self {
589+
fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Self
590+
where B1: Borrow<Self::X> + Sized,
591+
B2: Borrow<Self::X> + Sized
592+
{
593+
let low = *low_b.borrow();
594+
let high = *high_b.borrow();
546595
assert!(low <= high,
547596
"Uniform::new_inclusive called with `low > high`");
548597
let scale = high - low;
@@ -565,9 +614,13 @@ macro_rules! uniform_float_impl {
565614
value1_2 * self.scale + self.offset
566615
}
567616

568-
fn sample_single<R: Rng + ?Sized>(low: Self::X,
569-
high: Self::X,
570-
rng: &mut R) -> Self::X {
617+
fn sample_single<R: Rng + ?Sized, B1, B2>(low_b: B1, high_b: B2, rng: &mut R)
618+
-> Self::X
619+
where B1: Borrow<Self::X> + Sized,
620+
B2: Borrow<Self::X> + Sized
621+
{
622+
let low = *low_b.borrow();
623+
let high = *high_b.borrow();
571624
assert!(low < high,
572625
"Uniform::sample_single called with low >= high");
573626
let scale = high - low;
@@ -624,13 +677,23 @@ impl UniformSampler for UniformDuration {
624677
type X = Duration;
625678

626679
#[inline]
627-
fn new(low: Duration, high: Duration) -> UniformDuration {
680+
fn new<B1, B2>(low_b: B1, high_b: B2) -> Self
681+
where B1: Borrow<Self::X> + Sized,
682+
B2: Borrow<Self::X> + Sized
683+
{
684+
let low = *low_b.borrow();
685+
let high = *high_b.borrow();
628686
assert!(low < high, "Uniform::new called with `low >= high`");
629687
UniformDuration::new_inclusive(low, high - Duration::new(0, 1))
630688
}
631689

632690
#[inline]
633-
fn new_inclusive(low: Duration, high: Duration) -> UniformDuration {
691+
fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Self
692+
where B1: Borrow<Self::X> + Sized,
693+
B2: Borrow<Self::X> + Sized
694+
{
695+
let low = *low_b.borrow();
696+
let high = *high_b.borrow();
634697
assert!(low <= high, "Uniform::new_inclusive called with `low > high`");
635698
let size = high - low;
636699
let nanos = size
@@ -750,6 +813,18 @@ mod tests {
750813
assert!(low <= v && v <= high);
751814
}
752815

816+
let my_uniform = Uniform::new(&low, high);
817+
for _ in 0..1000 {
818+
let v: $ty = rng.sample(my_uniform);
819+
assert!(low <= v && v < high);
820+
}
821+
822+
let my_uniform = Uniform::new_inclusive(&low, &high);
823+
for _ in 0..1000 {
824+
let v: $ty = rng.sample(my_uniform);
825+
assert!(low <= v && v <= high);
826+
}
827+
753828
for _ in 0..1000 {
754829
let v: $ty = rng.gen_range(low, high);
755830
assert!(low <= v && v < high);
@@ -809,6 +884,7 @@ mod tests {
809884

810885
#[test]
811886
fn test_custom_uniform() {
887+
use core::borrow::Borrow;
812888
#[derive(Clone, Copy, PartialEq, PartialOrd)]
813889
struct MyF32 {
814890
x: f32,
@@ -819,12 +895,18 @@ mod tests {
819895
}
820896
impl UniformSampler for UniformMyF32 {
821897
type X = MyF32;
822-
fn new(low: Self::X, high: Self::X) -> Self {
898+
fn new<B1, B2>(low: B1, high: B2) -> Self
899+
where B1: Borrow<Self::X> + Sized,
900+
B2: Borrow<Self::X> + Sized
901+
{
823902
UniformMyF32 {
824-
inner: UniformFloat::<f32>::new(low.x, high.x),
903+
inner: UniformFloat::<f32>::new(low.borrow().x, high.borrow().x),
825904
}
826905
}
827-
fn new_inclusive(low: Self::X, high: Self::X) -> Self {
906+
fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
907+
where B1: Borrow<Self::X> + Sized,
908+
B2: Borrow<Self::X> + Sized
909+
{
828910
UniformSampler::new(low, high)
829911
}
830912
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {

src/lib.rs

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -300,10 +300,10 @@ pub mod isaac {
300300

301301

302302
use core::{mem, slice};
303+
use core::borrow::Borrow;
303304
use distributions::{Distribution, Standard};
304305
use distributions::uniform::{SampleUniform, UniformSampler};
305306

306-
307307
/// An automatically-implemented extension trait on [`RngCore`] providing high-level
308308
/// generic methods for sampling values and other convenience methods.
309309
///
@@ -387,7 +387,9 @@ pub trait Rng: RngCore {
387387
/// ```
388388
///
389389
/// [`Uniform`]: distributions/uniform/struct.Uniform.html
390-
fn gen_range<T: SampleUniform>(&mut self, low: T, high: T) -> T {
390+
fn gen_range<T: SampleUniform, B1, B2>(&mut self, low: B1, high: B2) -> T
391+
where B1: Borrow<T> + Sized,
392+
B2: Borrow<T> + Sized {
391393
T::Sampler::sample_single(low, high, self)
392394
}
393395

@@ -935,19 +937,19 @@ mod test {
935937
fn test_gen_range() {
936938
let mut r = rng(101);
937939
for _ in 0..1000 {
938-
let a = r.gen_range(-3, 42);
939-
assert!(a >= -3 && a < 42);
940-
assert_eq!(r.gen_range(0, 1), 0);
941-
assert_eq!(r.gen_range(-12, -11), -12);
942-
}
943-
944-
for _ in 0..1000 {
945-
let a = r.gen_range(10, 42);
946-
assert!(a >= 10 && a < 42);
947-
assert_eq!(r.gen_range(0, 1), 0);
940+
let a = r.gen_range(-4711, 17);
941+
assert!(a >= -4711 && a < 17);
942+
let a = r.gen_range(-3i8, 42);
943+
assert!(a >= -3i8 && a < 42i8);
944+
let a = r.gen_range(&10u16, 99);
945+
assert!(a >= 10u16 && a < 99u16);
946+
let a = r.gen_range(-100i32, &2000);
947+
assert!(a >= -100i32 && a < 2000i32);
948+
949+
assert_eq!(r.gen_range(0u32, 1), 0u32);
950+
assert_eq!(r.gen_range(-12i64, -11), -12i64);
948951
assert_eq!(r.gen_range(3_000_000, 3_000_001), 3_000_000);
949952
}
950-
951953
}
952954

953955
#[test]

0 commit comments

Comments
 (0)