-
-
Notifications
You must be signed in to change notification settings - Fork 108
Closed
Labels
Description
Policies are saved by
function save(f::String, p::AbstractPolicy)
policy = cpu(p)
BSON.@save f policy
end

in `ReinforcementLearningCore.jl/src/extensions/ReinforcementLearningBase.jl`.
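However, some policies, such as `PPOPolicy`, cannot be saved when they are on the GPU. The reason is that `cpu(p)` does not move their parameters back to the CPU: the returned policy still holds `CUDA.CuArray` fields (see the output below).

Code to reproduce: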
using ReinforcementLearningZoo, ReinforcementLearningBase,ReinforcementLearningCore,Flux
N_ENV = 2;
agent = Agent(
policy=PPOPolicy(
approximator=ActorCritic(
actor=Chain(Dense(3, 3),),
critic=Chain(Dense(3, 3),),
optimizer=ADAM(1e-3),
) |> gpu,
γ=0.99f0,λ=0.95f0,clip_range=0.1f0,max_grad_norm=0.5f0,n_epochs=4,n_microbatches=4,actor_loss_weight=1.0f0,critic_loss_weight=0.5f0,entropy_loss_weight=0.001f0,
),
trajectory=PPOTrajectory(;capacity=32,state_type=Float32,state_size=(2, N_ENV),action_type=Int,action_size=(N_ENV,),action_log_prob_type=Float32,action_log_prob_size=(N_ENV,),reward_type=Float32,reward_size=(N_ENV,),terminal_type=Bool,terminal_size=(N_ENV,),),
);
cpu(agent.policy) |> typeof
cpu(agent.policy.approximator) |> typeof

Output:
julia> cpu(agent.policy) |> typeof
PPOPolicy{ActorCritic{Chain{Tuple{Dense{typeof(identity),CUDA.CuArray{Float32,2},CUDA.CuArray{Float32,1}}}},Chain{Tuple{Dense{typeof(identity),CUDA.CuArray{Float32,2},CUDA.CuArray{Float32,1}}}},ADAM},Distributions.DiscreteNonParametric{Int64,P,Base.OneTo{Int64},Ps} where Ps<:AbstractArray{P,1} where P<:Real,Random._GLOBAL_RNG}
julia> cpu(agent.policy.approximator) |> typeof
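ActorCritic{Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}},ADAM}

Calling `cpu` on the approximator alone works (its fields become `Array`s), but calling it on the whole policy leaves the `CuArray`s in place. So I think these policies need a `Flux.functor` method, similar to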
Flux.functor(x::QBasedPolicy) = (learner = x.learner,), y -> @set x.learner = y.learner

I tried to add
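Flux.functor(x::PPOPolicy) = (approximator=x.approximator,), y -> @set x.approximator = y.approximator

in `ppo.jl`, but this does not work. The error shows that rebuilding the policy calls a `PPOPolicy` constructor that takes every field positionally, and no such method exists: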
julia> cpu(agent.policy) |> typeof
ERROR: MethodError: no method matching PPOPolicy(::ActorCritic{Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}},ADAM}, ::Float32, ::Float32, ::Float32, ::Float32, ::Int64, ::Int64, ::Float32, ::Float32, ::Float32, ::Random._GLOBAL_RNG, ::Array{Float32,2}, ::Array{Float32,2}, ::Array{Float32,2}, ::Array{Float32,2}, ::Array{Float32,2})