Following up from #607
We should expose functionality that lets the user write a training loop while thinking only about the loss, rather than about gradient and update!. loss is a higher-level concept than gradients.
Custom training loops are important, since many things do not comfortably fit into the abstraction of train!(args -> loss(args...), data, params, opt, callbacks). The train! function is good for things that comfortably fit supervised training, and while it can do anything, it becomes increasingly awkward the further you are from that.
At the other end is writing a custom training loop, invoking gradient and update! directly. This is fully general, and you can do all kinds of things, like messing with the gradients during the training loop.
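As an illustration of that fully general style, a custom loop might clip gradients before applying them. This is only a sketch: the Dense model, Descent optimiser, mse loss, and the clamp threshold are placeholder choices for the example, not part of this proposal.

```julia
using Flux
using Flux: update!

m = Dense(4, 1)                          # placeholder model
ps = Flux.params(m)                      # implicit parameter collection
opt = Descent(0.1)
x, y = rand(Float32, 4, 16), rand(Float32, 1, 16)

l0 = Flux.mse(m(x), y)                   # loss before training
for step in 1:50
    gs = gradient(() -> Flux.mse(m(x), y), ps)
    for p in ps
        # "messing with the gradients": clip each gradient elementwise
        gs[p] .= clamp.(gs[p], -1f0, 1f0)
    end
    update!(opt, ps, gs)
end
```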
But there is a middle ground, where you can define the loss but have nothing to say about the gradients.
For this I think we should have train_step!(getloss, ps, opt), where getloss is a 0-arg closure returning the loss (and using the model). This would have pleasing symmetry, in name and arguments, to train!(loss, ps, data, opt), where loss is a closure taking args as provided by iterating data. This would be useful because you are not required to use the abstraction of having data, but you keep the rest. The implementation would be very simple, but I feel the abstraction away from gradients is worth it.
function train_step!(getloss, ps, opt)
    gs = gradient(getloss, ps)  # getloss is a 0-arg closure over the model
    update!(opt, ps, gs)
end
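A quick usage sketch of the proposal (the Dense model, Descent optimiser, and mse loss below are placeholders for illustration; only train_step! itself is what is being proposed):

```julia
using Flux
using Flux: update!

# the proposed helper: loss in, parameters updated, gradients hidden
function train_step!(getloss, ps, opt)
    gs = gradient(getloss, ps)
    update!(opt, ps, gs)
end

m = Dense(10, 2)                         # placeholder model
ps = Flux.params(m)
opt = Descent(0.1)
x, y = rand(Float32, 10, 8), rand(Float32, 2, 8)

l0 = Flux.mse(m(x), y)                   # loss before training
for step in 1:100
    # the user writes only a loss closure; gradients never appear here
    train_step!(() -> Flux.mse(m(x), y), ps, opt)
end
```

Note that the loop above imposes no notion of a data iterator: the closure is free to capture whatever batch, buffer, or environment it likes.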
This would go into the core of train!, replacing (lines 71 to 74 in 3a4c627):
gs = gradient(ps) do
    loss(d...)
end
update!(opt, ps, gs)
with
train_step!(() -> loss(d...), ps, opt)