[RFC] Gradient clipping hooks in the LightningModule #6346

@carmocca

🚀 Feature

Add clipping hooks to the LightningModule

Motivation

It's currently very difficult to customize the gradient clipping logic.

Pitch

class LightningModule:
    def clip_gradients(self, optimizer, optimizer_idx):
        ...

The default implementation would be the same as we currently provide, where the trainer's clipping flags are used.

Maybe the trainer's clipping flags would then be deprecated in favor of LightningModule properties.
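To make the hook contract concrete, here is a minimal, framework-free sketch. Everything prefixed with `Toy` (and `Param`, `ValueClippedModule`, the `limit` value) is a hypothetical stand-in, not Lightning's API; only the `clip_gradients(self, optimizer, optimizer_idx)` signature comes from the pitch, and `gradient_clip_val` mirrors the existing Trainer flag. The default is assumed to clip by global L2 norm; the subclass shows how a user might swap in clipping by value.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class Param:
    """Stand-in for a tensor parameter; only the gradient matters here."""
    grad: List[float]


@dataclass
class ToyOptimizer:
    """Stand-in for a torch optimizer; holds a flat list of parameters."""
    params: List[Param]


@dataclass
class ToyTrainer:
    """Stand-in for the Trainer; 0.0 means clipping is disabled."""
    gradient_clip_val: float = 0.0


class ToyLightningModule:
    """Hypothetical base class exposing the proposed hook."""

    def __init__(self, trainer: ToyTrainer):
        self.trainer = trainer

    def clip_gradients(self, optimizer, optimizer_idx):
        # Assumed default: rescale all gradients so that their global
        # L2 norm does not exceed the trainer's gradient_clip_val.
        max_norm = self.trainer.gradient_clip_val
        if max_norm <= 0:
            return
        total = math.sqrt(sum(g * g for p in optimizer.params for g in p.grad))
        if total > max_norm:
            scale = max_norm / total
            for p in optimizer.params:
                p.grad = [g * scale for g in p.grad]


class ValueClippedModule(ToyLightningModule):
    """User override: clamp each gradient entry instead of rescaling by norm."""

    def clip_gradients(self, optimizer, optimizer_idx):
        limit = 0.5  # hypothetical per-element bound
        for p in optimizer.params:
            p.grad = [max(-limit, min(limit, g)) for g in p.grad]


def make_optimizer():
    # Two parameters whose gradients have a global L2 norm of 5.
    return ToyOptimizer([Param([3.0, 0.0]), Param([0.0, 4.0])])


default_opt = make_optimizer()
ToyLightningModule(ToyTrainer(gradient_clip_val=1.0)).clip_gradients(default_opt, 0)

custom_opt = make_optimizer()
ValueClippedModule(ToyTrainer()).clip_gradients(custom_opt, 0)
```

The point of the design is that the override lives entirely in user code: a new clipping technique only needs a different `clip_gradients` body, with no change to the trainer.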

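class LightningOptimizer:
    def step(self, closure=None):
        if closure is None:
            closure = do_nothing_closure

        def wrapper_closure():
            # Run the user's closure (forward + backward), then clip
            # before the wrapped optimizer applies its update.
            result = closure()
            self._trainer.call_hook("clip_gradients", self.optimizer)
            return result

        self.optimizer.step(closure=wrapper_closure)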
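The ordering that the wrapped closure is meant to guarantee (backward, then clip, then parameter update) can be checked with a small runnable sketch. `ToySGD`, `ToyLightningOptimizer`, `clip_to_unit`, and the `events` log are hypothetical stand-ins; the `clip_hook` callable takes the place of `self._trainer.call_hook("clip_gradients", ...)` from the pseudocode above.

```python
from typing import Callable, List, Optional


class ToySGD:
    """Stand-in for a torch optimizer whose step() accepts a closure."""

    def __init__(self, weights: List[float], lr: float):
        self.weights = weights
        self.grads = [0.0] * len(weights)
        self.lr = lr

    def step(self, closure: Optional[Callable[[], float]] = None):
        loss = closure() if closure is not None else None
        # The wrapped closure has already clipped by the time we update.
        self.weights = [w - self.lr * g for w, g in zip(self.weights, self.grads)]
        return loss


class ToyLightningOptimizer:
    """Hypothetical wrapper mirroring the pseudocode above."""

    def __init__(self, optimizer: ToySGD, clip_hook: Callable[[ToySGD], None]):
        self.optimizer = optimizer
        self._clip_hook = clip_hook  # stands in for the trainer's hook call

    def step(self, closure=None):
        if closure is None:
            closure = lambda: None  # same role as do_nothing_closure

        def wrapper_closure():
            result = closure()
            self._clip_hook(self.optimizer)
            return result

        return self.optimizer.step(closure=wrapper_closure)


events = []
sgd = ToySGD(weights=[1.0, 1.0], lr=0.1)


def training_closure():
    # Toy loss = sum(10 * w^2), so the gradient is 20 * w.
    events.append("backward")
    sgd.grads = [20.0 * w for w in sgd.weights]
    return sum(10.0 * w * w for w in sgd.weights)


def clip_to_unit(optimizer):
    # Clamp every gradient entry into [-1, 1].
    events.append("clip")
    optimizer.grads = [max(-1.0, min(1.0, g)) for g in optimizer.grads]


loss = ToyLightningOptimizer(sgd, clip_to_unit).step(training_closure)
```

Because clipping happens inside the closure, it also runs on every re-evaluation for optimizers that call the closure more than once, which is one of the behaviors that would need evaluating.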

We need to evaluate the limitations of this approach, since clipping is currently tied to plugins.

Additional context

This would fix #5096, #6123 (comment), #5671, and #5982, and would make it easy to implement new clipping techniques without having to merge them into Lightning.

cc: @rohitgr7 who has been pushing for this for a while

Labels: design, feature, help wanted, refactor