Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit ab738de

Browse files
authored
Fix some corner cases when sampling a SumTree (#83)
* fix bug in sum tree * sampling trick =。=
1 parent 1a5679a commit ab738de

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

src/utils/sum_tree.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ function Base.empty!(t::SumTree)
124124
t
125125
end
126126

127-
"!!! this is unsafe, always check the `real_ind`, or you may get bound error in some extreme cases."
128127
function Base.get(t::SumTree, v)
129128
parent_ind = 1
130129
leaf_ind = parent_ind
@@ -143,6 +142,9 @@ function Base.get(t::SumTree, v)
143142
end
144143
end
145144
end
145+
if leaf_ind <= t.nparents
146+
leaf_ind += t.capacity
147+
end
146148
p = t.tree[leaf_ind]
147149
ind = leaf_ind - t.nparents
148150
real_ind = ind >= t.first ? ind - t.first + 1 : ind + t.capacity - t.first + 1
@@ -152,10 +154,11 @@ end
152154
sample(rng::AbstractRNG, t::SumTree{T}) where {T} = get(t, rand(rng, T) * t.tree[1])
153155
sample(t::SumTree) = sample(Random.GLOBAL_RNG, t)
154156

155-
function sample(rng::AbstractRNG, t::SumTree, n::Int)
157+
function sample(rng::AbstractRNG, t::SumTree{T}, n::Int) where {T}
156158
inds, priorities = Vector{Int}(undef, n), Vector{Float64}(undef, n)
157159
for i in 1:n
158-
ind, p = sample(rng, t)
160+
v = (i-1+rand(rng,T)) / n
161+
ind, p = get(t, v * t.tree[1])
159162
inds[i] = ind
160163
priorities[i] = p
161164
end

0 commit comments

Comments
 (0)