|
1 | | -# # Deep Kernel Learning |
2 | | -# |
3 | | -# !!! warning |
4 | | -# This example is under construction |
5 | | - |
6 | | -# Setup |
7 | | - |
| 1 | +# # Deep Kernel Learning with Flux |
| 2 | +# ## Package loading |
| 3 | +# We load a few packages for plotting and for optimizing
| 4 | +# the kernel hyper-parameters
8 | 5 | using KernelFunctions |
9 | | -using MLDataUtils |
10 | | -using Zygote |
11 | 6 | using Flux |
12 | 7 | using Distributions, LinearAlgebra |
13 | 8 | using Plots |
| 9 | +using ProgressMeter |
| 10 | +using AbstractGPs |
| 11 | +pyplot(); |
| 12 | +default(; legendfontsize=15.0, linewidth=3.0); |
14 | 13 |
|
15 | | -Flux.@functor SqExponentialKernel |
16 | | -Flux.@functor KernelSum |
17 | | -Flux.@functor Matern32Kernel |
18 | | -Flux.@functor FunctionTransform |
19 | | - |
20 | | -# set up a kernel with a neural network feature extractor: |
21 | | - |
22 | | -neuralnet = Chain(Dense(1, 3), Dense(3, 2)) |
23 | | -k = SqExponentialKernel() ∘ FunctionTransform(neuralnet) |
24 | | - |
25 | | -# Generate date |
26 | | - |
| 14 | +# ## Data creation |
| 15 | +# We create a simple 1D problem whose target function varies on very different length scales
27 | 16 | xmin = -3; |
28 | | -xmax = 3; |
29 | | -x = range(xmin, xmax; length=100) |
30 | | -x_test = rand(Uniform(xmin, xmax), 200) |
31 | | -x, y = noisy_function(sinc, x; noise=0.1) |
32 | | -X = RowVecs(reshape(x, :, 1)) |
33 | | -X_test = RowVecs(reshape(x_test, :, 1)) |
34 | | -λ = [0.1] |
35 | | -#md nothing #hide |
36 | | - |
37 | | -# |
38 | | - |
39 | | -f(x, k, λ) = kernelmatrix(k, x, X) / (kernelmatrix(k, X) + exp(λ[1]) * I) * y |
40 | | -f(X, k, 1.0) |
41 | | - |
42 | | -# |
43 | | - |
44 | | -loss(k, λ) = (ŷ -> sum(y - ŷ) / length(y) + exp(λ[1]) * norm(ŷ))(f(X, k, λ)) |
45 | | -loss(k, λ) |
46 | | - |
47 | | -# |
48 | | - |
| 17 | +xmax = 3; # Domain limits
| 18 | +N = 150 # Number of training points
| 19 | +noise = 0.01 # Standard deviation of the observation noise
| 20 | +x_train = collect(eachrow(rand(Uniform(xmin, xmax), N))) # Training inputs, sampled uniformly
| 21 | +target_f(x) = sinc(abs(x)^abs(x)) # sinc composed with a rapidly growing argument, so the target varies at very different rates
| 22 | +target_f(x::AbstractArray) = target_f(first(x)) |
| 23 | +y_train = target_f.(x_train) + randn(N) * noise |
| 24 | +x_test = collect(eachrow(range(xmin, xmax; length=200))) # Testing dataset |
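|  | +# Note that each input is stored as a length-1 vector (via `eachrow`) rather than a scalar,
|  | +# since the `Dense` layers inside the kernel transform expect array inputs.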
| 26 | +# ## Model definition |
| 27 | +# We create a neural network with 3 dense layers, mapping the 1D input to a 5-dimensional feature space
| 28 | +# The data is passed through this network before being fed to the kernel
| 29 | +neuralnet = Chain(Dense(1, 20), Dense(20, 30), Dense(30, 5)) |
| 30 | +# We then place a Squared Exponential kernel on top of the neural-network features
| 31 | +k = SqExponentialKernel() ∘ FunctionTransform(neuralnet)
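|  | +# As a quick sanity check (a minimal sketch, with two arbitrarily chosen points),
|  | +# evaluating the composed kernel should be equivalent to applying the base kernel
|  | +# to the neural-network features:
|  | +x1, x2 = [0.5], [-1.0]
|  | +k(x1, x2) ≈ SqExponentialKernel()(neuralnet(x1), neuralnet(x2))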
| 33 | + |
| 34 | +# We use AbstractGPs.jl to define our model |
| 35 | +gpprior = GP(k) # GP Prior |
| 36 | +fx = AbstractGPs.FiniteGP(gpprior, x_train, noise) # Prior at the training inputs (noise is the observation variance)
| 37 | +fp = posterior(fx, y_train) # Posterior of f |
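|  | +# `fp` can be queried at new inputs; for instance, `marginals(fp(x_test))` below
|  | +# gives the predictive mean and standard deviation at each test point.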
| 38 | + |
| 39 | +# This computes the negative log evidence of `y`,
| 40 | +# which we use as the training objective
| 41 | +loss(y) = -logpdf(fx, y) |
| 42 | + |
| 43 | +@info "Init Loss = $(loss(y_train))" |
| 44 | + |
| 45 | +# Flux will automatically extract all the parameters of the kernel |
49 | 46 | ps = Flux.params(k) |
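|  | +# Here `ps` should contain the weights and biases of the three `Dense` layers stored
|  | +# inside the `FunctionTransform`, i.e. 6 parameter arrays in total.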
50 | | -# push!(ps,λ) |
51 | | -opt = Flux.Momentum(1.0) |
52 | | -#md nothing #hide |
53 | 47 |
|
54 | | -# |
55 | | - |
56 | | -plots = [] |
57 | | -for i in 1:10 |
58 | | - grads = Zygote.gradient(() -> loss(k, λ), ps) |
| 48 | +# We show the initial prediction with the untrained model |
| 49 | +p_init = Plots.plot( |
| 50 | + vcat(x_test...), target_f; lab="true f", title="Loss = $(loss(y_train))" |
| 51 | +) |
| 52 | +Plots.scatter!(vcat(x_train...), y_train; lab="data") |
| 53 | +pred = marginals(fp(x_test)) |
| 54 | +Plots.plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), lab="Prediction") |
| 55 | +# ## Training |
| 56 | +anim = Animation() # To visualize how the fit evolves during training
| 57 | +nmax = 1000 # Number of training iterations
| 58 | +opt = Flux.ADAM(0.1) # ADAM optimizer with learning rate 0.1
| 59 | +@showprogress for i in 1:nmax |
| 60 | + global grads = gradient(ps) do |
| 61 | + loss(y_train) |
| 62 | + end |
59 | 63 | Flux.Optimise.update!(opt, ps, grads) |
60 | | - p = Plots.scatter(x, y; lab="data", title="Loss = $(loss(k,λ))") |
61 | | - Plots.plot!(x, f(X, k, λ); lab="Prediction", lw=3.0) |
62 | | - push!(plots, p) |
| 64 | + if i % 100 == 0 |
| 65 | + L = loss(y_train)
| 66 | + @info "Iteration $i/$nmax, loss = $L"
| 67 | + p = Plots.plot(
| 68 | + vcat(x_test...), target_f; lab="true f", title="Loss = $L"
| 69 | + )
| 70 | + Plots.scatter!(vcat(x_train...), y_train; lab="data")
| 72 | + pred = marginals(posterior(fx, y_train)(x_test)) # Recompute the posterior with the current kernel parameters
| 73 | + Plots.plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), lab="Prediction") |
| 74 | + frame(anim) |
| 75 | + display(p) |
| 76 | + end |
63 | 77 | end |
64 | | - |
65 | | -# |
66 | | - |
67 | | -l = @layout grid(10, 1) |
68 | | -plot(plots...; layout=l, size=(300, 1500)) |
| 78 | +gif(anim; fps=5) |