Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[extensions]
DistributionsADForwardDiffExt = "ForwardDiff"
DistributionsADLazyArraysExt = "LazyArrays"
DistributionsADReverseDiffExt = "ReverseDiff"
DistributionsADTrackerExt = "Tracker"

Expand All @@ -38,7 +36,6 @@ Compat = "3.6, 4"
Distributions = "0.25.41"
FillArrays = "1.4.1"
ForwardDiff = "0.10.12, 1"
LazyArrays = "1, 2"
LinearAlgebra = "<0.0.1, 1"
PDMats = "0.9, 0.10, 0.11"
Random = "<0.0.1, 1"
Expand All @@ -53,6 +50,5 @@ julia = "1.6.5"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
7 changes: 1 addition & 6 deletions docs/src/api.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
# API

## Functions

```@docs
filldist
arraydist
```
This package provides automatic differentiation support for distributions in Distributions.jl.
52 changes: 0 additions & 52 deletions ext/DistributionsADLazyArraysExt.jl

This file was deleted.

17 changes: 1 addition & 16 deletions src/DistributionsAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,39 +19,24 @@ export TuringScalMvNormal,
TuringMvLogNormal,
TuringPoissonBinomial,
TuringWishart,
TuringInverseWishart,
arraydist,
filldist
TuringInverseWishart

include("common.jl")
include("arraydist.jl")
include("filldist.jl")
include("univariate.jl")
include("multivariate.jl")
include("matrixvariate.jl")
include("flatten.jl")

include("zygote.jl")

# Empty definition, function requires the LazyArrays extension
function lazyarray end
export lazyarray

if !isdefined(Base, :get_extension)
using Requires
end
function __init__()
# Better error message if users forget to load LazyArrays
Base.Experimental.register_error_hint(MethodError) do io, exc, arg_types, kwargs
if exc.f === lazyarray
print(io, "\\nDid you forget to load LazyArrays?")
end
end
@static if !isdefined(Base, :get_extension)
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("../ext/DistributionsADForwardDiffExt.jl")
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("../ext/DistributionsADReverseDiffExt.jl")
@require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("../ext/DistributionsADTrackerExt.jl")
@require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" include("../ext/DistributionsADLazyArraysExt.jl")
end
end

Expand Down
108 changes: 0 additions & 108 deletions src/arraydist.jl

This file was deleted.

123 changes: 0 additions & 123 deletions src/filldist.jl

This file was deleted.

8 changes: 0 additions & 8 deletions src/zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,6 @@ ZygoteRules.@adjoint function Distributions._logpdf(d::Product, x::AbstractVecto
sum(map(logpdf, d.v, x))
end
end
ZygoteRules.@adjoint function Distributions._logpdf(
d::FillVectorOfUnivariate,
x::AbstractVector{<:Real},
)
return ZygoteRules.pullback(d, x) do d, x
_flat_logpdf(d.v.value, x)
end
end

# Loglikelihood of multi- and matrixvariate distributions: multiple samples
# workaround for Zygote issues discussed in
Expand Down
Loading