Skip to content

Commit 6c3ac5b

Browse files
authored
Merge pull request #506 from sicking/borrow
Make Uniform and its helper traits use arguments of type Borrow<X>
2 parents 8fdd710 + fc6c7a0 commit 6c3ac5b

File tree

2 files changed

+121
-36
lines changed

2 files changed

+121
-36
lines changed

src/distributions/uniform.rs

Lines changed: 106 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,17 @@
5353
//! Those methods should include an assert to check the range is valid (i.e.
5454
//! `low < high`). The example below merely wraps another back-end.
5555
//!
56+
//! The `new`, `new_inclusive` and `sample_single` functions use arguments of
57+
//! type Borrow<X> in order to support passing in values by reference or by
58+
//! value. In the implementation of these functions, you can choose to simply
59+
//! use the reference returned by [`Borrow::borrow`], or you can choose to copy
60+
//! or clone the value, whatever is appropriate for your type.
61+
//!
5662
//! ```
5763
//! use rand::prelude::*;
5864
//! use rand::distributions::uniform::{Uniform, SampleUniform,
5965
//! UniformSampler, UniformFloat};
66+
//! use std::borrow::Borrow;
6067
//!
6168
//! struct MyF32(f32);
6269
//!
@@ -67,12 +74,18 @@
6774
//!
6875
//! impl UniformSampler for UniformMyF32 {
6976
//! type X = MyF32;
70-
//! fn new(low: Self::X, high: Self::X) -> Self {
77+
//! fn new<B1, B2>(low: B1, high: B2) -> Self
78+
//! where B1: Borrow<Self::X> + Sized,
79+
//! B2: Borrow<Self::X> + Sized
80+
//! {
7181
//! UniformMyF32 {
72-
//! inner: UniformFloat::<f32>::new(low.0, high.0),
82+
//! inner: UniformFloat::<f32>::new(low.borrow().0, high.borrow().0),
7383
//! }
7484
//! }
75-
//! fn new_inclusive(low: Self::X, high: Self::X) -> Self {
85+
//! fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
86+
//! where B1: Borrow<Self::X> + Sized,
87+
//! B2: Borrow<Self::X> + Sized
88+
//! {
7689
//! UniformSampler::new(low, high)
7790
//! }
7891
//! fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
@@ -96,9 +109,11 @@
96109
//! [`UniformInt`]: struct.UniformInt.html
97110
//! [`UniformFloat`]: struct.UniformFloat.html
98111
//! [`UniformDuration`]: struct.UniformDuration.html
112+
//! [`Borrow::borrow`]: https://doc.rust-lang.org/std/borrow/trait.Borrow.html
99113
100114
#[cfg(feature = "std")]
101115
use std::time::Duration;
116+
use core::borrow::Borrow;
102117

