
add trainstep! #666

@oxinabox

Description

Following up from #607

We should expose functionality that lets the user write a training loop
while thinking only about the loss,
rather than thinking about gradient and update!.
The loss is a higher-level concept than gradients.

Custom training loops are important, since many things do not comfortably fit into the abstraction of
train!(args->loss(args...), data, params, opt, callbacks).
The train! function is good for things that comfortably fit supervised training,
and while it can do anything, it becomes increasingly awkward the further you are from that.
At the other end is writing a custom training loop, invoking gradient and update! yourself.
This is fully general, and you can do all kinds of things, like messing with the gradients during the training loop.
But there is a middle ground,
where you can define the loss but have nothing to say about the gradients.
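
For reference, here is a minimal sketch of that fully general style of loop, written against the current Flux API. The model m, optimiser opt, data, and mse loss are just placeholders for this example:

using Flux
using Flux.Optimise: update!

m    = Dense(10, 2)                                          # placeholder model
opt  = Descent(0.1)                                          # placeholder optimiser
ps   = Flux.params(m)
data = [(rand(Float32, 10), rand(Float32, 2)) for _ in 1:4]  # placeholder data

for (x, y) in data
    gs = gradient(ps) do         # implicit-params gradient over a 0-arg closure
        Flux.mse(m(x), y)
    end
    # the user has full access to gs here, e.g. for clipping or logging
    update!(opt, ps, gs)
end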

For this I think we should have

train_step!(getloss, ps, opt), where getloss is a 0-arg closure returning the loss (and using the model).
This would have a pleasing symmetry, in name and arguments,
with train!(loss, ps, data, opt), where loss is a closure taking args as provided by iterating data.

This would be useful because you are not required to use the abstraction of having data, but you keep the rest.

The implementation would be very simple, but I feel the abstraction away from gradients is worth it.

function train_step!(getloss, ps, opt)
    gs = gradient(getloss, ps)  # getloss is a 0-arg closure; ps are the implicit params
    update!(opt, ps, gs)
end
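
For illustration, a user's loop might then look something like this (reusing the placeholder m, data, and mse loss from the sketch above; the do-block form passes the 0-arg closure as getloss):

for (x, y) in data
    train_step!(ps, opt) do
        Flux.mse(m(x), y)   # the loss, closing over m, x and y
    end
end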

This would go into the core of train!.
Replacing

gs = gradient(ps) do
    loss(d...)
end
update!(opt, ps, gs)

with

train_step!(() -> loss(d...), ps, opt)
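
Put together, the core of train! would then look roughly like this (a sketch only, omitting the progress logging and callback plumbing the real function has):

function train!(loss, ps, data, opt; cb = () -> ())
    for d in data
        train_step!(() -> loss(d...), ps, opt)
        cb()
    end
end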
