Skip to content

Commit 2050ffe

Browse files
committed
test alternative Beta implementation
1 parent f382a48 commit 2050ffe

File tree

1 file changed

+120
-3
lines changed

1 file changed

+120
-3
lines changed

denoiser.hpp

Lines changed: 120 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
#include "ggml_extend.hpp"
55
#include "gits_noise.inl"
6-
#include <boost/math/distributions/beta.hpp>
76

87
/*================================================= CompVisDenoiser ==================================================*/
98

@@ -252,6 +251,124 @@ struct KarrasSchedule : SigmaSchedule {
252251
}
253252
};
254253

254+
#if 0 // boost version
255+
256+
#include <boost/math/distributions/beta.hpp>
257+
258+
struct BetaDist {
259+
boost::math::beta_distribution<double> dist;
260+
261+
BetaDist(double a = 0.6f, double b = 0.6f)
262+
: dist(a, b) {}
263+
264+
double quantile(double u) {
265+
return boost::math::quantile(dist, u);
266+
}
267+
};
268+
269+
#else // local version
270+
271+
struct BetaDist {
272+
273+
BetaDist(double a=0.6f, double b=0.6f)
274+
: alpha(a), beta(b) {}
275+
276+
// Beta quantile function using Newton-Raphson method
277+
double quantile(double u) {
278+
if (u <= 0.0) return 0.0;
279+
if (u >= 1.0) return 1.0;
280+
281+
double x = u < 0.5 ? u * u : 1.0 - (1.0 - u) * (1.0 - u);
282+
283+
const int max_iterations = 50;
284+
const double tolerance = 1e-12;
285+
286+
for (int i = 0; i < max_iterations; ++i) {
287+
double err = beta_cdf(x) - u;
288+
if (std::abs(err) < tolerance) {
289+
break;
290+
}
291+
292+
double derivative = beta_pdf(x);
293+
if (std::abs(derivative) < 1e-30) {
294+
break;
295+
}
296+
297+
double new_x = x - err / derivative;
298+
x = std::max(0.0, std::min(1.0, new_x));
299+
}
300+
301+
return x;
302+
}
303+
304+
private:
305+
double alpha;
306+
double beta;
307+
308+
double lbeta_ab() {
309+
return std::lgamma(alpha) + std::lgamma(beta) - std::lgamma(alpha + beta);
310+
}
311+
312+
// Beta probability density function
313+
double beta_pdf(double x) {
314+
if (x <= 0.0 || x >= 1.0)
315+
return 0.0;
316+
return std::exp(-lbeta_ab()) * std::pow(x, alpha - 1.0) * std::pow(1.0 - x, beta - 1.0);
317+
}
318+
319+
// Beta cumulative distribution function
320+
double beta_cdf(double x) {
321+
if (x <= 0.0) return 0.0;
322+
if (x >= 1.0) return 1.0;
323+
if (x > (alpha + 1.0) / (alpha + beta + 2.0)) {
324+
// use symmetry relation
325+
return 1.0 - incomplete_beta(beta, alpha, 1.0 - x);
326+
}
327+
else {
328+
return incomplete_beta(alpha, beta, x);
329+
}
330+
}
331+
332+
// Incomplete beta function using continued fraction representation
333+
double incomplete_beta(double a, double b, double x) {
334+
335+
double f = 1.0, c = 1.0, d = 0.0;
336+
const int max_iterations = 200;
337+
const double tolerance = 1e-15;
338+
339+
for (int i = 0; i <= max_iterations; ++i) {
340+
int m = i / 2;
341+
double numerator;
342+
343+
if (i == 0) {
344+
numerator = 1.0;
345+
} else if (i % 2 == 0) {
346+
numerator = (m * (b - m) * x) / ((a + 2.0 * m - 1.0) * (a + 2.0 * m));
347+
} else {
348+
numerator = -((a + m) * (a + b + m) * x) / ((a + 2.0 * m) * (a + 2.0 * m + 1.0));
349+
}
350+
351+
d = 1.0 + numerator * d;
352+
if (std::abs(d) < 1e-30) d = 1e-30;
353+
d = 1.0 / d;
354+
355+
c = 1.0 + numerator / c;
356+
if (std::abs(c) < 1e-30) c = 1e-30;
357+
358+
double cd = c * d;
359+
f *= cd;
360+
361+
if (std::abs(cd - 1.0) < tolerance) {
362+
break;
363+
}
364+
}
365+
366+
return (std::exp(a * std::log(x) + b * std::log(1.0 - x) - lbeta_ab()) / a) * (f - 1.0);
367+
}
368+
};
369+
370+
#endif
371+
255372
struct BetaSchedule : SigmaSchedule {
256373
float alpha = 0.6f;
257374
float beta = 0.6f;
@@ -272,15 +389,15 @@ struct BetaSchedule : SigmaSchedule {
272389
}
273390

274391
// Beta-Verteilung (wie scipy.stats.beta.ppf)
275-
boost::math::beta_distribution<double> dist(alpha, beta);
392+
BetaDist beta_distribution;
276393

277394
int last_t = -1;
278395
for (uint32_t i = 0; i < n; i++) {
279396
// entspricht ts = 1 - linspace(0,1,n,endpoint=False)
280397
double u = 1.0 - static_cast<double>(i) / static_cast<double>(n);
281398

282399
// ppf(ts) * total_timesteps
283-
double t_cont = quantile(dist, u) * t_max;
400+
double t_cont = beta_distribution.quantile(u) * t_max;
284401
int t = (int)std::lround(t_cont);
285402

286403
if (t != last_t) {

0 commit comments

Comments
 (0)