Skip to content

Commit 3e6c09f

Browse files
authored
Add OAP-BPM (#736)
1 parent 9b80dc5 commit 3e6c09f

File tree

6 files changed

+242
-1
lines changed

6 files changed

+242
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ for (let i = 0; i < n; i++) {
122122
| task | model |
123123
| ---- | ----- |
124124
| clustering | (Soft / Kernel / Genetic / Weighted) k-means, k-means++, k-medois, k-medians, x-means, G-means, LBG, ISODATA, Fuzzy c-means, Possibilistic c-means, Agglomerative (complete linkage, single linkage, group average, Ward's, centroid, weighted average, median), DIANA, Monothetic, Mutual kNN, Mean shift, DBSCAN, OPTICS, HDBSCAN, DENCLUE, DBCLASD, CLUES, PAM, CLARA, CLARANS, BIRCH, CURE, ROCK, C2P, PLSA, Latent dirichlet allocation, GMM, VBGMM, Affinity propagation, Spectral clustering, Mountain, (Growing) SOM, GTM, (Growing) Neural gas, Growing cell structures, LVQ, ART, SVC, CAST, CHAMELEON, COLL, CLIQUE, PROCLUS, ORCLUS, FINDIT, NMF, Autoencoder |
125-
| classification | (Fisher's) Linear discriminant, Quadratic discriminant, Mixture discriminant, Least squares, (Multiclass / Kernel) Ridge, (Complement / Negation / Universal-set / Selective) Naive Bayes (gaussian), AODE, (Fuzzy / Weighted) k-nearest neighbor, Radius neighbor, Nearest centroid, ENN, ENaN, NNBCA, ADAMENN, DANN, IKNN, Decision tree, Random forest, Extra trees, GBDT, XGBoost, ALMA, (Aggressive) ROMMA, (Bounded) Online gradient descent, (Budgeted online) Passive aggressive, RLS, (Selective-sampling) Second order perceptron, AROW, NAROW, Confidence weighted, CELLIP, IELLIP, Normal herd, Stoptron, (Kernelized) Pegasos, MIRA, Forgetron, Projectron, Projectron++, Banditron, Ballseptron, (Multiclass) BSGD, ILK, SILK, (Multinomial) Logistic regression, (Multinomial) Probit, Ordered logistic, Ordered probit, PRank, SVM, Gaussian process, HMM, CRF, Bayesian Network, LVQ, (Average / Multiclass / Voted / Kernelized / Selective-sampling / Margin / Shifting / Budget / Tighter / Tightest) Perceptron, PAUM, RBP, ADALINE, MADALINE, MLP, LMNN |
125+
| classification | (Fisher's) Linear discriminant, Quadratic discriminant, Mixture discriminant, Least squares, (Multiclass / Kernel) Ridge, (Complement / Negation / Universal-set / Selective) Naive Bayes (gaussian), AODE, (Fuzzy / Weighted) k-nearest neighbor, Radius neighbor, Nearest centroid, ENN, ENaN, NNBCA, ADAMENN, DANN, IKNN, Decision tree, Random forest, Extra trees, GBDT, XGBoost, ALMA, (Aggressive) ROMMA, (Bounded) Online gradient descent, (Budgeted online) Passive aggressive, RLS, (Selective-sampling) Second order perceptron, AROW, NAROW, Confidence weighted, CELLIP, IELLIP, Normal herd, Stoptron, (Kernelized) Pegasos, MIRA, Forgetron, Projectron, Projectron++, Banditron, Ballseptron, (Multiclass) BSGD, ILK, SILK, (Multinomial) Logistic regression, (Multinomial) Probit, Ordered logistic, Ordered probit, PRank, OAP-BPM, SVM, Gaussian process, HMM, CRF, Bayesian Network, LVQ, (Average / Multiclass / Voted / Kernelized / Selective-sampling / Margin / Shifting / Budget / Tighter / Tightest) Perceptron, PAUM, RBP, ADALINE, MADALINE, MLP, LMNN |
126126
| semi-supervised classification | k-nearest neighbor, Radius neighbor, Label propagation, Label spreading, k-means, GMM, S3VM, Ladder network |
127127
| regression | Least squares, Ridge, Lasso, Elastic net, RLS, Bayesian linear, Poisson, Least absolute deviations, Huber, Tukey, Least trimmed squares, Least median squares, Lp norm linear, SMA, Deming, Segmented, LOWESS, LOESS, spline, Naive Bayes, Gaussian process, Principal components, Partial least squares, Projection pursuit, Quantile regression, k-nearest neighbor, Radius neighbor, IDW, Nadaraya Watson, Priestley Chao, Gasser Muller, RBF Network, RVM, Decision tree, Random forest, Extra trees, GBDT, XGBoost, SVR, MLP, GMR, Isotonic, Ramer Douglas Peucker, Theil-Sen, Passing-Bablok, Repeated median |
128128
| interpolation | Nearest neighbor, IDW, (Spherical) Linear, Brahmagupta, Logarithmic, Cosine, (Inverse) Smoothstep, Cubic, (Centripetal) Catmull-Rom, Hermit, Polynomial, Lagrange, Trigonometric, Spline, RBF Network, Akima, Natural neighbor, Delaunay |

js/model_selector.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ const AIMethods = [
211211
{ value: 'ordered_logistic', title: 'Ordered logistic regression' },
212212
{ value: 'ordered_probit', title: 'Ordered probit regression' },
213213
{ value: 'prank', title: 'PRank' },
214+
{ value: 'oapbpm', title: 'OAP-BPM' },
214215
],
215216
'': [
216217
{ value: 'least_square', title: 'Least squares' },

js/view/oapbpm.js

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import OAPBPM from '../../lib/model/oapbpm.js'
2+
import Controller from '../controller.js'
3+
4+
export default function (platform) {
5+
platform.setting.ml.usage =
6+
'Click and add data point. Next, click "Initialize". Finally, click "Fit" button repeatedly.'
7+
platform.setting.ml.reference = {
8+
author: 'R. F. Harrington',
9+
title: 'Online Ranking/Collaborative filtering using the Perceptron Algorithm',
10+
year: 2003,
11+
}
12+
const controller = new Controller(platform)
13+
14+
let model = null
15+
const fitModel = () => {
16+
if (!model) {
17+
return
18+
}
19+
20+
model.fit(
21+
platform.trainInput,
22+
platform.trainOutput.map(v => v[0])
23+
)
24+
const pred = model.predict(platform.testInput(4))
25+
platform.testResult(pred)
26+
}
27+
28+
const n = controller.input.number({ label: ' N ', value: 10, min: 1, max: 100 })
29+
const tau = controller.input.number({ label: ' Tau ', value: 0.5, min: 0, max: 1, step: 0.1 })
30+
const rate = controller.input.number({ label: ' Learning rate ', value: 0.1, min: 0, max: 100, step: 0.1 })
31+
controller
32+
.stepLoopButtons()
33+
.init(() => {
34+
model = new OAPBPM(n.value, tau.value, rate.value)
35+
platform.init()
36+
})
37+
.step(fitModel)
38+
.epoch()
39+
}

lib/model/oapbpm.js

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
/**
2+
* Online Aggregate Prank-Bayes Point Machine
3+
*/
4+
export default class OAPBPM {
5+
// Online Ranking/Collaborative filtering using the Perceptron Algorithm
6+
// https://cdn.aaai.org/ICML/2003/ICML03-035.pdf
7+
/**
8+
* @param {number} n Number of PRank models
9+
* @param {number} tau Probability to learn
10+
* @param {number} [rate=0.1] Learning rate
11+
*/
12+
constructor(n, tau, rate = 0.1) {
13+
this._n = n
14+
this._tau = tau
15+
this._wh = null
16+
this._w = []
17+
this._a = rate
18+
19+
this._bh = [0, Infinity]
20+
this._b = Array.from({ length: n }, () => [0, Infinity])
21+
this._min = 1
22+
}
23+
24+
_update(x, y, k) {
25+
const p = this._w[k].reduce((s, v, i) => s + v * x[i], 0)
26+
let r = 0
27+
for (; r < this._b[k].length; r++) {
28+
if (p - this._b[k][r] < 0) break
29+
}
30+
const yh = r + this._min
31+
if (y === yh) return
32+
let t = 0
33+
for (let i = 0; i < this._b[k].length - 1; i++) {
34+
const yt = y <= i + this._min ? -1 : 1
35+
if ((p - this._b[k][i]) * yt <= 0) {
36+
t += yt
37+
this._b[k][i] -= this._a * yt
38+
}
39+
}
40+
for (let m = 0; m < this._w[k].length; m++) {
41+
this._w[k][m] += this._a * t * x[m]
42+
}
43+
}
44+
45+
/**
46+
* Fit model.
47+
*
48+
* @param {Array<Array<number>>} x Training data
49+
* @param {Array<number>} y Target values
50+
*/
51+
fit(x, y) {
52+
if (!this._wh) {
53+
for (let i = 0; i < this._n; i++) {
54+
this._w[i] = Array(x[0].length).fill(0)
55+
}
56+
}
57+
58+
for (let k = 0; k < x.length; k++) {
59+
if (y[k] < this._min) {
60+
for (let j = 0; j < this._b.length; j++) {
61+
this._b[j].splice(0, 0, ...Array(this._min - y[k]).fill(this._b[j][0]))
62+
}
63+
this._min = y[k]
64+
} else if (y[k] >= this._min + this._b[0].length) {
65+
for (let j = 0; j < this._b.length; j++) {
66+
this._b[j].splice(
67+
this._b[j].length - 1,
68+
0,
69+
...Array(this._min + this._b[j].length - y[k] + 1).fill(this._b[j][this._b[j].length - 2])
70+
)
71+
}
72+
}
73+
74+
for (let j = 0; j < this._n; j++) {
75+
const p = this._w[j].reduce((s, v, i) => s + v * x[k][i], 0)
76+
let r = 0
77+
for (; r < this._b[j].length; r++) {
78+
if (p - this._b[j][r] < 0) break
79+
}
80+
const yh = r + this._min
81+
if (Math.random() < this._tau && y[k] !== yh) {
82+
let t = 0
83+
for (let i = 0; i < this._b[j].length - 1; i++) {
84+
const yt = y[k] <= i + this._min ? -1 : 1
85+
if ((p - this._b[j][i]) * yt <= 0) {
86+
t += yt
87+
this._b[j][i] -= this._a * yt
88+
}
89+
}
90+
for (let m = 0; m < this._w[j].length; m++) {
91+
this._w[j][m] += this._a * t * x[k][m]
92+
}
93+
}
94+
}
95+
}
96+
this._wh = Array(this._w[0].length).fill(0)
97+
this._bh = Array(this._b[0].length).fill(0)
98+
for (let j = 0; j < this._n; j++) {
99+
for (let m = 0; m < this._wh.length; m++) {
100+
this._wh[m] += this._w[j][m] / this._n
101+
}
102+
for (let i = 0; i < this._bh.length; i++) {
103+
this._bh[i] += this._b[j][i] / this._n
104+
}
105+
}
106+
}
107+
108+
/**
109+
* Returns predicted values.
110+
*
111+
* @param {Array<Array<number>>} x Sample data
112+
* @returns {Array<number>} Predicted values
113+
*/
114+
predict(x) {
115+
const p = []
116+
for (let k = 0; k < x.length; k++) {
117+
const v = this._wh.reduce((s, v, i) => s + v * x[k][i], 0)
118+
let r = 0
119+
for (; r < this._bh.length; r++) {
120+
if (v - this._bh[r] < 0) break
121+
}
122+
p[k] = r + this._min
123+
}
124+
return p
125+
}
126+
}

tests/gui/view/oapbpm.test.js

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import { getPage } from '../helper/browser'
2+
3+
describe('classification', () => {
4+
/** @type {Awaited<ReturnType<getPage>>} */
5+
let page
6+
beforeEach(async () => {
7+
page = await getPage()
8+
})
9+
10+
afterEach(async () => {
11+
await page?.close()
12+
})
13+
14+
test('initialize', async () => {
15+
const taskSelectBox = await page.waitForSelector('#ml_selector dl:first-child dd:nth-child(5) select')
16+
await taskSelectBox.selectOption('CF')
17+
const modelSelectBox = await page.waitForSelector('#ml_selector .model_selection #mlDisp')
18+
await modelSelectBox.selectOption('oapbpm')
19+
const methodMenu = await page.waitForSelector('#ml_selector #method_menu')
20+
const buttons = await methodMenu.waitForSelector('.buttons')
21+
22+
const n = await buttons.waitForSelector('input:nth-of-type(1)')
23+
await expect((await n.getProperty('value')).jsonValue()).resolves.toBe('10')
24+
const tau = await buttons.waitForSelector('input:nth-of-type(2)')
25+
await expect((await tau.getProperty('value')).jsonValue()).resolves.toBe('0.5')
26+
const rate = await buttons.waitForSelector('input:nth-of-type(3)')
27+
await expect((await rate.getProperty('value')).jsonValue()).resolves.toBe('0.1')
28+
})
29+
30+
test('learn', async () => {
31+
const taskSelectBox = await page.waitForSelector('#ml_selector dl:first-child dd:nth-child(5) select')
32+
await taskSelectBox.selectOption('CF')
33+
const modelSelectBox = await page.waitForSelector('#ml_selector .model_selection #mlDisp')
34+
await modelSelectBox.selectOption('oapbpm')
35+
const methodMenu = await page.waitForSelector('#ml_selector #method_menu')
36+
const buttons = await methodMenu.waitForSelector('.buttons')
37+
38+
const epoch = await buttons.waitForSelector('[name=epoch]')
39+
await expect(epoch.evaluate(el => el.textContent)).resolves.toBe('0')
40+
const methodFooter = await page.waitForSelector('#method_footer', { state: 'attached' })
41+
await expect(methodFooter.evaluate(el => el.textContent)).resolves.toBe('')
42+
43+
const initButton = await buttons.waitForSelector('input[value=Initialize]')
44+
await initButton.evaluate(el => el.click())
45+
const stepButton = await buttons.waitForSelector('input[value=Step]:enabled')
46+
await stepButton.evaluate(el => el.click())
47+
48+
await expect(epoch.evaluate(el => el.textContent)).resolves.toBe('1')
49+
await expect(methodFooter.evaluate(el => el.textContent)).resolves.toMatch(/^Accuracy:[0-9.]+$/)
50+
})
51+
})

tests/lib/model/oapbpm.test.js

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import Matrix from '../../../lib/util/matrix.js'
2+
import OAPBPM from '../../../lib/model/oapbpm.js'
3+
4+
import { rmse } from '../../../lib/evaluate/regression.js'
5+
6+
describe('ordinal', () => {
7+
test('fit', () => {
8+
const model = new OAPBPM(10, 0.3)
9+
const x = Matrix.concat(
10+
Matrix.concat(Matrix.randn(50, 2, -5, 0.2), Matrix.randn(50, 2, 0, 0.2)),
11+
Matrix.concat(Matrix.randn(50, 2, 5, 0.2), Matrix.randn(50, 2, 10, 0.2))
12+
).toArray()
13+
const t = []
14+
for (let i = 0; i < x.length; i++) {
15+
t[i] = Math.floor(i / 50)
16+
}
17+
for (let i = 0; i < 100; i++) {
18+
model.fit(x, t)
19+
}
20+
const y = model.predict(x)
21+
const err = rmse(y, t)
22+
expect(err).toBeLessThan(0.5)
23+
})
24+
})

0 commit comments

Comments
 (0)