Description
Is your feature request related to a problem? Please describe.
I rely on the forward_intermediates() API for object detection models, and I'm experimenting with ViT-g and would like to try gradient checkpointing.
Describe the solution you'd like
In `VisionTransformer.forward_features()` we have:

```python
if self.grad_checkpointing and not torch.jit.is_scripting():
    x = checkpoint_seq(self.blocks, x)
```

I'm thinking something like this could work in `VisionTransformer.forward_intermediates()`:
```python
for i, blk in enumerate(blocks):
    if self.grad_checkpointing and not torch.jit.is_scripting():
        x = checkpoint_module(blk, x)
    else:
        x = blk(x)
```

I called this `checkpoint_module()`, but I think we could just use `checkpoint_seq()` directly, based on the code? Either way, is this as simple as I think it would be, or am I missing something? I haven't used gradient checkpointing a lot, so I'm not entirely sure.
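For reference, here's a minimal standalone sketch of the per-block idea, using PyTorch's built-in `torch.utils.checkpoint.checkpoint` in place of the hypothetical `checkpoint_module()`; the `Linear` layers are just stand-ins for real ViT blocks, and `grad_checkpointing` stands in for `self.grad_checkpointing`:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Stand-ins for model.blocks inside timm's VisionTransformer.
blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
grad_checkpointing = True  # would be self.grad_checkpointing in the model

x = torch.randn(2, 8, requires_grad=True)
intermediates = []
out = x
for i, blk in enumerate(blocks):
    if grad_checkpointing and not torch.jit.is_scripting():
        # Drop this block's activations after the forward pass and
        # recompute them during backward, trading compute for memory.
        out = checkpoint(blk, out, use_reentrant=False)
    else:
        out = blk(out)
    intermediates.append(out)  # forward_intermediates would collect these

out.sum().backward()  # gradients still flow through checkpointed blocks
```

This is just to illustrate that the collected intermediates and the gradient path both survive per-block checkpointing; whether `checkpoint_seq()` is preferable inside timm is the question above.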
I'm happy to submit a PR for a few models if it's as simple as calling `checkpoint_seq()` in `forward_intermediates()` as I've outlined above. I'm not sure how many models use this API and/or `self.grad_checkpointing`, and whether you want this to be supported in all of them.