|
| 1 | +/* |
| 2 | + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one |
| 3 | + * or more contributor license agreements. Licensed under the Elastic License; |
| 4 | + * you may not use this file except in compliance with the Elastic License. |
| 5 | + */ |
| 6 | + |
| 7 | +#ifndef INCLUDED_ml_core_CImmutableRadixSet_h |
| 8 | +#define INCLUDED_ml_core_CImmutableRadixSet_h |
| 9 | + |
| 10 | +#include <core/CContainerPrinter.h> |
| 11 | + |
| 12 | +#include <algorithm> |
| 13 | +#include <limits> |
| 14 | +#include <numeric> |
| 15 | +#include <vector> |
| 16 | + |
| 17 | +namespace ml { |
| 18 | +namespace core { |
| 19 | + |
| 20 | +//! \brief An immutable sorted set which provides very fast lookup. |
| 21 | +//! |
| 22 | +//! DESCRIPTION:\n |
| 23 | +//! This supports lower bound and look up by index as well as a subset of the non |
| 24 | +//! modifying interface of std::set. Its main purpose is to provide much faster |
| 25 | +//! lookup. To this end it subdivides the range of sorted values into buckets. |
| 26 | +//! In the case that the values are uniformly distributed lowerBound will be O(1) |
| 27 | +//! with low constant. Otherwise, it is worst case O(log(n)). |
| 28 | +template<typename T> |
| 29 | +class CImmutableRadixSet { |
| 30 | +public: |
| 31 | + using TVec = std::vector<T>; |
| 32 | + using TCItr = typename std::vector<T>::const_iterator; |
| 33 | + |
| 34 | +public: |
| 35 | + // We only need to support floating point types at present (although it |
| 36 | + // could easily extended to support any numeric type). |
| 37 | + static_assert(std::is_floating_point<T>::value, "Only supports floating point types"); |
| 38 | + |
| 39 | +public: |
| 40 | + CImmutableRadixSet() = default; |
| 41 | + explicit CImmutableRadixSet(std::initializer_list<T> values) |
| 42 | + : m_Values{std::move(values)} { |
| 43 | + this->initialize(); |
| 44 | + } |
| 45 | + explicit CImmutableRadixSet(TVec values) : m_Values{std::move(values)} { |
| 46 | + this->initialize(); |
| 47 | + } |
| 48 | + |
| 49 | + // This is movable only because we hold iterators to the underlying container. |
| 50 | + CImmutableRadixSet(const CImmutableRadixSet&) = delete; |
| 51 | + CImmutableRadixSet& operator=(const CImmutableRadixSet&) = delete; |
| 52 | + CImmutableRadixSet(CImmutableRadixSet&&) = default; |
| 53 | + CImmutableRadixSet& operator=(CImmutableRadixSet&&) = default; |
| 54 | + |
| 55 | + //! \name Capacity |
| 56 | + //@{ |
| 57 | + bool empty() const { return m_Values.size(); } |
| 58 | + std::size_t size() const { return m_Values.size(); } |
| 59 | + //@} |
| 60 | + |
| 61 | + //! \name Iterators |
| 62 | + //@{ |
| 63 | + TCItr begin() const { m_Values.begin(); } |
| 64 | + TCItr end() const { m_Values.end(); } |
| 65 | + //@} |
| 66 | + |
| 67 | + //! \name Lookup |
| 68 | + //@{ |
| 69 | + const T& operator[](std::size_t i) const { return m_Values[i]; } |
| 70 | + std::ptrdiff_t upperBound(const T& value) const { |
| 71 | + // This branch is predictable so essentially free. |
| 72 | + if (m_Values.size() < 2) { |
| 73 | + return std::distance(m_Values.begin(), |
| 74 | + std::upper_bound(m_Values.begin(), m_Values.end(), value)); |
| 75 | + } |
| 76 | + |
| 77 | + std::ptrdiff_t bucket{static_cast<std::ptrdiff_t>(m_Scale * (value - m_Min))}; |
| 78 | + if (bucket < 0) { |
| 79 | + return 0; |
| 80 | + } |
| 81 | + if (bucket >= static_cast<std::ptrdiff_t>(m_Buckets.size())) { |
| 82 | + return static_cast<std::ptrdiff_t>(m_Values.size()); |
| 83 | + } |
| 84 | + TCItr beginBucket; |
| 85 | + TCItr endBucket; |
| 86 | + std::tie(beginBucket, endBucket) = m_Buckets[bucket]; |
| 87 | + return std::distance(m_Values.begin(), |
| 88 | + std::upper_bound(beginBucket, endBucket, value)); |
| 89 | + } |
| 90 | + //@} |
| 91 | + |
| 92 | + std::string print() const { |
| 93 | + return core::CContainerPrinter::print(m_Values); |
| 94 | + } |
| 95 | + |
| 96 | +private: |
| 97 | + using TCItrCItrPr = std::pair<TCItr, TCItr>; |
| 98 | + using TCItrCItrPrVec = std::vector<TCItrCItrPr>; |
| 99 | + using TPtrdiffVec = std::vector<std::ptrdiff_t>; |
| 100 | + |
| 101 | +private: |
| 102 | + void initialize() { |
| 103 | + std::sort(m_Values.begin(), m_Values.end()); |
| 104 | + m_Values.erase(std::unique(m_Values.begin(), m_Values.end()), m_Values.end()); |
| 105 | + if (m_Values.size() > 1) { |
| 106 | + std::size_t numberBuckets{m_Values.size()}; |
| 107 | + m_Min = m_Values[0]; |
| 108 | + m_Scale = static_cast<T>(numberBuckets) / (m_Values.back() - m_Min); |
| 109 | + m_Buckets.reserve(numberBuckets); |
| 110 | + T bucket{1}; |
| 111 | + T bucketClose{m_Min + bucket / m_Scale}; |
| 112 | + auto start = m_Values.begin(); |
| 113 | + for (auto i = m_Values.begin(); i != m_Values.end(); ++i) { |
| 114 | + if (*i > bucketClose) { |
| 115 | + m_Buckets.emplace_back(start, i); |
| 116 | + bucket += T{1}; |
| 117 | + bucketClose = m_Min + bucket / m_Scale; |
| 118 | + start = i; |
| 119 | + while (*i > bucketClose) { |
| 120 | + m_Buckets.emplace_back(start, i + 1); |
| 121 | + bucket += T{1}; |
| 122 | + bucketClose = m_Min + bucket / m_Scale; |
| 123 | + } |
| 124 | + } |
| 125 | + } |
| 126 | + if (m_Buckets.size() < numberBuckets) { |
| 127 | + m_Buckets.emplace_back(start, m_Values.end()); |
| 128 | + } |
| 129 | + } |
| 130 | + } |
| 131 | + |
| 132 | +private: |
| 133 | + T m_Min = T{0}; |
| 134 | + T m_Scale = T{0}; |
| 135 | + TCItrCItrPrVec m_Buckets; |
| 136 | + TVec m_Values; |
| 137 | +}; |
| 138 | +} |
| 139 | +} |
| 140 | + |
| 141 | +#endif // INCLUDED_ml_core_CImmutableRadixSet_h |
0 commit comments