Description
🚀 Feature
I propose supporting third-party data structures as batches. At the moment, if your batch is not a collection or one of the other supported data types, you have to override the transfer_batch_to_device method of your model.
I suggest accepting any data type, as long as it has a to(device) method.
Motivation
I want to use pytorch_geometric, but it uses its own dataloader and its own Batch data type, so I had some trouble using it together with PyTorch Lightning. After some struggle I figured out that I could override the transfer_batch_to_device method like this:
```python
import pytorch_lightning as pl

class Net(pl.LightningModule):
    # ... model definition, training_step, etc.
    def transfer_batch_to_device(self, batch, device):
        # torch_geometric's Batch implements .to(device), so delegate to it
        return batch.to(device)
```
At the very least, I think this requirement should be mentioned in the docs. Alternatively, the default behaviour of transfer_batch_to_device could be changed so that the override is no longer necessary.
Pitch
transfer_batch_to_device should accept any data type that implements a to(device) method.
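A rough sketch of what such a default could look like, written in plain Python so it runs without torch installed. The helper name move_to_device and the FakeBatch class are hypothetical illustrations, not Lightning's actual API: the helper recurses into dicts, lists, tuples and namedtuples, and hands any object with a callable to attribute over to that method.

```python
from typing import Any


def move_to_device(batch: Any, device: Any) -> Any:
    """Hypothetical helper (not Lightning's API): recursively move a batch to `device`."""
    # Duck typing: anything with a callable .to (tensors, torch_geometric's Batch,
    # user-defined batch classes, ...) is moved by its own method.
    to = getattr(batch, "to", None)
    if callable(to):
        return to(device)
    if isinstance(batch, dict):
        return {k: move_to_device(v, device) for k, v in batch.items()}
    if isinstance(batch, tuple) and hasattr(batch, "_fields"):  # namedtuple
        return type(batch)(*(move_to_device(v, device) for v in batch))
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device(v, device) for v in batch)
    # Anything else (ints, strings, ...) is returned unchanged.
    return batch


class FakeBatch:
    """Hypothetical stand-in for a third-party batch type such as torch_geometric's Batch."""

    def __init__(self, device: str = "cpu"):
        self.device = device

    def to(self, device: str) -> "FakeBatch":
        return FakeBatch(device)


moved = move_to_device({"graph": FakeBatch(), "ids": [FakeBatch(), 3]}, "cuda:0")
print(moved["graph"].device)  # cuda:0
```

Inside a LightningModule, an override could then simply return move_to_device(batch, device). Checking for a callable to attribute rather than a specific class is what would let third-party batch types work without any registration.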
Alternatives
Alternatively, the documentation should explain how to use non-default dataloaders and batch types.
Additional context
I saw #1756 but couldn't figure out whether it solves my problem and just hasn't been merged yet. If it does, sorry for the extra work.
Thank you for the nice library and all your work =)