|
| 1 | +import Layer from './base.js' |
| 2 | + |
| 3 | +/** |
| 4 | + * Function layer |
| 5 | + */ |
| 6 | +export default class FunctionLayer extends Layer { |
| 7 | + /** |
| 8 | + * @param {object} config config |
| 9 | + * @param {string} config.func Function |
| 10 | + */ |
| 11 | + constructor({ func, ...rest }) { |
| 12 | + super(rest) |
| 13 | + this._func = func |
| 14 | + this._exc = stringToFunction(this._func) |
| 15 | + } |
| 16 | + |
| 17 | + calc(x, y) { |
| 18 | + const o = x.copy() |
| 19 | + this._unary = !y |
| 20 | + if (this._unary) { |
| 21 | + o.map(v => this._exc({ x: v })) |
| 22 | + this._g = o.copy() |
| 23 | + this._g.map(v => v.gx) |
| 24 | + } else { |
| 25 | + o.broadcastOperate(y, (a, b) => this._exc({ x: a, y: b })) |
| 26 | + this._gx = o.copy() |
| 27 | + this._gx.map(v => v.gx) |
| 28 | + this._gy = o.copy() |
| 29 | + this._gy.map(v => v.gy) |
| 30 | + } |
| 31 | + o.map(v => v.v) |
| 32 | + return o |
| 33 | + } |
| 34 | + |
| 35 | + grad(bo) { |
| 36 | + if (this._unary) { |
| 37 | + const bi = bo.copy() |
| 38 | + bi.broadcastOperate(this._g, (a, b) => a * b) |
| 39 | + return bi |
| 40 | + } else { |
| 41 | + const bx = bo.copy() |
| 42 | + bx.broadcastOperate(this._gx, (a, b) => a * b) |
| 43 | + const by = bo.copy() |
| 44 | + by.broadcastOperate(this._gy, (a, b) => a * b) |
| 45 | + return [bx, by] |
| 46 | + } |
| 47 | + } |
| 48 | + |
| 49 | + toObject() { |
| 50 | + return { |
| 51 | + type: 'function', |
| 52 | + func: this._func, |
| 53 | + } |
| 54 | + } |
| 55 | +} |
| 56 | + |
| 57 | +FunctionLayer.registLayer() |
| 58 | + |
| 59 | +class OP { |
| 60 | + constructor(name, priority, func, grad) { |
| 61 | + this.name = name |
| 62 | + this.p = priority |
| 63 | + this.f = func |
| 64 | + this.g = grad |
| 65 | + } |
| 66 | + |
| 67 | + get length() { |
| 68 | + return this.f.length |
| 69 | + } |
| 70 | +} |
| 71 | + |
| 72 | +const uops = { |
| 73 | + '+': new OP( |
| 74 | + '+', |
| 75 | + 4, |
| 76 | + v => v, |
| 77 | + (v, g) => g |
| 78 | + ), |
| 79 | + '-': new OP( |
| 80 | + '-', |
| 81 | + 4, |
| 82 | + v => -v, |
| 83 | + (v, g) => -g |
| 84 | + ), |
| 85 | +} |
| 86 | + |
| 87 | +const bops = { |
| 88 | + '-': new OP( |
| 89 | + '-', |
| 90 | + 1, |
| 91 | + (a, b) => a - b, |
| 92 | + (va, vb, ga, gb) => ga - gb |
| 93 | + ), |
| 94 | + '+': new OP( |
| 95 | + '+', |
| 96 | + 1, |
| 97 | + (a, b) => a + b, |
| 98 | + (va, vb, ga, gb) => ga + gb |
| 99 | + ), |
| 100 | + '*': new OP( |
| 101 | + '*', |
| 102 | + 2, |
| 103 | + (a, b) => a * b, |
| 104 | + (va, vb, ga, gb) => va * gb + ga * vb |
| 105 | + ), |
| 106 | + '/': new OP( |
| 107 | + '/', |
| 108 | + 2, |
| 109 | + (a, b) => a / b, |
| 110 | + (va, vb, ga, gb) => (ga * vb - va * gb) / vb ** 2 |
| 111 | + ), |
| 112 | + '**': new OP( |
| 113 | + '**', |
| 114 | + 3, |
| 115 | + (a, b) => a ** b, |
| 116 | + (va, vb, ga, gb) => va ** vb * ((gb === 0 ? 0 : gb * Math.log(va)) + vb * (ga / va)) |
| 117 | + ), |
| 118 | +} |
| 119 | + |
| 120 | +const funcs = { |
| 121 | + abs: { f: Math.abs, g: (v, g) => (v < 0 ? -g : g) }, |
| 122 | + acos: { f: Math.acos, g: (v, g) => -g / (Math.sqrt(1 - v ** 2) + 1.0e-4) }, |
| 123 | + acosh: { f: Math.acosh, g: (v, g) => -g / (Math.sqrt(v ** 2 - 1) + 1.0e-4) }, |
| 124 | + asin: { f: Math.asin, g: (v, g) => g / (Math.sqrt(1 - v ** 2) + 1.0e-4) }, |
| 125 | + asinh: { f: Math.asinh, g: (v, g) => g / Math.sqrt(1 + v ** 2) }, |
| 126 | + atan: { f: Math.atan, g: (v, g) => g / (1 + v ** 2) }, |
| 127 | + atanh: { f: Math.atanh, g: (v, g) => g / (1 - v ** 2) }, |
| 128 | + cbrt: { f: Math.cbrt, g: (v, g) => g / (3 * Math.cbrt(v) ** 2) }, |
| 129 | + cos: { f: Math.cos, g: (v, g) => -g * Math.sin(v) }, |
| 130 | + cosh: { f: Math.cosh, g: (v, g) => g * Math.sinh(v) }, |
| 131 | + exp: { f: Math.exp, g: (v, g) => g * Math.exp(v) }, |
| 132 | + log: { f: Math.log, g: (v, g) => g / v }, |
| 133 | + log10: { f: Math.log10, g: (v, g) => g / (v * Math.LN10) }, |
| 134 | + log2: { f: Math.log2, g: (v, g) => g / (v * Math.LN2) }, |
| 135 | + max: { f: Math.max, g: (va, vb, ga, gb) => (va >= vb ? ga : gb) }, |
| 136 | + min: { f: Math.min, g: (va, vb, ga, gb) => (va <= vb ? ga : gb) }, |
| 137 | + sin: { f: Math.sin, g: (v, g) => g * Math.cos(v) }, |
| 138 | + sinh: { f: Math.sinh, g: (v, g) => g * Math.cosh(v) }, |
| 139 | + sqrt: { f: Math.sqrt, g: (v, g) => g / (2 * Math.sqrt(v)) }, |
| 140 | + tan: { f: Math.tan, g: (v, g) => g / Math.cos(v) ** 2 }, |
| 141 | + tanh: { f: Math.tanh, g: (v, g) => g * (1 - Math.tanh(v) ** 2) }, |
| 142 | +} |
| 143 | + |
| 144 | +const consts = { |
| 145 | + e: Math.E, |
| 146 | + ln2: Math.LN2, |
| 147 | + ln10: Math.LN10, |
| 148 | + log2e: Math.LOG2E, |
| 149 | + log10e: Math.LOG10E, |
| 150 | + pi: Math.PI, |
| 151 | + sqrt1_2: Math.SQRT1_2, |
| 152 | + sqrt2: Math.SQRT2, |
| 153 | +} |
| 154 | + |
| 155 | +const tokenTable = [...Object.keys(bops), ...Object.keys(uops), '(', ')', ',', '[', ']'] |
| 156 | +tokenTable.sort((a, b) => b.length - a.length) |
| 157 | + |
| 158 | +const tokenize = e => { |
| 159 | + let p = 0 |
| 160 | + const tk = [] |
| 161 | + |
| 162 | + const isToken = s => { |
| 163 | + for (const op of tokenTable) { |
| 164 | + if (op === e.slice(p + s, p + s + op.length)) { |
| 165 | + return op |
| 166 | + } |
| 167 | + } |
| 168 | + return null |
| 169 | + } |
| 170 | + |
| 171 | + while (p < e.length) { |
| 172 | + if (e[p] === ' ') { |
| 173 | + p++ |
| 174 | + continue |
| 175 | + } |
| 176 | + const op = isToken(0) |
| 177 | + if (op) { |
| 178 | + p += op.length |
| 179 | + tk.push(op) |
| 180 | + continue |
| 181 | + } |
| 182 | + |
| 183 | + let i = 1 |
| 184 | + for (; i < e.length - p; i++) { |
| 185 | + if (e[p + i] === ' ' || isToken(i)) { |
| 186 | + break |
| 187 | + } |
| 188 | + } |
| 189 | + tk.push(e.slice(p, p + i)) |
| 190 | + p += i |
| 191 | + } |
| 192 | + return tk |
| 193 | +} |
| 194 | + |
| 195 | +const construct = e => { |
| 196 | + const tokens = tokenize(e) |
| 197 | + |
| 198 | + const rpn = [] |
| 199 | + const stack = [] |
| 200 | + let lastExpr = false |
| 201 | + for (const token of tokens) { |
| 202 | + if (consts[token]) { |
| 203 | + rpn.push(consts[token]) |
| 204 | + lastExpr = true |
| 205 | + } else if (funcs[token]) { |
| 206 | + stack.push(token) |
| 207 | + lastExpr = false |
| 208 | + } else if (uops[token] || bops[token]) { |
| 209 | + if ((lastExpr && !bops[token]) || (!lastExpr && !uops[token])) { |
| 210 | + throw new Error(`Invalid operation '${token}'.`) |
| 211 | + } |
| 212 | + const op = lastExpr ? bops[token] : uops[token] |
| 213 | + while (true) { |
| 214 | + const lt = stack[stack.length - 1] |
| 215 | + if (lt instanceof OP && lt.p >= op.p) { |
| 216 | + rpn.push(stack.pop()) |
| 217 | + } else { |
| 218 | + break |
| 219 | + } |
| 220 | + } |
| 221 | + stack.push(op) |
| 222 | + lastExpr = false |
| 223 | + } else if (token === ',') { |
| 224 | + while (true) { |
| 225 | + if (stack.length === 0) { |
| 226 | + throw new Error('Invalid parenthesis') |
| 227 | + } |
| 228 | + if (stack[stack.length - 1] === '(') { |
| 229 | + break |
| 230 | + } |
| 231 | + rpn.push(stack.pop()) |
| 232 | + } |
| 233 | + lastExpr = false |
| 234 | + } else if (token === '(') { |
| 235 | + stack.push(token) |
| 236 | + lastExpr = false |
| 237 | + } else if (token === ')') { |
| 238 | + while (true) { |
| 239 | + const lt = stack.pop() |
| 240 | + if (!lt) { |
| 241 | + throw new Error('Invalid parenthesis') |
| 242 | + } |
| 243 | + if (lt === '(') { |
| 244 | + if (funcs[stack[stack.length - 1]]) { |
| 245 | + rpn.push(stack.pop()) |
| 246 | + } |
| 247 | + break |
| 248 | + } |
| 249 | + rpn.push(lt) |
| 250 | + } |
| 251 | + lastExpr = true |
| 252 | + } else if (Number.isFinite(+token)) { |
| 253 | + rpn.push(+token) |
| 254 | + lastExpr = true |
| 255 | + } else { |
| 256 | + rpn.push(token) |
| 257 | + lastExpr = true |
| 258 | + } |
| 259 | + } |
| 260 | + |
| 261 | + while (stack.length > 0) { |
| 262 | + rpn.push(stack.pop()) |
| 263 | + } |
| 264 | + return rpn |
| 265 | +} |
| 266 | + |
| 267 | +const execute = (rpn, env) => { |
| 268 | + const n = rpn.length |
| 269 | + let k = n - 1 |
| 270 | + |
| 271 | + const calc = () => { |
| 272 | + const token = rpn[k--] |
| 273 | + if (typeof token === 'number') { |
| 274 | + return { v: token, gx: 0, gy: 0 } |
| 275 | + } else if (Object.hasOwn(env, token)) { |
| 276 | + if (token === 'x') { |
| 277 | + return { v: env[token], gx: 1, gy: 0 } |
| 278 | + } else if (token === 'y') { |
| 279 | + return { v: env[token], gx: 0, gy: 1 } |
| 280 | + } |
| 281 | + } |
| 282 | + if (token instanceof OP) { |
| 283 | + const args = [] |
| 284 | + for (let i = 0; i < token.length; i++) { |
| 285 | + args.unshift(calc()) |
| 286 | + } |
| 287 | + const v = token.f(...args.map(v => v.v)) |
| 288 | + const gx = token.g(...args.map(v => v.v), ...args.map(v => v.gx)) |
| 289 | + const gy = token.g(...args.map(v => v.v), ...args.map(v => v.gy)) |
| 290 | + return { v, gx, gy } |
| 291 | + } |
| 292 | + if (funcs[token]) { |
| 293 | + const an = funcs[token].f.length |
| 294 | + const args = [] |
| 295 | + for (let i = 0; i < an; i++) { |
| 296 | + args.unshift(calc()) |
| 297 | + } |
| 298 | + const v = funcs[token].f(...args.map(v => v.v)) |
| 299 | + const gx = funcs[token].g(...args.map(v => v.v), ...args.map(v => v.gx)) |
| 300 | + const gy = funcs[token].g(...args.map(v => v.v), ...args.map(v => v.gy)) |
| 301 | + return { v, gx, gy } |
| 302 | + } |
| 303 | + throw new Error(`Invalid token '${token}'.`) |
| 304 | + } |
| 305 | + const ans = calc() |
| 306 | + if (k !== -1) { |
| 307 | + throw new Error('Invalid expression.') |
| 308 | + } |
| 309 | + return ans |
| 310 | +} |
| 311 | + |
| 312 | +const stringToFunction = e => { |
| 313 | + const rpn = construct(e) |
| 314 | + return env => execute(rpn, env) |
| 315 | +} |
0 commit comments