Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit b6cd83a

Browse files
authored
refactor NeuralNetworkApproximator with @forward (#221)
1 parent c70f865 commit b6cd83a

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ export NeuralNetworkApproximator, ActorCritic
22

33
using Flux
44
import Functors: functor
5+
using MacroTools: @forward
56

67
"""
78
NeuralNetworkApproximator(;kwargs)
@@ -21,12 +22,11 @@ end
2122
# some model may accept multiple inputs
2223
(app::NeuralNetworkApproximator)(args...; kwargs...) = app.model(args...; kwargs...)
2324

25+
@forward NeuralNetworkApproximator.model Flux.testmode!, Flux.trainmode!, Flux.params, device
2426

2527
functor(x::NeuralNetworkApproximator) =
2628
(model = x.model,), y -> NeuralNetworkApproximator(y.model, x.optimizer)
2729

28-
device(app::NeuralNetworkApproximator) = device(app.model)
29-
3030
RLBase.update!(app::NeuralNetworkApproximator, gs) =
3131
Flux.Optimise.update!(app.optimizer, params(app), gs)
3232

0 commit comments

Comments
 (0)