Skip to content

Commit d237846

Browse files
authored
Add diffusion model (#980)
* Add diffusion model * Fix full layer and improve tests * Initialize clusters and reset data in diffusion model tests * Add tests for Matrix class methods _to_position and _to_index * Add comprehensive tests for DiffusionModel with 2D and 3D layers, including custom configurations and validation of generated outputs.
1 parent 0177ca3 commit d237846

File tree

10 files changed

+447
-3
lines changed

10 files changed

+447
-3
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ for (let i = 0; i < n; i++) {
132132
| feature selection | Mutual information, Ridge, Lasso, Elastic net, Decision tree, NCA |
133133
| transformation | Box-Cox, Yeo-Johnson |
134134
| density estimation | Histogram, Average shifted histogram, Polynomial histogram, Maximum likelihood, Kernel density estimation, k-nearest neighbor, Naive Bayes, GMM, HMM |
135-
| generate | MH, Slice sampling, GMM, GBRBM, HMM, VAE, GAN, NICE |
135+
| generate | MH, Slice sampling, GMM, GBRBM, HMM, VAE, GAN, NICE, Diffusion |
136136
| smoothing | (Linear weighted / Triangular / Cumulative) Moving average, Exponential average, Moving median, KZ filter, Savitzky Golay filter, Hampel filter, Kalman filter, Particle filter, Lowpass filter, Bessel filter, Butterworth filter, Chebyshev filter, Elliptic filter |
137137
| timeseries prediction | Holt winters, AR, ARMA, SDAR, VAR, Kalman filter, MLP, RNN |
138138
| change point detection | Cumulative sum, k-nearest neighbor, LOF, COF, SST, KLIEP, LSIF, uLSIF, LSDD, HMM, Markov switching |

js/model_selector.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,7 @@ const AIMethods = [
484484
{ value: 'vae', title: 'VAE' },
485485
{ value: 'gan', title: 'GAN' },
486486
{ value: 'nice', title: 'NICE' },
487+
{ value: 'diffusion_model', title: 'Diffusion Model' },
487488
],
488489
},
489490
{

js/view/diffusion_model.js

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import Controller from '../controller.js'
2+
import { BaseWorker } from '../utils.js'
3+
4+
class DiffusionModelWorker extends BaseWorker {
5+
constructor() {
6+
super('js/view/worker/model_worker.js', { type: 'module' })
7+
}
8+
9+
initialize(timesteps) {
10+
return this._postMessage({ name: 'diffusion_model', method: 'constructor', arguments: [timesteps] })
11+
}
12+
13+
epoch() {
14+
return this._postMessage({ name: 'diffusion_model', method: 'epoch' }).then(r => r.data)
15+
}
16+
17+
fit(train_x, iteration, rate, batch) {
18+
return this._postMessage({
19+
name: 'diffusion_model',
20+
method: 'fit',
21+
arguments: [train_x, iteration, rate, batch],
22+
}).then(r => r.data)
23+
}
24+
25+
generate(n) {
26+
return this._postMessage({ name: 'diffusion_model', method: 'generate', arguments: [n] }).then(r => r.data)
27+
}
28+
}
29+
30+
export default function (platform) {
31+
platform.setting.ml.usage =
32+
'Click and add data point. Next, click "Initialize". Finally, click "Fit" button repeatedly.'
33+
const controller = new Controller(platform)
34+
const model = new DiffusionModelWorker()
35+
let epoch = 0
36+
37+
const fitModel = async () => {
38+
if (platform.datas.length === 0) {
39+
return
40+
}
41+
const tx = platform.trainInput
42+
const loss = await model.fit(tx, +iteration.value, rate.value, batch.value)
43+
epoch = await model.epoch()
44+
platform.plotLoss(loss[0])
45+
const gen_data = await model.generate(tx.length)
46+
platform.trainResult = gen_data
47+
}
48+
49+
const genValues = async () => {
50+
const ty = platform.trainOutput
51+
genBtn.element.disabled = true
52+
const gen_data = await model.generate(platform.trainInput.length, ty)
53+
genBtn.element.disabled = false
54+
console.log(gen_data)
55+
platform.trainResult = gen_data
56+
}
57+
58+
const slbConf = controller.stepLoopButtons().init(done => {
59+
model.initialize(100).then(done)
60+
platform.init()
61+
})
62+
const iteration = controller.select({ label: ' Iteration ', values: [1, 10, 100, 1000, 10000] })
63+
iteration.value = 10
64+
const rate = controller.input.number({ label: 'Learning rate ', min: 0, max: 100, step: 0.01, value: 0.01 })
65+
const batch = controller.input.number({ label: ' Batch size ', min: 1, max: 100, value: 10 })
66+
slbConf.step(fitModel).epoch(() => epoch)
67+
const genBtn = controller.input.button('Generate').on('click', genValues)
68+
69+
return () => {
70+
model.terminate()
71+
}
72+
}

lib/model/diffusion_model.js

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
import Matrix from '../util/matrix.js'
2+
import Tensor from '../util/tensor.js'
3+
import NeuralNetwork from './neuralnetwork.js'
4+
5+
/**
6+
* Diffusion model network
7+
*/
8+
export default class DiffusionModel {
9+
// https://qiita.com/pocokhc/items/5a015ee5b527a357dd67
10+
/**
11+
* @param {number} timesteps Number of timestep
12+
* @param {LayerObject[]} [layers] Layers
13+
*/
14+
constructor(timesteps, layers) {
15+
this._timesteps = timesteps
16+
this._ulayers = layers
17+
this._peDims = 32
18+
19+
this._model = null
20+
this._epoch = 0
21+
22+
const betaStart = 0.0001
23+
const betaEnd = 0.02
24+
const betaStep = (betaEnd - betaStart) / (this._timesteps - 1)
25+
this._beta = [betaStart]
26+
for (let t = 1; t < this._timesteps - 1; t++) {
27+
this._beta[t] = betaStart + betaStep * t
28+
}
29+
this._beta.push(betaEnd)
30+
this._alpha = [1 - this._beta[0]]
31+
this._alphaCumprod = [this._alpha[0]]
32+
for (let t = 1; t < this._beta.length; t++) {
33+
this._alpha[t] = 1 - this._beta[t]
34+
this._alphaCumprod[t] = this._alphaCumprod[t - 1] * this._alpha[t]
35+
}
36+
}
37+
38+
/**
39+
* Epoch
40+
* @type {number}
41+
*/
42+
get epoch() {
43+
return this._epoch
44+
}
45+
46+
_addNoise(x, t) {
47+
const at = this._alphaCumprod[t]
48+
const sqrtat = Math.sqrt(at)
49+
const sqrt1at = Math.sqrt(1 - at)
50+
const noize = Tensor.randn(x.sizes)
51+
const xNoised = x.copy()
52+
xNoised.broadcastOperate(noize, (a, b) => sqrtat * a + sqrt1at * b)
53+
return [xNoised, noize]
54+
}
55+
56+
_build() {
57+
if (this._dataShape.length === 1) {
58+
this._layers = [
59+
{ type: 'input', name: 'x' },
60+
{ type: 'input', name: 'position_encoding' },
61+
{ type: 'full', out_size: this._peDims, l2_decay: 0.001, activation: 'gelu', name: 'pe' },
62+
{ type: 'concat', input: ['x', 'pe'], axis: 1 },
63+
]
64+
if (this._ulayers) {
65+
this._layers.push(...this._ulayers)
66+
} else {
67+
this._layers.push(
68+
{ type: 'full', out_size: 32, l2_decay: 0.001, name: 'c1', activation: 'tanh' },
69+
{ type: 'full', out_size: 16, l2_decay: 0.001, activation: 'tanh' },
70+
{ type: 'full', out_size: 32, l2_decay: 0.001, name: 'u1', activation: 'tanh' },
71+
{ type: 'concat', input: ['u1', 'c1'], axis: 1 },
72+
{ type: 'full', out_size: 32, l2_decay: 0.001, activation: 'tanh' }
73+
)
74+
}
75+
this._layers.push({ type: 'full', out_size: this._dataShape[0], l2_decay: 0.001 }, { type: 'output' })
76+
} else {
77+
const dim = this._dataShape.length
78+
this._layers = [
79+
{ type: 'input', name: 'x' },
80+
{ type: 'input', name: 'position_encoding' },
81+
{ type: 'full', out_size: this._peDims, l2_decay: 0.001, activation: 'gelu' },
82+
{ type: 'reshape', size: [...Array(dim - 1).fill(1), this._peDims] },
83+
{ type: 'up_sampling', size: this._dataShape.slice(0, dim - 1), name: 'pe' },
84+
{ type: 'concat', input: ['x', 'pe'], axis: dim },
85+
]
86+
if (this._ulayers) {
87+
this._layers.push(...this._ulayers)
88+
} else {
89+
this._layers.push(
90+
{
91+
type: 'conv',
92+
kernel: 3,
93+
channel: 16,
94+
padding: 1,
95+
l2_decay: 0.001,
96+
name: 'c1',
97+
activation: 'relu',
98+
},
99+
{ type: 'max_pool', kernel: 2 },
100+
{ type: 'conv', kernel: 3, channel: 32, padding: 1, l2_decay: 0.001, activation: 'relu' },
101+
{ type: 'up_sampling', size: 2, name: 'u1' },
102+
{ type: 'concat', input: ['u1', 'c1'], axis: dim },
103+
{ type: 'conv', kernel: 3, channel: 16, padding: 1, l2_decay: 0.001, activation: 'relu' }
104+
)
105+
}
106+
this._layers.push(
107+
{ type: 'conv', kernel: 1, channel: this._dataShape[dim - 1], l2_decay: 0.001 },
108+
{ type: 'output' }
109+
)
110+
}
111+
112+
return NeuralNetwork.fromObject(this._layers, 'mse', 'adam')
113+
}
114+
115+
_positionEncoding(t, embdims) {
116+
const rates = Array.from({ length: embdims }, (_, i) => t / 10000 ** (2 * Math.floor(i / 2)) / embdims)
117+
const pe = rates.map((v, i) => (i % 2 === 0 ? Math.sin(v) : Math.cos(v)))
118+
return new Matrix(1, embdims, pe)
119+
}
120+
121+
/**
122+
* Fit model.
123+
* @param {Array<Array<number>>} train_x Training data
124+
* @param {number} iteration Iteration count
125+
* @param {number} rate Learning rate
126+
* @param {number} batch Batch size
127+
* @returns {{labeledLoss: number, unlabeledLoss: number}} Loss value
128+
*/
129+
fit(train_x, iteration, rate, batch) {
130+
const x = Tensor.fromArray(train_x)
131+
this._dataShape = x.sizes.slice(1)
132+
if (!this._model) {
133+
this._model = this._build()
134+
}
135+
let loss = null
136+
for (let i = 0; i < iteration; i++) {
137+
const t = Math.floor(Math.random() * this._timesteps)
138+
const pe = this._positionEncoding(t, this._peDims)
139+
pe.repeat(x.sizes[0], 0)
140+
const [noised_x, noise] = this._addNoise(x, t)
141+
142+
loss = this._model.fit({ x: noised_x, position_encoding: pe }, Tensor.fromArray(noise), 1, rate, batch)
143+
}
144+
this._epoch += iteration
145+
return loss
146+
}
147+
148+
/**
149+
* Returns generated data from the model.
150+
* @param {number} n Number of generated data
151+
* @returns {Array<Array<number>>} Generated values
152+
*/
153+
generate(n) {
154+
const ds = this._dataShape.concat()
155+
const samples = Tensor.randn([n, ...ds])
156+
for (let t = this._timesteps - 1; t >= 0; t--) {
157+
const pe = this._positionEncoding(t, this._peDims)
158+
pe.repeat(n, 0)
159+
160+
const pred = this._model.calc({ x: samples, position_encoding: pe })
161+
162+
samples.broadcastOperate(
163+
pred,
164+
(a, b) =>
165+
(1 / Math.sqrt(this._alpha[t])) * (a - (b * this._beta[t]) / Math.sqrt(1 - this._alphaCumprod[t]))
166+
)
167+
if (t > 0) {
168+
const s2 = ((1 - this._alphaCumprod[t - 1]) / (1 - this._alphaCumprod[t])) * this._beta[t]
169+
const noise = Tensor.randn(samples.sizes, 0, s2)
170+
samples.broadcastOperate(noise, (a, b) => a + b)
171+
}
172+
}
173+
174+
return samples.toArray()
175+
}
176+
}

lib/model/nns/layer/full.js

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,15 +93,20 @@ export default class FullyConnected extends Layer {
9393
i = this._i.copy()
9494
i.reshape(-1, this._w.rows)
9595
i = i.toMatrix()
96+
} else if (!(this._i instanceof Matrix)) {
97+
i = i.toMatrix()
9698
}
9799
let b = bo
98100
if (b.dimension !== 2) {
99101
b = bo.copy()
100102
b.reshape(-1, this._w.cols)
101103
b = b.toMatrix()
104+
} else if (!(b instanceof Matrix)) {
105+
b = b.toMatrix()
102106
}
107+
const n = this._i.sizes[0]
103108
this._dw = i.tDot(b)
104-
this._dw.div(this._i.rows)
109+
this._dw.div(n)
105110
if (this._l2_decay > 0 || this._l1_decay > 0) {
106111
for (let i = 0; i < this._dw.rows; i++) {
107112
for (let j = 0; j < this._dw.cols; j++) {
@@ -111,7 +116,7 @@ export default class FullyConnected extends Layer {
111116
}
112117
}
113118
this._db = b.sum(0)
114-
this._db.div(this._i.rows)
119+
this._db.div(n)
115120

116121
this._bi = bo.dot(this._w.t)
117122
if (this._wname || this._bname) {

lib/util/matrix.js

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,21 @@ export default class Matrix {
461461
return s + ']'
462462
}
463463

464+
_to_position(...i) {
465+
let p = 0
466+
for (let d = 0; d < this.dimension; d++) {
467+
if (i[d] < 0 || this._size[d] <= i[d]) {
468+
throw new MatrixException('Index out of bounds.')
469+
}
470+
p = p * this._size[d] + i[d]
471+
}
472+
return p
473+
}
474+
475+
_to_index(p) {
476+
return [Math.floor(p / this._size[1]), p % this._size[1]]
477+
}
478+
464479
/**
465480
* Returns a copy of this matrix.
466481
* @param {Matrix<T>} [dst] Destination matrix
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import { getPage } from '../helper/browser'
2+
3+
describe('generate', () => {
4+
/** @type {Awaited<ReturnType<getPage>>} */
5+
let page
6+
beforeEach(async () => {
7+
page = await getPage()
8+
const clusters = page.locator('#data_menu input[name=n]')
9+
await clusters.fill('1')
10+
const resetDataButton = page.locator('#data_menu input[value=Reset]')
11+
await resetDataButton.dispatchEvent('click')
12+
const taskSelectBox = page.locator('#ml_selector dl:first-child dd:nth-child(5) select')
13+
await taskSelectBox.selectOption('GR')
14+
const modelSelectBox = page.locator('#ml_selector .model_selection #mlDisp')
15+
await modelSelectBox.selectOption('diffusion_model')
16+
})
17+
18+
afterEach(async () => {
19+
await page?.close()
20+
})
21+
22+
test('initialize', async () => {
23+
const methodMenu = page.locator('#ml_selector #method_menu')
24+
const buttons = methodMenu.locator('.buttons')
25+
26+
const iteration = buttons.locator('select:nth-of-type(1)')
27+
await expect(iteration.inputValue()).resolves.toBe('10')
28+
const rate = buttons.locator('input:nth-of-type(2)')
29+
await expect(rate.inputValue()).resolves.toBe('0.01')
30+
const batch = buttons.locator('input:nth-of-type(3)')
31+
await expect(batch.inputValue()).resolves.toBe('10')
32+
})
33+
34+
test('learn', async () => {
35+
const methodMenu = page.locator('#ml_selector #method_menu')
36+
const buttons = methodMenu.locator('.buttons')
37+
38+
const epoch = buttons.locator('[name=epoch]')
39+
await expect(epoch.textContent()).resolves.toBe('0')
40+
const methodFooter = page.locator('#method_footer', { state: 'attached' })
41+
await expect(methodFooter.textContent()).resolves.toBe('')
42+
43+
const initButton = buttons.locator('input[value=Initialize]')
44+
await initButton.dispatchEvent('click')
45+
const stepButton = buttons.locator('input[value=Step]:enabled')
46+
await stepButton.dispatchEvent('click')
47+
await buttons.locator('input[value=Step]:enabled').waitFor()
48+
49+
await expect(epoch.textContent()).resolves.toBe('10')
50+
await expect(methodFooter.textContent()).resolves.toMatch(/^loss/)
51+
})
52+
})

0 commit comments

Comments
 (0)