Skip to content

Commit 0f0f4d0

Browse files
authored
Fix and improve RL environments, and add tests (#668)
* Fix and improve RL environments, and add tests * Forgot to commit and format * Add tests
1 parent e1adc5f commit 0f0f4d0

17 files changed

+1358
-304
lines changed

js/renderer/rl/draughts.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ class ManualPlayer {
160160
for (let i = 0; i < board.size[0]; i++) {
161161
this._check[i] = []
162162
for (let j = 0; j < board.size[1]; j++) {
163-
if ((i + j) % 2 > 0) continue
163+
if ((i + j) % 2 === 0) continue
164164
this._check[i][j] = document.createElementNS('http://www.w3.org/2000/svg', 'rect')
165165
this._check[i][j].setAttribute('x', dw * j)
166166
this._check[i][j].setAttribute('y', dh * i)

lib/rl/acrobot.js

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -46,21 +46,6 @@ export default class AcrobotRLEnvironment extends RLEnvironmentBase {
4646
]
4747
}
4848

49-
set reward(value) {
50-
this._reward = {
51-
goal: 0,
52-
step: -1,
53-
fail: 0,
54-
}
55-
if (value === 'achieve') {
56-
this._reward = {
57-
goal: 0,
58-
step: -1,
59-
fail: 0,
60-
}
61-
}
62-
}
63-
6449
reset() {
6550
super.reset()
6651
this._theta1 = Math.random() * 0.2 - 0.1
@@ -111,11 +96,11 @@ export default class AcrobotRLEnvironment extends RLEnvironmentBase {
11196

11297
const clip = (x, min, max) => (x < min ? min : x > max ? max : x)
11398
t1 += this._dt * dt1
114-
if (t1 < -Math.PI) t1 = t1 + 2 * Math.PI
115-
if (t1 > Math.PI) t1 = t1 - 2 * Math.PI
99+
while (t1 < -Math.PI) t1 += 2 * Math.PI
100+
while (t1 > Math.PI) t1 -= 2 * Math.PI
116101
t2 += this._dt * dt2
117-
if (t2 < -Math.PI) t2 = t2 + 2 * Math.PI
118-
if (t2 > Math.PI) t2 = t2 - 2 * Math.PI
102+
while (t2 < -Math.PI) t2 += 2 * Math.PI
103+
while (t2 > Math.PI) t2 -= 2 * Math.PI
119104
dt1 = clip(dt1 + this._dt * ddt1, -this._max_vel1, this._max_vel1)
120105
dt2 = clip(dt2 + this._dt * ddt2, -this._max_vel2, this._max_vel2)
121106

lib/rl/draughts.js

Lines changed: 85 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ export default class DraughtsRLEnvironment extends RLEnvironmentBase {
4747
]
4848
const checkBound = (x, y) => 0 <= x && x < this._size[0] && 0 <= y && y < this._size[1]
4949
for (let i = 0; i < this._size[0]; i++) {
50-
for (let j = i % 2 === 0 ? 0 : 1; j < this._size[1]; j += 2) {
50+
for (let j = i % 2 === 1 ? 0 : 1; j < this._size[1]; j += 2) {
5151
let midpath = []
5252
for (const [di, dj] of d) {
5353
const i1 = i + di
@@ -93,18 +93,14 @@ export default class DraughtsRLEnvironment extends RLEnvironmentBase {
9393
get states() {
9494
const s = [[RED, WHITE]]
9595
for (let i = 0; i < this._size[0]; i++) {
96-
for (let j = 0; j < this._size[1]; j++) {
97-
if (j % 2 === i % 2) {
98-
s.push([
99-
EMPTY,
100-
DraughtsRLEnvironment.OWN,
101-
DraughtsRLEnvironment.OWN | KING,
102-
DraughtsRLEnvironment.OTHER,
103-
DraughtsRLEnvironment.OTHER | KING,
104-
])
105-
} else {
106-
s.push([EMPTY])
107-
}
96+
for (let j = i % 2 === 0 ? 1 : 0; j < this._size[1]; j += 2) {
97+
s.push([
98+
EMPTY,
99+
DraughtsRLEnvironment.OWN,
100+
DraughtsRLEnvironment.OWN | KING,
101+
DraughtsRLEnvironment.OTHER,
102+
DraughtsRLEnvironment.OTHER | KING,
103+
])
108104
}
109105
}
110106
return s
@@ -123,7 +119,7 @@ export default class DraughtsRLEnvironment extends RLEnvironmentBase {
123119
_makeState(board, agentturn, gameturn) {
124120
const s = [gameturn]
125121
for (let i = 0; i < this._size[0]; i++) {
126-
for (let j = 0; j < this._size[1]; j++) {
122+
for (let j = i % 2 === 0 ? 1 : 0; j < this._size[1]; j += 2) {
127123
const p = board.at([i, j])
128124
if (p === EMPTY) {
129125
s.push(EMPTY)
@@ -144,7 +140,7 @@ export default class DraughtsRLEnvironment extends RLEnvironmentBase {
144140
const board = new DraughtsBoard(this._size, this._evaluation)
145141
const opturn = turn === RED ? WHITE : RED
146142
for (let i = 0, p = 1; i < this._size[0]; i++) {
147-
for (let j = 0; j < this._size[1]; j++, p++) {
143+
for (let j = i % 2 === 0 ? 1 : 0; j < this._size[1]; j += 2, p++) {
148144
if (state[p] === EMPTY) {
149145
board._board[i][j] = EMPTY
150146
} else {
@@ -241,6 +237,7 @@ class DraughtsBoard {
241237
constructor(size, evaluator) {
242238
this._evaluator = evaluator
243239
this._size = size
240+
this._lines = 3
244241

245242
this.reset()
246243
}
@@ -280,6 +277,26 @@ class DraughtsBoard {
280277
return null
281278
}
282279

280+
toString() {
281+
let buf = ''
282+
for (let i = 0; i < this._size[0]; i++) {
283+
for (let j = 0; j < this._size[1]; j++) {
284+
if (j > 0) {
285+
buf += ' '
286+
}
287+
if (this._board[i][j] === RED) {
288+
buf += 'x'
289+
} else if (this._board[i][j] === WHITE) {
290+
buf += 'o'
291+
} else {
292+
buf += '-'
293+
}
294+
}
295+
buf += '\n'
296+
}
297+
return buf
298+
}
299+
283300
nextTurn(turn) {
284301
if (turn === WHITE) {
285302
return RED
@@ -310,20 +327,44 @@ class DraughtsBoard {
310327
}
311328
}
312329

330+
_num_to_pos(n) {
331+
if (typeof n !== 'number') {
332+
return n
333+
}
334+
const r = Math.floor((n - 1) / this._size[1])
335+
const c = (n - 1) % this._size[1]
336+
if (c < (this._size[1] - 1) / 2) {
337+
return [r * 2, c * 2 + 1]
338+
} else {
339+
return [r * 2 + 1, (c - Math.floor(this._size[1] / 2)) * 2]
340+
}
341+
}
342+
313343
at(p) {
344+
if (typeof p === 'number') {
345+
p = this._num_to_pos(p)
346+
}
314347
return this._board[p[0]][p[1]]
315348
}
316349

317350
set(p, turn) {
318-
let piece = this._board[p.from[0]][p.from[1]]
351+
p = {
352+
from: this._num_to_pos(p.from),
353+
path: p.path.map(v => this._num_to_pos(v)),
354+
jump: p.jump.map(v => this._num_to_pos(v)),
355+
}
356+
let piece = this.at(p.from)
319357
if (!(turn & piece)) {
320358
return false
321359
}
360+
if ((p.jump.length !== 0 || p.path.length !== 1) && p.jump.length !== p.path.length) {
361+
return false
362+
}
322363
const nturn = this.nextTurn(turn)
323-
if (p.jump.some(([i, j]) => !(this._board[i][j] & nturn))) {
364+
if (p.jump.some(j => !(this.at(j) & nturn))) {
324365
return false
325366
}
326-
if (p.path.some(([i, j]) => this._board[i][j] !== EMPTY)) {
367+
if (p.path.some(j => this.at(j) !== EMPTY)) {
327368
return false
328369
}
329370

@@ -334,6 +375,27 @@ class DraughtsBoard {
334375
}
335376
}
336377

378+
if (p.jump.length === 0) {
379+
for (let i = 0; i < 2; i++) {
380+
if (Math.abs(p.from[i] - p.path[0][i]) !== 1) {
381+
return false
382+
}
383+
}
384+
} else {
385+
let pos = p.from
386+
for (let k = 0; k < p.path.length; k++) {
387+
for (let i = 0; i < 2; i++) {
388+
if (Math.abs(pos[i] - p.jump[k][i]) !== 1) {
389+
return false
390+
}
391+
if (Math.abs(p.jump[k][i] - p.path[k][i]) !== 1) {
392+
return false
393+
}
394+
}
395+
pos = p.path[k]
396+
}
397+
}
398+
337399
this._board[p.from[0]][p.from[1]] = EMPTY
338400
for (const [i, j] of p.jump) {
339401
this._board[i][j] = EMPTY
@@ -354,10 +416,10 @@ class DraughtsBoard {
354416
this._board[i] = Array(this._size[1]).fill(EMPTY)
355417
}
356418
for (let i = 0; i < this._size[0]; i++) {
357-
for (let j = 0; j < this._size[1]; j++) {
358-
if (i < 3 && (i + j) % 2 === 0) {
419+
for (let j = i % 2 === 0 ? 1 : 0; j < this._size[1]; j += 2) {
420+
if (i < this._lines) {
359421
this._board[i][j] = RED
360-
} else if (this._size[0] - 3 <= i && (i + j) % 2 === 0) {
422+
} else if (this._size[0] - this._lines <= i) {
361423
this._board[i][j] = WHITE
362424
}
363425
}
@@ -418,9 +480,9 @@ class DraughtsBoard {
418480
cp._board[x + dx * 2][y + dy * 2] = this._board[x][y]
419481
cp._board[x][y] = EMPTY
420482
cp._board[x + dx][y + dy] = EMPTY
421-
if (turn === RED && x * dx * 2 === this._size[0] - 1) {
483+
if (turn === RED && x + dx * 2 === this._size[0] - 1) {
422484
cp._board[x + dx * 2][y + dy * 2] |= KING
423-
} else if (turn === WHITE && x * dx * 2 === 0) {
485+
} else if (turn === WHITE && x + dx * 2 === 0) {
424486
cp._board[x + dx * 2][y + dy * 2] |= KING
425487
}
426488
const npath = cp.allPath(x + dx * 2, y + dy * 2, turn, false)

lib/rl/gomoku.js

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,26 @@ class GomokuBoard {
190190
return null
191191
}
192192

193+
toString() {
194+
let buf = ''
195+
for (let i = 0; i < this._size[0]; i++) {
196+
for (let j = 0; j < this._size[1]; j++) {
197+
if (j > 0) {
198+
buf += ' '
199+
}
200+
if (this._board[i][j] === BLACK) {
201+
buf += 'x'
202+
} else if (this._board[i][j] === WHITE) {
203+
buf += 'o'
204+
} else {
205+
buf += '-'
206+
}
207+
}
208+
buf += '\n'
209+
}
210+
return buf
211+
}
212+
193213
nextTurn(turn) {
194214
return turn === BLACK ? WHITE : BLACK
195215
}

lib/rl/inhypercube.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ export default class InHypercubeRLEnvironment extends RLEnvironmentBase {
6868
}
6969

7070
const success = p[this._success_dim] <= -this._fail_position
71-
const fail = !success && p.every(v => Math.abs(v) >= this._fail_position)
71+
const fail = !success && p.some(v => Math.abs(v) >= this._fail_position)
7272
const done = this.epoch >= this._max_step || success || fail
7373
const reward = fail ? this._reward.fail : success ? this._reward.goal : this._reward.step
7474
return {

lib/rl/reversi.js

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ export default class ReversiRLEnvironment extends RLEnvironmentBase {
4646
const a = [EMPTY]
4747
for (let i = 0; i < this._size[0]; i++) {
4848
for (let j = 0; j < this._size[1]; j++) {
49-
a.push(`${i}_${j}`)
49+
a.push(`${String.fromCharCode('a'.charCodeAt(0) + i)}${i + 1}`)
5050
}
5151
}
5252
return [a]
@@ -167,8 +167,7 @@ export default class ReversiRLEnvironment extends RLEnvironmentBase {
167167
invalid,
168168
}
169169
}
170-
const choice = action[0].split('_').map(v => +v)
171-
const changed = board.set(choice, agent)
170+
const changed = board.set(action[0], agent)
172171
const done = board.finish
173172
if (!changed) {
174173
return {
@@ -233,6 +232,26 @@ class ReversiBoard {
233232
return null
234233
}
235234

235+
toString() {
236+
let buf = ''
237+
for (let i = 0; i < this._size[0]; i++) {
238+
for (let j = 0; j < this._size[1]; j++) {
239+
if (j > 0) {
240+
buf += ' '
241+
}
242+
if (this._board[i][j] === BLACK) {
243+
buf += 'x'
244+
} else if (this._board[i][j] === WHITE) {
245+
buf += 'o'
246+
} else {
247+
buf += '-'
248+
}
249+
}
250+
buf += '\n'
251+
}
252+
return buf
253+
}
254+
236255
nextTurn(turn) {
237256
return flipPiece(turn)
238257
}
@@ -260,10 +279,16 @@ class ReversiBoard {
260279
}
261280

262281
at(p) {
282+
if (typeof p === 'string') {
283+
p = [p[1] - 1, p.charCodeAt(0) - 'a'.charCodeAt(0)]
284+
}
263285
return this._board[p[0]][p[1]]
264286
}
265287

266288
set(p, turn) {
289+
if (typeof p === 'string') {
290+
p = [p[1] - 1, p.charCodeAt(0) - 'a'.charCodeAt(0)]
291+
}
267292
const flips = this.flipPositions(p[0], p[1], turn)
268293
if (flips.length === 0) {
269294
return false
@@ -282,10 +307,10 @@ class ReversiBoard {
282307
}
283308
const cx = Math.floor(this._size[0] / 2)
284309
const cy = Math.floor(this._size[1] / 2)
285-
this._board[cx - 1][cy - 1] = BLACK
286-
this._board[cx - 1][cy] = WHITE
287-
this._board[cx][cy - 1] = WHITE
288-
this._board[cx][cy] = BLACK
310+
this._board[cx - 1][cy - 1] = WHITE
311+
this._board[cx - 1][cy] = BLACK
312+
this._board[cx][cy - 1] = BLACK
313+
this._board[cx][cy] = WHITE
289314
}
290315

291316
choices(turn) {

tests/lib/rl/acrobot.test.js

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,41 @@ describe('test', () => {
8080
expect(info.state[3]).toBeGreaterThan(0)
8181
})
8282

83+
test('small t1, t2', () => {
84+
const env = new AcrobotRLEnvironment()
85+
const info = env.test([-4, -13, 0, 0], [0])
86+
expect(info.done).toBeFalsy()
87+
expect(info.reward).toBe(-1)
88+
expect(info.state[0]).toBeCloseTo(-4 + 2 * Math.PI)
89+
expect(info.state[1]).toBeCloseTo(-13 + 4 * Math.PI)
90+
expect(info.state[2]).toBeLessThan(0)
91+
expect(info.state[3]).toBeGreaterThan(0)
92+
})
93+
94+
test('big t1, t2', () => {
95+
const env = new AcrobotRLEnvironment()
96+
const info = env.test([26, 4, 0, 0], [0])
97+
expect(info.done).toBeFalsy()
98+
expect(info.reward).toBe(-1)
99+
expect(info.state[0]).toBeCloseTo(26 - 8 * Math.PI)
100+
expect(info.state[1]).toBeCloseTo(4 - 2 * Math.PI)
101+
expect(info.state[2]).toBeLessThan(0)
102+
expect(info.state[3]).toBeGreaterThan(0)
103+
})
104+
105+
test('clip dt1, dt2', () => {
106+
const env = new AcrobotRLEnvironment()
107+
const info = env.test([0, 0, -100, 100], [0])
108+
expect(info.done).toBeFalsy()
109+
expect(info.reward).toBe(-1)
110+
for (let i = 0; i < 2; i++) {
111+
expect(info.state[i]).toBeLessThanOrEqual(Math.PI)
112+
expect(info.state[i]).toBeGreaterThanOrEqual(-Math.PI)
113+
}
114+
expect(info.state[2]).toBeCloseTo(-4 * Math.PI)
115+
expect(info.state[3]).toBeCloseTo(9 * Math.PI)
116+
})
117+
83118
test('goal', () => {
84119
const env = new AcrobotRLEnvironment()
85120
const info = env.test([Math.PI, Math.PI / 2, 0, 0], [0])

0 commit comments

Comments
 (0)