Skip to content

Commit ee71a05

Browse files
authored
Fix ELU gradient (#774)
1 parent 5f3b581 commit ee71a05

File tree

2 files changed: +27 −8 lines changed

lib/model/nns/layer/elu.js

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -14,14 +14,15 @@ export default class ELULayer extends Layer {
1414
}
1515

1616
calc(x) {
17-
this._o = x.copy()
18-
this._o.map(v => (v > 0 ? v : this._a * (Math.exp(v) - 1)))
19-
return this._o
17+
this._i = x
18+
const o = x.copy()
19+
o.map(v => (v > 0 ? v : this._a * (Math.exp(v) - 1)))
20+
return o
2021
}
2122

2223
grad(bo) {
2324
const bi = bo.copy()
24-
bi.broadcastOperate(this._o, (a, b) => a * (b > 0 ? 1 : this._a * Math.exp(b)))
25+
bi.broadcastOperate(this._i, (a, b) => a * (b > 0 ? 1 : this._a * Math.exp(b)))
2526
return bi
2627
}
2728

tests/lib/model/nns/layer/elu.test.js

Lines changed: 22 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -41,17 +41,35 @@ describe('layer', () => {
4141
})
4242

4343
describe('grad', () => {
44+
test('numeric calculation', () => {
45+
const layer = new ELULayer({})
46+
const e = 1.0e-8
47+
48+
const x = Matrix.fromArray([Array.from({ length: 100 }, (_, i) => (i - 50) / 50)])
49+
const y = layer.calc(x)
50+
51+
const bo = Matrix.ones(1, 100)
52+
const bi = layer.grad(bo)
53+
54+
const d = layer.calc(Matrix.add(x, e))
55+
for (let i = 0; i < x.rows; i++) {
56+
for (let j = 0; j < x.cols; j++) {
57+
expect(bi.at(i, j)).toBeCloseTo((d.at(i, j) - y.at(i, j)) / e)
58+
}
59+
}
60+
})
61+
4462
test('matrix', () => {
4563
const layer = new ELULayer({})
4664

4765
const x = Matrix.randn(100, 10)
48-
const y = layer.calc(x)
66+
layer.calc(x)
4967

5068
const bo = Matrix.ones(100, 10)
5169
const bi = layer.grad(bo)
5270
for (let i = 0; i < x.rows; i++) {
5371
for (let j = 0; j < x.cols; j++) {
54-
expect(bi.at(i, j)).toBeCloseTo(x.at(i, j) > 0 ? 1 : Math.exp(y.at(i, j)))
72+
expect(bi.at(i, j)).toBeCloseTo(x.at(i, j) > 0 ? 1 : Math.exp(x.at(i, j)))
5573
}
5674
}
5775
})
@@ -60,14 +78,14 @@ describe('layer', () => {
6078
const layer = new ELULayer({})
6179

6280
const x = Tensor.randn([15, 10, 7])
63-
const y = layer.calc(x)
81+
layer.calc(x)
6482

6583
const bo = Tensor.ones([15, 10, 7])
6684
const bi = layer.grad(bo)
6785
for (let i = 0; i < x.sizes[0]; i++) {
6886
for (let j = 0; j < x.sizes[1]; j++) {
6987
for (let k = 0; k < x.sizes[2]; k++) {
70-
expect(bi.at(i, j, k)).toBeCloseTo(x.at(i, j, k) > 0 ? 1 : Math.exp(y.at(i, j, k)))
88+
expect(bi.at(i, j, k)).toBeCloseTo(x.at(i, j, k) > 0 ? 1 : Math.exp(x.at(i, j, k)))
7189
}
7290
}
7391
}

0 commit comments

Comments
 (0)