Skip to content

Commit d602a29

Browse files
authored
Add string expression function layer (#779)
1 parent 1a4801b commit d602a29

File tree

4 files changed

+869
-1
lines changed

4 files changed

+869
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ for (let i = 0; i < n; i++) {
217217
| reduce | sum, mean, prod, variance, std, reduce max, reduce min, argmax, softargmax |
218218
| graph | convolutional, SAGE, readout |
219219
| loss | Huber, MSE |
220-
| other | concat, split, detach, clip, dropout, One-hot, reshape, flatten, transpose, reverse, sparce, conditional |
220+
| other | concat, split, detach, clip, dropout, One-hot, reshape, flatten, transpose, reverse, sparce, conditional, function |
221221

222222
## Contact
223223

lib/model/nns/layer/function.js

Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
import Layer from './base.js'
2+
3+
/**
4+
* Function layer
5+
*/
6+
export default class FunctionLayer extends Layer {
7+
/**
8+
* @param {object} config config
9+
* @param {string} config.func Function
10+
*/
11+
constructor({ func, ...rest }) {
12+
super(rest)
13+
this._func = func
14+
this._exc = stringToFunction(this._func)
15+
}
16+
17+
calc(x, y) {
18+
const o = x.copy()
19+
this._unary = !y
20+
if (this._unary) {
21+
o.map(v => this._exc({ x: v }))
22+
this._g = o.copy()
23+
this._g.map(v => v.gx)
24+
} else {
25+
o.broadcastOperate(y, (a, b) => this._exc({ x: a, y: b }))
26+
this._gx = o.copy()
27+
this._gx.map(v => v.gx)
28+
this._gy = o.copy()
29+
this._gy.map(v => v.gy)
30+
}
31+
o.map(v => v.v)
32+
return o
33+
}
34+
35+
grad(bo) {
36+
if (this._unary) {
37+
const bi = bo.copy()
38+
bi.broadcastOperate(this._g, (a, b) => a * b)
39+
return bi
40+
} else {
41+
const bx = bo.copy()
42+
bx.broadcastOperate(this._gx, (a, b) => a * b)
43+
const by = bo.copy()
44+
by.broadcastOperate(this._gy, (a, b) => a * b)
45+
return [bx, by]
46+
}
47+
}
48+
49+
toObject() {
50+
return {
51+
type: 'function',
52+
func: this._func,
53+
}
54+
}
55+
}
56+
57+
FunctionLayer.registLayer()
58+
59+
class OP {
60+
constructor(name, priority, func, grad) {
61+
this.name = name
62+
this.p = priority
63+
this.f = func
64+
this.g = grad
65+
}
66+
67+
get length() {
68+
return this.f.length
69+
}
70+
}
71+
72+
const uops = {
73+
'+': new OP(
74+
'+',
75+
4,
76+
v => v,
77+
(v, g) => g
78+
),
79+
'-': new OP(
80+
'-',
81+
4,
82+
v => -v,
83+
(v, g) => -g
84+
),
85+
}
86+
87+
const bops = {
88+
'-': new OP(
89+
'-',
90+
1,
91+
(a, b) => a - b,
92+
(va, vb, ga, gb) => ga - gb
93+
),
94+
'+': new OP(
95+
'+',
96+
1,
97+
(a, b) => a + b,
98+
(va, vb, ga, gb) => ga + gb
99+
),
100+
'*': new OP(
101+
'*',
102+
2,
103+
(a, b) => a * b,
104+
(va, vb, ga, gb) => va * gb + ga * vb
105+
),
106+
'/': new OP(
107+
'/',
108+
2,
109+
(a, b) => a / b,
110+
(va, vb, ga, gb) => (ga * vb - va * gb) / vb ** 2
111+
),
112+
'**': new OP(
113+
'**',
114+
3,
115+
(a, b) => a ** b,
116+
(va, vb, ga, gb) => va ** vb * ((gb === 0 ? 0 : gb * Math.log(va)) + vb * (ga / va))
117+
),
118+
}
119+
120+
const funcs = {
121+
abs: { f: Math.abs, g: (v, g) => (v < 0 ? -g : g) },
122+
acos: { f: Math.acos, g: (v, g) => -g / (Math.sqrt(1 - v ** 2) + 1.0e-4) },
123+
acosh: { f: Math.acosh, g: (v, g) => -g / (Math.sqrt(v ** 2 - 1) + 1.0e-4) },
124+
asin: { f: Math.asin, g: (v, g) => g / (Math.sqrt(1 - v ** 2) + 1.0e-4) },
125+
asinh: { f: Math.asinh, g: (v, g) => g / Math.sqrt(1 + v ** 2) },
126+
atan: { f: Math.atan, g: (v, g) => g / (1 + v ** 2) },
127+
atanh: { f: Math.atanh, g: (v, g) => g / (1 - v ** 2) },
128+
cbrt: { f: Math.cbrt, g: (v, g) => g / (3 * Math.cbrt(v) ** 2) },
129+
cos: { f: Math.cos, g: (v, g) => -g * Math.sin(v) },
130+
cosh: { f: Math.cosh, g: (v, g) => g * Math.sinh(v) },
131+
exp: { f: Math.exp, g: (v, g) => g * Math.exp(v) },
132+
log: { f: Math.log, g: (v, g) => g / v },
133+
log10: { f: Math.log10, g: (v, g) => g / (v * Math.LN10) },
134+
log2: { f: Math.log2, g: (v, g) => g / (v * Math.LN2) },
135+
max: { f: Math.max, g: (va, vb, ga, gb) => (va >= vb ? ga : gb) },
136+
min: { f: Math.min, g: (va, vb, ga, gb) => (va <= vb ? ga : gb) },
137+
sin: { f: Math.sin, g: (v, g) => g * Math.cos(v) },
138+
sinh: { f: Math.sinh, g: (v, g) => g * Math.cosh(v) },
139+
sqrt: { f: Math.sqrt, g: (v, g) => g / (2 * Math.sqrt(v)) },
140+
tan: { f: Math.tan, g: (v, g) => g / Math.cos(v) ** 2 },
141+
tanh: { f: Math.tanh, g: (v, g) => g * (1 - Math.tanh(v) ** 2) },
142+
}
143+
144+
const consts = {
145+
e: Math.E,
146+
ln2: Math.LN2,
147+
ln10: Math.LN10,
148+
log2e: Math.LOG2E,
149+
log10e: Math.LOG10E,
150+
pi: Math.PI,
151+
sqrt1_2: Math.SQRT1_2,
152+
sqrt2: Math.SQRT2,
153+
}
154+
155+
const tokenTable = [...Object.keys(bops), ...Object.keys(uops), '(', ')', ',', '[', ']']
156+
tokenTable.sort((a, b) => b.length - a.length)
157+
158+
const tokenize = e => {
159+
let p = 0
160+
const tk = []
161+
162+
const isToken = s => {
163+
for (const op of tokenTable) {
164+
if (op === e.slice(p + s, p + s + op.length)) {
165+
return op
166+
}
167+
}
168+
return null
169+
}
170+
171+
while (p < e.length) {
172+
if (e[p] === ' ') {
173+
p++
174+
continue
175+
}
176+
const op = isToken(0)
177+
if (op) {
178+
p += op.length
179+
tk.push(op)
180+
continue
181+
}
182+
183+
let i = 1
184+
for (; i < e.length - p; i++) {
185+
if (e[p + i] === ' ' || isToken(i)) {
186+
break
187+
}
188+
}
189+
tk.push(e.slice(p, p + i))
190+
p += i
191+
}
192+
return tk
193+
}
194+
195+
const construct = e => {
196+
const tokens = tokenize(e)
197+
198+
const rpn = []
199+
const stack = []
200+
let lastExpr = false
201+
for (const token of tokens) {
202+
if (consts[token]) {
203+
rpn.push(consts[token])
204+
lastExpr = true
205+
} else if (funcs[token]) {
206+
stack.push(token)
207+
lastExpr = false
208+
} else if (uops[token] || bops[token]) {
209+
if ((lastExpr && !bops[token]) || (!lastExpr && !uops[token])) {
210+
throw new Error(`Invalid operation '${token}'.`)
211+
}
212+
const op = lastExpr ? bops[token] : uops[token]
213+
while (true) {
214+
const lt = stack[stack.length - 1]
215+
if (lt instanceof OP && lt.p >= op.p) {
216+
rpn.push(stack.pop())
217+
} else {
218+
break
219+
}
220+
}
221+
stack.push(op)
222+
lastExpr = false
223+
} else if (token === ',') {
224+
while (true) {
225+
if (stack.length === 0) {
226+
throw new Error('Invalid parenthesis')
227+
}
228+
if (stack[stack.length - 1] === '(') {
229+
break
230+
}
231+
rpn.push(stack.pop())
232+
}
233+
lastExpr = false
234+
} else if (token === '(') {
235+
stack.push(token)
236+
lastExpr = false
237+
} else if (token === ')') {
238+
while (true) {
239+
const lt = stack.pop()
240+
if (!lt) {
241+
throw new Error('Invalid parenthesis')
242+
}
243+
if (lt === '(') {
244+
if (funcs[stack[stack.length - 1]]) {
245+
rpn.push(stack.pop())
246+
}
247+
break
248+
}
249+
rpn.push(lt)
250+
}
251+
lastExpr = true
252+
} else if (Number.isFinite(+token)) {
253+
rpn.push(+token)
254+
lastExpr = true
255+
} else {
256+
rpn.push(token)
257+
lastExpr = true
258+
}
259+
}
260+
261+
while (stack.length > 0) {
262+
rpn.push(stack.pop())
263+
}
264+
return rpn
265+
}
266+
267+
const execute = (rpn, env) => {
268+
const n = rpn.length
269+
let k = n - 1
270+
271+
const calc = () => {
272+
const token = rpn[k--]
273+
if (typeof token === 'number') {
274+
return { v: token, gx: 0, gy: 0 }
275+
} else if (Object.hasOwn(env, token)) {
276+
if (token === 'x') {
277+
return { v: env[token], gx: 1, gy: 0 }
278+
} else if (token === 'y') {
279+
return { v: env[token], gx: 0, gy: 1 }
280+
}
281+
}
282+
if (token instanceof OP) {
283+
const args = []
284+
for (let i = 0; i < token.length; i++) {
285+
args.unshift(calc())
286+
}
287+
const v = token.f(...args.map(v => v.v))
288+
const gx = token.g(...args.map(v => v.v), ...args.map(v => v.gx))
289+
const gy = token.g(...args.map(v => v.v), ...args.map(v => v.gy))
290+
return { v, gx, gy }
291+
}
292+
if (funcs[token]) {
293+
const an = funcs[token].f.length
294+
const args = []
295+
for (let i = 0; i < an; i++) {
296+
args.unshift(calc())
297+
}
298+
const v = funcs[token].f(...args.map(v => v.v))
299+
const gx = funcs[token].g(...args.map(v => v.v), ...args.map(v => v.gx))
300+
const gy = funcs[token].g(...args.map(v => v.v), ...args.map(v => v.gy))
301+
return { v, gx, gy }
302+
}
303+
throw new Error(`Invalid token '${token}'.`)
304+
}
305+
const ans = calc()
306+
if (k !== -1) {
307+
throw new Error('Invalid expression.')
308+
}
309+
return ans
310+
}
311+
312+
const stringToFunction = e => {
313+
const rpn = construct(e)
314+
return env => execute(rpn, env)
315+
}

lib/model/nns/layer/index.js

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ export { default as FastELULayer } from './felu.js'
2727
export { default as FlattenLayer } from './flatten.js'
2828
export { default as FlexibleReLULayer } from './frelu.js'
2929
export { default as FullyConnected } from './full.js'
30+
export { default as FunctionLayer } from './function.js'
3031
export { default as GaussianLayer } from './gaussian.js'
3132
export { default as GlobalAveragePoolLayer } from './global_averagepool.js'
3233
export { default as GlobalLpPoolLayer } from './global_lppool.js'
@@ -160,6 +161,7 @@ import Tensor from '../../../util/tensor.js'
160161
* { type: 'floor' } |
161162
* { type: 'frelu', b?: number } |
162163
* { type: 'full', out_size: number | string, w?: number[][] | Matrix | string, b?: number[][] | Matrix | string, activation?: string | object, l2_decay?: number, l1_decay?: number } |
164+
* { type: 'function', func: string } |
163165
* { type: 'gaussian' } |
164166
* { type: 'gelu' } |
165167
* { type: 'global_average_pool', channel_dim?: number } |

0 commit comments

Comments
 (0)