Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions src/optimise/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,36 @@ Update the array `x` according to `x .-= x̄`.
"""
function update!(x::AbstractArray, x̄)
    # Subtract `x̄` from `x` element-wise, writing the result back into `x`.
    x .= x .- x̄
    return nothing
end

function update!(opt, x, x̄)
# A gradient of `nothing` means there is nothing to apply: each method below
# returns `nothing` without touching its arguments.

# Plain array, no optimiser.
function update!(x::AbstractArray, x̄::Nothing)
    return nothing
end

# Array updated through an optimiser.
function update!(opt, x::AbstractArray, x̄::Nothing)
    return nothing
end

# Any other value `m` updated through an optimiser.
function update!(opt, m::M, ∇m::Nothing) where M
    return nothing
end

function update!(opt, x::AbstractArray, x̄)
    # Whatever `apply!(opt, x, x̄)` returns is subtracted from `x` in place.
    Δ = apply!(opt, x, x̄)
    x .-= Δ
    return nothing
end

# NOTE: the fields of a struct are fixed by its type, so no run-time loop over
# them is needed: the per-field `update!` calls are generated once per type `M`
# and emitted as straight-line code (intended as a small speed-up).
@generated function update!(opt, m::M, ∇m) where M
    calls = Expr(:block)
    for name in fieldnames(M)
        field = QuoteNode(name)
        # The same field is read from `m` and from its gradient `∇m`.
        from_m = Expr(:call, :getfield, :m, field)
        from_∇m = Expr(:call, :getfield, :∇m, field)
        push!(calls.args, Expr(:call, :update!, :opt, from_m, from_∇m))
    end
    return calls
end

# Update every parameter in `xs` with the gradient stored for it in `gs`
# (looked up as `gs[x]`). Parameters whose gradient is `nothing` are skipped.
# Always returns `nothing`.
function update!(opt, xs::Params, gs)
    for x in xs
        # Look the gradient up once and reuse it for the check and the update.
        g = gs[x]
        # Compare by identity: `==` is value equality and is not the right
        # test for the `nothing` singleton.
        g === nothing && continue
        update!(opt, x, g)
    end
    return nothing
end

# Callback niceties
Expand Down
23 changes: 22 additions & 1 deletion test/optimise.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Flux.Optimise
using Flux.Optimise: runall
using Flux.Optimise: runall, update!
using Flux: Params, gradient
using Test

Expand Down Expand Up @@ -89,3 +89,24 @@ end
@test decay_steps == ground_truth
@test o.eta == o.clip
end

@testset "update!" begin
    opt = ADAM()

    # A `nothing` gradient must leave a plain array untouched.
    x = rand(2, 2)
    x_before = copy(x)
    @test update!(opt, x, nothing) === nothing
    @test x == x_before

    # A `nothing` gradient must leave a whole layer untouched.
    m = Dense(10, 10)
    m_ref = deepcopy(m)
    @test update!(opt, m, nothing) === nothing
    # One assertion per field so a failure names the field that changed.
    @test m.W == m_ref.W
    @test m.b == m_ref.b

    # Updating the layer through a named-tuple gradient must give the same
    # result as updating each of its arrays individually.
    gs = (W=rand(10, 10), b=rand(10), σ=nothing)
    update!(opt, m, gs)

    update!(opt, m_ref.W, gs.W)
    update!(opt, m_ref.b, gs.b)

    @test m.W ≈ m_ref.W
    @test m.b ≈ m_ref.b
end