Skip to content

Commit 1ab94f7

Browse files
authored
Improve tests (#899)
1 parent e096a57 commit 1ab94f7

File tree

8 files changed

+227
-44
lines changed

8 files changed

+227
-44
lines changed

tests/lib/model/affinity_propagation.test.js

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@ test('predict', () => {
2020
}
2121
}
2222
expect(model.size).toBe(2)
23+
const centroids = model.centroids
24+
for (let i = 0; i < 2; i++) {
25+
let hasSame = false
26+
for (let k = 0; k < x.length && !hasSame; k++) {
27+
hasSame |= x[k].every((v, d) => v === centroids[i][d])
28+
}
29+
expect(hasSame).toBeTruthy()
30+
}
2331
const y = model.predict()
2432
expect(y).toHaveLength(x.length)
2533

tests/lib/model/arma.test.js

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,17 @@ test('sin', () => {
3636
expect(future[i]).toBeCloseTo(Math.sin(10 + i / 10), 1)
3737
}
3838
})
39+
40+
test('const', () => {
41+
const model = new ARMA(5, 2)
42+
const x = Array(100).fill(0)
43+
44+
for (let i = 0; i < 1; i++) {
45+
model.fit(x)
46+
}
47+
const future = model.predict(x, 20)
48+
expect(future).toHaveLength(20)
49+
for (let i = 0; i < 20; i++) {
50+
expect(future[i]).toBeCloseTo(0)
51+
}
52+
})

tests/lib/model/bridge.test.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { jest } from '@jest/globals'
2-
jest.retryTimes(3)
2+
jest.retryTimes(5)
33

44
import Matrix from '../../../lib/util/matrix.js'
55
import BRIDGE from '../../../lib/model/bridge.js'

tests/lib/model/diana.test.js

Lines changed: 54 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,58 @@ import DIANA from '../../../lib/model/diana.js'
33

44
import { randIndex } from '../../../lib/evaluate/clustering.js'
55

