
Eliminate redundant device moves for callbacks #6990

@import-antigravity

Description

🚀 Feature

Make it possible to access batches in callbacks without their having been moved to the CPU

Motivation

Currently, in callback hooks like on_train_batch_end where a batch tensor is available, the tensor arrives on the CPU by default. If I want to, for example, run that batch through another model, I have to move it back to CUDA with batch.to('cuda'), even though it must already have been on CUDA for the training step to have taken place. This redundant device transfer creates a serious performance bottleneck.
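
A minimal sketch of the workaround described above, assuming the batch is a single tensor. AuxiliaryModelCallback and aux_model are illustrative names, and the exact on_train_batch_end signature depends on the Lightning version:

```python
import torch
import pytorch_lightning as pl

class AuxiliaryModelCallback(pl.Callback):
    """Current workaround: the batch arrives on the CPU, so it must be
    copied back to the GPU before it can be reused."""

    def __init__(self, aux_model: torch.nn.Module):
        self.aux_model = aux_model

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # By the time this hook fires, `batch` has already been moved off
        # the accelerator, so we pay for a second host-to-device transfer.
        batch = batch.to(pl_module.device)
        with torch.no_grad():
            self.aux_model(batch)
```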

Pitch

Either add an argument to on_train_batch_end and similar hooks that keeps the batch on its original (CUDA) device, or add new hooks that fire before the device move takes place.
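
For illustration, one shape the second option could take. The hook name on_train_batch_end_pre_transfer is purely hypothetical and not part of Lightning's API:

```python
import pytorch_lightning as pl

class DeviceAwareCallback(pl.Callback):
    # Hypothetical hook that would fire before Lightning moves the batch
    # off the accelerator, so `batch` is still on the training device.
    def on_train_batch_end_pre_transfer(self, trainer, pl_module, outputs, batch, batch_idx):
        # No redundant batch.to('cuda') round trip would be needed here.
        assert batch.device == pl_module.device
```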

Additional context

#6945


Labels

feature (Is an improvement or enhancement), help wanted (Open to be worked on)
