
Improve moving data / model to GPU using torchtext #1245

@elkotito


🚀 Feature

Improve how data and the model are moved to the GPU when torchtext is used.

Motivation

Case 1:

The Batch object generated by torchtext.data.Iterator doesn't follow the rules described here: https://github.com/PyTorchLightning/pytorch-lightning/blob/45d671a4a81788b9d97fd6b47763816926e58e95/pytorch_lightning/trainer/distrib_parts.py#L420

As a result, the data is not moved to the GPU. The torchtext.data.Iterator is returned by the train_dataloader method. Keep in mind that torchtext.data.Iterator has a device argument that pytorch-lightning does not make proper use of.

    @ptl.data_loader
    def train_dataloader(self):
        ...  
        return Iterator(dataset=dataset, batch_size=self.batch_size, shuffle=False, device=DEVICE)

Partially reported in #226.
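To illustrate Case 1, here is a minimal sketch of a transfer routine, assuming the linked code recurses into lists, tuples and dicts and moves anything that has a `.to` method. An object that only exposes its tensors as attributes falls through all of those branches, so the sketch adds a fallback for it. `FakeTensor`, `FakeBatch` and `move_to_device` are hypothetical names used for illustration only; the stand-in classes let the example run without torch or a GPU, and the `fields` attribute is an assumption about how a torchtext-style batch lists its field names.

```python
import copy


class FakeTensor:
    """Stand-in for torch.Tensor: only tracks which device it lives on."""

    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        # Like Tensor.to, return a new object on the requested device.
        return FakeTensor(device)


class FakeBatch:
    """Stand-in for a torchtext-style batch: tensors are plain attributes."""

    def __init__(self, **tensors):
        self.fields = list(tensors)  # assumed: names of the tensor attributes
        for name, value in tensors.items():
            setattr(self, name, value)


def move_to_device(obj, device):
    """Hypothetical helper: recursively move tensors inside obj to device."""
    if hasattr(obj, "to"):
        return obj.to(device)
    if isinstance(obj, list):
        return [move_to_device(item, device) for item in obj]
    if isinstance(obj, tuple):
        return tuple(move_to_device(item, device) for item in obj)
    if isinstance(obj, dict):
        return {key: move_to_device(value, device) for key, value in obj.items()}
    if hasattr(obj, "fields"):
        # Fallback for batch-like objects: move each named attribute
        # on a shallow copy so the original batch is left untouched.
        moved = copy.copy(obj)
        for name in obj.fields:
            setattr(moved, name, move_to_device(getattr(obj, name), device))
        return moved
    return obj  # anything else is returned unchanged


batch = FakeBatch(text=FakeTensor(), label=FakeTensor())
moved = move_to_device(batch, "cuda:0")
```

Without the `fields` branch, `batch` would reach the final `return obj` and stay on the CPU, which matches the behaviour described in this case.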

Case 2:

Using torchtext, you can read pre-trained embeddings and create an nn.Embedding object as follows:

    def train_dataloader(self):
        ...
        self.text_field.build_vocab(
            dataset,
            vectors=Vectors("/data/embeddings/glove/glove.840B.300d.txt"),
        )

        self.embeddings = nn.Embedding(
            ...
            padding_idx=self.text_field.vocab.stoi[PAD_TOKEN],
            _weight=self.text_field.vocab.vectors.to(DEVICE),
        )

nn.Embedding clearly depends on self.text_field.vocab, which in turn depends on the dataset used by train_dataloader. Currently, any part of the model that is not fully created in __init__ of the ptl.LightningModule is not moved to the GPU. This still requires a global variable that determines the device, i.e. DEVICE, which makes Trainer(gpus=...) useless.
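For Case 2, one possible workaround (a sketch, not something pytorch-lightning does for you) is to place late-created weights on whatever device the already registered parameters live on, instead of reading a global. `LazyModel` and `build_embeddings` are hypothetical names, the random tensor stands in for `vocab.vectors`, and the example falls back to the CPU when no GPU is available.

```python
import torch
from torch import nn


class LazyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)  # registered in __init__, moved by .to()
        self.embeddings = None         # created later, once the vocab exists

    def build_embeddings(self, vectors):
        # Hypothetical helper: look up the device of an existing parameter
        # rather than relying on a global DEVICE variable.
        device = next(self.parameters()).device
        self.embeddings = nn.Embedding.from_pretrained(vectors.to(device))


model = LazyModel()
target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(target)  # stands in for the move the Trainer performs

# Stand-in for self.text_field.vocab.vectors (10 tokens, 4 dimensions).
model.build_embeddings(torch.randn(10, 4))
```

If the embedding happens to be created before the model is moved, it is already a registered submodule by then, so a later `.to()` call moves it along with the rest.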

Pitch

I would like to not have to worry about moving data to the GPU when using torchtext together with pytorch-lightning.

Labels: feature (Is an improvement or enhancement), help wanted (Open to be worked on)
