From 24d306ade564f9fc7705dc85aa7102b9710a0da1 Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 1 Jul 2021 14:02:31 +0300 Subject: [PATCH 1/3] extract deep-kernel-learning example from st/examples (#234) --- examples/deep-kernel-learning/script.jl | 126 +++++++++++++----------- 1 file changed, 68 insertions(+), 58 deletions(-) diff --git a/examples/deep-kernel-learning/script.jl b/examples/deep-kernel-learning/script.jl index 14c1560c3..ebe279618 100644 --- a/examples/deep-kernel-learning/script.jl +++ b/examples/deep-kernel-learning/script.jl @@ -1,68 +1,78 @@ -# # Deep Kernel Learning -# -# !!! warning -# This example is under construction - -# Setup - +# # Deep Kernel Learning with Flux +# ## Package loading +# We use a couple of useful packages to plot and optimize +# the different hyper-parameters using KernelFunctions -using MLDataUtils -using Zygote using Flux using Distributions, LinearAlgebra using Plots +using ProgressMeter +using AbstractGPs +pyplot(); +default(; legendfontsize=15.0, linewidth=3.0); -Flux.@functor SqExponentialKernel -Flux.@functor KernelSum -Flux.@functor Matern32Kernel -Flux.@functor FunctionTransform - -# set up a kernel with a neural network feature extractor: - -neuralnet = Chain(Dense(1, 3), Dense(3, 2)) -k = SqExponentialKernel() ∘ FunctionTransform(neuralnet) - -# Generate date - +# ## Data creation +# We create a simple 1D Problem with very different variations xmin = -3; -xmax = 3; -x = range(xmin, xmax; length=100) -x_test = rand(Uniform(xmin, xmax), 200) -x, y = noisy_function(sinc, x; noise=0.1) -X = RowVecs(reshape(x, :, 1)) -X_test = RowVecs(reshape(x_test, :, 1)) -λ = [0.1] -#md nothing #hide - -# - -f(x, k, λ) = kernelmatrix(k, x, X) / (kernelmatrix(k, X) + exp(λ[1]) * I) * y -f(X, k, 1.0) - -# - -loss(k, λ) = (ŷ -> sum(y - ŷ) / length(y) + exp(λ[1]) * norm(ŷ))(f(X, k, λ)) -loss(k, λ) - -# - +xmax = 3; # Limits +N = 150 +noise = 0.01 +x_train = collect(eachrow(rand(Uniform(xmin, xmax), N))) # Training dataset +target_f(x) = sinc(abs(x)^abs(x)) # We use sinc with a highly varying value +target_f(x::AbstractArray) = target_f(first(x)) +y_train = target_f.(x_train) + randn(N) * noise +x_test = collect(eachrow(range(xmin, xmax; length=200))) # Testing dataset +spectral_mixture_kernel() +# ## Model definition +# We create a neural net with 2 layers and 10 units each +# The data is passed through the NN before being used in the kernel +neuralnet = Chain(Dense(1, 20), Dense(20, 30), Dense(30, 5)) +# We use two cases : +# - The Squared Exponential Kernel +k = transform(SqExponentialKernel(), FunctionTransform(neuralnet)) + +# We use AbstractGPs.jl to define our model +gpprior = GP(k) # GP Prior +fx = AbstractGPs.FiniteGP(gpprior, x_train, noise) # Prior on f +fp = posterior(fx, y_train) # Posterior of f + +# This compute the log evidence of `y`, +# which is going to be used as the objective +loss(y) = -logpdf(fx, y) + +@info "Init Loss = $(loss(y_train))" + +# Flux will automatically extract all the parameters of the kernel ps = Flux.params(k) -# push!(ps,λ) -opt = Flux.Momentum(1.0) -#md nothing #hide -# - -plots = [] -for i in 1:10 - grads = Zygote.gradient(() -> loss(k, λ), ps) +# We show the initial prediction with the untrained model +p_init = Plots.plot( + vcat(x_test...), target_f; lab="true f", title="Loss = $(loss(y_train))" +) +Plots.scatter!(vcat(x_train...), y_train; lab="data") +pred = marginals(fp(x_test)) +Plots.plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), lab="Prediction") +# ## Training +anim = Animation() +nmax = 1000 +opt = Flux.ADAM(0.1) +@showprogress for i in 1:nmax + global grads = gradient(ps) do + loss(y_train) + end Flux.Optimise.update!(opt, ps, grads) - p = Plots.scatter(x, y; lab="data", title="Loss = $(loss(k,λ))") - Plots.plot!(x, f(X, k, λ); lab="Prediction", lw=3.0) - push!(plots, p) + if i % 100 == 0 + @info "$i/$nmax" + L = loss(y_train) + # @info "Loss = $L" + p = Plots.plot( + vcat(x_test...), target_f; lab="true f", title="Loss = $(loss(y_train))" + ) + p = Plots.scatter!(vcat(x_train...), y_train; lab="data") + pred = marginals(posterior(fx, y_train)(x_test)) + Plots.plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), lab="Prediction") + frame(anim) + display(p) + end end - -# - -l = @layout grid(10, 1) -plot(plots...; layout=l, size=(300, 1500)) +gif(anim; fps=5) From 1c0464bd3d09172cabe3d1b236fa61c2f7aa59c8 Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 1 Jul 2021 14:15:45 +0300 Subject: [PATCH 2/3] add missing dependencies --- examples/deep-kernel-learning/Manifest.toml | 12 ++++++++++++ examples/deep-kernel-learning/Project.toml | 4 ++++ 2 files changed, 16 insertions(+) diff --git a/examples/deep-kernel-learning/Manifest.toml b/examples/deep-kernel-learning/Manifest.toml index c3a4f40f2..478c804df 100644 --- a/examples/deep-kernel-learning/Manifest.toml +++ b/examples/deep-kernel-learning/Manifest.toml @@ -6,6 +6,12 @@ git-tree-sha1 = "485ee0867925449198280d4af84bdb46a2a404d0" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" version = "1.0.1" +[[AbstractGPs]] +deps = ["ChainRulesCore", "Distributions", "FillArrays", "KernelFunctions", "LinearAlgebra", "Random", "RecipesBase", "Reexport", "Statistics", "StatsBase"] +git-tree-sha1 = "d8b6584ff1d523dd1304671f2c8a557dad26e214" +uuid = "99985d1d-32ba-4be9-9821-2ec096f28918" +version = "0.3.6" + [[AbstractTrees]] git-tree-sha1 = "03e0550477d86222521d254b741d470ba17ea0b5" uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" @@ -729,6 +735,12 @@ uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" deps = ["Printf"] uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" +[[ProgressMeter]] +deps = ["Distributed", "Printf"] +git-tree-sha1 = "afadeba63d90ff223a6a48d2009434ecee2ec9e8" +uuid = "92933f4c-e287-5a05-a399-4b506db050ca" +version = "1.7.1" + [[Qt5Base_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Fontconfig_jll", "Glib_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "OpenSSL_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libxcb_jll", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_keysyms_jll", "Xorg_xcb_util_renderutil_jll", "Xorg_xcb_util_wm_jll", "Zlib_jll", "xkbcommon_jll"] git-tree-sha1 = "ad368663a5e20dbb8d6dc2fddeefe4dae0781ae8" diff --git a/examples/deep-kernel-learning/Project.toml b/examples/deep-kernel-learning/Project.toml index f3b3a5b77..e7a9e4661 100644 --- a/examples/deep-kernel-learning/Project.toml +++ b/examples/deep-kernel-learning/Project.toml @@ -1,4 +1,5 @@ [deps] +AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" @@ -6,14 +7,17 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] +AbstractGPs = "0.3" Distributions = "0.25" Flux = "0.12" KernelFunctions = "0.10" Literate = "2" MLDataUtils = "0.5" Plots = "1" +ProgressMeter = "1" Zygote = "0.6" julia = "1.3" From 16abb580e3ac13f22b10ae06daccff0b605f383b Mon Sep 17 00:00:00 2001 From: ST John Date: Fri, 2 Jul 2021 11:20:18 +0300 Subject: [PATCH 3/3] drop pyplot backend --- examples/deep-kernel-learning/script.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/deep-kernel-learning/script.jl b/examples/deep-kernel-learning/script.jl index ebe279618..88865f69a 100644 --- a/examples/deep-kernel-learning/script.jl +++ b/examples/deep-kernel-learning/script.jl @@ -8,7 +8,6 @@ using Distributions, LinearAlgebra using Plots using ProgressMeter using AbstractGPs -pyplot(); default(; legendfontsize=15.0, linewidth=3.0); # ## Data creation