Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,30 @@
name = "LogDensityProblemsAD"
uuid = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
authors = ["Tamás K. Papp <[email protected]>"]
version = "1.1.1"
version = "1.2.0"

[deps]
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
EnzymeExt = "Enzyme"
ForwardDiffBenchmarkToolsExt = ["BenchmarkTools", "ForwardDiff"]
ForwardDiffExt = "ForwardDiff"
ReverseDiffExt = "ReverseDiff"
TrackerExt = "Tracker"
ZygoteExt = "Zygote"

[compat]
julia = "1.6"
DocStringExtensions = "0.8, 0.9"
Expand All @@ -18,14 +34,13 @@ UnPack = "0.1, 1"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["BenchmarkTools", "ForwardDiff", "Pkg", "Random", "ReverseDiff", "Test", "Tracker", "Zygote"]
test = ["BenchmarkTools", "Enzyme", "ForwardDiff", "Random", "ReverseDiff", "Test", "Tracker", "Zygote"]
11 changes: 5 additions & 6 deletions src/DiffResults_helpers.jl → ext/DiffResults_helpers.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
#####
##### Helper functions for working with DiffResults
#####
##### Only included when required by AD wrappers.

import .DiffResults
###
### Helper functions for working with DiffResults.
### Only included when required by AD wrappers.
### Requires that `DiffResults` and `DocStringExtensions.SIGNATURES` are available.
###

"""
$(SIGNATURES)
Expand Down
19 changes: 18 additions & 1 deletion src/AD_Enzyme.jl → ext/EnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,19 @@
import .Enzyme
"""
Gradient AD implementation using Enzyme.
"""
module EnzymeExt

using LogDensityProblems: logdensity
using LogDensityProblemsAD: ADGradientWrapper, EXTENSIONS_SUPPORTED
using UnPack: @unpack

import LogDensityProblems: logdensity_and_gradient
import LogDensityProblemsAD: ADgradient
if EXTENSIONS_SUPPORTED
import Enzyme
else
import ..Enzyme
end

struct EnzymeGradientLogDensity{L,M<:Union{Enzyme.ForwardMode,Enzyme.ReverseMode},S} <: ADGradientWrapper
ℓ::L
Expand Down Expand Up @@ -61,3 +76,5 @@ function logdensity_and_gradient(∇ℓ::EnzymeGradientLogDensity{<:Any,<:Enzyme
Enzyme.Duplicated(x, ∂ℓ_∂x))
y, ∂ℓ_∂x
end

end # module
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
using .BenchmarkTools: @belapsed
using .ForwardDiff
"""
Utilities for benchmarking a log density problem with various chunk sizes using ForwardDiff.

Loaded when both ForwardDiff and BenchmarkTools are loaded.
"""
module ForwardDiffBenchmarkToolsExt

using DocStringExtensions: SIGNATURES
using LogDensityProblems: dimension, logdensity_and_gradient
using LogDensityProblemsAD: ADgradient, EXTENSIONS_SUPPORTED

if EXTENSIONS_SUPPORTED
using BenchmarkTools: @belapsed
using ForwardDiff: Chunk
else
using ..BenchmarkTools: @belapsed
using ..ForwardDiff: Chunk
end

import LogDensityProblemsAD: benchmark_ForwardDiff_chunks, heuristic_chunks

"""
$(SIGNATURES)
Expand Down Expand Up @@ -38,8 +56,10 @@ function benchmark_ForwardDiff_chunks(ℓ;
markprogress = true,
x = zeros(dimension(ℓ)))
map(chunks) do chunk
∇ℓ = ADgradient(Val(:ForwardDiff), ℓ; chunk = ForwardDiff.Chunk(chunk))
∇ℓ = ADgradient(Val(:ForwardDiff), ℓ; chunk = Chunk(chunk))
markprogress && print(".")
chunk => @belapsed logdensity_and_gradient($(∇ℓ), $(x))
end
end

end # module
27 changes: 22 additions & 5 deletions src/AD_ForwardDiff.jl → ext/ForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
#####
##### Gradient AD implementation using ForwardDiff
#####
"""
Gradient AD implementation using ForwardDiff.
"""
module ForwardDiffExt

import .ForwardDiff
using DocStringExtensions: SIGNATURES
using LogDensityProblems: dimension, logdensity
using LogDensityProblemsAD: ADGradientWrapper, EXTENSIONS_SUPPORTED
using UnPack: @unpack

import .ForwardDiff.DiffResults # should load DiffResults_helpers.jl
import LogDensityProblems: logdensity_and_gradient
import LogDensityProblemsAD: ADgradient
if EXTENSIONS_SUPPORTED
import ForwardDiff
import ForwardDiff: DiffResults
else
import ..ForwardDiff
import ..ForwardDiff: DiffResults
end

# Load DiffResults helpers
include("DiffResults_helpers.jl")

struct ForwardDiffLogDensity{L, C} <: ADGradientWrapper
ℓ::L
Expand Down Expand Up @@ -51,3 +66,5 @@ function logdensity_and_gradient(fℓ::ForwardDiffLogDensity, x::AbstractVector)
result = ForwardDiff.gradient!(buffer, Base.Fix1(logdensity, ℓ), x, gradientconfig)
_diffresults_extract(result)
end

end # module
30 changes: 24 additions & 6 deletions src/AD_ReverseDiff.jl → ext/ReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
#####
##### Gradient AD implementation using ReverseDiff
#####
"""
Gradient AD implementation using ReverseDiff.
"""
module ReverseDiffExt

