Skip to content

Commit d63d370

Browse files
authored
More efficient PoissonBinomial (#1285)
1 parent afdcc69 commit d63d370

File tree

2 files changed

+66
-54
lines changed

2 files changed

+66
-54
lines changed

src/univariate/discrete/poissonbinomial.jl

Lines changed: 49 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# TODO: this distribution may need clean-up
21
"""
32
PoissonBinomial(p)
43
@@ -24,35 +23,45 @@ External links:
2423
* [Poisson-binomial distribution on Wikipedia](http://en.wikipedia.org/wiki/Poisson_binomial_distribution)
2524
2625
"""
27-
struct PoissonBinomial{T<:Real} <: DiscreteUnivariateDistribution
28-
p::Vector{T}
29-
pmf::Vector{T}
30-
31-
function PoissonBinomial{T}(p::AbstractArray) where {T <: Real}
32-
pb = poissonbinomial_pdf(p)
33-
@assert isprobvec(pb)
34-
new{T}(p, pb)
26+
mutable struct PoissonBinomial{T<:Real,P<:AbstractVector{T}} <: DiscreteUnivariateDistribution
27+
p::P
28+
pmf::Union{Nothing,Vector{T}} # lazy computation of the probability mass function
29+
30+
function PoissonBinomial{T}(p::AbstractVector{T}; check_args=true) where {T <: Real}
31+
check_args && @check_args(PoissonBinomial, all(x -> zero(x) <= x <= one(x), p))
32+
return new{T,typeof(p)}(p, nothing)
3533
end
3634
end
3735

38-
function PoissonBinomial(p::AbstractArray{T}; check_args=true) where {T <: Real}
39-
if check_args
40-
for i in eachindex(p)
41-
@check_args(PoissonBinomial, 0 <= p[i] <= 1)
36+
function PoissonBinomial(p::AbstractVector{T}; check_args=true) where {T<:Real}
37+
return PoissonBinomial{T}(p; check_args=check_args)
38+
end
39+
40+
function Base.getproperty(d::PoissonBinomial, x::Symbol)
41+
if x === :pmf
42+
z = getfield(d, :pmf)
43+
if z === nothing
44+
y = poissonbinomial_pdf(d.p)
45+
isprobvec(y) || error("probability mass function is not normalized")
46+
setfield!(d, :pmf, y)
47+
return y
48+
else
49+
return z
4250
end
51+
else
52+
return getfield(d, x)
4353
end
44-
return PoissonBinomial{T}(p)
4554
end
4655

4756
@distr_support PoissonBinomial 0 length(d.p)
4857

4958
#### Conversions
5059

51-
function PoissonBinomial(::Type{PoissonBinomial{T}}, p::Vector{S}) where {T, S}
52-
return PoissonBinomial(Vector{T}(p))
60+
function PoissonBinomial(::Type{PoissonBinomial{T}}, p::AbstractVector{S}) where {T, S}
61+
return PoissonBinomial(AbstractVector{T}(p))
5362
end
5463
function PoissonBinomial(::Type{PoissonBinomial{T}}, d::PoissonBinomial{S}) where {T, S}
55-
return PoissonBinomial(Vector{T}(d.p), check_args=false)
64+
return PoissonBinomial(AbstractVector{T}(d.p), check_args=false)
5665
end
5766

5867
#### Parameters
@@ -67,7 +76,7 @@ partype(::PoissonBinomial{T}) where {T} = T
6776
#### Properties
6877

6978
mean(d::PoissonBinomial) = sum(succprob(d))
70-
var(d::PoissonBinomial) = sum(succprob(d) .* failprob(d))
79+
var(d::PoissonBinomial) = sum(p * (1 - p) for p in succprob(d))
7180

7281
function skewness(d::PoissonBinomial{T}) where {T}
7382
v = zero(T)
@@ -91,23 +100,27 @@ function kurtosis(d::PoissonBinomial{T}) where {T}
91100
s / v / v
92101
end
93102

94-
entropy(d::PoissonBinomial) = entropy(Categorical(d.pmf))
103+
entropy(d::PoissonBinomial) = entropy(d.pmf)
95104
median(d::PoissonBinomial) = median(Categorical(d.pmf)) - 1
96105
mode(d::PoissonBinomial) = argmax(d.pmf) - 1
97-
modes(d::PoissonBinomial) = [x - 1 for x in modes(Categorical(d.pmf))]
106+
modes(d::PoissonBinomial) = modes(DiscreteNonParametric(support(d), d.pmf))
98107

99108
#### Evaluation
100109

101110
quantile(d::PoissonBinomial, x::Float64) = quantile(Categorical(d.pmf), x) - 1
102111

103-
function mgf(d::PoissonBinomial{T}, t::Real) where {T}
104-
p, = params(d)
105-
prod(one(T) .- p .+ p .* exp(t))
112+
function mgf(d::PoissonBinomial, t::Real)
113+
expm1_t = expm1(t)
114+
mapreduce(*, succprob(d)) do p
115+
1 + p * expm1_t
116+
end
106117
end
107118

108-
function cf(d::PoissonBinomial{T}, t::Real) where {T}
109-
p, = params(d)
110-
prod(one(T) .- p .+ p .* cis(t))
119+
function cf(d::PoissonBinomial, t::Real)
120+
cis_t = cis(t)
121+
mapreduce(*, succprob(d)) do p
122+
1 - p + p * cis_t
123+
end
111124
end
112125

113126
pdf(d::PoissonBinomial, k::Real) = insupport(d, k) ? d.pmf[k+1] : zero(eltype(d.pmf))
@@ -120,19 +133,17 @@ logpdf(d::PoissonBinomial, k::Real) = log(pdf(d, k))
120133
# Calculating binomial probabilities when the trial probabilities are unequal,
121134
# Journal of Statistical Computation and Simulation, 14:2, 125-131, DOI: 10.1080/00949658208810534
122135
#
123-
function poissonbinomial_pdf(p::AbstractArray{T,1}) where {T <: Real}
124-
n = length(p)
125-
S = zeros(T, n+1)
126-
S[1] = 1-p[1]
127-
S[2] = p[1]
128-
@inbounds for col in 2:n
129-
for r in 1:col
130-
row = col - r + 1
131-
S[row+1] = (1-p[col])*S[row+1] + p[col] * S[row]
136+
function poissonbinomial_pdf(p)
137+
S = zeros(eltype(p), length(p) + 1)
138+
S[1] = 1
139+
@inbounds for (col, p_col) in enumerate(p)
140+
q_col = 1 - p_col
141+
for row in col:(-1):1
142+
S[row + 1] = q_col * S[row + 1] + p_col * S[row]
143+
end
144+
S[1] *= q_col
132145
end
133-
S[1] *= 1-p[col]
134-
end
135-
return S
146+
return S
136147
end
137148

138149
# Computes the pdf of a poisson-binomial random variable using

test/poissonbinomial.jl

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ naive_sol = naive_pb(p)
2929

3030
@test Distributions.poissonbinomial_pdf_fft(p) ≈ naive_sol
3131
@test Distributions.poissonbinomial_pdf(p) ≈ naive_sol
32+
@test Distributions.poissonbinomial_pdf(Tuple(p)) ≈ naive_sol
3233

3334
@test Distributions.poissonbinomial_pdf_fft(p) ≈ Distributions.poissonbinomial_pdf(p)
3435

@@ -46,23 +47,23 @@ for (p, n) in [(0.8, 6), (0.5, 10), (0.04, 20)]
4647
@test maximum(d) == n
4748
@test extrema(d) == (0, n)
4849
@test ntrials(d) == n
49-
@test entropy(d) ≈ entropy(dref)
50-
@test median(d) ≈ median(dref)
51-
@test mean(d) ≈ mean(dref)
52-
@test var(d) ≈ var(dref)
53-
@test kurtosis(d) ≈ kurtosis(dref)
54-
@test skewness(d) ≈ skewness(dref)
50+
@test @inferred(entropy(d)) ≈ entropy(dref)
51+
@test @inferred(median(d)) ≈ median(dref)
52+
@test @inferred(mean(d)) ≈ mean(dref)
53+
@test @inferred(var(d)) ≈ var(dref)
54+
@test @inferred(kurtosis(d)) ≈ kurtosis(dref)
55+
@test @inferred(skewness(d)) ≈ skewness(dref)
5556

5657
for t=0:5
57-
@test mgf(d, t) ≈ mgf(dref, t)
58-
@test cf(d, t) ≈ cf(dref, t)
58+
@test @inferred(mgf(d, t)) ≈ mgf(dref, t)
59+
@test @inferred(cf(d, t)) ≈ cf(dref, t)
5960
end
6061
for i=0.1:0.1:.9
61-
@test quantile(d, i) ≈ quantile(dref, i)
62+
@test @inferred(quantile(d, i)) ≈ quantile(dref, i)
6263
end
6364
for i=0:n
64-
@test isapprox(cdf(d, i), cdf(dref, i), atol=1e-15)
65-
@test isapprox(pdf(d, i), pdf(dref, i), atol=1e-15)
65+
@test isapprox(@inferred(cdf(d, i)), cdf(dref, i), atol=1e-15)
66+
@test isapprox(@inferred(pdf(d, i)), pdf(dref, i), atol=1e-15)
6667
end
6768

6869
end
@@ -88,11 +89,11 @@ for (n₁, n₂, n₃, p₁, p₂, p₃) in [(10, 10, 10, 0.1, 0.5, 0.9),
8889
pmf2 = pdf.(b2, support(b2))
8990
pmf3 = pdf.(b3, support(b3))
9091

91-
@test mean(d) ≈ (mean(b1) + mean(b2) + mean(b3))
92-
@test var(d) ≈ (var(b1) + var(b2) + var(b3))
92+
@test @inferred(mean(d)) ≈ (mean(b1) + mean(b2) + mean(b3))
93+
@test @inferred(var(d)) ≈ (var(b1) + var(b2) + var(b3))
9394
for t=0:5
94-
@test mgf(d, t) ≈ (mgf(b1, t) * mgf(b2, t) * mgf(b3, t))
95-
@test cf(d, t) ≈ (cf(b1, t) * cf(b2, t) * cf(b3, t))
95+
@test @inferred(mgf(d, t)) ≈ (mgf(b1, t) * mgf(b2, t) * mgf(b3, t))
96+
@test @inferred(cf(d, t)) ≈ (cf(b1, t) * cf(b2, t) * cf(b3, t))
9697
end
9798

9899
for k=0:n
@@ -104,7 +105,7 @@ for (n₁, n₂, n₃, p₁, p₂, p₃) in [(10, 10, 10, 0.1, 0.5, 0.9),
104105
end
105106
m += pmf1[i+1] * mc
106107
end
107-
@test isapprox(pdf(d, k), m, atol=1e-15)
108+
@test isapprox(@inferred(pdf(d, k)), m, atol=1e-15)
108109
end
109110
end
110111

0 commit comments

Comments
 (0)