@@ -16,17 +16,17 @@ Generally speaking, it does nothing but update the trajectory and policy appropr
 
 - `policy`::[`AbstractPolicy`](@ref): the policy to use
 - `trajectory`::[`AbstractTrajectory`](@ref): used to store transitions between an agent and an environment
-- `role=:DEFAULT_PLAYER`: used to distinguish different agents
+- `role=RLBase.DEFAULT_PLAYER`: used to distinguish different agents
 """
 Base.@kwdef mutable struct Agent{P<:AbstractPolicy,T<:AbstractTrajectory,R} <: AbstractAgent
     policy::P
-    trajectory::T
-    role::R = :DEFAULT_PLAYER
+    trajectory::T = DummyTrajectory()
+    role::R = RLBase.DEFAULT_PLAYER
     is_training::Bool = true
 end
 
 # avoid polluting trajectory
-(agent::Agent)(obs) = agent.policy(obs)
+(agent::Agent)(env) = agent.policy(env)
 
 Flux.functor(x::Agent) = (policy = x.policy,), y -> @set x.policy = y.policy
 
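Editor's note: a minimal, hedged usage sketch of the new defaults (not part of the diff); `my_policy` and `env` stand in for a real policy and environment from the surrounding packages:

    # hypothetical sketch — with the new defaults only the policy is required
    agent = Agent(policy = my_policy)   # trajectory = DummyTrajectory(), role = RLBase.DEFAULT_PLAYER
    action = agent(env)                 # delegates to agent.policy(env); the trajectory stays untouched
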
@@ -69,47 +69,54 @@ function Flux.testmode!(agent::Agent, mode = true)
     testmode!(agent.policy, mode)
 end
 
-(agent::Agent)(stage::AbstractStage, obs) =
-    agent.is_training ? agent(Training(stage), obs) : agent(Testing(stage), obs)
+(agent::Agent)(stage::AbstractStage, env) =
+    agent.is_training ? agent(Training(stage), env) : agent(Testing(stage), env)
 
-(agent::Agent)(::Testing, obs) = nothing
-(agent::Agent)(::Testing{PreActStage}, obs) = agent.policy(obs)
+(agent::Agent)(::Testing, env) = nothing
+(agent::Agent)(::Testing{PreActStage}, env) = agent.policy(env)
+
+#####
+# DummyTrajectory
+#####
+
+(agent::Agent{<:AbstractPolicy,<:DummyTrajectory})(stage::AbstractStage, env) = nothing
+(agent::Agent{<:AbstractPolicy,<:DummyTrajectory})(stage::PreActStage, env) = agent.policy(env)
 
 #####
 # EpisodicCompactSARTSATrajectory
 #####
 function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
     ::Training{PreEpisodeStage},
-    obs,
+    env,
 )
     empty!(agent.trajectory)
     nothing
 end
 
 function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
     ::Training{PreActStage},
-    obs,
+    env,
 )
-    action = agent.policy(obs)
-    push!(agent.trajectory; state = get_state(obs), action = action)
+    action = agent.policy(env)
+    push!(agent.trajectory; state = get_state(env), action = action)
     update!(agent.policy, agent.trajectory)
     action
 end
 
 function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
     ::Training{PostActStage},
-    obs,
+    env,
 )
-    push!(agent.trajectory; reward = get_reward(obs), terminal = get_terminal(obs))
+    push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env))
     nothing
 end
 
 function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
     ::Training{PostEpisodeStage},
-    obs,
+    env,
 )
-    action = agent.policy(obs)
-    push!(agent.trajectory; state = get_state(obs), action = action)
+    action = agent.policy(env)
+    push!(agent.trajectory; state = get_state(env), action = action)
     update!(agent.policy, agent.trajectory)
     action
 end
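Editor's note: to make the stage protocol above concrete, a hedged sketch of one training episode with an `EpisodicCompactSARTSATrajectory`; the loop itself lives in the runner, not in this file, and `env(action)` assumes a callable environment:

    agent(PreEpisodeStage(), env)            # empties the trajectory
    while !get_terminal(env)
        action = agent(PreActStage(), env)   # pushes state & action, then update!(policy, trajectory)
        env(action)                          # the runner steps the environment (assumed callable)
        agent(PostActStage(), env)           # pushes reward & terminal
    end
    agent(PostEpisodeStage(), env)           # pushes the final state and one more action from the policy, then update!
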
@@ -125,7 +132,7 @@ function (
     }
 )(
     ::Training{PreEpisodeStage},
-    obs,
+    env,
 )
     if length(agent.trajectory) > 0
         pop!(agent.trajectory, :state, :action)
@@ -140,10 +147,10 @@ function (
     }
 )(
     ::Training{PreActStage},
-    obs,
+    env,
 )
-    action = agent.policy(obs)
-    push!(agent.trajectory; state = get_state(obs), action = action)
+    action = agent.policy(env)
+    push!(agent.trajectory; state = get_state(env), action = action)
     update!(agent.policy, agent.trajectory)
     action
 end
@@ -155,9 +162,9 @@ function (
     }
 )(
     ::Training{PostActStage},
-    obs,
+    env,
 )
-    push!(agent.trajectory; reward = get_reward(obs), terminal = get_terminal(obs))
+    push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env))
     nothing
 end
 
@@ -168,10 +175,10 @@ function (
     }
 )(
     ::Training{PostEpisodeStage},
-    obs,
+    env,
 )
-    action = agent.policy(obs)
-    push!(agent.trajectory; state = get_state(obs), action = action)
+    action = agent.policy(env)
+    push!(agent.trajectory; state = get_state(env), action = action)
     update!(agent.policy, agent.trajectory)
     action
 end
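Editor's reading (hedged): unlike the episodic case, the circular and vectorial trajectories keep transitions across episodes, so `Training{PreEpisodeStage}` only drops the trailing `state`/`action` pair pushed at the previous `PostEpisodeStage`. A sketch of why this appears safe:

    # After an episode ends, the buffer holds ..., s_T, a_T, where a_T is effectively a
    # placeholder and the preceding transition is already marked terminal.
    # pop!(trajectory, :state, :action) removes that dangling pair, and the next
    # PreActStage pushes the new episode's first state/action in its place; since the
    # last stored transition is terminal, its successor state is never used for bootstrapping.
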
@@ -182,7 +189,7 @@
 
 function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})(
     ::Training{PreEpisodeStage},
-    obs,
+    env,
 )
     if length(agent.trajectory) > 0
         pop!(agent.trajectory, :state, :action)
@@ -192,28 +199,28 @@ end
 
 function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})(
     ::Training{PreActStage},
-    obs,
+    env,
 )
-    action = agent.policy(obs)
-    push!(agent.trajectory; state = get_state(obs), action = action)
+    action = agent.policy(env)
+    push!(agent.trajectory; state = get_state(env), action = action)
     update!(agent.policy, agent.trajectory)
     action
 end
 
 function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})(
     ::Training{PostActStage},
-    obs,
+    env,
 )
-    push!(agent.trajectory; reward = get_reward(obs), terminal = get_terminal(obs))
+    push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env))
     nothing
 end
 
 function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})(
     ::Training{PostEpisodeStage},
-    obs,
+    env,
 )
-    action = agent.policy(obs)
-    push!(agent.trajectory; state = get_state(obs), action = action)
+    action = agent.policy(env)
+    push!(agent.trajectory; state = get_state(env), action = action)
     update!(agent.policy, agent.trajectory)
     action
 end
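Editor's note: a hedged sketch of the evaluation path defined earlier in this diff; with `is_training = false` the stage dispatch routes through `Testing`, so only `PreActStage` queries the policy and the trajectory is never touched (assuming the trajectory type is not `DummyTrajectory`):

    agent.is_training = false
    agent(PreActStage(), env)    # -> Testing{PreActStage}: returns agent.policy(env)
    agent(PostActStage(), env)   # -> Testing: returns nothing, nothing is recorded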