diff --git a/src/extensions/Distributions.jl b/src/extensions/Distributions.jl new file mode 100644 index 0000000..6a3cfbf --- /dev/null +++ b/src/extensions/Distributions.jl @@ -0,0 +1,13 @@ +export normlogpdf + +# watch https://github.com/JuliaStats/Distributions.jl/issues/1183 + +""" +GPU automatic differentiable version for the logpdf function of normal distributions. +Adding an epsilon value to guarantee numeric stability if sigma is exactly zero +(e.g. if relu is used in output layer). +""" +function normlogpdf(μ, σ, x; ϵ = 1.0f-8) + z = (x .- μ) ./ (σ .+ ϵ) + -(z .^ 2 .+ log(2.0f0π)) / 2.0f0 .- log.(σ .+ ϵ) +end diff --git a/src/extensions/extensions.jl b/src/extensions/extensions.jl index 9375b69..f4a17d9 100644 --- a/src/extensions/extensions.jl +++ b/src/extensions/extensions.jl @@ -3,3 +3,4 @@ include("CUDA.jl") include("Zygote.jl") include("ReinforcementLearningBase.jl") include("ElasticArrays.jl") +include("Distributions.jl") diff --git a/test/extensions.jl b/test/extensions.jl index 07659dc..6a3265a 100644 --- a/test/extensions.jl +++ b/test/extensions.jl @@ -14,3 +14,23 @@ clip_by_global_norm!(gs, ps, 4.0f0) @test isapprox(gs[:x], [0.0 0.0 0.0; 0.0 0.0 0.0]) end + + +@testset "Distributions" begin + @test isapprox(logpdf(Normal(), 2), normlogpdf(0, 1, 2)) + @test isapprox(logpdf.([Normal(), Normal()], [2, 10]), normlogpdf([0, 0], [1, 1], [2, 10])) + + # Test numeric stability for 0 sigma + @test isnan(normlogpdf(0, 0, 2, ϵ=0)) + @test !isnan(normlogpdf(0, 0, 2)) + + if CUDA.functional() + cpu_grad = Zygote.gradient([0.2, 0.5]) do x + sum(logpdf.([Normal(1,0.1), Normal(2,0.2)], x)) + end + gpu_grad = Zygote.gradient(cu([0.2, 0.5])) do x + sum(normlogpdf(cu([1, 2]), cu([0.1, 0.2]),x)) + end + @test isapprox(cpu_grad[1], gpu_grad[1] |> Array) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index b71ad52..a563f68 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,7 +3,7 @@ using ReinforcementLearningCore using Random using Test using StatsBase -using Distributions: probs +using Distributions: probs, Normal, logpdf using ReinforcementLearningEnvironments using Flux using Zygote