diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b200fd3b6..63a6162718 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ and this project adheres to `Decimal`/`Decimal256`. - cosmwasm-std: Implement `pow`/`saturating_pow` for `Decimal`/`Decimal256`. - cosmwasm-std: Implement `ceil`/`floor` for `Decimal`/`Decimal256`. +- cosmwasm-std: Implement `PartialEq` for reference on one side and owned value + on the other for all `Uint` and `Decimal` types - cosmwasm-std: Implement `saturating_add`/`sub`/`mul` for `Decimal`/`Decimal256`. - cosmwasm-std: Implement `MIN` const value for all `Uint` and `Decimal` types diff --git a/packages/std/src/math/decimal.rs b/packages/std/src/math/decimal.rs index 1bb985d626..5c808104a8 100644 --- a/packages/std/src/math/decimal.rs +++ b/packages/std/src/math/decimal.rs @@ -207,7 +207,7 @@ impl Decimal { /// Rounds value up after decimal places. Returns OverflowError on overflow. pub fn checked_ceil(&self) -> Result { let floor = self.floor(); - if &floor == self { + if floor == self { Ok(floor) } else { floor @@ -647,6 +647,18 @@ impl<'de> de::Visitor<'de> for DecimalVisitor { } } +impl PartialEq<&Decimal> for Decimal { + fn eq(&self, rhs: &&Decimal) -> bool { + self == *rhs + } +} + +impl PartialEq for &Decimal { + fn eq(&self, rhs: &Decimal) -> bool { + *self == rhs + } +} + #[cfg(test)] mod tests { use super::*; @@ -1750,7 +1762,7 @@ mod tests { ); let empty: Vec = vec![]; - assert_eq!(Decimal::zero(), empty.iter().sum()); + assert_eq!(Decimal::zero(), empty.iter().sum::()); } #[test] @@ -1983,4 +1995,24 @@ mod tests { Err(RoundUpOverflowError { .. }) )); } + + #[test] + fn decimal_partial_eq() { + let test_cases = [ + ("1", "1", true), + ("0.5", "0.5", true), + ("0.5", "0.51", false), + ("0", "0.00000", true), + ] + .into_iter() + .map(|(lhs, rhs, expected)| (dec(lhs), dec(rhs), expected)); + + #[allow(clippy::op_ref)] + for (lhs, rhs, expected) in test_cases { + assert_eq!(lhs == rhs, expected); + assert_eq!(&lhs == rhs, expected); + assert_eq!(lhs == &rhs, expected); + assert_eq!(&lhs == &rhs, expected); + } + } } diff --git a/packages/std/src/math/decimal256.rs b/packages/std/src/math/decimal256.rs index eed988c18d..f52be3264c 100644 --- a/packages/std/src/math/decimal256.rs +++ b/packages/std/src/math/decimal256.rs @@ -220,7 +220,7 @@ impl Decimal256 { /// Rounds value up after decimal places. Returns OverflowError on overflow. pub fn checked_ceil(&self) -> Result { let floor = self.floor(); - if &floor == self { + if floor == self { Ok(floor) } else { floor @@ -672,6 +672,18 @@ impl<'de> de::Visitor<'de> for Decimal256Visitor { } } +impl PartialEq<&Decimal256> for Decimal256 { + fn eq(&self, rhs: &&Decimal256) -> bool { + self == *rhs + } +} + +impl PartialEq for &Decimal256 { + fn eq(&self, rhs: &Decimal256) -> bool { + *self == rhs + } +} + #[cfg(test)] mod tests { use super::*; @@ -1885,7 +1897,7 @@ mod tests { ); let empty: Vec = vec![]; - assert_eq!(Decimal256::zero(), empty.iter().sum()); + assert_eq!(Decimal256::zero(), empty.iter().sum::()); } #[test] @@ -2130,4 +2142,24 @@ mod tests { ); assert_eq!(Decimal256::MAX.checked_ceil(), Err(RoundUpOverflowError)); } + + #[test] + fn decimal256_partial_eq() { + let test_cases = [ + ("1", "1", true), + ("0.5", "0.5", true), + ("0.5", "0.51", false), + ("0", "0.00000", true), + ] + .into_iter() + .map(|(lhs, rhs, expected)| (dec(lhs), dec(rhs), expected)); + + #[allow(clippy::op_ref)] + for (lhs, rhs, expected) in test_cases { + assert_eq!(lhs == rhs, expected); + assert_eq!(&lhs == rhs, expected); + assert_eq!(lhs == &rhs, expected); + assert_eq!(&lhs == &rhs, expected); + } + } } diff --git a/packages/std/src/math/uint128.rs b/packages/std/src/math/uint128.rs index 150166810a..b04f88ed71 100644 --- a/packages/std/src/math/uint128.rs +++ b/packages/std/src/math/uint128.rs @@ -513,6 +513,18 @@ where } } +impl PartialEq<&Uint128> for Uint128 { + fn eq(&self, rhs: &&Uint128) -> bool { + self == *rhs + } +} + +impl PartialEq for &Uint128 { + fn eq(&self, rhs: &Uint128) -> bool { + *self == rhs + } +} + #[cfg(test)] mod tests { use super::*; @@ -854,10 +866,10 @@ mod tests { let nums = vec![Uint128(17), Uint128(123), Uint128(540), Uint128(82)]; let expected = Uint128(762); - let sum_as_ref = nums.iter().sum(); + let sum_as_ref: Uint128 = nums.iter().sum(); assert_eq!(expected, sum_as_ref); - let sum_as_owned = nums.into_iter().sum(); + let sum_as_owned: Uint128 = nums.into_iter().sum(); assert_eq!(expected, sum_as_owned); } @@ -984,4 +996,19 @@ mod tests { assert_eq!(a.abs_diff(b), expected); assert_eq!(b.abs_diff(a), expected); } + + #[test] + fn uint128_partial_eq() { + let test_cases = [(1, 1, true), (42, 42, true), (42, 24, false), (0, 0, true)] + .into_iter() + .map(|(lhs, rhs, expected)| (Uint128::new(lhs), Uint128::new(rhs), expected)); + + #[allow(clippy::op_ref)] + for (lhs, rhs, expected) in test_cases { + assert_eq!(lhs == rhs, expected); + assert_eq!(&lhs == rhs, expected); + assert_eq!(lhs == &rhs, expected); + assert_eq!(&lhs == &rhs, expected); + } + } } diff --git a/packages/std/src/math/uint256.rs b/packages/std/src/math/uint256.rs index 3d7883b608..b7677c0118 100644 --- a/packages/std/src/math/uint256.rs +++ b/packages/std/src/math/uint256.rs @@ -621,6 +621,18 @@ where } } +impl PartialEq<&Uint256> for Uint256 { + fn eq(&self, rhs: &&Uint256) -> bool { + self == *rhs + } +} + +impl PartialEq for &Uint256 { + fn eq(&self, rhs: &Uint256) -> bool { + *self == rhs + } +} + #[cfg(test)] mod tests { use super::*; @@ -1410,10 +1422,10 @@ mod tests { ]; let expected = Uint256::from(762u32); - let sum_as_ref = nums.iter().sum(); + let sum_as_ref: Uint256 = nums.iter().sum(); assert_eq!(expected, sum_as_ref); - let sum_as_owned = nums.into_iter().sum(); + let sum_as_owned: Uint256 = nums.into_iter().sum(); assert_eq!(expected, sum_as_owned); } @@ -1563,4 +1575,21 @@ mod tests { assert_eq!(a.abs_diff(b), expected); assert_eq!(b.abs_diff(a), expected); } + + #[test] + fn uint256_partial_eq() { + let test_cases = [(1, 1, true), (42, 42, true), (42, 24, false), (0, 0, true)] + .into_iter() + .map(|(lhs, rhs, expected): (u64, u64, bool)| { + (Uint256::from(lhs), Uint256::from(rhs), expected) + }); + + #[allow(clippy::op_ref)] + for (lhs, rhs, expected) in test_cases { + assert_eq!(lhs == rhs, expected); + assert_eq!(&lhs == rhs, expected); + assert_eq!(lhs == &rhs, expected); + assert_eq!(&lhs == &rhs, expected); + } + } } diff --git a/packages/std/src/math/uint512.rs b/packages/std/src/math/uint512.rs index e53bad95d3..75556c1d82 100644 --- a/packages/std/src/math/uint512.rs +++ b/packages/std/src/math/uint512.rs @@ -582,6 +582,18 @@ where } } +impl PartialEq<&Uint512> for Uint512 { + fn eq(&self, rhs: &&Uint512) -> bool { + self == *rhs + } +} + +impl PartialEq for &Uint512 { + fn eq(&self, rhs: &Uint512) -> bool { + *self == rhs + } +} + #[cfg(test)] mod tests { use super::*; @@ -1045,10 +1057,10 @@ mod tests { ]; let expected = Uint512::from(762u32); - let sum_as_ref = nums.iter().sum(); + let sum_as_ref: Uint512 = nums.iter().sum(); assert_eq!(expected, sum_as_ref); - let sum_as_owned = nums.into_iter().sum(); + let sum_as_owned: Uint512 = nums.into_iter().sum(); assert_eq!(expected, sum_as_owned); } @@ -1198,4 +1210,21 @@ mod tests { assert_eq!(a.abs_diff(b), expected); assert_eq!(b.abs_diff(a), expected); } + + #[test] + fn uint512_partial_eq() { + let test_cases = [(1, 1, true), (42, 42, true), (42, 24, false), (0, 0, true)] + .into_iter() + .map(|(lhs, rhs, expected): (u64, u64, bool)| { + (Uint512::from(lhs), Uint512::from(rhs), expected) + }); + + #[allow(clippy::op_ref)] + for (lhs, rhs, expected) in test_cases { + assert_eq!(lhs == rhs, expected); + assert_eq!(&lhs == rhs, expected); + assert_eq!(lhs == &rhs, expected); + assert_eq!(&lhs == &rhs, expected); + } + } } diff --git a/packages/std/src/math/uint64.rs b/packages/std/src/math/uint64.rs index 3b1df29d17..e7f65db0f2 100644 --- a/packages/std/src/math/uint64.rs +++ b/packages/std/src/math/uint64.rs @@ -467,6 +467,18 @@ where } } +impl PartialEq<&Uint64> for Uint64 { + fn eq(&self, rhs: &&Uint64) -> bool { + self == *rhs + } +} + +impl PartialEq for &Uint64 { + fn eq(&self, rhs: &Uint64) -> bool { + *self == rhs + } +} + #[cfg(test)] mod tests { use super::*; @@ -769,10 +781,10 @@ mod tests { let nums = vec![Uint64(17), Uint64(123), Uint64(540), Uint64(82)]; let expected = Uint64(762); - let sum_as_ref = nums.iter().sum(); + let sum_as_ref: Uint64 = nums.iter().sum(); assert_eq!(expected, sum_as_ref); - let sum_as_owned = nums.into_iter().sum(); + let sum_as_owned: Uint64 = nums.into_iter().sum(); assert_eq!(expected, sum_as_owned); } @@ -897,4 +909,19 @@ mod tests { assert_eq!(a.abs_diff(b), expected); assert_eq!(b.abs_diff(a), expected); } + + #[test] + fn uint64_partial_eq() { + let test_cases = [(1, 1, true), (42, 42, true), (42, 24, false), (0, 0, true)] + .into_iter() + .map(|(lhs, rhs, expected)| (Uint64::new(lhs), Uint64::new(rhs), expected)); + + #[allow(clippy::op_ref)] + for (lhs, rhs, expected) in test_cases { + assert_eq!(lhs == rhs, expected); + assert_eq!(&lhs == rhs, expected); + assert_eq!(lhs == &rhs, expected); + assert_eq!(&lhs == &rhs, expected); + } + } }