@@ -154,19 +154,43 @@ class MATHS_COMMON_EXPORT CNaiveBayesFeatureDensityFromPrior final
154154 TPriorPtr m_Prior;
155155};
156156
157+ // ! \brief Enables using custom feature weights in class prediction.
// !
// ! Abstract interface: implementations receive values through add and
// ! report a single weight through calculate. CNaiveBayes obtains objects
// ! of this type from a TFeatureWeightProvider, whose documentation says
// ! the weight is computed from the class conditional log-likelihood of
// ! the feature value and should be in the range [0,1].
158+ class CNaiveBayesFeatureWeight {
159+ public:
// ! Virtual so that implementations can be destroyed through this interface.
160+ virtual ~CNaiveBayesFeatureWeight () = default ;
// ! Supply one log-likelihood for one class.
// ! NOTE(review): presumably \p class_ is the class label and
// ! \p logLikelihood is that class's conditional log-likelihood of the
// ! feature value; confirm against the callers in the implementation file.
161+ virtual void add (std::size_t class_, double logLikelihood) = 0;
// ! Get the feature weight.
// ! NOTE(review): presumably derived from the values passed to add so far.
162+ virtual double calculate () const = 0;
163+ };
164+
157165// ! \brief Implements a Naive Bayes classifier.
158166class MATHS_COMMON_EXPORT CNaiveBayes {
159167public:
168+ using TDoubleDoublePr = std::pair<double , double >;
160169 using TDoubleSizePr = std::pair<double , std::size_t >;
161170 using TDoubleSizePrVec = std::vector<TDoubleSizePr>;
171+ using TDoubleSizePrVecDoublePr = std::pair<TDoubleSizePrVec, double >;
162172 using TDouble1Vec = core::CSmallVector<double , 1 >;
163173 using TDouble1VecVec = std::vector<TDouble1Vec>;
164- using TOptionalDouble = std::optional<double >;
174+ using TFeatureWeightProvider = std::function<CNaiveBayesFeatureWeight&()>;
175+
176+ private:
177+ // ! \brief All features have unit weight in class prediction.
// !
// ! This is the weight type handed out by CUnitFeatureWeightProvider.
178+ class CUnitFeatureWeight : public CNaiveBayesFeatureWeight {
179+ public:
// ! Deliberately a no-op: the weight does not depend on the supplied values.
180+ void add (std::size_t , double ) override {}
// ! Always returns a weight of one.
181+ double calculate () const override { return 1.0 ; }
182+ };
183+
// ! \brief Supplies a CUnitFeatureWeight. A default constructed object of
// ! this type is the default weightProvider argument of the class
// ! probability methods.
184+ class CUnitFeatureWeightProvider {
185+ public:
// ! Every call returns a reference to the same member object.
186+ CUnitFeatureWeight& operator ()() const { return m_UnitWeight; }
187+
188+ private:
// ! Mutable so the const call operator can return a non-const reference.
189+ mutable CUnitFeatureWeight m_UnitWeight;
190+ };
165191
166192public:
167- explicit CNaiveBayes (const CNaiveBayesFeatureDensity& exemplar,
168- double decayRate = 0.0 ,
169- TOptionalDouble minMaxLogLikelihoodToUseFeature = TOptionalDouble());
193+ explicit CNaiveBayes (const CNaiveBayesFeatureDensity& exemplar, double decayRate = 0.0 );
170194 CNaiveBayes (const CNaiveBayesFeatureDensity& exemplar,
171195 const SDistributionRestoreParams& params,
172196 core::CStateRestoreTraverser& traverser);
@@ -184,6 +208,9 @@ class MATHS_COMMON_EXPORT CNaiveBayes {
184208 // ! Check if any training data has been added.
185209 bool initialized () const ;
186210
211+ // ! Get the number of classes.
212+ std::size_t numberClasses () const ;
213+
187214 // ! This can be used to optionally seed the class counts
188215 // ! with \p counts. These are added on to data class counts
189216 // ! to compute the class posterior probabilities.
@@ -210,27 +237,53 @@ class MATHS_COMMON_EXPORT CNaiveBayes {
210237 // !
211238 // ! \param[in] n The number of class probabilities to estimate.
212239 // ! \param[in] x The feature values.
240+ // ! \param[in] weightProvider Supplies the object which computes a feature
241+ // ! weight from the class conditional log-likelihood of the feature value.
242+ // ! The weight should be in the range [0,1]: the smaller it is the less
243+ // ! impact the feature has on class selection. Defaults to unit weights.
244+ // ! \return The class probabilities and the minimum feature weight.
213245 // ! \note \p x size should be equal to the number of features.
214246 // ! A missing feature is indicated by passing an empty vector
215247 // ! for that feature.
216- TDoubleSizePrVec highestClassProbabilities (std::size_t n, const TDouble1VecVec& x) const ;
248+ TDoubleSizePrVecDoublePr highestClassProbabilities (
249+ std::size_t n,
250+ const TDouble1VecVec& x,
251+ const TFeatureWeightProvider& weightProvider = CUnitFeatureWeightProvider{}) const ;
217252
218253 // ! Get the probability of the class labeled \p label for \p x.
219254 // !
220255 // ! \param[in] label The label of the class of interest.
221256 // ! \param[in] x The feature values.
257+ // ! \param[in] weightProvider Supplies the object which computes a feature
258+ // ! weight from the class conditional log-likelihood of the feature value.
259+ // ! The weight should be in the range [0,1]: the smaller it is the less
260+ // ! impact the feature has on class selection. Defaults to unit weights.
261+ // ! \return The probability of the class labeled \p label and the
262+ // ! minimum feature weight.
222263 // ! \note \p x size should be equal to the number of features.
223264 // ! A missing feature is indicated by passing an empty vector
224265 // ! for that feature.
225- double classProbability (std::size_t label, const TDouble1VecVec& x) const ;
266+ TDoubleDoublePr classProbability (std::size_t label,
267+ const TDouble1VecVec& x,
268+ const TFeatureWeightProvider& weightProvider =
269+ CUnitFeatureWeightProvider{}) const ;
226270
227271 // ! Get the probabilities of all the classes for \p x.
228272 // !
229273 // ! \param[in] x The feature values.
274+ // ! \param[in] weightProvider Supplies the object which computes a feature
275+ // ! weight from the class conditional log-likelihood of the feature value.
276+ // ! The weight should be in the range [0,1]: the smaller it is the less
277+ // ! impact the feature has on class selection. Defaults to unit weights.
278+ // ! \return The probabilities of all the classes for \p x together
279+ // ! with the minimum feature weight, as the first and second members
280+ // ! of the returned pair respectively.
230281 // ! \note \p x size should be equal to the number of features.
231282 // ! A missing feature is indicated by passing an empty vector
232283 // ! for that feature.
233- TDoubleSizePrVec classProbabilities (const TDouble1VecVec& x) const ;
284+ TDoubleSizePrVecDoublePr
285+ classProbabilities (const TDouble1VecVec& x,
286+ const TFeatureWeightProvider& weightProvider = CUnitFeatureWeightProvider{}) const ;
234287
235288 // ! Debug the memory used by this object.
236289 void debugMemoryUsage (const core::CMemoryUsage::TMemoryUsagePtr& mem) const ;
@@ -298,13 +351,6 @@ class MATHS_COMMON_EXPORT CNaiveBayes {
298351 bool validate (const TDouble1VecVec& x) const ;
299352
300353private:
301- // ! It is not always appropriate to use features with very low
302- // ! probability in all classes to discriminate: the class choice
303- // ! will be very sensitive to the underlying conditional density
304- // ! model. This is a cutoff (for the minimum maximum class log
305- // ! likelihood) in order to use a feature.
306- TOptionalDouble m_MinMaxLogLikelihoodToUseFeature;
307-
308354 // ! Controls the rate at which data are aged out.
309355 double m_DecayRate;
310356
0 commit comments