@@ -3,6 +3,7 @@ using Pkg
33Pkg. develop (; path= joinpath (@__DIR__ , " .." ))
44
55using DynamicPPL: DynamicPPL, make_benchmark_suite, VarInfo
6+ using ADTypes
67using BenchmarkTools: @benchmark , median, run
78using PrettyTables: PrettyTables, ft_printf
89using ForwardDiff: ForwardDiff
@@ -48,29 +49,44 @@ lda_instance = begin
4849 Models. lda (2 , d, w)
4950end
5051
52+ # AD types setup
53+ fd = AutoForwardDiff ()
54+ rd = AutoReverseDiff ()
55+ mc = AutoMooncake (; config= nothing )
56+ """
57+ get_adtype_shortname(adtype::ADTypes.AbstractADType)
58+
59+ Get the package name that corresponds to the the AD backend `adtype`. Only used
60+ for pretty-printing.
61+ """
62+ get_adtype_shortname (:: AutoMooncake ) = " Mooncake"
63+ get_adtype_shortname (:: AutoForwardDiff ) = " ForwardDiff"
64+ get_adtype_shortname (:: AutoReverseDiff{false} ) = " ReverseDiff"
65+ get_adtype_shortname (:: AutoReverseDiff{true} ) = " ReverseDiff:Compiled"
66+
5167# Specify the combinations to test:
5268# (Model Name, model instance, VarInfo choice, AD backend, linked)
5369chosen_combinations = [
5470 (
5571 " Simple assume observe" ,
5672 Models. simple_assume_observe (randn (StableRNG (23 ))),
5773 :typed ,
58- :forwarddiff ,
74+ fd ,
5975 false ,
6076 ),
61- (" Smorgasbord" , smorgasbord_instance, :typed , :forwarddiff , false ),
62- (" Smorgasbord" , smorgasbord_instance, :simple_namedtuple , :forwarddiff , true ),
63- (" Smorgasbord" , smorgasbord_instance, :untyped , :forwarddiff , true ),
64- (" Smorgasbord" , smorgasbord_instance, :simple_dict , :forwarddiff , true ),
65- (" Smorgasbord" , smorgasbord_instance, :typed , :reversediff , true ),
66- (" Smorgasbord" , smorgasbord_instance, :typed , :mooncake , true ),
67- (" Loop univariate 1k" , loop_univariate1k, :typed , :mooncake , true ),
68- (" Multivariate 1k" , multivariate1k, :typed , :mooncake , true ),
69- (" Loop univariate 10k" , loop_univariate10k, :typed , :mooncake , true ),
70- (" Multivariate 10k" , multivariate10k, :typed , :mooncake , true ),
71- (" Dynamic" , Models. dynamic (), :typed , :mooncake , true ),
72- (" Submodel" , Models. parent (randn (StableRNG (23 ))), :typed , :mooncake , true ),
73- (" LDA" , lda_instance, :typed , :reversediff , true ),
77+ (" Smorgasbord" , smorgasbord_instance, :typed , fd , false ),
78+ (" Smorgasbord" , smorgasbord_instance, :simple_namedtuple , fd , true ),
79+ (" Smorgasbord" , smorgasbord_instance, :untyped , fd , true ),
80+ (" Smorgasbord" , smorgasbord_instance, :simple_dict , fd , true ),
81+ (" Smorgasbord" , smorgasbord_instance, :typed , rd , true ),
82+ (" Smorgasbord" , smorgasbord_instance, :typed , mc , true ),
83+ (" Loop univariate 1k" , loop_univariate1k, :typed , mc , true ),
84+ (" Multivariate 1k" , multivariate1k, :typed , mc , true ),
85+ (" Loop univariate 10k" , loop_univariate10k, :typed , mc , true ),
86+ (" Multivariate 10k" , multivariate10k, :typed , mc , true ),
87+ (" Dynamic" , Models. dynamic (), :typed , mc , true ),
88+ (" Submodel" , Models. parent (randn (StableRNG (23 ))), :typed , mc , true ),
89+ (" LDA" , lda_instance, :typed , rd , true ),
7490]
7591
7692# Time running a model-like function that does not use DynamicPPL, as a reference point.
8399results_table = Tuple{String,Int,String,String,Bool,Float64,Float64}[]
84100
85101for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations
86- @info " Running benchmark for $model_name "
102+ @info " Running benchmark for $model_name / $varinfo_choice / $( get_adtype_shortname (adbackend)) "
87103 suite = make_benchmark_suite (StableRNG (23 ), model, varinfo_choice, adbackend, islinked)
88104 results = run (suite)
89105 eval_time = median (results[" evaluation" ]). time
@@ -95,7 +111,7 @@ for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinati
95111 (
96112 model_name,
97113 model_dimension (model, islinked),
98- string (adbackend),
114+ get_adtype_shortname (adbackend),
99115 string (varinfo_choice),
100116 islinked,
101117 relative_eval_time,
0 commit comments