From 497bf8a3b9884f6b1a0a4fa36569821cdf357f13 Mon Sep 17 00:00:00 2001 From: ishii-norimi Date: Thu, 23 Nov 2023 10:50:00 +0900 Subject: [PATCH] Make partial method in tensor class --- lib/model/nns/layer/gru.js | 4 +-- lib/model/nns/layer/lstm.js | 4 +-- lib/model/nns/layer/rnn.js | 4 +-- lib/util/tensor.js | 21 ++++++++++--- tests/lib/util/tensor.test.js | 55 +++++++++++++++++++++++++++++++---- 5 files changed, 73 insertions(+), 15 deletions(-) diff --git a/lib/model/nns/layer/gru.js b/lib/model/nns/layer/gru.js index 276b22f56..6ec23f4cf 100644 --- a/lib/model/nns/layer/gru.js +++ b/lib/model/nns/layer/gru.js @@ -45,7 +45,7 @@ export default class GRULayer extends Layer { x = x.transpose(1, 0, 2) this._x = [] for (let k = 0; k < x.sizes[0]; k++) { - this._x[k] = x.at(k).toMatrix() + this._x[k] = x.index(k).toMatrix() } const s = [] @@ -69,7 +69,7 @@ export default class GRULayer extends Layer { if (this._return_sequences) { bo = bo.transpose(1, 0, 2) for (let i = 0; i < s; i++) { - this._bo[i] = bo.at(i).toMatrix() + this._bo[i] = bo.index(i).toMatrix() } } else { this._bo[s - 1] = bo diff --git a/lib/model/nns/layer/lstm.js b/lib/model/nns/layer/lstm.js index 2af6eebba..4b5b2934a 100644 --- a/lib/model/nns/layer/lstm.js +++ b/lib/model/nns/layer/lstm.js @@ -74,7 +74,7 @@ export default class LSTMLayer extends Layer { x = x.transpose(1, 0, 2) this._x = [] for (let k = 0; k < x.sizes[0]; k++) { - this._x[k] = x.at(k).toMatrix() + this._x[k] = x.index(k).toMatrix() } this._y = [] @@ -98,7 +98,7 @@ export default class LSTMLayer extends Layer { if (this._return_sequences) { bo = bo.transpose(1, 0, 2) for (let i = 0; i < s; i++) { - this._bo[i] = bo.at(i).toMatrix() + this._bo[i] = bo.index(i).toMatrix() } } else { this._bo[s - 1] = bo diff --git a/lib/model/nns/layer/rnn.js b/lib/model/nns/layer/rnn.js index f7f890522..c9c83965b 100644 --- a/lib/model/nns/layer/rnn.js +++ b/lib/model/nns/layer/rnn.js @@ -66,7 +66,7 @@ export default class RNNLayer extends Layer { x = x.transpose(1, 0, 2) this._i = [] for (let k = 0; k < x.sizes[0]; k++) { - this._i[k] = x.at(k).toMatrix() + this._i[k] = x.index(k).toMatrix() } this._o = [] this._z = [] @@ -97,7 +97,7 @@ export default class RNNLayer extends Layer { if (this._return_sequences) { bo = bo.transpose(1, 0, 2) for (let i = 0; i < s; i++) { - this._bo[i] = bo.at(i).toMatrix() + this._bo[i] = bo.index(i).toMatrix() } } else { this._bo[s - 1] = bo diff --git a/lib/util/tensor.js b/lib/util/tensor.js index 53c4544a6..7eea08d51 100644 --- a/lib/util/tensor.js +++ b/lib/util/tensor.js @@ -287,17 +287,30 @@ export default class Tensor { } /** - * Returns value(s) at the index position. + * Returns value at the index position. * * @param {...number} i Index values - * @returns {number | Tensor} The value or sub tensor + * @returns {number} The value */ at(...i) { if (Array.isArray(i[0])) { i = i[0] } - if (i.length === this.dimension) { - return this._value[this._to_position(...i)] + if (i.length !== this.dimension) { + throw new MatrixException('Length is invalid.') + } + return this._value[this._to_position(...i)] + } + + /** + * Returns tensor at the index position. + * + * @param {...number} i Index values + * @returns {Tensor} Sub tensor + */ + index(...i) { + if (Array.isArray(i[0])) { + i = i[0] } let s = 0 diff --git a/tests/lib/util/tensor.test.js b/tests/lib/util/tensor.test.js index cd40ac726..c7a5ed031 100644 --- a/tests/lib/util/tensor.test.js +++ b/tests/lib/util/tensor.test.js @@ -353,6 +353,51 @@ describe('Tensor', () => { expect(() => ten.at([i, j, k])).toThrow('Index out of bounds.') }) + test.each([[0], [0, 0]])('fail[%i]', i => { + const ten = new Tensor([2, 3, 4]) + expect(() => ten.at(i)).toThrow('Length is invalid.') + expect(() => ten.at([i])).toThrow('Length is invalid.') + }) + }) + + describe('index', () => { + test('default', () => { + const data = [ + [ + [1, 2], + [3, 4], + [5, 6], + ], + [ + [7, 8], + [9, 10], + [11, 12], + ], + ] + const ten = new Tensor([2, 3, 2], data) + for (let i = 0; i < 2; i++) { + for (let j = 0; j < 3; j++) { + for (let k = 0; k < 2; k++) { + expect(ten.index(i, j, k).at()).toBe(data[i][j][k]) + expect(ten.index([i, j, k]).at()).toBe(data[i][j][k]) + } + } + } + }) + + test.each([ + [-1, 0, 0], + [2, 0, 0], + [0, -1, 0], + [0, 3, 0], + [0, 0, -1], + [0, 0, 4], + ])('fail[%i, %i, %i]', (i, j, k) => { + const ten = new Tensor([2, 3, 4]) + expect(() => ten.index(i, j, k)).toThrow('Index out of bounds.') + expect(() => ten.index([i, j, k])).toThrow('Index out of bounds.') + }) + test('multi', () => { const data = [ [ @@ -368,7 +413,7 @@ describe('Tensor', () => { ] const ten = new Tensor([2, 3, 2], data) - const at = ten.at(1, 1) + const at = ten.index(1, 1) expect(at.sizes).toEqual([2]) expect(at.at(0)).toBe(9) expect(at.at(1)).toBe(10) @@ -376,8 +421,8 @@ describe('Tensor', () => { test.each([[-1], [2]])('fail[%i]', i => { const ten = new Tensor([2, 3, 4]) - expect(() => ten.at(i)).toThrow('Index out of bounds.') - expect(() => ten.at([i])).toThrow('Index out of bounds.') + expect(() => ten.index(i)).toThrow('Index out of bounds.') + expect(() => ten.index([i])).toThrow('Index out of bounds.') }) test.each([ @@ -387,8 +432,8 @@ describe('Tensor', () => { [0, 3], ])('fail[%i, %i]', (i, j) => { const ten = new Tensor([2, 3, 4]) - expect(() => ten.at(i, j)).toThrow('Index out of bounds.') - expect(() => ten.at([i, j])).toThrow('Index out of bounds.') + expect(() => ten.index(i, j)).toThrow('Index out of bounds.') + expect(() => ten.index([i, j])).toThrow('Index out of bounds.') }) })