This repository was archived by the owner on May 6, 2021. It is now read-only.
Merged
13 changes: 13 additions & 0 deletions src/extensions/Distributions.jl
@@ -0,0 +1,13 @@
export normlogpdf

# watch https://github.com/JuliaStats/Distributions.jl/issues/1183

"""
GPU automatic differentiable version for the logpdf function of normal distributions.
Adding an epsilon value to guarantee numeric stability if sigma is exactly zero
(e.g. if relu is used in output layer).
"""
function normlogpdf(μ, σ, x; ϵ = 1.0f-8)
    z = (x .- μ) ./ (σ .+ ϵ)
    -(z .^ 2 .+ log(2.0f0π)) / 2.0f0 .- log.(σ .+ ϵ)
end
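For context, a minimal sketch (not part of this diff) of the failure mode the `ϵ` guards against, assuming Flux's `relu` and Float32 literals: with `σ` exactly zero, the default `ϵ` keeps both the value and its gradient finite, whereas `ϵ = 0` yields `NaN`.

```julia
using Zygote
using Flux: relu                        # relu is re-exported by Flux (from NNlib)

σ = relu(-0.3f0)                        # exactly 0.0f0, as a relu output layer can produce
normlogpdf(0.0f0, σ, 0.5f0)             # finite (very negative) thanks to the default ϵ = 1.0f-8
normlogpdf(0.0f0, σ, 0.5f0; ϵ = 0.0f0)  # NaN: -Inf - (-Inf) without the stabilizing ϵ

# The ϵ also keeps the gradient with respect to σ well defined:
Zygote.gradient(s -> normlogpdf(0.0f0, s, 0.5f0), σ)
```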
1 change: 1 addition & 0 deletions src/extensions/extensions.jl
@@ -3,3 +3,4 @@ include("CUDA.jl")
include("Zygote.jl")
include("ReinforcementLearningBase.jl")
include("ElasticArrays.jl")
include("Distributions.jl")
20 changes: 20 additions & 0 deletions test/extensions.jl
@@ -14,3 +14,23 @@
    clip_by_global_norm!(gs, ps, 4.0f0)
    @test isapprox(gs[:x], [0.0 0.0 0.0; 0.0 0.0 0.0])
end


@testset "Distributions" begin
@test isapprox(logpdf(Normal(), 2), normlogpdf(0, 1, 2))
@test isapprox(logpdf.([Normal(), Normal()], [2, 10]), normlogpdf([0, 0], [1, 1], [2, 10]))

# Test numeric stability for 0 sigma
@test isnan(normlogpdf(0, 0, 2, ϵ=0))
@test !isnan(normlogpdf(0, 0, 2))

if CUDA.functional()
cpu_grad = Zygote.gradient([0.2, 0.5]) do x
sum(logpdf.([Normal(1,0.1), Normal(2,0.2)], x))
end
gpu_grad = Zygote.gradient(cu([0.2, 0.5])) do x
sum(normlogpdf(cu([1, 2]), cu([0.1, 0.2]),x))
end
@test isapprox(cpu_grad[1], gpu_grad[1] |> Array)
end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -3,7 +3,7 @@ using ReinforcementLearningCore
using Random
using Test
using StatsBase
-using Distributions: probs
+using Distributions: probs, Normal, logpdf
using ReinforcementLearningEnvironments
using Flux
using Zygote