|
| 1 | +# The AD generates fairly large backtraces that are unhelpful if you interrupt |
| 2 | +# while training; this just cleans that up. |
| 3 | +macro interrupts(ex) |
| 4 | + :(try $(esc(ex)) |
| 5 | + catch e |
| 6 | + e isa InterruptException || rethrow() |
| 7 | + throw(e) |
| 8 | + end) |
| 9 | +end |
| 10 | + |
| 11 | +# In-place gradients |
| 12 | + |
1 | 13 | init_grad(x) = zero(x) |
2 | 14 | zero_grad!(x) = zero(x) |
3 | 15 | zero_grad!(x::AbstractArray) = (x .= 0) |
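
Note on the `@interrupts` macro added above: any exception other than an `InterruptException` is passed straight through by `rethrow()`, while an interrupt is caught and re-thrown via `throw(e)` from the catch site, so the reported backtrace stops at the wrapped call instead of walking the full AD stack. A minimal sketch of the behaviour, assuming the macro is in scope; `slow_loss` is a hypothetical stand-in for `back!(l)`:

    # Hypothetical long-running call standing in for `back!(l)`.
    slow_loss() = sum(abs2, randn(10^7))

    # Interrupting this with Ctrl-C surfaces an InterruptException whose
    # backtrace ends at the macro's `throw(e)`, not deep inside the AD.
    @interrupts slow_loss()
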
@@ -66,63 +78,33 @@ function back!(x, Δ; once = true) |
66 | 78 | return |
67 | 79 | end |
68 | 80 |
|
| 81 | +function extract_grad!(x) |
| 82 | + x̄ = copy(grad(x)) |
| 83 | + x̄ = nobacksies("Use `gradient(...; nest = true)` for nested derivatives", x̄) |
| 84 | + tracker(x).grad = zero_grad!(grad(x)) |
| 85 | + return x̄ |
| 86 | +end |
| 87 | + |
69 | 88 | function gradient_(f, xs...) |
70 | 89 | xs = param.(data.(xs)) |
71 | 90 | l = f(xs...) |
72 | 91 | losscheck(l) |
73 | | - back!(l) |
74 | | - nobacksies("Use `gradient(...; nest = true)` for nested derivatives", |
75 | | - grad.(xs)) |
| 92 | + @interrupts back!(l) |
| 93 | + extract_grad!.(xs) |
76 | 94 | end |
77 | 95 |
|
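
For reference, `gradient_(f, xs...)` above is the non-nested path behind `gradient(f, xs...)`: it wraps the inputs in `param`, runs the forward pass, backpropagates under `@interrupts`, and extracts each gradient with `extract_grad!`, which copies the accumulated value, tags it with `nobacksies`, and zeroes the tracker's storage so repeated calls do not accumulate. A rough usage sketch, assuming a Flux version where this Tracker code is the AD and `Tracker.gradient` is the entry point:

    import Flux.Tracker   # assumption: Tracker is Flux's AD at this point

    W, x = rand(2, 3), rand(3)
    # Returns plain arrays (gradients w.r.t. W and x); the trackers' stored
    # gradients have been zeroed by extract_grad! once this call returns.
    gW, gx = Tracker.gradient((W, x) -> sum(W * x), W, x)
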
78 | | -# Out-of-place gradients |
79 | | - |
80 | | -struct Params |
81 | | - order::Vector{Any} |
82 | | - params::IdSet{Any} |
83 | | - Params() = new([], IdSet()) |
84 | | -end |
85 | | - |
86 | | -@forward Params.order Base.iterate, Base.length |
87 | | - |
88 | | -function Base.push!(ps::Params, x) |
89 | | - if !(x in ps.params) |
90 | | - push!(ps.order, x) |
91 | | - push!(ps.params, x) |
| 96 | +function gradient_(f, xs::Params) |
| 97 | + l = f() |
| 98 | + losscheck(l) |
| 99 | + @interrupts back!(l) |
| 100 | + gs = Grads() |
| 101 | + for x in xs |
| 102 | + gs[tracker(x)] = extract_grad!(x) |
92 | 103 | end |
93 | | - return ps |
94 | | -end |
95 | | - |
96 | | -Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps) |
97 | | - |
98 | | -Params(xs) = push!(Params(), xs...) |
99 | | - |
100 | | -function Base.show(io::IO, ps::Params) |
101 | | - print(io, "Params([") |
102 | | - join(io, ps.order, ", ") |
103 | | - print(io, "])") |
104 | | -end |
105 | | - |
106 | | -struct Grads |
107 | | - grads::IdDict{Any,Any} |
| 104 | + return gs |
108 | 105 | end |
109 | 106 |
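
The new `gradient_(f, xs::Params)` method is the non-nested counterpart for implicit parameters: the closure takes no arguments, and gradients are collected into a `Grads` dictionary keyed by each parameter's tracker, zeroing the stored gradients as they are extracted. (The `Params`/`Grads` definitions removed from this file are assumed to live elsewhere in the package after this change.) A sketch of the intended call pattern; the two-parameter "model" below is hypothetical:

    using Flux                             # assumption: Flux exports `param`
    import Flux.Tracker: gradient, Params  # assumption: both names are defined in Tracker

    W, b = param(rand(2, 3)), param(rand(2))
    x = rand(3)
    # Implicit-parameter gradients: the loss closure takes no arguments and the
    # parameters of interest are passed as a Params collection.
    gs = gradient(() -> sum(W * x .+ b), Params([W, b]))
    gs[W]   # gradient of the loss with respect to W
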
|
110 | | -Base.show(io::IO, ps::Grads) = println(io, "Grads(...)") |
111 | | - |
112 | | -Grads() = Grads(IdDict()) |
113 | | - |
114 | | -@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate |
115 | | - |
116 | | -Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps)) |
117 | | - |
118 | | -Base.getindex(g::Grads, x::Tracked) = g.grads[x] |
119 | | - |
120 | | -function Base.getindex(g::Grads, x) |
121 | | - istracked(x) || error("Object not tracked: $x") |
122 | | - g[tracker(x)] |
123 | | -end |
124 | | - |
125 | | -accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ |
| 107 | +# Out-of-place gradients |
126 | 108 |
|
127 | 109 | function back_(g::Grads, c::Call, Δ) |
128 | 110 | Δs = c.func(Δ) |
|
182 | 164 | gradient(f, xs...; nest = false) = |
183 | 165 | nest ? gradient_nested(f, xs...) : gradient_(f, xs...) |
184 | 166 |
|
185 | | -gradient(f, ps::Params) = gradient_nested(f, ps) |
186 | | - |
187 | 167 | # Jacobians and Hessians |
188 | 168 |
|
189 | 169 | import ..Flux |
|