Skip to content

Commit e9fe330

Browse files
authored
Add ZINB (#757)
1 parent 91430f7 commit e9fe330

File tree

3 files changed

+179
-1
lines changed

3 files changed

+179
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ for (let i = 0; i < n; i++) {
153153
| regression | Weighted least squares |
154154
| interpolation | Cubic convolution, Sinc, Lanczos, Bilinear, n-linear, n-cubic |
155155
| scaling | Max absolute scaler, Minmax normalization, Robust scaler, Standardization |
156-
| density estimation | ZIP, ZTP |
156+
| density estimation | ZINB, ZIP, ZTP |
157157
| density ratio estimation | RuLSIF |
158158

159159
## Models (meta)

lib/model/zinb.js

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
const logGamma = z => {
2+
// https://en.wikipedia.org/wiki/Lanczos_approximation
3+
// https://slpr.sakura.ne.jp/qp/gamma-function/
4+
let x = 0
5+
if (Number.isInteger(z)) {
6+
for (let i = 2; i < z; i++) {
7+
x += Math.log(i)
8+
}
9+
} else if (Number.isInteger(z - 0.5)) {
10+
const n = z - 0.5
11+
x = Math.log(Math.sqrt(Math.PI)) - Math.log(2) * n
12+
for (let i = 2 * n - 1; i > 0; i -= 2) {
13+
x += Math.log(i)
14+
}
15+
} else if (z < 0.5) {
16+
x = Math.log(Math.PI) - Math.log(Math.sin(Math.PI * z)) - logGamma(1 - z)
17+
} else {
18+
const p = [
19+
676.5203681218851, -1259.1392167224028, 771.32342877765313, -176.61502916214059, 12.507343278686905,
20+
-0.13857109526572012, 9.9843695780195716e-6, 1.5056327351493116e-7,
21+
]
22+
z -= 1
23+
x = 0.99999999999980993
24+
for (let i = 0; i < p.length; i++) {
25+
x += p[i] / (z + i + 1)
26+
}
27+
const t = z + p.length - 0.5
28+
x = Math.log(Math.sqrt(2 * Math.PI)) + Math.log(t) * (z + 0.5) - t + Math.log(x)
29+
}
30+
return x
31+
}
32+
33+
/**
34+
* Zero-inflated negative binomial
35+
*/
36+
export default class ZeroInflatedNegativeBinomial {
37+
// https://stats.oarc.ucla.edu/r/dae/zinb/
38+
// https://juniperpublishers.com/bboaj/pdf/BBOAJ.MS.ID.555566.pdf
39+
// https://trialsjournal.biomedcentral.com/articles/10.1186/s13063-023-07648-8
40+
/**
41+
* Fit model.
42+
*
43+
* @param {number[]} x Training data
44+
*/
45+
fit(x) {
46+
const m = x.reduce((s, v) => s + v, 0) / x.length
47+
const s2 = x.reduce((s, v) => s + (v - m) ** 2, 0) / x.length
48+
49+
const counts = []
50+
for (let i = 0; i < x.length; i++) {
51+
counts[x[i]] = (counts[x[i]] ?? 0) + 1
52+
}
53+
54+
const calc_llh = pi => {
55+
const l = m / (1 - pi)
56+
const k = (s2 / m - 1) / l - pi
57+
if (k <= 1.0e-5) {
58+
return -Infinity
59+
}
60+
let llh = 0
61+
for (let i = 0; i < counts.length; i++) {
62+
if (counts[i]) {
63+
llh += counts[i] * Math.log(this._probability(i, pi, k, l))
64+
}
65+
}
66+
return llh
67+
}
68+
69+
const maxpi = (s2 / m - 1) / (m + s2 / m - 1)
70+
const values = [0, maxpi / 2, maxpi].map(p => [p, calc_llh(p)])
71+
while (values[2][0] - values[0][0] > 1.0e-8) {
72+
const llh_l = calc_llh((values[0][0] + values[1][0]) / 2)
73+
const llh_h = calc_llh((values[1][0] + values[2][0]) / 2)
74+
if (values[1][1] < llh_l) {
75+
values[2] = values[1]
76+
values[1] = [(values[0][0] + values[1][0]) / 2, llh_l]
77+
} else if (values[1][1] < llh_h) {
78+
values[0] = values[1]
79+
values[1] = [(values[1][0] + values[2][0]) / 2, llh_h]
80+
} else {
81+
values[0] = [(values[0][0] + values[1][0]) / 2, llh_l]
82+
values[2] = [(values[1][0] + values[2][0]) / 2, llh_h]
83+
}
84+
}
85+
this._pi = values[1][0]
86+
this._l = m / (1 - this._pi)
87+
this._k = (s2 / m - 1) / this._l - this._pi
88+
}
89+
90+
_probability(z, p, k, l) {
91+
if (z === 0) {
92+
return p + (1 - p) / (1 + k * l) ** (1 / k)
93+
}
94+
return (
95+
(1 - p) *
96+
Math.exp(
97+
logGamma(z + 1 / k) +
98+
z * Math.log(k * l) -
99+
logGamma(z + 1) -
100+
logGamma(1 / k) -
101+
(z + 1 / k) * Math.log(1 + k * l)
102+
)
103+
)
104+
}
105+
106+
/**
107+
* Returns predicted probabilities.
108+
*
109+
* @param {number[]} x Sample data
110+
* @returns {number[]} Predicted values
111+
*/
112+
probability(x) {
113+
return x.map(v => {
114+
return this._probability(v, this._pi, this._k, this._l)
115+
})
116+
}
117+
}

tests/lib/model/zinb.test.js

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import { jest } from '@jest/globals'
2+
jest.retryTimes(3)
3+
4+
import ZeroInflatedNegativeBinomial from '../../../lib/model/zinb.js'
5+
6+
const random_negative_binomial = (r, p) => {
7+
let rand = Math.random()
8+
let k = 0
9+
while (true) {
10+
let c = (1 - p) ** r * p ** k
11+
for (let i = 0; i < k; i++) {
12+
c *= k + r - 1 - i
13+
c /= i + 1
14+
}
15+
rand -= c
16+
if (rand < 0) {
17+
return k
18+
}
19+
k++
20+
}
21+
}
22+
23+
test('density estimation', () => {
24+
const model = new ZeroInflatedNegativeBinomial()
25+
const r = 10
26+
const prob = 0.5
27+
const zero_prob = 0.4
28+
const x = []
29+
for (let i = 0; i < 10000; i++) {
30+
const zr = Math.random()
31+
if (zr < zero_prob) {
32+
x.push(0)
33+
} else {
34+
x.push(random_negative_binomial(r, prob))
35+
}
36+
}
37+
38+
model.fit(x)
39+
40+
const y = []
41+
const p = []
42+
p[0] = zero_prob
43+
for (let t = 0; t < 20; t++) {
44+
let c = (1 - prob) ** r * prob ** t
45+
for (let i = 0; i < t; i++) {
46+
c *= t + r - 1 - i
47+
c /= i + 1
48+
}
49+
y[t] = t
50+
if (t === 0) {
51+
p[0] = zero_prob + c * (1 - zero_prob)
52+
} else {
53+
p[t] = c * (1 - zero_prob)
54+
}
55+
}
56+
57+
const pred = model.probability(y)
58+
for (let i = 0; i < y.length; i++) {
59+
expect(pred[i]).toBeCloseTo(p[i])
60+
}
61+
})

0 commit comments

Comments
 (0)