From cef1840d7f6976de0d958bf888f74ce0038f325e Mon Sep 17 00:00:00 2001 From: findmyway Date: Mon, 21 Dec 2020 00:41:40 +0000 Subject: [PATCH] Format .jl files --- src/policies/agents/agent.jl | 3 ++- src/policies/q_based_policies/q_based_policy.jl | 8 ++++---- src/policies/random_policy.jl | 16 ++++++++++++---- src/policies/random_start_policy.jl | 2 +- test/components/agents.jl | 3 +-- 5 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/policies/agents/agent.jl b/src/policies/agents/agent.jl index 2f09fd8..4d2cfde 100644 --- a/src/policies/agents/agent.jl +++ b/src/policies/agents/agent.jl @@ -24,7 +24,8 @@ functor(x::Agent) = (policy = x.policy,), y -> @set x.policy = y.policy (agent::Agent)(env) = agent.policy(env) function check(agent::Agent, env::AbstractEnv) - if ActionStyle(env) === FULL_ACTION_SET && !haskey(agent.trajectory, :legal_actions_mask) + if ActionStyle(env) === FULL_ACTION_SET && + !haskey(agent.trajectory, :legal_actions_mask) @warn "The env[$(nameof(env))] is of FULL_ACTION_SET, but I can not find a trace named :legal_actions_mask in the trajectory" end check(agent.policy, env) diff --git a/src/policies/q_based_policies/q_based_policy.jl b/src/policies/q_based_policies/q_based_policy.jl index 7d4b184..7a15643 100644 --- a/src/policies/q_based_policies/q_based_policy.jl +++ b/src/policies/q_based_policies/q_based_policy.jl @@ -20,11 +20,11 @@ Flux.functor(x::QBasedPolicy) = (learner = x.learner,), y -> @set x.learner = y. (π::QBasedPolicy)(env) = π(env, ActionStyle(env)) (π::QBasedPolicy)(env, ::MinimalActionSet) = π.explorer(π.learner(env)) -(π::QBasedPolicy)(env, ::FullActionSet) = π.explorer(π.learner(env), legal_action_space_mask(env)) +(π::QBasedPolicy)(env, ::FullActionSet) = + π.explorer(π.learner(env), legal_action_space_mask(env)) RLBase.prob(p::QBasedPolicy, env) = prob(p, env, ActionStyle(env)) -RLBase.prob(p::QBasedPolicy, env, ::MinimalActionSet) = - prob(p.explorer, p.learner(env)) +RLBase.prob(p::QBasedPolicy, env, ::MinimalActionSet) = prob(p.explorer, p.learner(env)) RLBase.prob(p::QBasedPolicy, env, ::FullActionSet) = prob(p.explorer, p.learner(env), legal_action_space_mask(env)) @@ -36,7 +36,7 @@ RLBase.update!(p::QBasedPolicy, trajectory::AbstractTrajectory) = function check(p::QBasedPolicy, env::AbstractEnv) A = action_space(env) if (A isa AbstractVector && A == 1:length(A)) || - (A isa Tuple && A == Tuple(1:length(A))) + (A isa Tuple && A == Tuple(1:length(A))) # this is expected else @warn "Applying a QBasedPolicy to an environment with a unknown action space. Maybe convert the environment with `discrete2standard_discrete` in ReinforcementLearningEnvironments.jl first or redesign the environment." diff --git a/src/policies/random_policy.jl b/src/policies/random_policy.jl index d3c690e..67b0c17 100644 --- a/src/policies/random_policy.jl +++ b/src/policies/random_policy.jl @@ -20,19 +20,23 @@ end Random.seed!(p::RandomPolicy, seed) = Random.seed!(p.rng, seed) -RandomPolicy(s=nothing; rng = Random.GLOBAL_RNG) = RandomPolicy(s, rng) +RandomPolicy(s = nothing; rng = Random.GLOBAL_RNG) = RandomPolicy(s, rng) (p::RandomPolicy{Nothing})(env) = rand(p.rng, legal_action_space(env)) (p::RandomPolicy)(env) = rand(p.rng, p.action_space) function RLBase.prob(p::RandomPolicy{<:Union{AbstractVector,Tuple}}, env::AbstractEnv) n = length(p.action_space) - Categorical(fill(1/n, n); check_args=false) + Categorical(fill(1 / n, n); check_args = false) end RLBase.prob(p::RandomPolicy{Nothing}, env::AbstractEnv) = prob(p, env, ChanceStyle(env)) -function RLBase.prob(p::RandomPolicy{Nothing}, env::AbstractEnv, ::RLBase.AbstractChanceStyle) +function RLBase.prob( + p::RandomPolicy{Nothing}, + env::AbstractEnv, + ::RLBase.AbstractChanceStyle, +) mask = legal_action_space_mask(env) n = sum(mask) prob = zeros(length(mask)) @@ -40,7 +44,11 @@ function RLBase.prob(p::RandomPolicy{Nothing}, env::AbstractEnv, ::RLBase.Abstra prob end -function RLBase.prob(p::RandomPolicy{Nothing}, env::AbstractEnv, ::RLBase.ExplicitStochastic) +function RLBase.prob( + p::RandomPolicy{Nothing}, + env::AbstractEnv, + ::RLBase.ExplicitStochastic, +) if current_player(env) == chance_player(env) prob(env, chance_player(env)) else diff --git a/src/policies/random_start_policy.jl b/src/policies/random_start_policy.jl index 17cafa6..8aef0a5 100644 --- a/src/policies/random_start_policy.jl +++ b/src/policies/random_start_policy.jl @@ -25,4 +25,4 @@ for f in (:prob, :priority) $f(p.random_policy, args...) end end -end \ No newline at end of file +end diff --git a/test/components/agents.jl b/test/components/agents.jl index 54a1e1b..55e80e0 100644 --- a/test/components/agents.jl +++ b/test/components/agents.jl @@ -1,2 +1 @@ -@testset "Agent" begin -end +@testset "Agent" begin end