From 14dd7787c74c30c6efc182c0e4b6ff004bc3cae8 Mon Sep 17 00:00:00 2001 From: Tom de Geus Date: Fri, 19 Aug 2022 15:07:07 +0200 Subject: [PATCH 1/2] average: using `XTENSOR_ASSERT_MSG` instead of `XTENSOR_THROW` --- include/xtensor/xmath.hpp | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/include/xtensor/xmath.hpp b/include/xtensor/xmath.hpp index 836e864bd..71d64af29 100644 --- a/include/xtensor/xmath.hpp +++ b/include/xtensor/xmath.hpp @@ -1975,21 +1975,13 @@ namespace detail { auto ax = normalize_axis(e, axes); if (weights.dimension() == 1) { - if (weights.size() != e.shape()[ax[0]]) - { - XTENSOR_THROW(std::runtime_error, "Weights need to have the same shape as expression at axes."); - } - + XTENSOR_ASSERT_MSG(weights.size() == e.shape()[ax[0]], "Weights need to have the same shape as expression at axes."); std::fill(broadcast_shape.begin(), broadcast_shape.end(), std::size_t(1)); broadcast_shape[ax[0]] = weights.size(); } else { - if (!same_shape(e.shape(), weights.shape())) - { - XTENSOR_THROW(std::runtime_error, "Weights with dim > 1 need to have the same shape as expression."); - } - + XTENSOR_ASSERT_MSG(same_shape(e.shape(), weights.shape()), "Weights with dim > 1 need to have the same shape as expression."); std::copy(e.shape().begin(), e.shape().end(), broadcast_shape.begin()); } From 9cdf2c911678eb3753b33d02adcc9001914c4979 Mon Sep 17 00:00:00 2001 From: Tom de Geus Date: Fri, 19 Aug 2022 15:07:19 +0200 Subject: [PATCH 2/2] average: adding test to find bug --- test/test_xmath.cpp | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/test/test_xmath.cpp b/test/test_xmath.cpp index 30c189b55..9258cdae0 100644 --- a/test/test_xmath.cpp +++ b/test/test_xmath.cpp @@ -889,6 +889,27 @@ namespace xt EXPECT_EQ(xt::average(v, w, {0, 1})(), m); } + TEST(xmath, average_random) + { + xt::xtensor v = xt::random::rand({4, 5, 6, 7}); + xt::xtensor w = xt::random::rand({4, 5, 6, 7}) + 1.0; + xt::xtensor r = xt::zeros({6, 7}); + xt::xtensor n = xt::zeros({6, 7}); + + for (size_t i = 0; i < v.shape(0); ++i) { + for (size_t j = 0; j < v.shape(1); ++j) { + for (size_t k = 0; k < v.shape(2); ++k) { + for (size_t l = 0; l < v.shape(3); ++l) { + r(k, l) += v(i, j, k, l) * w(i, j, k, l); + n(k, l) += w(i, j, k, l); + } + } + } + } + + EXPECT_TRUE(xt::allclose(xt::average(v, w, {0, 1}), xt::eval(r / n))); + } + /************************ * Linear interpolation * ************************/