Skip to content

Commit d5fac14

Browse files
authored
Add Multiclass Ridge classifier (#697)
* Add Multiclass Ridge classifier * Fix test * Fix least square
1 parent d8d002b commit d5fac14

File tree

7 files changed

+181
-106
lines changed

7 files changed

+181
-106
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ for (let i = 0; i < n; i++) {
122122
| task | model |
123123
| ---- | ----- |
124124
| clustering | (Soft / Kernel / Genetic / Weighted) k-means, k-means++, k-medoids, k-medians, x-means, G-means, LBG, ISODATA, Fuzzy c-means, Possibilistic c-means, Agglomerative (complete linkage, single linkage, group average, Ward's, centroid, weighted average, median), DIANA, Monothetic, Mutual kNN, Mean shift, DBSCAN, OPTICS, HDBSCAN, DENCLUE, DBCLASD, CLUES, PAM, CLARA, CLARANS, BIRCH, CURE, ROCK, C2P, PLSA, Latent dirichlet allocation, GMM, VBGMM, Affinity propagation, Spectral clustering, Mountain, (Growing) SOM, GTM, (Growing) Neural gas, Growing cell structures, LVQ, ART, SVC, CAST, CHAMELEON, COLL, CLIQUE, PROCLUS, ORCLUS, FINDIT, NMF, Autoencoder |
125-
| classification | (Fisher's) Linear discriminant, Quadratic discriminant, Mixture discriminant, Least squares, Ridge, (Complement / Negation / Universal-set / Selective) Naive Bayes (gaussian), AODE, (Fuzzy / Weighted) k-nearest neighbor, Radius neighbor, Nearest centroid, ENN, ENaN, NNBCA, ADAMENN, DANN, IKNN, Decision tree, Random forest, Extra trees, GBDT, XGBoost, ALMA, (Aggressive) ROMMA, (Bounded) Online gradient descent, (Budgeted online) Passive aggressive, RLS, (Selective-sampling) Second order perceptron, AROW, NAROW, Confidence weighted, CELLIP, IELLIP, Normal herd, Stoptron, (Kernelized) Pegasos, MIRA, Forgetron, Projectron, Projectron++, Banditron, Ballseptron, (Multiclass) BSGD, ILK, SILK, (Multinomial) Logistic regression, (Multinomial) Probit, SVM, Gaussian process, HMM, CRF, Bayesian Network, LVQ, (Average / Multiclass / Voted / Kernelized / Selective-sampling / Margin / Shifting / Budget / Tighter / Tightest) Perceptron, PAUM, RBP, ADALINE, MADALINE, MLP, LMNN |
125+
| classification | (Fisher's) Linear discriminant, Quadratic discriminant, Mixture discriminant, Least squares, (Multiclass / Kernel) Ridge, (Complement / Negation / Universal-set / Selective) Naive Bayes (gaussian), AODE, (Fuzzy / Weighted) k-nearest neighbor, Radius neighbor, Nearest centroid, ENN, ENaN, NNBCA, ADAMENN, DANN, IKNN, Decision tree, Random forest, Extra trees, GBDT, XGBoost, ALMA, (Aggressive) ROMMA, (Bounded) Online gradient descent, (Budgeted online) Passive aggressive, RLS, (Selective-sampling) Second order perceptron, AROW, NAROW, Confidence weighted, CELLIP, IELLIP, Normal herd, Stoptron, (Kernelized) Pegasos, MIRA, Forgetron, Projectron, Projectron++, Banditron, Ballseptron, (Multiclass) BSGD, ILK, SILK, (Multinomial) Logistic regression, (Multinomial) Probit, SVM, Gaussian process, HMM, CRF, Bayesian Network, LVQ, (Average / Multiclass / Voted / Kernelized / Selective-sampling / Margin / Shifting / Budget / Tighter / Tightest) Perceptron, PAUM, RBP, ADALINE, MADALINE, MLP, LMNN |
126126
| semi-supervised classification | k-nearest neighbor, Radius neighbor, Label propagation, Label spreading, k-means, GMM, S3VM, Ladder network |
127127
| regression | Least squares, Ridge, Lasso, Elastic net, RLS, Bayesian linear, Poisson, Least absolute deviations, Huber, Tukey, Least trimmed squares, Least median squares, Lp norm linear, SMA, Deming, Segmented, LOWESS, LOESS, spline, Gaussian process, Principal components, Partial least squares, Projection pursuit, Quantile regression, k-nearest neighbor, Radius neighbor, IDW, Nadaraya Watson, Priestley Chao, Gasser Muller, RBF Network, RVM, Decision tree, Random forest, Extra trees, GBDT, XGBoost, SVR, MLP, GMR, Isotonic, Ramer Douglas Peucker, Theil-Sen, Passing-Bablok, Repeated median |
128128
| interpolation | Nearest neighbor, IDW, (Spherical) Linear, Brahmagupta, Logarithmic, Cosine, (Inverse) Smoothstep, Cubic, (Centripetal) Catmull-Rom, Hermite, Polynomial, Lagrange, Trigonometric, Spline, RBF Network, Akima, Natural neighbor, Delaunay |

js/view/least_square.js

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import Matrix from '../../lib/util/matrix.js'
33
import LeastSquares from '../../lib/model/least_square.js'
44
import stringToFunction from '../expression.js'
55
import EnsembleBinaryModel from '../../lib/model/ensemble_binary.js'
6+
import Controller from '../controller.js'
67

78
const combination_repetition = (n, k) => {
89
const c = []
@@ -56,6 +57,7 @@ export class BasisFunctions {
5657
}
5758

5859
makeHtml(r) {
60+
r = d3.select(r)
5961
if (!this._e) {
6062
this._e = r.append('div').attr('id', `ls_model_${this._name}`)
6163
} else {
@@ -160,47 +162,12 @@ export class BasisFunctions {
160162
}
161163
}
162164

163-
var dispLeastSquares = function (elm, platform) {
165+
export default function (platform) {
166+
platform.setting.ml.usage = 'Click and add data point. Next, click "Fit" button.'
164167
platform.setting.ml.reference = {
165168
title: 'Least squares (Wikipedia)',
166169
url: 'https://en.wikipedia.org/wiki/Least_squares',
167170
}
168-
const fitModel = () => {
169-
let model
170-
if (platform.task === 'CF') {
171-
const method = elm.select('[name=method]').property('value')
172-
model = new EnsembleBinaryModel(LeastSquares, method)
173-
} else {
174-
model = new LeastSquares()
175-
}
176-
model.fit(basisFunctions.apply(platform.trainInput).toArray(), platform.trainOutput)
177-
178-
let pred = model.predict(basisFunctions.apply(platform.testInput(2)).toArray())
179-
platform.testResult(pred)
180-
}
181-
182-
if (platform.task === 'CF') {
183-
elm.append('select')
184-
.attr('name', 'method')
185-
.selectAll('option')
186-
.data(['oneone', 'onerest'])
187-
.enter()
188-
.append('option')
189-
.property('value', d => d)
190-
.text(d => d)
191-
}
192-
const basisFunctions = new BasisFunctions(platform)
193-
basisFunctions.makeHtml(elm)
194-
195-
elm.append('input')
196-
.attr('type', 'button')
197-
.attr('value', 'Fit')
198-
.on('click', () => fitModel())
199-
}
200-
201-
export default function (platform) {
202-
platform.setting.ml.usage = 'Click and add data point. Next, click "Fit" button.'
203-
dispLeastSquares(platform.setting.ml.configElement, platform)
204171
platform.setting.ml.detail = `
205172
The model form is
206173
$$
@@ -218,4 +185,26 @@ $$
218185
$$
219186
where $ G_{ij} = g_i(x_j) $.
220187
`
188+
const controller = new Controller(platform)
189+
const fitModel = () => {
190+
let model
191+
if (platform.task === 'CF') {
192+
model = new EnsembleBinaryModel(LeastSquares, method.value)
193+
} else {
194+
model = new LeastSquares()
195+
}
196+
model.fit(basisFunctions.apply(platform.trainInput).toArray(), platform.trainOutput)
197+
198+
let pred = model.predict(basisFunctions.apply(platform.testInput(2)).toArray())
199+
platform.testResult(pred)
200+
}
201+
202+
let method = null
203+
if (platform.task === 'CF') {
204+
method = controller.select(['oneone', 'onerest'])
205+
}
206+
const basisFunctions = new BasisFunctions(platform)
207+
basisFunctions.makeHtml(controller.element)
208+
209+
controller.input.button('Fit').on('click', () => fitModel())
221210
}

js/view/ridge.js

Lines changed: 46 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,52 @@
11
import Matrix from '../../lib/util/matrix.js'
22

33
import { BasisFunctions } from './least_square.js'
4-
import { Ridge, KernelRidge } from '../../lib/model/ridge.js'
4+
import { Ridge, MulticlassRidge, KernelRidge } from '../../lib/model/ridge.js'
55
import EnsembleBinaryModel from '../../lib/model/ensemble_binary.js'
6+
import Controller from '../controller.js'
67

7-
var dispRidge = function (elm, platform) {
8+
export default function (platform) {
9+
platform.setting.ml.usage = 'Click and add data point. Next, click "Fit" button.'
810
platform.setting.ml.reference = {
911
title: 'Ridge regression (Wikipedia)',
1012
url: 'https://en.wikipedia.org/wiki/Ridge_regression',
1113
}
14+
platform.setting.ml.detail = `
15+
The model form is
16+
$$
17+
f(X) = X W + \\epsilon
18+
$$
19+
20+
The loss function can be written as
21+
$$
22+
L(W) = \\| X W - y \\|^2 + \\lambda \\| W \\|^2
23+
$$
24+
where $ y $ is the observed value corresponding to $ X $.
25+
Therefore, the optimum parameter $ \\hat{W} $ is estimated as
26+
$$
27+
\\hat{W} = \\left( X^T X + \\lambda I \\right)^{-1} X^T y
28+
$$
29+
`
30+
const controller = new Controller(platform)
1231
const task = platform.task
13-
const fitModel = cb => {
32+
const fitModel = () => {
1433
const dim = platform.datas.dimension
15-
const kernel = elm.select('[name=kernel]').property('value')
16-
const kernelName = kernel === 'no kernel' ? null : kernel
34+
const kernelName = kernel.value === 'no kernel' ? null : kernel.value
1735
let model
18-
const l = +elm.select('[name=lambda]').property('value')
36+
const l = +lambda.value
1937
if (task === 'CF') {
20-
const method = elm.select('[name=method]').property('value')
2138
if (kernelName) {
2239
model = new EnsembleBinaryModel(function () {
2340
return new KernelRidge(l, kernelName)
24-
}, method)
41+
}, method.value)
2542
} else {
26-
model = new EnsembleBinaryModel(function () {
27-
return new Ridge(l)
28-
}, method)
43+
if (method.value === 'multiclass') {
44+
model = new MulticlassRidge(l)
45+
} else {
46+
model = new EnsembleBinaryModel(function () {
47+
return new Ridge(l)
48+
}, method.value)
49+
}
2950
}
3051
} else {
3152
if (kernelName) {
@@ -53,61 +74,23 @@ var dispRidge = function (elm, platform) {
5374
}
5475

5576
const basisFunction = new BasisFunctions(platform)
77+
let method = null
5678
if (task === 'CF') {
57-
elm.append('select')
58-
.attr('name', 'method')
59-
.selectAll('option')
60-
.data(['oneone', 'onerest'])
61-
.enter()
62-
.append('option')
63-
.property('value', d => d)
64-
.text(d => d)
79+
method = controller.select(['oneone', 'onerest', 'multiclass']).on('change', () => {
80+
if (method.value === 'multiclass') {
81+
kernel.element.style.display = 'none'
82+
} else {
83+
kernel.element.style.display = null
84+
}
85+
})
6586
}
87+
let kernel = null
6688
if (task !== 'FS') {
67-
basisFunction.makeHtml(elm)
68-
elm.append('select')
69-
.attr('name', 'kernel')
70-
.selectAll('option')
71-
.data(['no kernel', 'gaussian'])
72-
.enter()
73-
.append('option')
74-
.property('value', d => d)
75-
.text(d => d)
89+
basisFunction.makeHtml(controller.element)
90+
kernel = controller.select(['no kernel', 'gaussian'])
7691
} else {
77-
elm.append('input').attr('type', 'hidden').attr('name', 'kernel').property('value', '')
92+
kernel = controller.input({ type: 'hidden', value: '' })
7893
}
79-
elm.append('span').text('lambda = ')
80-
elm.append('select')
81-
.attr('name', 'lambda')
82-
.selectAll('option')
83-
.data([0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100])
84-
.enter()
85-
.append('option')
86-
.property('value', d => d)
87-
.text(d => d)
88-
elm.append('input')
89-
.attr('type', 'button')
90-
.attr('value', 'Fit')
91-
.on('click', () => fitModel())
92-
}
93-
94-
export default function (platform) {
95-
platform.setting.ml.usage = 'Click and add data point. Next, click "Fit" button.'
96-
dispRidge(platform.setting.ml.configElement, platform)
97-
platform.setting.ml.detail = `
98-
The model form is
99-
$$
100-
f(X) = X W + \\epsilon
101-
$$
102-
103-
The loss function can be written as
104-
$$
105-
L(W) = \\| X W - y \\|^2 + \\lambda \\| W \\|^2
106-
$$
107-
where $ y $ is the observed value corresponding to $ X $.
108-
Therefore, the optimum parameter $ \\hat{W} $ is estimated as
109-
$$
110-
\\hat{W} = \\left( X^T X + \\lambda I \\right)^{-1} X^T y
111-
$$
112-
`
94+
const lambda = controller.select({ label: 'lambda = ', values: [0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100] })
95+
controller.input.button('Fit').on('click', () => fitModel())
11396
}

lib/model/ridge.js

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,73 @@ export class Ridge {
5050
}
5151
}
5252

53+
/**
54+
* Multiclass ridge regression
55+
*/
56+
export class MulticlassRidge {
57+
/**
58+
* @param {number} [lambda=0.1] Regularization strength
59+
*/
60+
constructor(lambda = 0.1) {
61+
this._w = null
62+
this._lambda = lambda
63+
this._classes = []
64+
}
65+
66+
/**
67+
* Category list
68+
*
69+
* @type {*[]}
70+
*/
71+
get categories() {
72+
return this._classes
73+
}
74+
75+
/**
76+
* Fit model.
77+
*
78+
* @param {Array<Array<number>>} x Training data
79+
* @param {*[]} y Target values
80+
*/
81+
fit(x, y) {
82+
x = Matrix.fromArray(x)
83+
this._classes = [...new Set(y)]
84+
const p = new Matrix(y.length, this._classes.length, -1)
85+
for (let i = 0; i < y.length; i++) {
86+
p.set(i, this._classes.indexOf(y[i]), 1)
87+
}
88+
const xtx = x.tDot(x)
89+
for (let i = 0; i < xtx.rows; i++) {
90+
xtx.addAt(i, i, this._lambda)
91+
}
92+
93+
this._w = xtx.solve(x.t).dot(p)
94+
}
95+
96+
/**
97+
* Returns predicted values.
98+
*
99+
* @param {Array<Array<number>>} x Sample data
100+
* @returns {*[]} Predicted values
101+
*/
102+
predict(x) {
103+
x = Matrix.fromArray(x)
104+
return x
105+
.dot(this._w)
106+
.argmax(1)
107+
.value.map(i => this._classes[i])
108+
}
109+
110+
/**
111+
* Returns importances of the features.
112+
*
113+
* @returns {number[]} Importances
114+
*/
115+
importance() {
116+
return this._w.toArray()
117+
}
118+
}
119+
53120
/**
54121
* Kernel ridge regression
55122
*/

tests/gui/view/least_square.test.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ describe('classification', () => {
1919
const methodMenu = await page.waitForSelector('#ml_selector #method_menu')
2020
const buttons = await methodMenu.waitForSelector('.buttons')
2121

22-
const methods = await buttons.waitForSelector('[name=method]')
22+
const methods = await buttons.waitForSelector('select:nth-of-type(1)')
2323
await expect((await methods.getProperty('value')).jsonValue()).resolves.toBe('oneone')
2424
const preset = await buttons.waitForSelector('[name=preset]')
2525
await expect((await preset.getProperty('value')).jsonValue()).resolves.toBe('linear')

tests/gui/view/ridge.test.js

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ describe('classification', () => {
1919
const methodMenu = await page.waitForSelector('#ml_selector #method_menu')
2020
const buttons = await methodMenu.waitForSelector('.buttons')
2121

22-
const methods = await buttons.waitForSelector('[name=method]')
22+
const methods = await buttons.waitForSelector('select:nth-of-type(1)')
2323
await expect((await methods.getProperty('value')).jsonValue()).resolves.toBe('oneone')
2424
const preset = await buttons.waitForSelector('[name=preset]')
2525
await expect((await preset.getProperty('value')).jsonValue()).resolves.toBe('linear')
26-
const kernel = await buttons.waitForSelector('[name=kernel]')
26+
const kernel = await buttons.waitForSelector('select:nth-of-type(2)')
2727
await expect((await kernel.getProperty('value')).jsonValue()).resolves.toBe('no kernel')
28-
const lambda = await buttons.waitForSelector('[name=lambda]')
28+
const lambda = await buttons.waitForSelector('select:nth-of-type(3)')
2929
await expect((await lambda.getProperty('value')).jsonValue()).resolves.toBe('0')
3030
}, 10000)
3131

tests/lib/model/ridge.test.js

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ import { jest } from '@jest/globals'
22
jest.retryTimes(3)
33

44
import Matrix from '../../../lib/util/matrix.js'
5-
import { Ridge, KernelRidge } from '../../../lib/model/ridge.js'
5+
import { Ridge, MulticlassRidge, KernelRidge } from '../../../lib/model/ridge.js'
66

77
import { rmse } from '../../../lib/evaluate/regression.js'
8+
import { accuracy } from '../../../lib/evaluate/classification.js'
89

910
describe('ridge', () => {
1011
test('default', () => {
@@ -40,6 +41,41 @@ describe('ridge', () => {
4041
})
4142
})
4243

44+
describe('multiclass ridge', () => {
45+
test('default', () => {
46+
const model = new MulticlassRidge()
47+
expect(model._lambda).toBe(0.1)
48+
})
49+
50+
test('fit', () => {
51+
const model = new MulticlassRidge(0.001)
52+
const x = Matrix.concat(Matrix.randn(50, 2, 0, 0.2), Matrix.randn(50, 2, [0, 5], 0.2)).toArray()
53+
const t = []
54+
for (let i = 0; i < x.length; i++) {
55+
t[i] = String.fromCharCode('a'.charCodeAt(0) + Math.floor(i / 50))
56+
}
57+
model.fit(x, t)
58+
const y = model.predict(x)
59+
const acc = accuracy(y, t)
60+
expect(acc).toBeGreaterThan(0.75)
61+
})
62+
63+
test('importance', () => {
64+
const model = new MulticlassRidge(0.01)
65+
const x = Matrix.concat(Matrix.randn(50, 3, 0, 0.2), Matrix.randn(50, 3, 5, 0.2)).toArray()
66+
const t = []
67+
for (let i = 0; i < x.length; i++) {
68+
t[i] = String.fromCharCode('a'.charCodeAt(0) + Math.floor(i / 50))
69+
}
70+
model.fit(x, t)
71+
const importance = model.importance()
72+
expect(importance).toHaveLength(3)
73+
expect(importance[0]).toHaveLength(2)
74+
expect(importance[1]).toHaveLength(2)
75+
expect(importance[2]).toHaveLength(2)
76+
})
77+
})
78+
4379
describe('kernel ridge', () => {
4480
test('default', () => {
4581
const model = new KernelRidge()

0 commit comments

Comments
 (0)