diff --git a/src/policies/policies.jl b/src/policies/policies.jl index ca0cf6d..06ccea3 100644 --- a/src/policies/policies.jl +++ b/src/policies/policies.jl @@ -2,3 +2,4 @@ include("base.jl") include("agents/agents.jl") include("q_based_policies/q_based_policies.jl") include("random_policy.jl") +include("random_start_policy.jl") diff --git a/src/policies/q_based_policies/q_based_policy.jl b/src/policies/q_based_policies/q_based_policy.jl index 4eb8b84..7fe1ad1 100644 --- a/src/policies/q_based_policies/q_based_policy.jl +++ b/src/policies/q_based_policies/q_based_policy.jl @@ -19,9 +19,8 @@ end Flux.functor(x::QBasedPolicy) = (learner = x.learner,), y -> @set x.learner = y.learner (π::QBasedPolicy)(env) = π(env, ActionStyle(env)) -(π::QBasedPolicy)(env, ::MinimalActionSet) = action_space(env)[π.explorer(π.learner(env))] -(π::QBasedPolicy)(env, ::FullActionSet) = - action_space(env)[π.explorer(π.learner(env), legal_action_space_mask(env))] +(π::QBasedPolicy)(env, ::MinimalActionSet) = π.explorer(π.learner(env)) +(π::QBasedPolicy)(env, ::FullActionSet) = π.explorer(π.learner(env), legal_action_space_mask(env)) RLBase.prob(p::QBasedPolicy, env) = prob(p, env, ActionStyle(env)) RLBase.prob(p::QBasedPolicy, env, ::MinimalActionSet) =