@@ -2,20 +2,41 @@ using Distributions: Distributions
22using Bijectors: Bijectors
33using Distributions: Univariate, Multivariate, Matrixvariate
44
5+ """
6+ Base type for distribution wrappers.
7+ """
8+ abstract type WrappedDistribution{variate,support,Td<: Distribution{variate,support} } < :
9+ Distribution{variate,support} end
10+
11+ wrapped_dist_type (:: Type{<:WrappedDistribution{<:Any,<:Any,Td}} ) where {Td} = Td
12+ wrapped_dist_type (d:: WrappedDistribution ) = wrapped_dist_type (d)
13+
14+ wrapped_dist (d:: WrappedDistribution ) = d. dist
15+
16+ Base. length (d:: WrappedDistribution{<:Multivariate} ) = length (wrapped_dist (d))
17+ Base. size (d:: WrappedDistribution{<:Multivariate} ) = size (wrapped_dist (d))
18+ Base. eltype (:: Type{T} ) where {T<: WrappedDistribution } = eltype (wrapped_dist_type (T))
19+ Base. eltype (d:: WrappedDistribution ) = eltype (wrapped_dist_type (d))
20+
21+ function Distributions. rand (rng:: Random.AbstractRNG , d:: WrappedDistribution )
22+ return rand (rng, wrapped_dist (d))
23+ end
24+ Distributions. minimum (d:: WrappedDistribution ) = minimum (wrapped_dist (d))
25+ Distributions. maximum (d:: WrappedDistribution ) = maximum (wrapped_dist (d))
26+
27+ Bijectors. bijector (d:: WrappedDistribution ) = bijector (wrapped_dist (d))
28+
529"""
630A named distribution that carries the name of the random variable with it.
731"""
832struct NamedDist{variate,support,Td<: Distribution{variate,support} ,Tv<: VarName } < :
9- Distribution {variate,support}
33+ WrappedDistribution {variate,support,Td }
1034 dist:: Td
1135 name:: Tv
1236end
1337
1438NamedDist (dist:: Distribution , name:: Symbol ) = NamedDist (dist, VarName {name} ())
1539
16- Base. length (dist:: NamedDist ) = Base. length (dist. dist)
17- Base. size (dist:: NamedDist ) = Base. size (dist. dist)
18-
1940Distributions. logpdf (dist:: NamedDist , x:: Real ) = Distributions. logpdf (dist. dist, x)
2041function Distributions. logpdf (dist:: NamedDist , x:: AbstractArray{<:Real} )
2142 return Distributions. logpdf (dist. dist, x)
@@ -27,29 +48,27 @@ function Distributions.loglikelihood(dist::NamedDist, x::AbstractArray{<:Real})
2748 return Distributions. loglikelihood (dist. dist, x)
2849end
2950
30- Bijectors. bijector (d:: NamedDist ) = Bijectors. bijector (d. dist)
51+ """
52+ Wrapper around distribution `Td` that suppresses `logpdf()` calculation.
3153
54+ Note that *SampleFromPrior* would still sample from `Td`.
55+ """
3256struct NoDist{variate,support,Td<: Distribution{variate,support} } < :
33- Distribution {variate,support}
57+ WrappedDistribution {variate,support,Td }
3458 dist:: Td
3559end
3660NoDist (dist:: NamedDist ) = NamedDist (NoDist (dist. dist), dist. name)
3761
3862nodist (dist:: Distribution ) = NoDist (dist)
3963nodist (dists:: AbstractArray ) = nodist .(dists)
4064
41- Base. length (dist:: NoDist ) = Base. length (dist. dist)
42- Base. size (dist:: NoDist ) = Base. size (dist. dist)
43-
4465Distributions. rand (rng:: Random.AbstractRNG , d:: NoDist ) = rand (rng, d. dist)
4566Distributions. logpdf (d:: NoDist{<:Univariate} , :: Real ) = 0
4667Distributions. logpdf (d:: NoDist{<:Multivariate} , :: AbstractVector{<:Real} ) = 0
4768function Distributions. logpdf (d:: NoDist{<:Multivariate} , x:: AbstractMatrix{<:Real} )
4869 return zeros (Int, size (x, 2 ))
4970end
5071Distributions. logpdf (d:: NoDist{<:Matrixvariate} , :: AbstractMatrix{<:Real} ) = 0
51- Distributions. minimum (d:: NoDist ) = minimum (d. dist)
52- Distributions. maximum (d:: NoDist ) = maximum (d. dist)
5372
5473Bijectors. logpdf_with_trans (d:: NoDist{<:Univariate} , :: Real , :: Bool ) = 0
5574function Bijectors. logpdf_with_trans (
@@ -67,5 +86,3 @@ function Bijectors.logpdf_with_trans(
6786)
6887 return 0
6988end
70-
71- Bijectors. bijector (d:: NoDist ) = Bijectors. bijector (d. dist)
0 commit comments