6-
test('clustering', () => {
7-
const model = new DIANA()
8-
const n = 50
9-
const x = Matrix.concat(Matrix.randn(n, 2, 0, 0.1), Matrix.randn(n, 2, 5, 0.1)).toArray()
10-
11-
model.init(x)
12-
model.fit()
13-
expect(model.size).toBe(2)
14-
const y = model.predict()
15-
expect(y).toHaveLength(x.length)
16-
17-
const t = []
18-
for (let i = 0; i < x.length; i++) {
19-
t[i] = Math.floor(i / n)
20-
}
21-
const ri = randIndex(y, t)
22-
expect(ri).toBeGreaterThan(0.9)
6+
describe('clustering', () => {
7+
test('2 clusters', () => {
8+
const model = new DIANA()
9+
const n = 50
10+
const x = Matrix.concat(Matrix.randn(n, 2, 0, 0.1), Matrix.randn(n, 2, 5, 0.1)).toArray()
11+
12+
model.init(x)
13+
model.fit()
14+
expect(model.size).toBe(2)
15+
const y = model.predict()
16+
expect(y).toHaveLength(x.length)
17+
18+
const t = []
19+
for (let i = 0; i < x.length; i++) {
20+
t[i] = Math.floor(i / n)
21+
}
22+
const ri = randIndex(y, t)
23+
expect(ri).toBeGreaterThan(0.9)
24+
})
25+
26+
test('4 clusters', () => {
27+
const model = new DIANA()
28+
const n = 50
29+
const x = Matrix.concat(
30+
Matrix.concat(Matrix.randn(n, 2, 0, 0.1), Matrix.randn(n, 2, 5, 0.1)),
31+
Matrix.concat(Matrix.randn(n, 2, 10, 0.1), Matrix.randn(n, 2, 15, 0.1))
32+
).toArray()
33+
34+
model.init(x)
35+
model.fit()
36+
model.fit()
37+
expect(model.size).toBe(4)
38+
const y = model.predict()
39+
expect(y).toHaveLength(x.length)
40+
41+
const t = []
42+
for (let i = 0; i < x.length; i++) {
43+
t[i] = Math.floor(i / n)
44+
}
45+
const ri = randIndex(y, t)
46+
expect(ri).toBeGreaterThan(0.9)
47+
})
48+
49+
test('single cluster', () => {
50+
const model = new DIANA()
51+
const x = [[0, 0]]
52+
53+
model.init(x)
54+
model.fit()
55+
expect(model.size).toBe(1)
56+
const y = model.predict()
57+
expect(y).toHaveLength(x.length)
58+
expect(y[0]).toBe(0)
59+
})
2360
})

tests/lib/model/lsif.test.js

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,38 @@ jest.retryTimes(3)
44
import Matrix from '../../../lib/util/matrix.js'
55
import LSIF from '../../../lib/model/lsif.js'
66

7-
test('LSIF', () => {
8-
const sigmas = []
9-
const lambdas = []
10-
for (let i = -3; i <= 3; i += 1) {
11-
sigmas.push(10 ** i)
12-
lambdas.push(10 ** i)
13-
}
14-
const model = new LSIF(sigmas, lambdas, 5, 100)
7+
describe('LSIF', () => {
8+
test('some candidates', () => {
9+
const sigmas = []
10+
const lambdas = []
11+
for (let i = -3; i <= 3; i += 1) {
12+
sigmas.push(10 ** i)
13+
lambdas.push(10 ** i)
14+
}
15+
const model = new LSIF(sigmas, lambdas, 5, 100)
1516

16-
const x1 = Matrix.randn(300, 1, 0).toArray()
17-
const x2 = Matrix.randn(200, 1, 0).toArray()
18-
model.fit(x1, x2)
17+
const x1 = Matrix.randn(300, 1, 0).toArray()
18+
const x2 = Matrix.randn(200, 1, 0).toArray()
19+
model.fit(x1, x2)
1920

20-
const r = model.predict(x2)
21-
for (let i = 0; i < x2.length; i++) {
22-
expect(r[i]).toBeCloseTo(1, 0)
23-
}
21+
const r = model.predict(x2)
22+
for (let i = 0; i < x2.length; i++) {
23+
expect(r[i]).toBeCloseTo(1, 0)
24+
}
25+
})
26+
27+
test('single candidates', () => {
28+
const sigmas = [10]
29+
const lambdas = [0.001]
30+
const model = new LSIF(sigmas, lambdas, 5, 100)
31+
32+
const x1 = Matrix.randn(300, 1, 0).toArray()
33+
const x2 = Matrix.randn(200, 1, 0).toArray()
34+
model.fit(x1, x2)
35+
36+
const r = model.predict(x2)
37+
for (let i = 0; i < x2.length; i++) {
38+
expect(r[i]).toBeCloseTo(1, 0)
39+
}
40+
})
2441
})

tests/lib/model/odin.test.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { jest } from '@jest/globals'
2-
jest.retryTimes(3)
2+
jest.retryTimes(5)
33

44
import Matrix from '../../../lib/util/matrix.js'
55
import ODIN from '../../../lib/model/odin.js'

tests/lib/rl/breaker.test.js

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,29 @@ describe('test', () => {
6969
expect(env.epoch).toBe(0)
7070
})
7171

72-
test.each([0, 1000])('bar position: %p', p => {
72+
test('bar position: under', () => {
7373
const env = new BreakerRLEnvironment()
7474
const state = env.reset()
75-
state[4] = p
75+
state[4] = 0
7676

7777
const info = env.test(state, [0])
7878
expect(info.done).toBeFalsy()
7979
expect(info.reward).toBe(0.1)
8080
expect(info.state).toHaveLength(85)
81+
expect(info.state[4]).toBe(env._paddle_size[0] / 2)
82+
expect(env.epoch).toBe(0)
83+
})
84+
85+
test('bar position: over', () => {
86+
const env = new BreakerRLEnvironment()
87+
const state = env.reset()
88+
state[4] = 1000
89+
90+
const info = env.test(state, [0])
91+
expect(info.done).toBeFalsy()
92+
expect(info.reward).toBe(0.1)
93+
expect(info.state).toHaveLength(85)
94+
expect(info.state[4]).toBe(env._size[0] - env._paddle_size[0] / 2)
8195
expect(env.epoch).toBe(0)
8296
})
8397

@@ -97,27 +111,62 @@ describe('test', () => {
97111
expect(env.epoch).toBe(0)
98112
})
99113

100-
test('hit paddle side', () => {
114+
test.each([
115+
[1, -1],
116+
[-1, 1],
117+
])('hit paddle side: %p', (dx, dy) => {
101118
const env = new BreakerRLEnvironment()
102119
const state = env.reset()
103120
state[0] = 100 - env._paddle_size[0] / 2
104121
state[1] = env._paddle_baseline
105-
state[2] = 1
106-
state[3] = -1
122+
state[2] = dx
123+
state[3] = dy
107124
state[4] = 100
108125

109126
const info = env.test(state, [0])
110127
expect(info.done).toBeFalsy()
111128
expect(info.reward).toBe(100)
112129
expect(info.state).toHaveLength(85)
113-
expect(info.state[0]).toBe(state[0] + 1)
114-
expect(info.state[1]).toBe(state[1] - 1)
115-
expect(info.state[2]).toBe(-1)
116-
expect(info.state[3]).toBe(-1)
130+
expect(info.state[0]).toBe(state[0] + dx)
131+
expect(info.state[1]).toBe(state[1] + dy)
132+
expect(info.state[2]).toBe(-dx)
133+
expect(info.state[3]).toBe(dy)
117134
expect(info.state[4]).toBe(100)
118135
expect(env.epoch).toBe(0)
119136
})
120137

138+
test('hit paddle side top (no contact)', () => {
139+
const env = new BreakerRLEnvironment()
140+
const state = env.reset()
141+
state[0] = 100 - env._paddle_size[0] / 2
142+
state[1] = env._paddle_baseline + env._paddle_size[1] / 2
143+
state[2] = -env._ball_radius / Math.SQRT2
144+
state[3] = env._ball_radius / Math.SQRT2
145+
state[4] = 100
146+
147+
const info = env.test(state, [0])
148+
expect(info.done).toBeFalsy()
149+
expect(info.reward).toBe(0.1)
150+
expect(info.state).toHaveLength(85)
151+
expect(env.epoch).toBe(0)
152+
})
153+
154+
test('hit paddle side top (contact)', () => {
155+
const env = new BreakerRLEnvironment()
156+
const state = env.reset()
157+
state[0] = 100 - env._paddle_size[0] / 2
158+
state[1] = env._paddle_baseline + env._paddle_size[1] / 2
159+
state[2] = -env._ball_radius / 2
160+
state[3] = env._ball_radius / 2
161+
state[4] = 100
162+
163+
const info = env.test(state, [0])
164+
expect(info.done).toBeFalsy()
165+
expect(info.reward).toBe(100)
166+
expect(info.state).toHaveLength(85)
167+
expect(env.epoch).toBe(0)
168+
})
169+
121170
test('hit side left', () => {
122171
const env = new BreakerRLEnvironment()
123172
const state = env.reset()

tests/lib/rl/waterball.test.js

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ test('actions', () => {
1313
test('states', () => {
1414
const env = new WaterballRLEnvironment(100, 100)
1515
expect(env.states).toHaveLength(82)
16+
expect(env.states).toHaveLength(82)
1617
})
1718

1819
test('reset', () => {
@@ -62,11 +63,68 @@ describe('state', () => {
6263
})
6364

6465
describe('step', () => {
65-
test('step', () => {
66+
test.each([0, 1, 2, 3])('step %p', action => {
67+
const env = new WaterballRLEnvironment(100, 100)
68+
const info = env.step([action])
69+
expect(info.done).toBeFalsy()
70+
expect(info.reward).toBeDefined()
71+
expect(info.state).toBeInstanceOf(Array)
72+
})
73+
74+
test('big velocity', () => {
75+
const env = new WaterballRLEnvironment(100, 100)
76+
for (let i = 0; i < 20; i++) {
77+
env.step([0])
78+
}
79+
const info = env.step([0])
80+
expect(info.done).toBeFalsy()
81+
expect(info.reward).toBeDefined()
82+
expect(info.state).toBeInstanceOf(Array)
83+
expect(info.state[0]).toBe(env._agent_max_velocity)
84+
})
85+
86+
test('small velocity', () => {
87+
const env = new WaterballRLEnvironment(100, 100)
88+
for (let i = 0; i < 20; i++) {
89+
env.step([1])
90+
}
91+
const info = env.step([1])
92+
expect(info.done).toBeFalsy()
93+
expect(info.reward).toBeDefined()
94+
expect(info.state).toBeInstanceOf(Array)
95+
expect(info.state[0]).toBe(-env._agent_max_velocity)
96+
})
97+
98+
test('touch wall max', () => {
6699
const env = new WaterballRLEnvironment(100, 100)
100+
for (let i = 0; i < 60; i++) {
101+
env.step([0])
102+
}
67103
const info = env.step([0])
68104
expect(info.done).toBeFalsy()
69105
expect(info.reward).toBeDefined()
70106
expect(info.state).toBeInstanceOf(Array)
71107
})
108+
109+
test('touch wall min', () => {
110+
const env = new WaterballRLEnvironment(100, 100)
111+
for (let i = 0; i < 60; i++) {
112+
env.step([1])
113+
}
114+
const info = env.step([1])
115+
expect(info.done).toBeFalsy()
116+
expect(info.reward).toBeDefined()
117+
expect(info.state).toBeInstanceOf(Array)
118+
})
119+
120+
test('add ball', () => {
121+
const env = new WaterballRLEnvironment(10, 10)
122+
for (let i = 0; i < 1000; i++) {
123+
env.step([Math.floor(Math.random() * 4)])
124+
}
125+
const info = env.step([1])
126+
expect(info.done).toBeFalsy()
127+
expect(info.reward).toBeDefined()
128+
expect(info.state).toBeInstanceOf(Array)
129+
})
72130
})

0 commit comments

Comments
 (0)