|
| 1 | +import Matrix from '../../../util/matrix.js' |
| 2 | + |
/**
 * AMSBound optimizer.
 *
 * Keeps, per parameter key, exponential moving averages of the gradient (`m`)
 * and of the squared gradient (`v`), plus the element-wise running maximum of
 * `v` (`vh`). The per-element step size `alpha / sqrt(vh)` is clipped into a
 * step-dependent band around the learning rate and then scaled by `1 / sqrt(t)`.
 */
export class AMSBoundOptimizer {
	/**
	 * @param {number} [lr=0.001] Learning rate; the centre of the clipping band
	 * @param {number} [alpha=0.003] Numerator of the unclipped step size `alpha / sqrt(vh)`
	 * @param {number} [beta1=0.9] Decay rate of the gradient moving average
	 * @param {number} [beta2=0.999] Decay rate of the squared-gradient moving average
	 */
	constructor(lr = 0.001, alpha = 0.003, beta1 = 0.9, beta2 = 0.999) {
		this._learningrate = lr
		this._alpha = alpha
		this._beta1 = beta1
		this._beta2 = beta2

		// Both bounds read `this._learningrate` lazily, so a later `learningRate`
		// assignment is honoured; `beta2` is captured from the constructor argument.
		this._eta_lbound = t => this._learningrate * (1 - 1 / ((1 - beta2) * t + 1))
		// NOTE(review): the published AdaBound/AMSBound schedule appears to use
		// `1 / ((1 - beta2) * t)` (no `+ 1`) for the upper bound - confirm that the
		// extra `+ 1` here is intentional.
		this._eta_ubound = t => this._learningrate * (1 + 1 / ((1 - beta2) * t + 1))
	}

	/**
	 * Sets the learning rate used by this optimizer and by managers it created.
	 * @param {number} value New learning rate
	 */
	set learningRate(value) {
		this._learningrate = value
	}

	/**
	 * Creates a manager holding independent per-key optimizer state.
	 *
	 * @returns {{lr: number, params: object, delta: Function}} Manager exposing the
	 * current learning rate (`lr`), the per-key state (`params`) and `delta(key, value)`
	 */
	manager() {
		const optimizer = this
		return {
			get lr() {
				return optimizer._learningrate
			},
			// Prototype-less dictionary: keys such as "constructor", "toString" or
			// "__proto__" must not resolve to inherited Object.prototype members,
			// otherwise they would be mistaken for already-initialised state.
			params: Object.create(null),
			/**
			 * Computes the update amount for the parameter identified by `key`.
			 *
			 * @param {string} key Identifier of the parameter whose state is used
			 * @param {number | Matrix} value Gradient; a number is wrapped via `new Matrix(1, 1, value)`
			 * @returns {number | Matrix} Update amount; a number when `value` was a number
			 */
			delta(key, value) {
				const valueIsNumber = typeof value === 'number'
				const grad = valueIsNumber ? new Matrix(1, 1, value) : value

				let state = this.params[key]
				if (!state) {
					// First call for this key: zero-filled state shaped like the gradient.
					const zeros = grad.copy()
					zeros.fill(0)
					state = { m: zeros.copy(), v: zeros.copy(), vh: zeros, t: 1 }
					this.params[key] = state
				}

				const beta1 = optimizer._beta1
				const beta2 = optimizer._beta2
				state.m.broadcastOperate(grad, (m, g) => m * beta1 + g * (1 - beta1))
				state.v.broadcastOperate(grad, (v, g) => v * beta2 + (1 - beta2) * g * g)
				// Running maximum of v (the AMS part).
				state.vh.broadcastOperate(state.v, (vh, v) => Math.max(vh, v))

				const etaLower = optimizer._eta_lbound(state.t)
				const etaUpper = optimizer._eta_ubound(state.t)
				const alpha = optimizer._alpha
				const eta = state.vh.copy()
				// A zero in vh yields Infinity here, which the clip reduces to etaUpper.
				eta.map(v => Math.min(etaUpper, Math.max(etaLower, alpha / Math.sqrt(v))))

				const sqrtT = Math.sqrt(state.t)
				const ret = state.m.copy()
				ret.broadcastOperate(eta, (m, e) => (m * e) / sqrtT)
				state.t++
				return valueIsNumber ? ret.toScaler() : ret
			},
		}
	}
}
0 commit comments