33
44#include " ggml_extend.hpp"
55#include " gits_noise.inl"
6- #include < boost/math/distributions/beta.hpp>
76
87/* ================================================= CompVisDenoiser ==================================================*/
98
@@ -252,6 +251,124 @@ struct KarrasSchedule : SigmaSchedule {
252251 }
253252};
254253
254+ #if 0 // boost version
255+
256+ #include <boost/math/distributions/beta.hpp>
257+
258+ struct BetaDist {
259+ boost::math::beta_distribution<double> dist;
260+
261+ BetaDist(double a = 0.6f, double b = 0.6f)
262+ : dist(a, b) {}
263+
264+ double quantile(double u) {
265+ return boost::math::quantile(dist, u);
266+ }
267+ };
268+
269+ #else // local version
270+
271+ struct BetaDist {
272+
273+ BetaDist (double a=0 .6f , double b=0 .6f )
274+ : alpha(a), beta(b) {}
275+
276+ // Beta quantile function using Newton-Raphson method
277+ double quantile (double u) {
278+ if (u <= 0.0 ) return 0.0 ;
279+ if (u >= 1.0 ) return 1.0 ;
280+
281+ double x = u < 0.5 ? u * u : 1.0 - (1.0 - u) * (1.0 - u);
282+
283+ const int max_iterations = 50 ;
284+ const double tolerance = 1e-12 ;
285+
286+ for (int i = 0 ; i < max_iterations; ++i) {
287+ double err = beta_cdf (x) - u;
288+ if (std::abs (err) < tolerance) {
289+ break ;
290+ }
291+
292+ double derivative = beta_pdf (x);
293+ if (std::abs (derivative) < 1e-30 ) {
294+ break ;
295+ }
296+
297+ double new_x = x - err / derivative;
298+ x = std::max (0.0 , std::min (1.0 , new_x));
299+ }
300+
301+ return x;
302+ }
303+
304+ private:
305+ double alpha;
306+ double beta;
307+
308+ double lbeta_ab () {
309+ return std::lgamma (alpha) + std::lgamma (beta) - std::lgamma (alpha + beta);
310+ }
311+
312+ // Beta probability density function
313+ double beta_pdf (double x) {
314+ if (x <= 0.0 || x >= 1.0 )
315+ return 0.0 ;
316+ return std::exp (-lbeta_ab ()) * std::pow (x, alpha - 1.0 ) * std::pow (1.0 - x, beta - 1.0 );
317+ }
318+
319+ // Beta cumulative distribution function
320+ double beta_cdf (double x) {
321+ if (x <= 0.0 ) return 0.0 ;
322+ if (x >= 1.0 ) return 1.0 ;
323+ if (x > (alpha + 1.0 ) / (alpha + beta + 2.0 )) {
324+ // use symmetry relation
325+ return 1.0 - incomplete_beta (beta, alpha, 1.0 - x);
326+ }
327+ else {
328+ return incomplete_beta (alpha, beta, x);
329+ }
330+ }
331+
332+ // Incomplete beta function using continued fraction representation
333+ double incomplete_beta (double a, double b, double x) {
334+
335+ double f = 1.0 , c = 1.0 , d = 0.0 ;
336+ const int max_iterations = 200 ;
337+ const double tolerance = 1e-15 ;
338+
339+ for (int i = 0 ; i <= max_iterations; ++i) {
340+ int m = i / 2 ;
341+ double numerator;
342+
343+ if (i == 0 ) {
344+ numerator = 1.0 ;
345+ } else if (i % 2 == 0 ) {
346+ numerator = (m * (b - m) * x) / ((a + 2.0 * m - 1.0 ) * (a + 2.0 * m));
347+ } else {
348+ numerator = -((a + m) * (a + b + m) * x) / ((a + 2.0 * m) * (a + 2.0 * m + 1.0 ));
349+ }
350+
351+ d = 1.0 + numerator * d;
352+ if (std::abs (d) < 1e-30 ) d = 1e-30 ;
353+ d = 1.0 / d;
354+
355+ c = 1.0 + numerator / c;
356+ if (std::abs (c) < 1e-30 ) c = 1e-30 ;
357+
358+ double cd = c * d;
359+ f *= cd;
360+
361+ if (std::abs (cd - 1.0 ) < tolerance) {
362+ break ;
363+ }
364+ }
365+
366+ return (std::exp (a * std::log (x) + b * std::log (1.0 - x) - lbeta_ab ()) / a) * (f - 1.0 );
367+ }
368+ };
369+
370+ #endif
371+
255372struct BetaSchedule : SigmaSchedule {
256373 float alpha = 0 .6f ;
257374 float beta = 0 .6f ;
@@ -272,15 +389,15 @@ struct BetaSchedule : SigmaSchedule {
272389 }
273390
274391 // Beta-Verteilung (wie scipy.stats.beta.ppf)
275- boost::math::beta_distribution< double > dist (alpha, beta) ;
392+ BetaDist beta_distribution ;
276393
277394 int last_t = -1 ;
278395 for (uint32_t i = 0 ; i < n; i++) {
279396 // entspricht ts = 1 - linspace(0,1,n,endpoint=False)
280397 double u = 1.0 - static_cast <double >(i) / static_cast <double >(n);
281398
282399 // ppf(ts) * total_timesteps
283- double t_cont = quantile (dist, u) * t_max;
400+ double t_cont = beta_distribution. quantile (u) * t_max;
284401 int t = (int )std::lround (t_cont);
285402
286403 if (t != last_t ) {
0 commit comments