From 9179e892b0c2d6589eda7649e701fe33c92cc642 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Fri, 21 Aug 2020 01:17:27 +0800 Subject: [PATCH] support mask for more explorers --- src/components/explorers/batch_explorer.jl | 3 +++ .../explorers/gumbel_softmax_explorer.jl | 5 ++++ src/components/explorers/weighted_explorer.jl | 16 ++++++++--- .../explorers/weighted_softmax_explorer.jl | 27 +++++++++++++++++++ 4 files changed, 47 insertions(+), 4 deletions(-) create mode 100644 src/components/explorers/weighted_softmax_explorer.jl diff --git a/src/components/explorers/batch_explorer.jl b/src/components/explorers/batch_explorer.jl index 0b13af1..45cf384 100644 --- a/src/components/explorers/batch_explorer.jl +++ b/src/components/explorers/batch_explorer.jl @@ -16,6 +16,9 @@ Apply inner explorer to each column of `values`. """ (x::BatchExplorer)(values::AbstractMatrix) = [x.explorer(v) for v in eachcol(values)] +(x::BatchExplorer)(values::AbstractMatrix, mask::AbstractMatrix) = [x.explorer(v,m) for (v,m) in zip(eachcol(values), eachcol(mask))] + (x::BatchExplorer)(v::AbstractVector) = x.explorer(v) +(x::BatchExplorer)(v::AbstractVector, m::AbstractVector) = x.explorer(v,m) Flux.testmode!(x::BatchExplorer, mode = true) = testmode!(x.explorer, mode) diff --git a/src/components/explorers/gumbel_softmax_explorer.jl b/src/components/explorers/gumbel_softmax_explorer.jl index da8a738..5e2861e 100644 --- a/src/components/explorers/gumbel_softmax_explorer.jl +++ b/src/components/explorers/gumbel_softmax_explorer.jl @@ -14,3 +14,8 @@ function (p::GumbelSoftmaxExplorer)(v::AbstractVector{T}) where {T} u = rand(p.rng, T, length(logits)) argmax(logits .- log.(-log.(u))) end + +function (p::GumbelSoftmaxExplorer)(v::AbstractVector{T}, mask::AbstractVector{Bool}) where {T} + v[.!mask] .= typemin(T) + p(v) +end diff --git a/src/components/explorers/weighted_explorer.jl b/src/components/explorers/weighted_explorer.jl index 77ddc5c..accfbaa 100644 --- a/src/components/explorers/weighted_explorer.jl +++ b/src/components/explorers/weighted_explorer.jl @@ -2,13 +2,17 @@ export WeightedExplorer using Random using StatsBase: sample, Weights -using Flux: softmax """ - WeightedExplorer(;is_normalized::Bool) + WeightedExplorer(;is_normalized::Bool, rng=Random.GLOBAL_RNG) `is_normalized` is used to indicate if the feeded action values are alrady normalized to have a sum of `1.0`. + +!!! warning + Elements are assumed to be `>=0`. + +See also: [`WeightedSoftmaxExplorer`](@ref) """ struct WeightedExplorer{T,R<:AbstractRNG} <: AbstractExplorer rng::R @@ -21,6 +25,10 @@ end (s::WeightedExplorer{true})(values::AbstractVector{T}) where {T} = sample(s.rng, Weights(values, one(T))) -# ??? add a softmax layer here? (s::WeightedExplorer{false})(values::AbstractVector{T}) where {T} = - sample(s.rng, Weights(softmax(values), one(T))) + sample(s.rng, Weights(values)) + +function (s::WeightedExplorer)(values, mask) + values[.!mask] .= 0 + s(values) +end diff --git a/src/components/explorers/weighted_softmax_explorer.jl b/src/components/explorers/weighted_softmax_explorer.jl new file mode 100644 index 0000000..0bdf5d3 --- /dev/null +++ b/src/components/explorers/weighted_softmax_explorer.jl @@ -0,0 +1,27 @@ +export WeightedSoftmaxExplorer + +using Random +using StatsBase: sample, Weights +using Flux: softmax + +""" + WeightedSoftmaxExplorer(;rng=Random.GLOBAL_RNG) + +See also: [`WeightedExplorer`](@ref) +""" +struct WeightedSoftmaxExplorer{R<:AbstractRNG} <: AbstractExplorer + rng::R +end + +function WeightedSoftmaxExplorer(;rng = Random.GLOBAL_RNG) + WeightedSoftmaxExplorer(rng) +end + +(s::WeightedSoftmaxExplorer)(values::AbstractVector{T}) where {T} = + sample(s.rng, Weights(softmax(values), one(T))) + +function (s::WeightedSoftmaxExplorer)(values::AbstractVector{T}, mask) where T + values[.!mask] .= typemin(T) + s(values) +end +