diff --git a/src/algorithms/dqns/iqn.jl b/src/algorithms/dqns/iqn.jl
index bb205dc..0c9f825 100644
--- a/src/algorithms/dqns/iqn.jl
+++ b/src/algorithms/dqns/iqn.jl
@@ -160,7 +160,7 @@ function (learner::IQNLearner)(obs)
     τ = rand(learner.device_rng, Float32, learner.K, 1)
     τₑₘ = embed(τ, learner.Nₑₘ)
     quantiles = learner.approximator(state, τₑₘ)
-    vec(sum(quantiles; dims = 2)) |> send_to_host
+    vec(mean(quantiles; dims = 2)) |> send_to_host
 end
 
 embed(x, Nₑₘ) = cos.(Float32(π) .* (1:Nₑₘ) .* reshape(x, 1, :))
@@ -184,9 +184,10 @@ function RLBase.update!(learner::IQNLearner, batch::NamedTuple)
     τ′ = rand(learner.device_rng, Float32, N′, batch_size)  # TODO: support β distribution
     τₑₘ′ = embed(τ′, Nₑₘ)
     zₜ = Zₜ(s′, τₑₘ′)
-    aₜ = argmax(zₜ; dims = 1)  # risk-sensitive
-    qₜ = reshape(zₜ[aₜ], N′, batch_size)  # view?
-    target = reshape(r, 1, batch_size) .+ reshape(1 .- t, 1, batch_size) .* qₜ  # reshape to allow broadcast
+    aₜ = argmax(mean(zₜ, dims = 2), dims = 1)
+    aₜ = aₜ .+ typeof(aₜ)(CartesianIndices((0:0, 0:N′-1, 0:0)))
+    qₜ = reshape(zₜ[aₜ], :, batch_size)
+    target = reshape(r, 1, batch_size) .+ learner.γ * reshape(1 .- t, 1, batch_size) .* qₜ  # reshape to allow broadcast
 
     τ = rand(learner.device_rng, Float32, N, batch_size)
     τₑₘ = embed(τ, Nₑₘ)
diff --git a/src/experiments/atari.jl b/src/experiments/atari.jl
index 4c1f994..e059dd1 100644
--- a/src/experiments/atari.jl
+++ b/src/experiments/atari.jl
@@ -157,7 +157,7 @@ function RLCore.Experiment(
     stop_condition = StopAfterStep(N_TRAINING_STEPS)
 
     description = """
-    This experiment uses alomost the same config in [dopamine](https://github.com/google/dopamine/blob/master/dopamine/agents/dqn/configs/dqn.gin). But do notice that there are some minor differences:
+    This experiment uses almost the same config in [dopamine](https://github.com/google/dopamine/blob/master/dopamine/agents/dqn/configs/dqn.gin). But do notice that there are some minor differences:
     - The RMSProp in Flux do not support center option (also the epsilon is not the same).
     - The image resize method used here is provided by ImageTransformers, which is not the same with the one in cv2.
     """
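
Two review notes on the iqn.jl hunks, for context.

The second hunk packs three fixes into four lines: the target action is now a single greedy action chosen on the quantile-averaged Q-values (as in the IQN paper) rather than an independent per-quantile argmax, the chosen index is broadcast across the quantile dimension before gathering, and the previously missing discount factor `learner.γ` enters the target. The broadcast step is the least obvious part, so here is a minimal standalone sketch of it. The shapes assume `zₜ` is laid out as `(n_actions, N′, batch_size)`, which `reshape(zₜ[aₜ], :, batch_size)` implies; the sizes are toy values, not from the repo:

```julia
using Statistics  # for mean

n_actions, N′, batch_size = 3, 4, 2
zₜ = rand(Float32, n_actions, N′, batch_size)  # stands in for Zₜ(s′, τₑₘ′)

# Greedy action per batch element, chosen on the quantile-averaged values.
# argmax(...; dims = 1) pins the quantile coordinate of each returned
# CartesianIndex to 1, so zₜ[aₜ] alone would fetch only the first quantile.
aₜ = argmax(mean(zₜ, dims = 2), dims = 1)                    # (1, 1, batch_size)

# Adding offsets (0, j, 0) for j = 0:N′-1 fans each index out across the
# quantile dimension, so the gather returns all N′ quantiles of that action.
aₜ = aₜ .+ typeof(aₜ)(CartesianIndices((0:0, 0:N′-1, 0:0)))  # (1, N′, batch_size)

qₜ = reshape(zₜ[aₜ], :, batch_size)
@assert size(qₜ) == (N′, batch_size)
```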
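
The `sum` → `mean` change in the first hunk matters because IQN estimates the action value as the expectation over sampled quantile fractions, Q(s, a) ≈ (1/K) Σᵢ Z_τᵢ(s, a); `sum` scaled every value by K, which leaves the greedy argmax unchanged but skews the values wherever they are read directly. As a quick shape check of the cosine embedding both hunks rely on (`embed` is copied from the context line above; `K` and `Nₑₘ` are toy values):

```julia
embed(x, Nₑₘ) = cos.(Float32(π) .* (1:Nₑₘ) .* reshape(x, 1, :))

K, Nₑₘ = 32, 64
τ = rand(Float32, K, 1)         # K sampled quantile fractions for one state
τₑₘ = embed(τ, Nₑₘ)             # cos(π i τ) features, i = 1:Nₑₘ (IQN, Dabney et al. 2018)
@assert size(τₑₘ) == (Nₑₘ, K)   # one Nₑₘ-dim feature column per sampled τ
```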