Description
🐛 Bug
In the auto-scaling-of-batch-size section of the docs, the following code example is given:
# Use default in trainer construction
trainer = Trainer()
tuner = Tuner(trainer)
# Invoke method
new_batch_size = tuner.scale_batch_size(model, *extra_parameters_here)
# Override old batch size
model.hparams.batch_size = new_batch_size
# Fit as normal
trainer.fit(model)
However, this does not work as expected when the batch size lives on a LightningDataModule that the LightningModule holds as an attribute (self.datamodule). Following the code above raises MisconfigurationException: Field batch_size not found in both model and model.hparams.
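For clarity, the tuner's lookup only succeeds when batch_size is an attribute of the LightningModule itself (or of its hparams); it fails when batch_size is only defined on a datamodule held as self.datamodule. A minimal sketch of the two layouts (the class names here are illustrative, not from the docs):
import pytorch_lightning as pl

# Layout the tuner's lookup handles: batch_size lives on the module itself.
class ModuleWithBatchSize(pl.LightningModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size   # found via hasattr(model, "batch_size")

# Layout from this report that fails: batch_size only lives on the attached datamodule.
class ModuleWithDataModule(pl.LightningModule):
    def __init__(self, datamodule):
        super().__init__()
        self.datamodule = datamodule   # datamodule.batch_size is not checked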
To Reproduce
See my one-page reproduction script below:
import torch
import torchvision
from torchvision import transforms
from torchvision import models
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class ImageNetDataModule(pl.LightningDataModule):
    def __init__(self, imagenet_root, batch_size=128, num_workers=32):
        super().__init__()
        self.batch_size = batch_size
        self.imagenet_root = imagenet_root
        self.num_workers = num_workers

    def setup(self, stage):
        train_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomGrayscale(),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.imagenet_train = torchvision.datasets.ImageNet(
            root=self.imagenet_root + "train/", split="train", transform=train_transform)
        val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.imagenet_val = torchvision.datasets.ImageNet(
            root=self.imagenet_root + "val/", split="val", transform=val_transform)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.imagenet_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            persistent_workers=True,
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.imagenet_val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            persistent_workers=True,
        )


class NetWrapper(pl.LightningModule):
    def __init__(self, model, datamodule, criterion=torch.nn.CrossEntropyLoss()):
        super().__init__()
        self.model = model
        self.criterion = criterion
        self.lr = 1e-3
        self.datamodule = datamodule

    def forward(self, x):
        raw_prob = self.model(x)  # (batch, 1000)
        return raw_prob

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)  # (batch, 1000)
        loss = self.criterion(preds, y)
        self.log("cross_entropy_loss_training", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)  # (batch, 1000)
        loss = self.criterion(preds, y)
        self.log("cross_entropy_loss_val", loss)
        return loss

    def validation_epoch_end(self, validation_step_outputs):
        all_outputs = torch.tensor(validation_step_outputs)
        std, mean = torch.std_mean(all_outputs)
        self.log("validation_mean", mean)
        self.log("validation_std", std)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


imagenet_dm = ImageNetDataModule("../../datasets/ImageNet/", batch_size=1024)
resnet50 = models.resnet50(pretrained=False)
resnet50_pl_module = NetWrapper(resnet50, imagenet_dm)

trainer = pl.Trainer(gpus=1,
                     accelerator='dp',
                     auto_scale_batch_size='binsearch')
tuner = pl.tuner.tuning.Tuner(trainer)

# this line raises MisconfigurationException: Field batch_size not found in both model and model.hparams
new_batch_size = tuner.scale_batch_size(resnet50_pl_module, mode="binsearch", init_val=128)

# the below line works fine
trainer.tune(resnet50_pl_module)
Expected behavior
Tuner should find the attribute batch_size in model.datamodule in the method Tuner.scale_batch_size().
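Conceptually, the lookup I would expect is something like the following simplified sketch (this is not the actual lightning_hasattr implementation, just the behavior I am asking for):
def expected_lightning_hasattr(model, attribute="batch_size"):
    # Check the module itself first.
    if hasattr(model, attribute):
        return True
    # Then its hparams, as today.
    if hasattr(model, "hparams") and attribute in model.hparams:
        return True
    # Additionally fall back to a datamodule attached directly to the module.
    datamodule = getattr(model, "datamodule", None)
    if datamodule is not None and hasattr(datamodule, attribute):
        return True
    return False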
Environment
This issue should be independent of the environment.
Additional context
I took a look at the source code and found that when Trainer.tune() is called, the call chain is
trainer.tune() -> tuner.tune() -> tuner.scale_batch_size() -> batch_size_scaling.scale_batch_size() -> lightning_hasattr(model, attribute) -> ...
while the call chain when Tuner.scale_batch_size() is called directly is
tuner.scale_batch_size() -> batch_size_scaling.scale_batch_size() -> lightning_hasattr(model, attribute) -> ...
The problem is that lightning_hasattr(model, attribute) cannot find the attribute model.datamodule.batch_size when the registration steps performed in trainer.tune() are skipped.
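If my reading is correct, the step that matters is the datamodule registration that trainer.tune() performs before the scaling starts; calling Tuner.scale_batch_size() directly skips it, so the lookup never sees the datamodule. A rough illustration of the difference as I understand it (simplified from my reading of the source; the exact internal calls differ):
# Path that works: trainer.tune() first attaches model.datamodule to the trainer,
# so the later batch_size lookup has a datamodule to fall back to.
trainer.tune(resnet50_pl_module)

# Path that fails: calling the tuner directly skips that registration, so
# lightning_hasattr only inspects model and model.hparams and raises
# MisconfigurationException: Field batch_size not found in both model and model.hparams.
tuner = pl.tuner.tuning.Tuner(trainer)
new_batch_size = tuner.scale_batch_size(resnet50_pl_module, mode="binsearch", init_val=128)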