diff --git a/src/extensions/Distributions.jl b/src/extensions/Distributions.jl index 6a3cfbf..9d06fa1 100644 --- a/src/extensions/Distributions.jl +++ b/src/extensions/Distributions.jl @@ -1,5 +1,7 @@ export normlogpdf +using Distributions: DiscreteNonParametric, support, probs + # watch https://github.com/JuliaStats/Distributions.jl/issues/1183 """ @@ -11,3 +13,8 @@ function normlogpdf(μ, σ, x; ϵ = 1.0f-8) z = (x .- μ) ./ (σ .+ ϵ) -(z .^ 2 .+ log(2.0f0π)) / 2.0f0 .- log.(σ .+ ϵ) end + + +# watch https://github.com/JuliaStats/Distributions.jl/pull/1184 + +Base.convert(::Type{DiscreteNonParametric{T,P,Ts,Ps}}, d::DiscreteNonParametric) where {T,P,Ts,Ps} = DiscreteNonParametric{T,P,Ts,Ps}(convert(Ts, support(d)), convert(Ps, probs(d)), check_args=false) \ No newline at end of file diff --git a/src/utils/printing.jl b/src/utils/printing.jl index 6d3736d..81690f8 100644 --- a/src/utils/printing.jl +++ b/src/utils/printing.jl @@ -16,8 +16,10 @@ AT.children( t::StructTree{T}, ) where {T<:Union{AbstractArray,MersenneTwister,ProgressMeter.Progress,Function}} = () AT.children(t::Pair{Symbol,<:StructTree}) = children(last(t)) -AT.printnode(io::IO, t::StructTree{<:Union{Number,Symbol}}) = print(io, t.x) +AT.children(t::StructTree{UnionAll}) = () +AT.printnode(io::IO, t::StructTree{<:Union{Number,Symbol}}) = print(io, t.x) +AT.printnode(io::IO, t::StructTree{UnionAll}) = print(io, t.x) AT.printnode(io::IO, t::StructTree{T}) where {T} = print(io, T.name) AT.printnode(io::IO, t::StructTree{<:AbstractArray}) where {T} = summary(io, t.x)