
Support for other data types as batches #2465

@Tieftaucher

Description


🚀 Feature

My proposal is to support third-party data structures as batches. At the moment, you need to override the transfer_batch_to_device method of your model if your batch is not a collection or one of the other supported data types.
My suggestion would be to accept any data type, as long as it has a to(device) method.

Motivation

I want to use pytorch_geometric, but it uses its own dataloader and its own Batch data type, so I had some trouble using it together with PyTorch Lightning. After some struggling, I figured out that I could override the transfer_batch_to_device method like this:

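import pytorch_lightning as pl

class Net(pl.LightningModule):
    # ... model definition, training_step, etc. ...

    def transfer_batch_to_device(self, batch, device):
        # torch_geometric's Batch implements .to(device), so delegate to it
        return batch.to(device)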

At the very least, I think it would be nice to mention this requirement in the docs. Alternatively, the default behaviour of transfer_batch_to_device could be changed so that overriding it is no longer necessary.

Pitch

transfer_batch_to_device should accept any data type that has a to(device) method.
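A minimal sketch of what such a default could look like, assuming the existing collection handling is kept and a duck-typed to(device) fallback is added at the end. The helper name move_to_device and the ToyBatch class are hypothetical illustrations, not Lightning's actual implementation; plain device strings stand in for torch.device so the sketch runs without PyTorch installed.

```python
from collections.abc import Mapping
from typing import Any


def move_to_device(batch: Any, device: Any) -> Any:
    """Hypothetical default for transfer_batch_to_device: recurse into
    common collections, then fall back to any object with a callable .to()."""
    # Mappings (e.g. dict): move every value, return a plain dict.
    if isinstance(batch, Mapping):
        return {key: move_to_device(value, device) for key, value in batch.items()}
    # Namedtuples need their fields passed positionally.
    if isinstance(batch, tuple) and hasattr(batch, "_fields"):
        return type(batch)(*(move_to_device(item, device) for item in batch))
    # Plain lists and tuples: rebuild the same container type.
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device(item, device) for item in batch)
    # Duck typing: tensors, torch_geometric Batch objects, or any custom
    # class exposing .to(device) are moved by calling that method.
    to = getattr(batch, "to", None)
    if callable(to):
        return to(device)
    # Anything else (ints, strings, None, ...) is returned unchanged.
    return batch


class ToyBatch:
    """Hypothetical stand-in for a third-party batch type."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "ToyBatch":
        return ToyBatch(device)


moved = move_to_device({"graph": ToyBatch(), "ids": [1, 2], "name": "x"}, "cuda:0")
print(moved["graph"].device)  # cuda:0
```

Checking for a callable to attribute, rather than testing isinstance against a fixed list of known types, is what would let types such as torch_geometric's Batch work without overriding the hook.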

Alternatives

Alternatively, the documentation should mention how to use non-default dataloaders and batch types.

Additional context

I saw #1756 but couldn't figure out whether it solves my problem and just hasn't been merged yet. If it does, sorry for the extra work.

Thank you for the nice library and all your work =)

Metadata


Assignees

No one assigned

    Labels

    feature (Is an improvement or enhancement), help wanted (Open to be worked on)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