import .ReverseDiff
using DocStringExtensions: SIGNATURES
using LogDensityProblems: dimension, logdensity
using LogDensityProblemsAD: ADGradientWrapper, EXTENSIONS_SUPPORTED
using UnPack: @unpack

import .ReverseDiff.DiffResults # should load DiffResults_helpers.jl
import LogDensityProblems: logdensity_and_gradient
import LogDensityProblemsAD: ADgradient
if EXTENSIONS_SUPPORTED
import ReverseDiff
import ReverseDiff: DiffResults
else
import ..ReverseDiff
import ..ReverseDiff: DiffResults
end

# Load DiffResults helpers
include("DiffResults_helpers.jl")

struct ReverseDiffLogDensity{L,C} <: ADGradientWrapper
ℓ::L
Expand All @@ -29,7 +44,8 @@ By default, no tape is created.
However, if the log density contains branches, use of a compiled tape can lead to silently incorrect results.
"""
function ADgradient(::Val{:ReverseDiff}, ℓ;
compile::Union{Val{true},Val{false}}=Val(false), x::Union{Nothing,AbstractVector}=nothing)
compile::Union{Val{true},Val{false}}=Val(false),
x::Union{Nothing,AbstractVector}=nothing)
ReverseDiffLogDensity(ℓ, _compiledtape(ℓ, compile, x))
end

Expand Down Expand Up @@ -58,3 +74,5 @@ function logdensity_and_gradient(∇ℓ::ReverseDiffLogDensity, x::AbstractVecto
end
_diffresults_extract(result)
end

end # module
21 changes: 17 additions & 4 deletions src/AD_Tracker.jl → ext/TrackerExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
#####
##### Gradient AD implementation using Tracker
#####
"""
Gradient AD implementation using Tracker.
"""
module TrackerExt

import .Tracker
using LogDensityProblems: logdensity
using LogDensityProblemsAD: ADGradientWrapper, EXTENSIONS_SUPPORTED
using UnPack: @unpack

import LogDensityProblems: logdensity_and_gradient
import LogDensityProblemsAD: ADgradient
if EXTENSIONS_SUPPORTED
import Tracker
else
import ..Tracker
end

struct TrackerGradientLogDensity{L} <: ADGradientWrapper
ℓ::L
Expand All @@ -29,3 +40,5 @@ function logdensity_and_gradient(∇ℓ::TrackerGradientLogDensity, x::AbstractV
S = typeof(z + 0.0)
S(yval)::S, (S.(first(Tracker.data.(back(1)))))::Vector{S}
end

end # module
19 changes: 18 additions & 1 deletion src/AD_Zygote.jl → ext/ZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,19 @@
import .Zygote
"""
Gradient AD implementation using Zygote.
"""
module ZygoteExt

using LogDensityProblems: logdensity
using LogDensityProblemsAD: ADGradientWrapper, EXTENSIONS_SUPPORTED
using UnPack: @unpack

import LogDensityProblems: logdensity_and_gradient
import LogDensityProblemsAD: ADgradient
if EXTENSIONS_SUPPORTED
import Zygote
else
import ..Zygote
end

struct ZygoteGradientLogDensity{L} <: ADGradientWrapper
ℓ::L
Expand All @@ -19,3 +34,5 @@ function logdensity_and_gradient(∇ℓ::ZygoteGradientLogDensity, x::AbstractVe
y, back = Zygote.pullback(Base.Fix1(logdensity, ℓ), x)
y, first(back(Zygote.sensitivity(y)))
end

end # module
40 changes: 23 additions & 17 deletions src/LogDensityProblemsAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ module LogDensityProblemsAD

export ADgradient

using DocStringExtensions: SIGNATURES, TYPEDEF
using Requires: @require
using UnPack: @unpack
using DocStringExtensions: SIGNATURES
import LogDensityProblems: logdensity, logdensity_and_gradient, capabilities, dimension
using LogDensityProblems: LogDensityOrder

Expand Down Expand Up @@ -65,23 +63,31 @@ function ADgradient(v::Val{kind}, P; kwargs...) where kind
throw(MethodError(ADgradient, (v, P)))
end

####
#### AD wrappers - specific
####

#####
##### Empty method definitions for easier discoverability and backward compatibility
#####
function benchmark_ForwardDiff_chunks end
function heuristic_chunks end

# Backward compatible AD wrappers on Julia versions that do not support extensions
# TODO: Replace with proper version
const EXTENSIONS_SUPPORTED = isdefined(Base, :get_extension)
if !EXTENSIONS_SUPPORTED
using Requires: @require
end
function __init__()
@require DiffResults="163ba53b-c6d8-5494-b064-1a9d43ac40c5" include("DiffResults_helpers.jl")
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
include("AD_ForwardDiff.jl")
@require BenchmarkTools="6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" begin
include("ForwardDiff_benchmarking.jl")
@static if !EXTENSIONS_SUPPORTED
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
include("../ext/ForwardDiffExt.jl")
@require BenchmarkTools="6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" begin
include("../ext/ForwardDiffBenchmarkToolsExt.jl")
end
end

@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("../ext/TrackerExt.jl")
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" include("../ext/ZygoteExt.jl")
@require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("../ext/ReverseDiffExt.jl")
@require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" include("../ext/EnzymeExt.jl")
end
@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("AD_Tracker.jl")
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" include("AD_Zygote.jl")
@require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("AD_ReverseDiff.jl")
@require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" include("AD_Enzyme.jl")
end

end # module
Loading