103118
use Rng;
104119
use distributions::Distribution;
@@ -155,13 +170,19 @@ pub struct Uniform<X: SampleUniform> {
155170
impl<X: SampleUniform> Uniform<X> {
156171
/// Create a new `Uniform` instance which samples uniformly from the half
157172
/// open range `[low, high)` (excluding `high`). Panics if `low >= high`.
158-
pub fn new(low: X, high: X) -> Uniform<X> {
173+
pub fn new<B1, B2>(low: B1, high: B2) -> Uniform<X>
174+
where B1: Borrow<X> + Sized,
175+
B2: Borrow<X> + Sized
176+
{
159177
Uniform { inner: X::Sampler::new(low, high) }
160178
}
161179

162180
/// Create a new `Uniform` instance which samples uniformly from the closed
163181
/// range `[low, high]` (inclusive). Panics if `low > high`.
164-
pub fn new_inclusive(low: X, high: X) -> Uniform<X> {
182+
pub fn new_inclusive<B1, B2>(low: B1, high: B2) -> Uniform<X>
183+
where B1: Borrow<X> + Sized,
184+
B2: Borrow<X> + Sized
185+
{
165186
Uniform { inner: X::Sampler::new_inclusive(low, high) }
166187
}
167188
}
@@ -206,14 +227,18 @@ pub trait UniformSampler: Sized {
206227
///
207228
/// Usually users should not call this directly but instead use
208229
/// `Uniform::new`, which asserts that `low < high` before calling this.
209-
fn new(low: Self::X, high: Self::X) -> Self;
230+
fn new<B1, B2>(low: B1, high: B2) -> Self
231+
where B1: Borrow<Self::X> + Sized,
232+
B2: Borrow<Self::X> + Sized;
210233

211234
/// Construct self, with inclusive bounds `[low, high]`.
212235
///
213236
/// Usually users should not call this directly but instead use
214237
/// `Uniform::new_inclusive`, which asserts that `low <= high` before
215238
/// calling this.
216-
fn new_inclusive(low: Self::X, high: Self::X) -> Self;
239+
fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
240+
where B1: Borrow<Self::X> + Sized,
241+
B2: Borrow<Self::X> + Sized;
217242

218243
/// Sample a value.
219244
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X;
@@ -229,8 +254,10 @@ pub trait UniformSampler: Sized {
229254
/// sampling only a single value from the specified range. The default
230255
/// implementation simply calls `UniformSampler::new` then `sample` on the
231256
/// result.
232-
fn sample_single<R: Rng + ?Sized>(low: Self::X, high: Self::X, rng: &mut R)
257+
fn sample_single<R: Rng + ?Sized, B1, B2>(low: B1, high: B2, rng: &mut R)
233258
-> Self::X
259+
where B1: Borrow<Self::X> + Sized,
260+
B2: Borrow<Self::X> + Sized
234261
{
235262
let uniform: Self = UniformSampler::new(low, high);
236263
uniform.sample(rng)
@@ -311,14 +338,24 @@ macro_rules! uniform_int_impl {
311338

312339
#[inline] // if the range is constant, this helps LLVM to do the
313340
// calculations at compile-time.
314-
fn new(low: Self::X, high: Self::X) -> Self {
341+
fn new<B1, B2>(low_b: B1, high_b: B2) -> Self
342+
where B1: Borrow<Self::X> + Sized,
343+
B2: Borrow<Self::X> + Sized
344+
{
345+
let low = *low_b.borrow();
346+
let high = *high_b.borrow();
315347
assert!(low < high, "Uniform::new called with `low >= high`");
316348
UniformSampler::new_inclusive(low, high - 1)
317349
}
318350

319351
#[inline] // if the range is constant, this helps LLVM to do the
320352
// calculations at compile-time.
321-
fn new_inclusive(low: Self::X, high: Self::X) -> Self {
353+
fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Self
354+
where B1: Borrow<Self::X> + Sized,
355+
B2: Borrow<Self::X> + Sized
356+
{
357+
let low = *low_b.borrow();
358+
let high = *high_b.borrow();
322359
assert!(low <= high,
323360
"Uniform::new_inclusive called with `low > high`");
324361
let unsigned_max = ::core::$unsigned::MAX;
@@ -362,10 +399,13 @@ macro_rules! uniform_int_impl {
362399
}
363400
}
364401

365-
fn sample_single<R: Rng + ?Sized>(low: Self::X,
366-
high: Self::X,
367-
rng: &mut R) -> Self::X
402+
fn sample_single<R: Rng + ?Sized, B1, B2>(low_b: B1, high_b: B2, rng: &mut R)
403+
-> Self::X
404+
where B1: Borrow<Self::X> + Sized,
405+
B2: Borrow<Self::X> + Sized
368406
{
407+
let low = *low_b.borrow();
408+
let high = *high_b.borrow();
369409
assert!(low < high,
370410
"Uniform::sample_single called with low >= high");
371411
let range = high.wrapping_sub(low) as $unsigned as $u_large;
@@ -532,7 +572,12 @@ macro_rules! uniform_float_impl {
532572
impl UniformSampler for UniformFloat<$ty> {
533573
type X = $ty;
534574

535-
fn new(low: Self::X, high: Self::X) -> Self {
575+
fn new<B1, B2>(low_b: B1, high_b: B2) -> Self
576+
where B1: Borrow<Self::X> + Sized,
577+
B2: Borrow<Self::X> + Sized
578+
{
579+
let low = *low_b.borrow();
580+
let high = *high_b.borrow();
536581
assert!(low < high, "Uniform::new called with `low >= high`");
537582
let scale = high - low;
538583
let offset = low - scale;
@@ -542,7 +587,12 @@ macro_rules! uniform_float_impl {
542587
}
543588
}
544589

545-
fn new_inclusive(low: Self::X, high: Self::X) -> Self {
590+
fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Self
591+
where B1: Borrow<Self::X> + Sized,
592+
B2: Borrow<Self::X> + Sized
593+
{
594+
let low = *low_b.borrow();
595+
let high = *high_b.borrow();
546596
assert!(low <= high,
547597
"Uniform::new_inclusive called with `low > high`");
548598
let scale = high - low;
@@ -565,9 +615,13 @@ macro_rules! uniform_float_impl {
565615
value1_2 * self.scale + self.offset
566616
}
567617

568-
fn sample_single<R: Rng + ?Sized>(low: Self::X,
569-
high: Self::X,
570-
rng: &mut R) -> Self::X {
618+
fn sample_single<R: Rng + ?Sized, B1, B2>(low_b: B1, high_b: B2, rng: &mut R)
619+
-> Self::X
620+
where B1: Borrow<Self::X> + Sized,
621+
B2: Borrow<Self::X> + Sized
622+
{
623+
let low = *low_b.borrow();
624+
let high = *high_b.borrow();
571625
assert!(low < high,
572626
"Uniform::sample_single called with low >= high");
573627
let scale = high - low;
@@ -624,13 +678,23 @@ impl UniformSampler for UniformDuration {
624678
type X = Duration;
625679

626680
#[inline]
627-
fn new(low: Duration, high: Duration) -> UniformDuration {
681+
fn new<B1, B2>(low_b: B1, high_b: B2) -> Self
682+
where B1: Borrow<Self::X> + Sized,
683+
B2: Borrow<Self::X> + Sized
684+
{
685+
let low = *low_b.borrow();
686+
let high = *high_b.borrow();
628687
assert!(low < high, "Uniform::new called with `low >= high`");
629688
UniformDuration::new_inclusive(low, high - Duration::new(0, 1))
630689
}
631690

632691
#[inline]
633-
fn new_inclusive(low: Duration, high: Duration) -> UniformDuration {
692+
fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Self
693+
where B1: Borrow<Self::X> + Sized,
694+
B2: Borrow<Self::X> + Sized
695+
{
696+
let low = *low_b.borrow();
697+
let high = *high_b.borrow();
634698
assert!(low <= high, "Uniform::new_inclusive called with `low > high`");
635699
let size = high - low;
636700
let nanos = size
@@ -750,6 +814,18 @@ mod tests {
750814
assert!(low <= v && v <= high);
751815
}
752816

817+
let my_uniform = Uniform::new(&low, high);
818+
for _ in 0..1000 {
819+
let v: $ty = rng.sample(my_uniform);
820+
assert!(low <= v && v < high);
821+
}
822+
823+
let my_uniform = Uniform::new_inclusive(&low, &high);
824+
for _ in 0..1000 {
825+
let v: $ty = rng.sample(my_uniform);
826+
assert!(low <= v && v <= high);
827+
}
828+
753829
for _ in 0..1000 {
754830
let v: $ty = rng.gen_range(low, high);
755831
assert!(low <= v && v < high);
@@ -809,6 +885,7 @@ mod tests {
809885

810886
#[test]
811887
fn test_custom_uniform() {
888+
use core::borrow::Borrow;
812889
#[derive(Clone, Copy, PartialEq, PartialOrd)]
813890
struct MyF32 {
814891
x: f32,
@@ -819,12 +896,18 @@ mod tests {
819896
}
820897
impl UniformSampler for UniformMyF32 {
821898
type X = MyF32;
822-
fn new(low: Self::X, high: Self::X) -> Self {
899+
fn new<B1, B2>(low: B1, high: B2) -> Self
900+
where B1: Borrow<Self::X> + Sized,
901+
B2: Borrow<Self::X> + Sized
902+
{
823903
UniformMyF32 {
824-
inner: UniformFloat::<f32>::new(low.x, high.x),
904+
inner: UniformFloat::<f32>::new(low.borrow().x, high.borrow().x),
825905
}
826906
}
827-
fn new_inclusive(low: Self::X, high: Self::X) -> Self {
907+
fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
908+
where B1: Borrow<Self::X> + Sized,
909+
B2: Borrow<Self::X> + Sized
910+
{
828911
UniformSampler::new(low, high)
829912
}
830913
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {

src/lib.rs

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -300,10 +300,10 @@ pub mod isaac {
300300

301301

302302
use core::{mem, slice};
303+
use core::borrow::Borrow;
303304
use distributions::{Distribution, Standard};
304305
use distributions::uniform::{SampleUniform, UniformSampler};
305306

306-
307307
/// An automatically-implemented extension trait on [`RngCore`] providing high-level
308308
/// generic methods for sampling values and other convenience methods.
309309
///
@@ -387,7 +387,9 @@ pub trait Rng: RngCore {
387387
/// ```
388388
///
389389
/// [`Uniform`]: distributions/uniform/struct.Uniform.html
390-
fn gen_range<T: SampleUniform>(&mut self, low: T, high: T) -> T {
390+
fn gen_range<T: SampleUniform, B1, B2>(&mut self, low: B1, high: B2) -> T
391+
where B1: Borrow<T> + Sized,
392+
B2: Borrow<T> + Sized {
391393
T::Sampler::sample_single(low, high, self)
392394
}
393395

@@ -935,19 +937,19 @@ mod test {
935937
fn test_gen_range() {
936938
let mut r = rng(101);
937939
for _ in 0..1000 {
938-
let a = r.gen_range(-3, 42);
939-
assert!(a >= -3 && a < 42);
940-
assert_eq!(r.gen_range(0, 1), 0);
941-
assert_eq!(r.gen_range(-12, -11), -12);
942-
}
943-
944-
for _ in 0..1000 {
945-
let a = r.gen_range(10, 42);
946-
assert!(a >= 10 && a < 42);
947-
assert_eq!(r.gen_range(0, 1), 0);
940+
let a = r.gen_range(-4711, 17);
941+
assert!(a >= -4711 && a < 17);
942+
let a = r.gen_range(-3i8, 42);
943+
assert!(a >= -3i8 && a < 42i8);
944+
let a = r.gen_range(&10u16, 99);
945+
assert!(a >= 10u16 && a < 99u16);
946+
let a = r.gen_range(-100i32, &2000);
947+
assert!(a >= -100i32 && a < 2000i32);
948+
949+
assert_eq!(r.gen_range(0u32, 1), 0u32);
950+
assert_eq!(r.gen_range(-12i64, -11), -12i64);
948951
assert_eq!(r.gen_range(3_000_000, 3_000_001), 3_000_000);
949952
}
950-
951953
}
952954

953955
#[test]

0 commit comments

Comments
 (0)