1 change: 1 addition & 0 deletions lib/model/nns/onnx/layer/index.js
@@ -77,6 +77,7 @@ export { default as log_softmax } from './log_softmax.js'
 export { default as loglog } from './loglog.js'
 export { default as logsigmoid } from './logsigmoid.js'
 export { default as lp_pool } from './lp_pool.js'
+export { default as lrn } from './lrn.js'
 export { default as matmul } from './matmul.js'
 export { default as max } from './max.js'
 export { default as max_pool } from './max_pool.js'
86 changes: 86 additions & 0 deletions lib/model/nns/onnx/layer/lrn.js
@@ -0,0 +1,86 @@
import { onnx } from '../onnx_exporter.js'

/**
 * Handle lrn layer
 */
export default {
	/**
	 * Export to onnx object.
	 * @param {onnx.ModelProto} model Model object
	 * @param {import("../../graph.js").LayerObject & {type: 'lrn'}} obj Node object
	 * @param {{[key: string]: {type: onnx.TensorProto.DataType; size: number[]}}} info Output information of other layers
	 */
	export(model, obj, info) {
		const graph = model.getGraph()

		const input = Array.isArray(obj.input) ? obj.input[0] : obj.input
		const size = info[input].size.concat()

		const node = new onnx.NodeProto()
		node.setOpType('LRN')
		if (obj.channel_dim === 1) {
			node.addInput(input)
			node.addOutput(obj.name)
		} else if (obj.channel_dim == null || obj.channel_dim === -1) {
			// ONNX LRN normalizes over axis 1, so a channels-last input is
			// transposed to channels-first, run through LRN, and transposed back.
			const node_transpose1 = new onnx.NodeProto()
			node_transpose1.setOpType('Transpose')
			node_transpose1.addInput(input)
			node_transpose1.addOutput(obj.name + '_t1')
			const attrPerm1 = new onnx.AttributeProto()
			attrPerm1.setName('perm')
			attrPerm1.setType(onnx.AttributeProto.AttributeType.INTS)
			const perm1 = Array.from(size, (_, i) => i - 1)
			perm1[0] = 0
			perm1[1] = size.length - 1
			attrPerm1.setIntsList(perm1)
			node_transpose1.addAttribute(attrPerm1)
			graph.addNode(node_transpose1)

			node.addInput(obj.name + '_t1')
			node.addOutput(obj.name + '_gap')

			const node_transpose2 = new onnx.NodeProto()
			node_transpose2.setOpType('Transpose')
			node_transpose2.addInput(obj.name + '_gap')
			node_transpose2.addOutput(obj.name)
			const attrPerm2 = new onnx.AttributeProto()
			attrPerm2.setName('perm')
			attrPerm2.setType(onnx.AttributeProto.AttributeType.INTS)
			const perm2 = Array.from(size, (_, i) => i + 1)
			perm2[0] = 0
			perm2[perm2.length - 1] = 1
			attrPerm2.setIntsList(perm2)
			node_transpose2.addAttribute(attrPerm2)
			graph.addNode(node_transpose2)
		} else {
			throw new Error(`Not implemented value of attribute 'channel_dim' ${obj.channel_dim}.`)
		}

		if (obj.n == null) {
			throw new Error("Require attribute 'n'")
		}
		const attrSize = new onnx.AttributeProto()
		attrSize.setName('size')
		attrSize.setType(onnx.AttributeProto.AttributeType.INT)
		attrSize.setI(obj.n)
		node.addAttribute(attrSize)

		const attrAlpha = new onnx.AttributeProto()
		attrAlpha.setName('alpha')
		attrAlpha.setType(onnx.AttributeProto.AttributeType.FLOAT)
		attrAlpha.setF(obj.alpha ?? 0.0001)
		node.addAttribute(attrAlpha)
		const attrBeta = new onnx.AttributeProto()
		attrBeta.setName('beta')
		attrBeta.setType(onnx.AttributeProto.AttributeType.FLOAT)
		attrBeta.setF(obj.beta ?? 0.75)
		node.addAttribute(attrBeta)
		const attrBias = new onnx.AttributeProto()
		attrBias.setName('bias')
		attrBias.setType(onnx.AttributeProto.AttributeType.FLOAT)
		attrBias.setF(obj.k ?? 2)
		node.addAttribute(attrBias)

		graph.addNode(node)
	},
}
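For reference, this is what the two `perm` attributes above evaluate to for a rank-4, channels-last input. A small sketch mirroring the exporter's logic, with a made-up `size` (the `null` stands for the batch dimension, as in `info`):

	const size = [null, 8, 8, 3] // hypothetical NHWC shape
	const perm1 = Array.from(size, (_, i) => i - 1) // [-1, 0, 1, 2]
	perm1[0] = 0
	perm1[1] = size.length - 1 // perm1 = [0, 3, 1, 2]: channels move to axis 1 (NCHW)
	const perm2 = Array.from(size, (_, i) => i + 1) // [1, 2, 3, 4]
	perm2[0] = 0
	perm2[perm2.length - 1] = 1 // perm2 = [0, 2, 3, 1]: channels move back to the last axis

`perm2` is the inverse permutation of `perm1`, so the layer's output keeps the input's channels-last layout.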
77 changes: 77 additions & 0 deletions tests/lib/model/nns/onnx/layer/lrn.test.js
@@ -0,0 +1,77 @@
import { jest } from '@jest/globals'
jest.retryTimes(3)

import * as ort from 'onnxruntime-web'
ort.env.wasm.numThreads = 1

import ONNXExporter from '../../../../../../lib/model/nns/onnx/onnx_exporter.js'
import lrn from '../../../../../../lib/model/nns/onnx/layer/lrn.js'
import LRNLayer from '../../../../../../lib/model/nns/layer/lrn.js'
import Tensor from '../../../../../../lib/util/tensor.js'

describe('export', () => {
	test.each([
		{ input: 'x', channel_dim: -1, n: 3 },
		{ input: ['x'], n: 3 },
	])('last channel %p', param => {
		const model = ONNXExporter.createONNXModel()
		lrn.export(model, { type: 'lrn', ...param }, { x: { size: [null, 10, 3] } })
		const nodes = model.getGraph().getNodeList()
		expect(nodes).toHaveLength(3)
		expect(nodes[0].getOpType()).toBe('Transpose')
		expect(nodes[1].getOpType()).toBe('Transpose')
		expect(nodes[2].getOpType()).toBe('LRN')
	})

	test('first channel', () => {
		const model = ONNXExporter.createONNXModel()
		lrn.export(model, { type: 'lrn', input: 'x', channel_dim: 1, n: 3 }, { x: { size: [null, 10, 3] } })
		const nodes = model.getGraph().getNodeList()
		expect(nodes).toHaveLength(1)
		expect(nodes[0].getOpType()).toBe('LRN')
	})

	test('invalid channel dim', () => {
		const model = ONNXExporter.createONNXModel()
		expect(() =>
			lrn.export(model, { type: 'lrn', input: ['x'], channel_dim: 0, n: 3 }, { x: { size: [null, 10, 3] } })
		).toThrow("Not implemented value of attribute 'channel_dim' 0")
	})

	test('require n', () => {
		const model = ONNXExporter.createONNXModel()
		expect(() =>
			lrn.export(model, { type: 'lrn', input: ['x'], channel_dim: -1 }, { x: { size: [null, 10, 3] } })
		).toThrow("Require attribute 'n'")
	})
})

describe('runtime', () => {
	let session
	afterEach(async () => {
		await session?.release()
		session = null
	})

	test.each([
		[{ channel_dim: 1, n: 3 }, [null, 4, 3, 3], [1, 4, 3, 3]],
		[{ n: 5 }, [null, 4, 4, 10], [1, 4, 4, 10]],
		[{ alpha: 0.0002, beta: 0.7, k: 2, n: 5 }, [null, 3, 3, 5], [1, 3, 3, 5]],
	])('lrn %p %p %p', async (param, inSize, actualSize) => {
		const buf = ONNXExporter.dump([{ type: 'input', size: inSize }, { type: 'lrn', ...param }, { type: 'output' }])
		session = await ort.InferenceSession.create(buf)

		const x = Tensor.randn(actualSize)
		const xten = new ort.Tensor('float32', x.value, x.sizes)
		const out = await session.run({ _input: xten })
		const yten = out._lrn
		expect(yten.dims).toEqual(actualSize)
		const y = await yten.getData(true)

		const t = new LRNLayer(param).calc(x)
		expect(yten.dims).toEqual(t.sizes)
		for (let i = 0; i < y.length; i++) {
			expect(y[i]).toBeCloseTo(t.value[i])
		}
	})
})
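The runtime tests validate the exported graph against the library's own LRNLayer. As an independent cross-check, here is a minimal sketch of what the exported LRN node computes along the channel axis, following the ONNX operator definition (the function name and plain-array layout are illustrative, not part of the PR):

	// y[c] = x[c] / (k + alpha / n * sum of x[i]^2 over the size-n window around c)
	function lrnAtChannel(x, c, n, alpha = 0.0001, beta = 0.75, k = 2) {
		const lo = Math.max(0, c - Math.floor((n - 1) / 2))
		const hi = Math.min(x.length - 1, c + Math.ceil((n - 1) / 2))
		let squareSum = 0
		for (let i = lo; i <= hi; i++) squareSum += x[i] ** 2
		return x[c] / (k + (alpha / n) * squareSum) ** beta
	}

The default arguments mirror the attribute values the exporter writes when `obj.alpha`, `obj.beta`, and `obj.k` are omitted.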