1- import { jest } from '@jest/globals'
1+ import { expect , jest } from '@jest/globals'
22jest . retryTimes ( 5 )
33
44import DQNAgent from '../../../lib/model/dqn.js'
5+ import ReversiRLEnvironment from '../../../lib/rl/reversi.js'
56import CartPoleRLEnvironment from '../../../lib/rl/cartpole.js'
67import InHypercubeRLEnvironment from '../../../lib/rl/inhypercube.js'
8+ import PendulumRLEnvironment from '../../../lib/rl/pendulum.js'
79
810test ( 'update dqn' , ( ) => {
911 const env = new InHypercubeRLEnvironment ( 2 )
@@ -29,6 +31,7 @@ test('update dqn', () => {
2931 }
3032 }
3133 expect ( totalReward . slice ( Math . max ( 0 , totalReward . length - 10 ) ) . every ( v => v > 0 ) ) . toBeTruthy ( )
34+ agent . terminate ( )
3235} )
3336
3437test ( 'update ddqn' , ( ) => {
@@ -56,16 +59,94 @@ test('update ddqn', () => {
5659 }
5760 }
5861 expect ( totalReward . slice ( Math . max ( 0 , totalReward . length - 10 ) ) . every ( v => v > 0 ) ) . toBeTruthy ( )
62+ agent . terminate ( )
63+ } )
64+
test('realrange action', () => {
	// A continuous-action (real-range) environment: the agent must still emit an action array.
	const env = new PendulumRLEnvironment()
	const agent = new DQNAgent(env, 10, [{ type: 'full', out_size: 3, activation: 'tanh' }], 'adam')
	// Shrink the net's internal schedule so a single update() call performs a learning step.
	agent._net._batch_size = 1
	agent._net._fix_param_update_step = 1
	agent._net._do_update_step = 1

	const curState = env.reset()
	const action = agent.get_action(curState, 0.9)
	const { state, reward, done } = env.step(action)
	agent.update(action, curState, state, reward, done, 0.001, 10)

	// Greedy rate 0 requests the best action; it must be a single action value.
	const best_action = agent.get_action(state, 0)
	expect(best_action).toHaveLength(1)
	// Release agent resources, matching the cleanup done by the dqn/ddqn tests above.
	agent.terminate()
})
80+
test('array state action', () => {
	// A board-game environment whose states/actions are arrays rather than scalars.
	const env = new ReversiRLEnvironment()
	const agent = new DQNAgent(env, 20, [{ type: 'full', out_size: 10, activation: 'tanh' }], 'adam')

	// Shrink the net's internal schedule so a single update() call performs a learning step.
	agent._net._batch_size = 1
	agent._net._fix_param_update_step = 1
	agent._net._do_update_step = 1

	const curState = env.reset()
	const action = agent.get_action(curState, 0.9)
	const { state, reward, done } = env.step(action)
	agent.update(action, curState, state, reward, done, 0.001, 10)

	// Greedy rate 0 requests the best action; it must be a single action value.
	const best_action = agent.get_action(state, 0)
	expect(best_action).toHaveLength(1)
	// Release agent resources, matching the cleanup done by the dqn/ddqn tests above.
	agent.terminate()
})
97+
test('max memory size', () => {
	const env = new InHypercubeRLEnvironment(2)
	const agent = new DQNAgent(env, 10, [{ type: 'full', out_size: 3, activation: 'tanh' }], 'adam')
	agent.method = 'DDQN'
	agent._net._batch_size = 1
	// Cap the replay memory so repeated updates must evict old transitions.
	agent._net._max_memory_size = 10

	const curState = env.reset()
	const action = agent.get_action(curState, 0.9)
	const { state, reward, done } = env.step(action)
	// Push the same transition well past the cap; memory length must never exceed it.
	for (let i = 0; i < 20; i++) {
		agent.update(action, curState, state, reward, done, 0.001, 10)
		expect(agent._net._memory.length).toBeLessThanOrEqual(10)
	}
	// Release agent resources, matching the cleanup done by the dqn/ddqn tests above.
	agent.terminate()
})
113+
test('reset to dqn', () => {
	const env = new InHypercubeRLEnvironment(2)
	const agent = new DQNAgent(env, 10, [{ type: 'full', out_size: 3, activation: 'tanh' }], 'adam')
	agent.method = 'DDQN'
	// Shrink the net's internal schedule so a single update() call performs a learning step.
	agent._net._batch_size = 1
	agent._net._fix_param_update_step = 1
	agent._net._do_update_step = 1

	const curState = env.reset()
	const action = agent.get_action(curState, 0.9)
	const { state, reward, done } = env.step(action)
	agent.update(action, curState, state, reward, done, 0.001, 10)

	// DDQN keeps a separate target network; switching the method back to DQN must drop it.
	expect(agent._net._target).toBeDefined()
	agent.method = 'DQN'
	expect(agent._net._target).toBeNull()
	// Release agent resources, matching the cleanup done by the dqn/ddqn tests above.
	agent.terminate()
})
60131
test('get_score', () => {
	const env = new CartPoleRLEnvironment()
	const agent = new DQNAgent(env, 12, [{ type: 'full', out_size: 10, activation: 'tanh' }], 'adam')

	// The score is a nested array: one axis of length 12 per resolution step,
	// with the innermost axis holding a value per action (2 for CartPole).
	const score = agent.get_score()
	let level = score
	for (const expectedLength of [12, 12, 12, 12, 2]) {
		expect(level).toHaveLength(expectedLength)
		level = level[0]
	}

	// Second call — presumably hits a cached/incremental path; must not throw. TODO confirm.
	agent.get_score()
})
145+
test('get_action default', () => {
	const env = new InHypercubeRLEnvironment(2)
	const agent = new DQNAgent(env, 10, [{ type: 'full', out_size: 3, activation: 'tanh' }], 'adam')

	// Calling get_action without an explicit greedy rate still yields a single action.
	expect(agent.get_action(env.state())).toHaveLength(1)
})
0 commit comments