Skip to content

Commit 9b80dc5

Browse files
authored
Add PRank (#734)
1 parent e3567c9 commit 9b80dc5

File tree

6 files changed

+193
-1
lines changed

6 files changed

+193
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ for (let i = 0; i < n; i++) {
122122
| task | model |
123123
| ---- | ----- |
124124
| clustering | (Soft / Kernel / Genetic / Weighted) k-means, k-means++, k-medois, k-medians, x-means, G-means, LBG, ISODATA, Fuzzy c-means, Possibilistic c-means, Agglomerative (complete linkage, single linkage, group average, Ward's, centroid, weighted average, median), DIANA, Monothetic, Mutual kNN, Mean shift, DBSCAN, OPTICS, HDBSCAN, DENCLUE, DBCLASD, CLUES, PAM, CLARA, CLARANS, BIRCH, CURE, ROCK, C2P, PLSA, Latent dirichlet allocation, GMM, VBGMM, Affinity propagation, Spectral clustering, Mountain, (Growing) SOM, GTM, (Growing) Neural gas, Growing cell structures, LVQ, ART, SVC, CAST, CHAMELEON, COLL, CLIQUE, PROCLUS, ORCLUS, FINDIT, NMF, Autoencoder |
125-
| classification | (Fisher's) Linear discriminant, Quadratic discriminant, Mixture discriminant, Least squares, (Multiclass / Kernel) Ridge, (Complement / Negation / Universal-set / Selective) Naive Bayes (gaussian), AODE, (Fuzzy / Weighted) k-nearest neighbor, Radius neighbor, Nearest centroid, ENN, ENaN, NNBCA, ADAMENN, DANN, IKNN, Decision tree, Random forest, Extra trees, GBDT, XGBoost, ALMA, (Aggressive) ROMMA, (Bounded) Online gradient descent, (Budgeted online) Passive aggressive, RLS, (Selective-sampling) Second order perceptron, AROW, NAROW, Confidence weighted, CELLIP, IELLIP, Normal herd, Stoptron, (Kernelized) Pegasos, MIRA, Forgetron, Projectron, Projectron++, Banditron, Ballseptron, (Multiclass) BSGD, ILK, SILK, (Multinomial) Logistic regression, (Multinomial) Probit, Ordered logistic, Ordered probit, SVM, Gaussian process, HMM, CRF, Bayesian Network, LVQ, (Average / Multiclass / Voted / Kernelized / Selective-sampling / Margin / Shifting / Budget / Tighter / Tightest) Perceptron, PAUM, RBP, ADALINE, MADALINE, MLP, LMNN |
125+
| classification | (Fisher's) Linear discriminant, Quadratic discriminant, Mixture discriminant, Least squares, (Multiclass / Kernel) Ridge, (Complement / Negation / Universal-set / Selective) Naive Bayes (gaussian), AODE, (Fuzzy / Weighted) k-nearest neighbor, Radius neighbor, Nearest centroid, ENN, ENaN, NNBCA, ADAMENN, DANN, IKNN, Decision tree, Random forest, Extra trees, GBDT, XGBoost, ALMA, (Aggressive) ROMMA, (Bounded) Online gradient descent, (Budgeted online) Passive aggressive, RLS, (Selective-sampling) Second order perceptron, AROW, NAROW, Confidence weighted, CELLIP, IELLIP, Normal herd, Stoptron, (Kernelized) Pegasos, MIRA, Forgetron, Projectron, Projectron++, Banditron, Ballseptron, (Multiclass) BSGD, ILK, SILK, (Multinomial) Logistic regression, (Multinomial) Probit, Ordered logistic, Ordered probit, PRank, SVM, Gaussian process, HMM, CRF, Bayesian Network, LVQ, (Average / Multiclass / Voted / Kernelized / Selective-sampling / Margin / Shifting / Budget / Tighter / Tightest) Perceptron, PAUM, RBP, ADALINE, MADALINE, MLP, LMNN |
126126
| semi-supervised classification | k-nearest neighbor, Radius neighbor, Label propagation, Label spreading, k-means, GMM, S3VM, Ladder network |
127127
| regression | Least squares, Ridge, Lasso, Elastic net, RLS, Bayesian linear, Poisson, Least absolute deviations, Huber, Tukey, Least trimmed squares, Least median squares, Lp norm linear, SMA, Deming, Segmented, LOWESS, LOESS, spline, Naive Bayes, Gaussian process, Principal components, Partial least squares, Projection pursuit, Quantile regression, k-nearest neighbor, Radius neighbor, IDW, Nadaraya Watson, Priestley Chao, Gasser Muller, RBF Network, RVM, Decision tree, Random forest, Extra trees, GBDT, XGBoost, SVR, MLP, GMR, Isotonic, Ramer Douglas Peucker, Theil-Sen, Passing-Bablok, Repeated median |
128128
| interpolation | Nearest neighbor, IDW, (Spherical) Linear, Brahmagupta, Logarithmic, Cosine, (Inverse) Smoothstep, Cubic, (Centripetal) Catmull-Rom, Hermit, Polynomial, Lagrange, Trigonometric, Spline, RBF Network, Akima, Natural neighbor, Delaunay |

js/model_selector.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ const AIMethods = [
210210
Ranking: [
211211
{ value: 'ordered_logistic', title: 'Ordered logistic regression' },
212212
{ value: 'ordered_probit', title: 'Ordered probit regression' },
213+
{ value: 'prank', title: 'PRank' },
213214
],
214215
'': [
215216
{ value: 'least_square', title: 'Least squares' },

js/view/prank.js

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import PRank from '../../lib/model/prank.js'
2+
import Controller from '../controller.js'
3+
4+
export default function (platform) {
5+
platform.setting.ml.usage =
6+
'Click and add data point. Next, click "Initialize". Finally, click "Fit" button repeatedly.'
7+
platform.setting.ml.reference = {
8+
author: 'K. Crammer, Y. Singer',
9+
title: 'Pranking with Ranking ',
10+
year: 2001,
11+
}
12+
const controller = new Controller(platform)
13+
14+
let model = null
15+
const fitModel = () => {
16+
if (!model) {
17+
return
18+
}
19+
20+
model.fit(
21+
platform.trainInput,
22+
platform.trainOutput.map(v => v[0])
23+
)
24+
const pred = model.predict(platform.testInput(4))
25+
platform.testResult(pred)
26+
}
27+
28+
const rate = controller.input.number({ label: ' Learning rate ', value: 0.1, min: 0, max: 100, step: 0.1 })
29+
controller
30+
.stepLoopButtons()
31+
.init(() => {
32+
model = new PRank(rate.value)
33+
platform.init()
34+
})
35+
.step(fitModel)
36+
.epoch()
37+
}

lib/model/prank.js

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/**
2+
* Perceptron ranking
3+
*/
4+
export default class PRank {
5+
// Pranking with Ranking
6+
// https://proceedings.neurips.cc/paper_files/paper/2001/file/5531a5834816222280f20d1ef9e95f69-Paper.pdf
7+
/**
8+
* @param {number} [rate=0.1] Learning rate
9+
*/
10+
constructor(rate = 0.1) {
11+
this._w = null
12+
this._a = rate
13+
14+
this._b = [0, Infinity]
15+
this._min = 1
16+
}
17+
18+
/**
19+
* Fit model.
20+
*
21+
* @param {Array<Array<number>>} x Training data
22+
* @param {Array<number>} y Target values
23+
*/
24+
fit(x, y) {
25+
if (!this._w) {
26+
this._w = Array(x[0].length).fill(0)
27+
}
28+
29+
for (let k = 0; k < x.length; k++) {
30+
if (y[k] < this._min) {
31+
this._b.splice(0, 0, ...Array(this._min - y[k]).fill(this._b[0]))
32+
this._min = y[k]
33+
} else if (y[k] >= this._min + this._b.length) {
34+
this._b.splice(
35+
this._b.length - 1,
36+
0,
37+
...Array(this._min + this._b.length - y[k] + 1).fill(this._b[this._b.length - 2])
38+
)
39+
}
40+
41+
const p = this._w.reduce((s, v, i) => s + v * x[k][i], 0)
42+
let r = 0
43+
for (; r < this._b.length; r++) {
44+
if (p - this._b[r] < 0) break
45+
}
46+
const yh = r + this._min
47+
if (y[k] === yh) continue
48+
let t = 0
49+
for (let i = 0; i < this._b.length - 1; i++) {
50+
const yt = y[k] <= i + this._min ? -1 : 1
51+
if ((p - this._b[i]) * yt <= 0) {
52+
t += yt
53+
this._b[i] -= this._a * yt
54+
}
55+
}
56+
for (let m = 0; m < this._w.length; m++) {
57+
this._w[m] += this._a * t * x[k][m]
58+
}
59+
}
60+
}
61+
62+
/**
63+
* Returns predicted values.
64+
*
65+
* @param {Array<Array<number>>} x Sample data
66+
* @returns {Array<number>} Predicted values
67+
*/
68+
predict(x) {
69+
const p = []
70+
for (let k = 0; k < x.length; k++) {
71+
const v = this._w.reduce((s, v, i) => s + v * x[k][i], 0)
72+
let r = 0
73+
for (; r < this._b.length; r++) {
74+
if (v - this._b[r] < 0) break
75+
}
76+
p[k] = r + this._min
77+
}
78+
return p
79+
}
80+
}

tests/gui/view/prank.test.js

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import { getPage } from '../helper/browser'
2+
3+
describe('classification', () => {
4+
/** @type {Awaited<ReturnType<getPage>>} */
5+
let page
6+
beforeEach(async () => {
7+
page = await getPage()
8+
})
9+
10+
afterEach(async () => {
11+
await page?.close()
12+
})
13+
14+
test('initialize', async () => {
15+
const taskSelectBox = await page.waitForSelector('#ml_selector dl:first-child dd:nth-child(5) select')
16+
await taskSelectBox.selectOption('CF')
17+
const modelSelectBox = await page.waitForSelector('#ml_selector .model_selection #mlDisp')
18+
await modelSelectBox.selectOption('prank')
19+
const methodMenu = await page.waitForSelector('#ml_selector #method_menu')
20+
const buttons = await methodMenu.waitForSelector('.buttons')
21+
22+
const rate = await buttons.waitForSelector('input:nth-of-type(1)')
23+
await expect((await rate.getProperty('value')).jsonValue()).resolves.toBe('0.1')
24+
})
25+
26+
test('learn', async () => {
27+
const taskSelectBox = await page.waitForSelector('#ml_selector dl:first-child dd:nth-child(5) select')
28+
await taskSelectBox.selectOption('CF')
29+
const modelSelectBox = await page.waitForSelector('#ml_selector .model_selection #mlDisp')
30+
await modelSelectBox.selectOption('prank')
31+
const methodMenu = await page.waitForSelector('#ml_selector #method_menu')
32+
const buttons = await methodMenu.waitForSelector('.buttons')
33+
34+
const epoch = await buttons.waitForSelector('[name=epoch]')
35+
await expect(epoch.evaluate(el => el.textContent)).resolves.toBe('0')
36+
const methodFooter = await page.waitForSelector('#method_footer', { state: 'attached' })
37+
await expect(methodFooter.evaluate(el => el.textContent)).resolves.toBe('')
38+
39+
const initButton = await buttons.waitForSelector('input[value=Initialize]')
40+
await initButton.evaluate(el => el.click())
41+
const stepButton = await buttons.waitForSelector('input[value=Step]:enabled')
42+
await stepButton.evaluate(el => el.click())
43+
44+
await expect(epoch.evaluate(el => el.textContent)).resolves.toBe('1')
45+
await expect(methodFooter.evaluate(el => el.textContent)).resolves.toMatch(/^Accuracy:[0-9.]+$/)
46+
})
47+
})

tests/lib/model/prank.test.js

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import { jest } from '@jest/globals'
2+
jest.retryTimes(10)
3+
4+
import Matrix from '../../../lib/util/matrix.js'
5+
import PRank from '../../../lib/model/prank.js'
6+
7+
import { rmse } from '../../../lib/evaluate/regression.js'
8+
9+
describe('ordinal', () => {
10+
test('fit', () => {
11+
const model = new PRank()
12+
const x = Matrix.concat(
13+
Matrix.concat(Matrix.randn(50, 2, -5, 0.2), Matrix.randn(50, 2, 0, 0.2)),
14+
Matrix.concat(Matrix.randn(50, 2, 5, 0.2), Matrix.randn(50, 2, 10, 0.2))
15+
).toArray()
16+
const t = []
17+
for (let i = 0; i < x.length; i++) {
18+
t[i] = Math.floor(i / 50)
19+
}
20+
for (let i = 0; i < 100; i++) {
21+
model.fit(x, t)
22+
}
23+
const y = model.predict(x)
24+
const err = rmse(y, t)
25+
expect(err).toBeLessThan(0.5)
26+
})
27+
})

0 commit comments

Comments
 (0)