Commit 24d306a
extract deep-kernel-learning example from st/examples (#234)
1 parent b150888 commit 24d306a

1 file changed: 67 additions & 58 deletions

@@ -1,68 +1,77 @@
-# # Deep Kernel Learning
-#
-# !!! warning
-#     This example is under construction
-
-# Setup
-
+# # Deep Kernel Learning with Flux
+# ## Package loading
+# We load a couple of useful packages for plotting and for
+# optimizing the hyperparameters.
 using KernelFunctions
-using MLDataUtils
-using Zygote
 using Flux
 using Distributions, LinearAlgebra
 using Plots
+using ProgressMeter
+using AbstractGPs
+pyplot();
+default(; legendfontsize=15.0, linewidth=3.0);
 
-Flux.@functor SqExponentialKernel
-Flux.@functor KernelSum
-Flux.@functor Matern32Kernel
-Flux.@functor FunctionTransform
-
-# set up a kernel with a neural network feature extractor:
-
-neuralnet = Chain(Dense(1, 3), Dense(3, 2))
-k = SqExponentialKernel() ∘ FunctionTransform(neuralnet)
-
-# Generate data
-
+# ## Data creation
+# We create a simple 1D problem whose behavior varies strongly across the domain
 xmin = -3;
-xmax = 3;
-x = range(xmin, xmax; length=100)
-x_test = rand(Uniform(xmin, xmax), 200)
-x, y = noisy_function(sinc, x; noise=0.1)
-X = RowVecs(reshape(x, :, 1))
-X_test = RowVecs(reshape(x_test, :, 1))
-λ = [0.1]
-#md nothing #hide
-
-#
-
-f(x, k, λ) = kernelmatrix(k, x, X) / (kernelmatrix(k, X) + exp(λ[1]) * I) * y
-f(X, k, 1.0)
-
-#
-
-loss(k, λ) = (ŷ -> sum(y - ŷ) / length(y) + exp(λ[1]) * norm(ŷ))(f(X, k, λ))
-loss(k, λ)
-
-#
-
+xmax = 3; # Limits
+N = 150
+noise = 0.01
+x_train = collect(eachrow(rand(Uniform(xmin, xmax), N))) # Training dataset
+target_f(x) = sinc(abs(x)^abs(x)) # We use sinc with a strongly varying argument
+target_f(x::AbstractArray) = target_f(first(x))
+y_train = target_f.(x_train) + randn(N) * noise
+x_test = collect(eachrow(range(xmin, xmax; length=200))) # Testing dataset
+# ## Model definition
+# We create a neural network with three dense layers.
+# The data is passed through the network before being used in the kernel.
+neuralnet = Chain(Dense(1, 20), Dense(20, 30), Dense(30, 5))
+# We compose the network with a base kernel via a `FunctionTransform`,
+# here the squared exponential kernel on the network's output:
+k = transform(SqExponentialKernel(), FunctionTransform(neuralnet))
+
+# We use AbstractGPs.jl to define our model
+gpprior = GP(k) # GP prior
+fx = AbstractGPs.FiniteGP(gpprior, x_train, noise) # Prior on f
+fp = posterior(fx, y_train) # Posterior of f
+
+# This computes the log evidence of `y`,
+# which is going to be used as the objective.
+loss(y) = -logpdf(fx, y)
+
+@info "Init Loss = $(loss(y_train))"
+
+# Flux will automatically extract all the parameters of the kernel
 ps = Flux.params(k)
-# push!(ps,λ)
-opt = Flux.Momentum(1.0)
-#md nothing #hide
 
-#
-
-plots = []
-for i in 1:10
-    grads = Zygote.gradient(() -> loss(k, λ), ps)
+# We show the initial prediction with the untrained model
+p_init = Plots.plot(
+    vcat(x_test...), target_f; lab="true f", title="Loss = $(loss(y_train))"
+)
+Plots.scatter!(vcat(x_train...), y_train; lab="data")
+pred = marginals(fp(x_test))
+Plots.plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), lab="Prediction")
+# ## Training
+anim = Animation()
+nmax = 1000
+opt = Flux.ADAM(0.1)
+@showprogress for i in 1:nmax
+    global grads = gradient(ps) do
+        loss(y_train)
+    end
     Flux.Optimise.update!(opt, ps, grads)
-    p = Plots.scatter(x, y; lab="data", title="Loss = $(loss(k,λ))")
-    Plots.plot!(x, f(X, k, λ); lab="Prediction", lw=3.0)
-    push!(plots, p)
+    if i % 100 == 0
+        @info "$i/$nmax"
+        L = loss(y_train)
+        # @info "Loss = $L"
+        p = Plots.plot(
+            vcat(x_test...), target_f; lab="true f", title="Loss = $L"
+        )
+        p = Plots.scatter!(vcat(x_train...), y_train; lab="data")
+        pred = marginals(posterior(fx, y_train)(x_test))
+        Plots.plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), lab="Prediction")
+        frame(anim)
+        display(p)
+    end
 end
-
-#
-
-l = @layout grid(10, 1)
-plot(plots...; layout=l, size=(300, 1500))
+gif(anim; fps=5)
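For orientation, the core pattern of the new example condenses to a few lines. The sketch below is not part of the commit: it assumes the same APIs the example imports (KernelFunctions' `transform` and `FunctionTransform`, AbstractGPs' `GP` and `logpdf`, Flux's implicit `params`), and its data, layer sizes, and learning rate are placeholder choices.

# Minimal deep-kernel-learning sketch with placeholder data.
using KernelFunctions, AbstractGPs, Flux
using Distributions: Uniform

# Toy 1D inputs stored as length-1 vectors, matching the eachrow layout above.
x = [rand(Uniform(-3, 3), 1) for _ in 1:20]
y = sin.(first.(x)) .+ 0.01 .* randn(20)

# Deep kernel: each input is mapped through the network, then compared
# with a squared exponential kernel in feature space.
nn = Chain(Dense(1, 10, relu), Dense(10, 2))
k = transform(SqExponentialKernel(), FunctionTransform(nn))

# Finite GP at the observed inputs (0.01 is the observation-noise variance)
# and the negative log evidence used as the training objective.
fx = GP(k)(x, 0.01)
nll() = -logpdf(fx, y)

# One implicit-parameter gradient step, as in the training loop above.
ps = Flux.params(nn)
opt = Flux.ADAM(0.1)
grads = Flux.gradient(nll, ps)
Flux.Optimise.update!(opt, ps, grads)

Note that the sketch passes `Flux.params(nn)` directly; the example's `Flux.params(k)` instead relies on Flux walking the kernel object to find the network inside, which may require functor definitions depending on the package versions in use.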

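The plotting sections read predictions off the GP posterior. In isolation, and reusing `fx` and `y` from the sketch above, that step looks roughly like this (again an illustration with placeholder test inputs, not commit content):

using Statistics: mean, std

# Condition on the observations, then query posterior marginals at new
# inputs; the example plots mean.(m) as the line and std.(m) as the ribbon.
fp = posterior(fx, y)                            # exact GP posterior
x_star = [[v] for v in range(-3, 3; length=50)]  # placeholder test inputs
m = marginals(fp(x_star))                        # one Normal per test point
μ, σ = mean.(m), std.(m)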