# Patch summary: add Lux.jl neural-network gradient scenarios to DifferentiationInterfaceTest (DIT)
# and exercise them with Zygote in both the DifferentiationInterface and DIT test suites.
#
# NOTE(review): in this copy of the patch every hunk has been collapsed onto a handful of very long
# physical lines (diff line breaks, blank context lines and source indentation are not preserved).
# The line structure presumably has to be restored before the patch can be applied -- confirm.
#
# Per-file overview (in the order the sections appear below):
#
# 1. .github/workflows/Test.yml
#    - Adds `Down/Lux` to the list of test groups, right after `Down/Flux`.
#    - Adds an `exclude` entry pairing version 'lts' with group `Down/Lux`, next to the existing
#      'lts' / `Down/Flux` entry.
#
# 2. DifferentiationInterface/test/Down/Flux/test.jl
#    - Only reorders the names inside the existing `Pkg.add([...])` call; the set of packages
#      (Enzyme, FiniteDifferences, Flux, Zygote) is unchanged.
#
# 3. DifferentiationInterface/test/Down/Lux/test.jl (new file, 22 lines)
#    - Installs FiniteDiff, Lux, LuxTestUtils and Zygote with `Pkg.add`, loads them together with
#      ComponentArrays, and seeds the global RNG with `Random.seed!(0)`.
#    - Calls `test_differentiation(AutoZygote(), DIT.lux_scenarios(Random.Xoshiro(63)); ...)` with
#      `DIT.lux_isequal` / `DIT.lux_isapprox` as comparison functions, `rtol=1.0f-2`, `atol=1.0f-3`.
#    - NOTE(review): `LOGGING` is used but not defined in this file; presumably it is provided by the
#      script that includes this test -- confirm.
#    - NOTE(review): ComponentArrays is loaded with `using` but is not in the `Pkg.add` list; this
#      assumes it is already available in the test environment -- confirm.
#
# 4. DifferentiationInterfaceTest/Project.toml
#    - [weakdeps]: adds FiniteDiff, Lux, LuxTestUtils and Zygote.
#    - [extensions]: declares `DifferentiationInterfaceTestLuxExt`, triggered by ComponentArrays,
#      FiniteDiff, Lux and LuxTestUtils.
#    - [compat]: moves the `Flux` line after `FiniteDifferences` (ordering only).
#    - Adds Lux and LuxTestUtils UUIDs to a later section (presumably [extras]; its header is outside
#      the visible hunk) and adds both names to the `test` target.
#    - NOTE(review): Zygote is added to [weakdeps], but no extension visible in this patch lists it as
#      a trigger -- confirm it is intended.
#    - NOTE(review): no [compat] bounds for Lux or LuxTestUtils appear in the visible hunks -- confirm
#      whether they exist elsewhere or still need to be added.
#
# 5. DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl
#    (new file, 138 lines)
#    - `DIT.lux_isequal(a, b)`: calls `check_approx(a, b; atol=0, rtol=0)`.
#    - `DIT.lux_isapprox(a, b; atol, rtol)`: forwards both tolerances to `check_approx`.
#    - `SquareLoss{M,X,S}`: struct holding `model`, `x` and `st`; calling it on parameters `ps`
#      returns `sum(abs2, first(model(x, ps, st)))`.
#    - `DIT.lux_scenarios(rng=default_rng())`: builds a list of (model, input) pairs covering Dense,
#      Scale, Conv, pooling, Maxout, Bilinear, SkipConnection, ConvTranspose, stateful recurrent
#      cells (RNN / LSTM / GRU) and normalization layers (BatchNorm, GroupNorm, LayerNorm,
#      InstanceNorm). For each pair it runs `Lux.setup(rng, model)`, wraps the parameters in a
#      `ComponentArray`, evaluates the loss, computes a reference gradient with
#      `DI.gradient(loss, DI.AutoFiniteDiff(), ps)` and pushes a
#      `GradientScenario(loss; x=ps, y=l, grad=g, nb_args=1, place=:outofplace)`.
#      The same `rng` is used both for the random inputs and for `Lux.setup`.
#
# 6. DifferentiationInterfaceTest/src/scenarios/extensions.jl
#    - Rewords the Flux warning so the required packages are listed as FiniteDifferences.jl, Flux.jl.
#    - Adds docstrings and empty function stubs `lux_scenarios`, `lux_isapprox` and `lux_isequal`;
#      their methods are the ones defined in the extension module of section 5.
#
# 7. DifferentiationInterfaceTest/test/weird.jl
#    - Loads Lux and LuxTestUtils and appends the same `test_differentiation` call on
#      `DIT.lux_scenarios(Random.Xoshiro(63))` that section 3 adds.
#
diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index 79bf398e0..25172c5fd 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -50,6 +50,7 @@ jobs: - Down/Detector - Down/DifferentiateWith - Down/Flux + - Down/Lux exclude: # lts - version: 'lts' @@ -74,6 +75,8 @@ jobs: group: Down/Detector - version: 'lts' group: Down/Flux + - version: 'lts' + group: Down/Lux # pre-release - version: 'pre' group: Formalities diff --git a/DifferentiationInterface/test/Down/Flux/test.jl b/DifferentiationInterface/test/Down/Flux/test.jl index 4c21cd9ec..c7ef89c1e 100644 --- a/DifferentiationInterface/test/Down/Flux/test.jl +++ b/DifferentiationInterface/test/Down/Flux/test.jl @@ -1,5 +1,5 @@ using Pkg -Pkg.add(["Enzyme", "FiniteDifferences", "Flux", "Zygote"]) +Pkg.add(["FiniteDifferences", "Enzyme", "Flux", "Zygote"]) using DifferentiationInterface, DifferentiationInterfaceTest import DifferentiationInterfaceTest as DIT diff --git a/DifferentiationInterface/test/Down/Lux/test.jl b/DifferentiationInterface/test/Down/Lux/test.jl new file mode 100644 index 000000000..bbe294e34 --- /dev/null +++ b/DifferentiationInterface/test/Down/Lux/test.jl @@ -0,0 +1,22 @@ +using Pkg +Pkg.add(["FiniteDiff", "Lux", "LuxTestUtils", "Zygote"]) + +using ComponentArrays: ComponentArrays +using DifferentiationInterface, DifferentiationInterfaceTest +import DifferentiationInterfaceTest as DIT +using FiniteDiff: FiniteDiff +using Lux: Lux +using LuxTestUtils: LuxTestUtils +using Random + +Random.seed!(0) + +test_differentiation( + AutoZygote(), + DIT.lux_scenarios(Random.Xoshiro(63)); + isequal=DIT.lux_isequal, + isapprox=DIT.lux_isapprox, + rtol=1.0f-2, + atol=1.0f-3, + logging=LOGGING, +) diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index 74953fb50..35fea5311 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -21,15 +21,20 @@ Test = 
"8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] DifferentiationInterfaceTestComponentArraysExt = "ComponentArrays" DifferentiationInterfaceTestFluxExt = ["FiniteDifferences", "Flux"] DifferentiationInterfaceTestJLArraysExt = "JLArrays" +DifferentiationInterfaceTestLuxExt = ["ComponentArrays", "FiniteDiff", "Lux", "LuxTestUtils"] DifferentiationInterfaceTestStaticArraysExt = "StaticArrays" [compat] @@ -40,8 +45,8 @@ ComponentArrays = "0.15" DataFrames = "1.6.1" DifferentiationInterface = "0.5.6" DocStringExtensions = "0.8,0.9" -Flux = "0.13,0.14" FiniteDifferences = "0.12" +Flux = "0.13,0.14" Functors = "0.4" JET = "0.4 - 0.8, 0.9" JLArrays = "0.1" @@ -68,6 +73,8 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -78,4 +85,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "DifferentiationInterface", "FiniteDiff", "FiniteDifferences", "Flux", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", 
"SparseMatrixColorings", "StaticArrays", "Test", "Zygote"] +test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "DifferentiationInterface", "FiniteDiff", "FiniteDifferences", "Flux", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Lux", "LuxTestUtils", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StaticArrays", "Test", "Zygote"] diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl new file mode 100644 index 000000000..cd0b7ae66 --- /dev/null +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl @@ -0,0 +1,138 @@ +module DifferentiationInterfaceTestLuxExt + +using Compat: @compat +using ComponentArrays: ComponentArray +import DifferentiationInterface as DI +using DifferentiationInterfaceTest +import DifferentiationInterfaceTest as DIT +using FiniteDiff: FiniteDiff +using Lux +using LuxTestUtils +using LuxTestUtils: check_approx +using Random: AbstractRNG, default_rng + +#= +Relevant discussions: + +- https://github.com/LuxDL/Lux.jl/issues/769 +=# + +function DIT.lux_isequal(a, b) + return check_approx(a, b; atol=0, rtol=0) +end + +function DIT.lux_isapprox(a, b; atol, rtol) + return check_approx(a, b; atol, rtol) +end + +struct SquareLoss{M,X,S} + model::M + x::X + st::S +end + +function (sql::SquareLoss)(ps) + @compat (; model, x, st) = sql + return sum(abs2, first(model(x, ps, st))) +end + +function DIT.lux_scenarios(rng::AbstractRNG=default_rng()) + models_and_xs = [ + (Dense(2, 4), randn(rng, Float32, 2, 3)), + (Dense(2, 4, gelu), randn(rng, Float32, 2, 3)), + (Dense(2, 4, gelu; use_bias=false), randn(rng, Float32, 2, 3)), + (Chain(Dense(2, 4, relu), Dense(4, 3)), randn(rng, Float32, 2, 3)), + (Scale(2), randn(rng, Float32, 2, 3)), + (Conv((3, 3), 2 => 3), randn(rng, Float32, 3, 
3, 2, 2)), + (Conv((3, 3), 2 => 3, gelu; pad=SamePad()), randn(rng, Float32, 3, 3, 2, 2)), + ( + Conv((3, 3), 2 => 3, relu; use_bias=false, pad=SamePad()), + randn(rng, Float32, 3, 3, 2, 2), + ), + ( + Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)), + rand(rng, Float32, 5, 5, 2, 2), + ), + ( + Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), + rand(rng, Float32, 5, 5, 2, 2), + ), + ( + Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), + rand(rng, Float32, 5, 5, 2, 2), + ), + (Maxout(() -> Dense(5 => 4, tanh), 3), randn(rng, Float32, 5, 2)), + (Bilinear((2, 2) => 3), randn(rng, Float32, 2, 3)), + (SkipConnection(Dense(2 => 2), vcat), randn(rng, Float32, 2, 3)), + (ConvTranspose((3, 3), 3 => 2; stride=2), rand(rng, Float32, 5, 5, 3, 1)), + (StatefulRecurrentCell(RNNCell(3 => 5)), rand(rng, Float32, 3, 2)), + (StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(rng, Float32, 3, 2)), + ( + StatefulRecurrentCell(RNNCell(3 => 5, gelu; use_bias=false)), + rand(rng, Float32, 3, 2), + ), + ( + Chain( + StatefulRecurrentCell(RNNCell(3 => 5)), + StatefulRecurrentCell(RNNCell(5 => 3)), + ), + rand(rng, Float32, 3, 2), + ), + (StatefulRecurrentCell(LSTMCell(3 => 5)), rand(rng, Float32, 3, 2)), + ( + Chain( + StatefulRecurrentCell(LSTMCell(3 => 5)), + StatefulRecurrentCell(LSTMCell(5 => 3)), + ), + rand(rng, Float32, 3, 2), + ), + (StatefulRecurrentCell(GRUCell(3 => 5)), rand(rng, Float32, 3, 10)), + ( + Chain( + StatefulRecurrentCell(GRUCell(3 => 5)), + StatefulRecurrentCell(GRUCell(5 => 3)), + ), + rand(rng, Float32, 3, 10), + ), + (Chain(Dense(2, 4), BatchNorm(4)), randn(rng, Float32, 2, 3)), + (Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(rng, Float32, 2, 3)), + ( + Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), + randn(rng, Float32, 2, 3), + ), + (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(rng, Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(rng, Float32, 6, 6, 2, 
2)), + (Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(rng, Float32, 2, 3)), + (Chain(Dense(2, 4), GroupNorm(4, 2)), randn(rng, Float32, 2, 3)), + (Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(rng, Float32, 6, 6, 2, 2)), + ( + Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)), + randn(rng, Float32, 6, 6, 2, 2), + ), + ( + Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))), + randn(rng, Float32, 4, 4, 2, 2), + ), + (Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(rng, Float32, 6, 6, 2, 2)), + ( + Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), + randn(rng, Float32, 6, 6, 2, 2), + ), + ] + + scens = Scenario[] + + for (model, x) in models_and_xs + ps, st = Lux.setup(rng, model) + ps = ComponentArray(ps) + loss = SquareLoss(model, x, st) + l = loss(ps) + g = DI.gradient(loss, DI.AutoFiniteDiff(), ps) + scen = GradientScenario(loss; x=ps, y=l, grad=g, nb_args=1, place=:outofplace) + push!(scens, scen) + end + + return scens +end + +end diff --git a/DifferentiationInterfaceTest/src/scenarios/extensions.jl b/DifferentiationInterfaceTest/src/scenarios/extensions.jl index 9c16c857e..83c1ce284 100644 --- a/DifferentiationInterfaceTest/src/scenarios/extensions.jl +++ b/DifferentiationInterfaceTest/src/scenarios/extensions.jl @@ -34,7 +34,7 @@ function gpu_scenarios end Create a vector of [`Scenario`](@ref)s with neural networks from [Flux.jl](https://github.com/FluxML/Flux.jl). !!! warning - This function requires Flux.jl and FiniteDifferences.jl to be loaded (it is implemented in a package extension). + This function requires FiniteDifferences.jl and Flux.jl to be loaded (it is implemented in a package extension). !!! danger These scenarios are still experimental and not part of the public API. @@ -55,3 +55,31 @@ function flux_isapprox end Exact comparison function to use in correctness tests with gradients of Flux.jl networks. 
""" function flux_isequal end + +""" + lux_scenarios(rng=Random.default_rng()) + +Create a vector of [`Scenario`](@ref)s with neural networks from [Lux.jl](https://github.com/LuxDL/Lux.jl). + +!!! warning + This function requires ComponentArrays.jl, FiniteDiff.jl, Lux.jl and LuxTestUtils.jl to be loaded (it is implemented in a package extension). + +!!! danger + These scenarios are still experimental and not part of the public API. + Their ground truth values are computed with finite differences, and thus subject to imprecision. +""" +function lux_scenarios end + +""" + lux_isapprox(x, y; atol, rtol) + +Approximate comparison function to use in correctness tests with gradients of Lux.jl networks. +""" +function lux_isapprox end + +""" + lux_isequal(x, y) + +Exact comparison function to use in correctness tests with gradients of Lux.jl networks. +""" +function lux_isequal end diff --git a/DifferentiationInterfaceTest/test/weird.jl b/DifferentiationInterfaceTest/test/weird.jl index ba86d363b..b8475b528 100644 --- a/DifferentiationInterfaceTest/test/weird.jl +++ b/DifferentiationInterfaceTest/test/weird.jl @@ -8,6 +8,8 @@ using FiniteDifferences: FiniteDifferences using Flux: Flux using ForwardDiff: ForwardDiff using JLArrays: JLArrays +using Lux: Lux +using LuxTestUtils: LuxTestUtils using Random using SparseConnectivityTracer using SparseMatrixColorings @@ -43,3 +45,13 @@ test_differentiation( atol=1e-6, logging=LOGGING, ) + +test_differentiation( + AutoZygote(), + DIT.lux_scenarios(Random.Xoshiro(63)); + isequal=DIT.lux_isequal, + isapprox=DIT.lux_isapprox, + rtol=1.0f-2, + atol=1.0f-3, + logging=LOGGING, +)