
"Wrong" code example in doc auto-scaling-of-batch-size section #5967

@ifsheldon

🐛 Bug

In the auto-scaling-of-batch-size section of the docs, the code example is:

# Use default in trainer construction
trainer = Trainer()
tuner = Tuner(trainer)

# Invoke method
new_batch_size = tuner.scale_batch_size(model, *extra_parameters_here)

# Override old batch size
model.hparams.batch_size = new_batch_size

# Fit as normal
trainer.fit(model)

However, this does not work as expected when a LightningModule holds its datamodule as an attribute (self.datamodule). Following that code raises MisconfigurationException: Field batch_size not found in both model and model.hparams.

To Reproduce

See my one-page reproduction below:

import torch
import torchvision
from torchvision import transforms
from torchvision import models
import pytorch_lightning as pl

class ImageNetDataModule(pl.LightningDataModule):
    def __init__(self, imagenet_root, batch_size=128, num_workers=32):
        super().__init__()
        self.batch_size = batch_size
        self.imagenet_root = imagenet_root
        self.num_workers = num_workers

    def setup(self, stage):
        train_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomGrayscale(),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.imagenet_train = torchvision.datasets.ImageNet(
            root=self.imagenet_root + "train/", split="train", transform=train_transform
        )
        val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.imagenet_val = torchvision.datasets.ImageNet(
            root=self.imagenet_root + "val/", split="val", transform=val_transform
        )

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.imagenet_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            persistent_workers=True,
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.imagenet_val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            persistent_workers=True,
        )
    
class NetWrapper(pl.LightningModule):
    def __init__(self, model, datamodule, criterion=torch.nn.CrossEntropyLoss()):
        super().__init__()
        self.model = model
        self.criterion = criterion
        self.lr = 1e-3
        self.datamodule = datamodule  # the datamodule lives on the module itself

    def forward(self, x):
        raw_prob = self.model(x)  # (batch, 1000)
        return raw_prob

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)  # (batch, 1000)
        loss = self.criterion(preds, y)
        self.log("cross_entropy_loss_training", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)  # (batch, 1000)
        loss = self.criterion(preds, y)
        self.log("cross_entropy_loss_val", loss)
        return loss

    def validation_epoch_end(self, validation_step_outputs):
        # stack the scalar losses rather than passing a list of tensors to torch.tensor()
        all_outputs = torch.stack(validation_step_outputs)
        std, mean = torch.std_mean(all_outputs)
        self.log("validation_mean", mean)
        self.log("validation_std", std)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

imagenet_dm = ImageNetDataModule("../../datasets/ImageNet/", batch_size=1024)
resnet50 = models.resnet50(pretrained=False)
resnet50_pl_module = NetWrapper(resnet50, imagenet_dm)
trainer = pl.Trainer(gpus=1,
                     accelerator='dp',
                     auto_scale_batch_size='binsearch')

tuner = pl.tuner.tuning.Tuner(trainer)
# this call raises MisconfigurationException: Field batch_size not found in both model and model.hparams
new_batch_size = tuner.scale_batch_size(resnet50_pl_module, mode="binsearch", init_val=128)
# the line below works fine
trainer.tune(resnet50_pl_module)

Expected behavior

Tuner should find the attribute batch_size in model.datamodule in the method Tuner.scale_batch_size().
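
As a workaround, per the reproduction above, letting trainer.tune() drive the scaling instead of calling Tuner.scale_batch_size() directly does work, since trainer.tune() performs the registration steps described under Additional context below. A minimal sketch reusing the objects from the reproduction:

trainer = pl.Trainer(gpus=1,
                     accelerator='dp',
                     auto_scale_batch_size='binsearch')
# trainer.tune() registers the datamodule before the scan, so the
# batch_size attribute on model.datamodule should be found and updated
trainer.tune(resnet50_pl_module)
trainer.fit(resnet50_pl_module)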

Environment

This issue should be independent of the environment.

Additional context

I took a look at the source code and found that when Trainer.tune() is called, the invocation chain is

trainer.tune() -> tuner.tune() -> tuner.scale_batch_size() -> batch_size_scaling.scale_batch_size() -> lightning_hasattr(model, attribute) -> ...

while the invocation chain when calling Tuner.scale_batch_size() directly is

tuner.scale_batch_size() -> batch_size_scaling.scale_batch_size() -> lightning_hasattr(model, attribute) -> ...

The problem is that lightning_hasattr(model, attribute) cannot find the attribute model.datamodule.batch_size when the registration steps performed inside trainer.tune() are skipped.
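
For illustration, here is a simplified sketch of the lookup I believe lightning_hasattr performs. This is a hypothetical reconstruction from the error message and the behaviour above, not the actual implementation:

def lightning_hasattr_sketch(model, attribute):
    # 1) plain attribute on the LightningModule itself
    if hasattr(model, attribute):
        return True
    # 2) entry in model.hparams (covers the documented use case)
    if hasattr(model, "hparams") and attribute in model.hparams:
        return True
    # 3) datamodule registered on the trainer -- this registration only
    #    happens inside trainer.tune(); calling Tuner.scale_batch_size()
    #    directly skips it, so model.datamodule is never consulted
    datamodule = getattr(model.trainer, "datamodule", None)
    return datamodule is not None and hasattr(datamodule, attribute)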
