
Commit 3a4c627

Merge pull request #651 from FluxML/mji/dogfood
Refactor training loop
2 parents fc6232b + 4cf43c0 commit 3a4c627

File tree: 7 files changed, +111 −87 lines

src/optimise/optimisers.jl

Lines changed: 6 additions & 6 deletions
@@ -37,7 +37,7 @@ Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
 
 function apply!(o::Momentum, x, Δ)
   η, ρ = o.eta, o.rho
-  v = get!(o.velocity, x, zero(x))::typeof(x)
+  v = get!(o.velocity, x, zero(x))::typeof(data(x))
   @. v = ρ * v - η * Δ
   @. Δ = -v
 end
@@ -57,7 +57,7 @@ Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
 
 function apply!(o::Nesterov, x, Δ)
   η, ρ = o.eta, o.rho
-  v = get!(o.velocity, x, zero(x))::typeof(x)
+  v = get!(o.velocity, x, zero(x))::typeof(data(x))
   d = @. ρ^2 * v - (1+ρ) * η * Δ
   @. v = ρ*v - η*Δ
   @. Δ = -d
@@ -80,7 +80,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
 
 function apply!(o::RMSProp, x, Δ)
   η, ρ = o.eta, o.rho
-  acc = get!(o.acc, x, zero(x))::typeof(x)
+  acc = get!(o.acc, x, zero(x))::typeof(data(x))
   @. acc = ρ * acc + (1 - ρ) * Δ^2
   @. Δ *= η / (acc + ϵ)
 end
@@ -147,7 +147,7 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
 
 function apply!(o::ADAGrad, x, Δ)
   η = o.eta
-  acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
+  acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(data(x))
   @. acc += Δ^2
   @. Δ *= η / (acc + ϵ)
 end
@@ -321,7 +321,7 @@ end
 
 WeightDecay() = WeightDecay(0)
 
-function apply!(o::WeightDecay, x, Δ)
+function apply!(o::WeightDecay, x, Δ)
   wd = o.wd
-  @. Δ += wd * x
+  @. Δ += wd * data(x)
 end
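
For orientation, a minimal sketch of how these `apply!` methods are driven on a plain parameter array (a hypothetical standalone usage, not part of this diff; the `data(x)` calls above are what keep the optimiser state as plain arrays when `x` is tracked):

```julia
using Flux
using Flux.Optimise: Momentum, apply!

w = randn(5)                 # a plain (untracked) parameter array
opt = Momentum(0.01, 0.9)

g = 2 .* w                   # stand-in gradient, e.g. of sum(abs2, w)
Δ = apply!(opt, w, copy(g))  # apply! mutates its Δ argument, so pass a copy
w .-= Δ                      # descent step; roughly what update!(opt, w, g) does
```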

src/optimise/train.jl

Lines changed: 17 additions & 19 deletions
@@ -1,16 +1,23 @@
 using Juno
-import Flux.Tracker: data, grad, back!, update!
+import Flux.Tracker: Params, gradient, data, update!
 import Base.depwarn
 
 function update!(opt, x, x̄)
-  update!(x, apply!(opt, x, copy(data(x̄))))
+  update!(x, -apply!(opt, x, data(x̄)))
 end
 
+function update!(opt, xs::Params, gs)
+  for x in xs
+    update!(opt, x, gs[x])
+  end
+end
+
+# Added as an internal API but everyone started using it.
 function _update_params!(opt, xs)
+  depwarn("`_update_params!` is deprecated, use `update!` instead.", :stop)
   for x in xs
-    Δ = apply!(opt, x.data, x.grad)
-    x.data .-= Δ
-    Δ .= 0
+    update!(opt, x, Tracker.grad(x))
+    x.tracker.grad = Tracker.zero_grad!(x.tracker.grad)
   end
 end
 
@@ -19,16 +26,6 @@ call(f, xs...) = f(xs...)
 runall(f) = f
 runall(fs::AbstractVector) = () -> foreach(call, fs)
 
-# The AD generates fairly large backtraces that are unhelpful if you interrupt
-# while training; this just cleans that up.
-macro interrupts(ex)
-  :(try $(esc(ex))
-  catch e
-    e isa InterruptException || rethrow()
-    throw(e)
-  end)
-end
-
 struct StopException <: Exception end
 """
     stop()
@@ -67,13 +64,14 @@ The callback can call `Flux.stop()` to interrupt the training loop.
 Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
 """
 function train!(loss, ps, data, opt; cb = () -> ())
+  ps = Params(ps)
   cb = runall(cb)
-  opt = runall(opt)
   @progress for d in data
     try
-      l = loss(d...)
-      @interrupts back!(l)
-      _update_params!(opt, ps)
+      gs = gradient(ps) do
+        loss(d...)
+      end
+      update!(opt, ps, gs)
       if cb() == :stop
         depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
         break
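
For reference, a small sketch of the per-batch step that `train!` now performs, written out by hand (hypothetical model and data; `predict` and the shapes are illustrative only):

```julia
using Flux, Flux.Optimise
using Flux.Tracker: Params, gradient, update!

W, b = param(randn(2, 5)), param(zeros(2))
predict(x) = W*x .+ b
loss(x, y) = sum(abs2, predict(x) .- y)

ps  = Params([W, b])
opt = Descent(0.1)
x, y = randn(5), randn(2)

# One step of the refactored loop: take gradients of the loss with respect
# to the tracked params, then let the optimiser update them in place.
gs = gradient(ps) do
  loss(x, y)
end
update!(opt, ps, gs)
```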

src/tracker/Tracker.jl

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ macro grad(ex)
 end
 
 include("idset.jl")
+include("params.jl")
 include("back.jl")
 include("numeric.jl")
 include("lib/real.jl")

src/tracker/back.jl

Lines changed: 30 additions & 50 deletions
@@ -1,3 +1,15 @@
+# The AD generates fairly large backtraces that are unhelpful if you interrupt
+# while training; this just cleans that up.
+macro interrupts(ex)
+  :(try $(esc(ex))
+  catch e
+    e isa InterruptException || rethrow()
+    throw(e)
+  end)
+end
+
+# In-place gradients
+
 init_grad(x) = zero(x)
 zero_grad!(x) = zero(x)
 zero_grad!(x::AbstractArray) = (x .= 0)
@@ -66,63 +78,33 @@ function back!(x, Δ; once = true)
   return
 end
 
+function extract_grad!(x)
+  x̄ = copy(grad(x))
+  x̄ = nobacksies("Use `gradient(...; nest = true)` for nested derivatives", x̄)
+  tracker(x).grad = zero_grad!(grad(x))
+  return x̄
+end
+
 function gradient_(f, xs...)
   xs = param.(data.(xs))
   l = f(xs...)
   losscheck(l)
-  back!(l)
-  nobacksies("Use `gradient(...; nest = true)` for nested derivatives",
-             grad.(xs))
+  @interrupts back!(l)
+  extract_grad!.(xs)
 end
 
-# Out-of-place gradients
-
-struct Params
-  order::Vector{Any}
-  params::IdSet{Any}
-  Params() = new([], IdSet())
-end
-
-@forward Params.order Base.iterate, Base.length
-
-function Base.push!(ps::Params, x)
-  if !(x in ps.params)
-    push!(ps.order, x)
-    push!(ps.params, x)
+function gradient_(f, xs::Params)
+  l = f()
+  losscheck(l)
+  @interrupts back!(l)
+  gs = Grads()
+  for x in xs
+    gs[tracker(x)] = extract_grad!(x)
   end
-  return ps
-end
-
-Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
-
-Params(xs) = push!(Params(), xs...)
-
-function Base.show(io::IO, ps::Params)
-  print(io, "Params([")
-  join(io, ps.order, ", ")
-  print(io, "])")
-end
-
-struct Grads
-  grads::IdDict{Any,Any}
+  return gs
 end
 
-Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
-
-Grads() = Grads(IdDict())
-
-@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
-
-Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
-
-Base.getindex(g::Grads, x::Tracked) = g.grads[x]
-
-function Base.getindex(g::Grads, x)
-  istracked(x) || error("Object not tracked: $x")
-  g[tracker(x)]
-end
-
-accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
+# Out-of-place gradients
 
 function back_(g::Grads, c::Call, Δ)
   Δs = c.func(Δ)
@@ -182,8 +164,6 @@ end
 gradient(f, xs...; nest = false) =
   nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
 
-gradient(f, ps::Params) = gradient_nested(f, ps)
-
 # Jacobians and Hessians
 
 import ..Flux
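
A quick illustration of the out-of-place path above (a hedged sketch with made-up inputs): `gradient(f, xs...)` now runs `back!` under `@interrupts` and hands the results back through `extract_grad!`, which copies out and zeroes each accumulated gradient.

```julia
using Flux.Tracker: gradient

f(a, b) = sum(a .* b)
a, b = rand(3), rand(3)

# Non-nested gradients of f at (a, b); for this bilinear f,
# ∂f/∂a is b and ∂f/∂b is a.
ā, b̄ = gradient(f, a, b)
```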

src/tracker/lib/array.jl

Lines changed: 5 additions & 0 deletions
@@ -71,6 +71,11 @@ function update!(x::TrackedArray, Δ)
   return x
 end
 
+function update!(x::AbstractArray, Δ)
+  x .+= data(Δ)
+  return x
+end
+
 # Fallthrough methods
 
 for f in :[Base.size, Base.ndims, Base.collect].args
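
This fallback lets `update!` accept plain arrays too, so the optimiser code above does not have to care whether a parameter is tracked; a minimal sketch:

```julia
using Flux.Tracker: update!

w = zeros(3)
update!(w, [1.0, 2.0, 3.0])  # in place: w .+= Δ
```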

src/tracker/params.jl

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+struct Params
+  order::Vector{Any}
+  params::IdSet{Any}
+  Params() = new([], IdSet())
+end
+
+@forward Params.order Base.iterate, Base.length
+
+function Base.push!(ps::Params, x)
+  if !(x in ps.params)
+    push!(ps.order, x)
+    push!(ps.params, x)
+  end
+  return ps
+end
+
+Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
+
+Params(xs) = push!(Params(), xs...)
+
+function Base.show(io::IO, ps::Params)
+  print(io, "Params([")
+  join(io, ps.order, ", ")
+  print(io, "])")
+end
+
+struct Grads
+  grads::IdDict{Any,Any}
+end
+
+Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
+
+Grads() = Grads(IdDict())
+
+@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
+
+Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
+
+Base.getindex(g::Grads, x::Tracked) = g.grads[x]
+
+function Base.getindex(g::Grads, x)
+  istracked(x) || error("Object not tracked: $x")
+  g[tracker(x)]
+end
+
+accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
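
A brief sketch of how these containers behave (hypothetical parameters, not taken from the diff): `Params` keeps insertion order while deduplicating by object identity, and `Grads` stores gradients keyed by each parameter's tracker but can be indexed by the parameter itself.

```julia
using Flux
using Flux.Tracker: Params, Grads

W = param(rand(2, 2))
b = param(rand(2))

ps = Params([W, b])
push!(ps, W)             # no-op: W is already registered
@assert length(ps) == 2

gs = Grads(ps)           # grads initialised to zero(data(p)) for each p
gs[W]                    # indexed by the parameter, resolved via tracker(W)
```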

test/optimise.jl

Lines changed: 6 additions & 12 deletions
@@ -4,21 +4,15 @@ using Flux.Tracker
 using Test
 @testset "Optimise" begin
   w = randn(10, 10)
-  @testset for Opt in [ADAMW, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
+  @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
+                       NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
+                       Momentum()]
     w′ = param(randn(10, 10))
     loss(x) = Flux.mse(w*x, w′*x)
-    opt = Opt(0.001)
-    if opt isa Descent || opt isa ADAGrad
-      opt = Opt(0.1)
-    end
-    if opt isa ADADelta
-      opt = Opt(0.9)
-    end
     for t = 1: 10^5
-      l = loss(rand(10))
-      back!(l)
-      delta = Optimise.apply!(opt, w′.data, w′.grad)
-      w′.data .-= delta
+      θ = Params([w′])
+      θ̄ = gradient(() -> loss(rand(10)), θ)
+      Optimise.update!(opt, θ, θ̄)
     end
     @test Flux.mse(w, w′) < 0.01
   end
