diff --git a/lib/model/zip.js b/lib/model/zip.js index afc68e9b1..ae90b5a3b 100644 --- a/lib/model/zip.js +++ b/lib/model/zip.js @@ -7,8 +7,11 @@ export default class ZeroInflatedPoisson { // https://qiita.com/nozma/items/52211b1bacaa8a898164 // http://web.uvic.ca/~dgiles/downloads/count/zip.pdf // https://ncss-wpengine.netdna-ssl.com/wp-content/themes/ncss/pdf/Procedures/NCSS/Zero-Inflated_Poisson_Regression.pdf - constructor() { - this._method = 'ml' + /** + * @param {'moments' | 'maximum_likelihood'} [method=maximum_likelihood] Method name + */ + constructor(method = 'maximum_likelihood') { + this._method = method } /** @@ -17,7 +20,7 @@ export default class ZeroInflatedPoisson { * @param {number[]} x Training data */ fit(x) { - if (this._method === 'mo') { + if (this._method === 'moments') { this._mo(x) } else { this._ml(x) diff --git a/tests/lib/model/zip.test.js b/tests/lib/model/zip.test.js index 1de8fcedc..0397c2905 100644 --- a/tests/lib/model/zip.test.js +++ b/tests/lib/model/zip.test.js @@ -14,33 +14,31 @@ const random_poisson = l => { return k - 1 } -test('density estimation', () => { - const model = new ZeroInflatedPoisson() - const x = [] - for (let i = 0; i < 10000; i++) { - const r = Math.random() - if (r < 0.5) { - x.push(0) - } else { - x.push(random_poisson(1)) +describe('density estimation', () => { + test.each([undefined, 'moments', 'maximum_likelihood'])('%s', method => { + const model = new ZeroInflatedPoisson(method) + const x = [] + for (let i = 0; i < 10000; i++) { + const r = Math.random() + x.push(r < 0.5 ? 0 : random_poisson(1)) } - } - model.fit(x) + model.fit(x) - const y = [0, 1, 2, 3, 4, 5] - const p = Array(y.length).fill(0) - p[0] += 0.5 - for (let i = 0; i < y.length; i++) { - let f = 1 - for (let k = 2; k <= i; k++) { - f *= k + const y = [0, 1, 2, 3, 4, 5] + const p = Array(y.length).fill(0) + p[0] += 0.5 + for (let i = 0; i < y.length; i++) { + let f = 1 + for (let k = 2; k <= i; k++) { + f *= k + } + p[i] += ((1 / f) * Math.exp(-1)) / 2 } - p[i] += ((1 / f) * Math.exp(-1)) / 2 - } - const pred = model.probability(y) - for (let i = 0; i < y.length; i++) { - expect(pred[i]).toBeCloseTo(p[i]) - } + const pred = model.probability(y) + for (let i = 0; i < y.length; i++) { + expect(pred[i]).toBeCloseTo(p[i]) + } + }) })