diff --git a/src/utils/sum_tree.jl b/src/utils/sum_tree.jl index 2d97fcc..422142c 100644 --- a/src/utils/sum_tree.jl +++ b/src/utils/sum_tree.jl @@ -124,7 +124,6 @@ function Base.empty!(t::SumTree) t end -"!!! this is unsafe, always check the `real_ind`, or you may get bound error in some extreme cases." function Base.get(t::SumTree, v) parent_ind = 1 leaf_ind = parent_ind @@ -143,6 +142,9 @@ function Base.get(t::SumTree, v) end end end + if leaf_ind <= t.nparents + leaf_ind += t.capacity + end p = t.tree[leaf_ind] ind = leaf_ind - t.nparents real_ind = ind >= t.first ? ind - t.first + 1 : ind + t.capacity - t.first + 1 @@ -152,10 +154,11 @@ end sample(rng::AbstractRNG, t::SumTree{T}) where {T} = get(t, rand(rng, T) * t.tree[1]) sample(t::SumTree) = sample(Random.GLOBAL_RNG, t) -function sample(rng::AbstractRNG, t::SumTree, n::Int) +function sample(rng::AbstractRNG, t::SumTree{T}, n::Int) where {T} inds, priorities = Vector{Int}(undef, n), Vector{Float64}(undef, n) for i in 1:n - ind, p = sample(rng, t) + v = (i-1+rand(rng,T)) / n + ind, p = get(t, v * t.tree[1]) inds[i] = ind priorities[i] = p end