
Commit 626db1b

Fix type hints
Parent: b9b2d68

3 files changed: +16 −7 lines


pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -89,7 +89,6 @@ module = [
     "pytorch_lightning.trainer.optimizers",
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",
-    "pytorch_lightning.tuner.batch_size_scaling",
     "pytorch_lightning.tuner.lr_finder",
     "pytorch_lightning.tuner.tuning",
     "pytorch_lightning.utilities.auto_restart",

src/pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
@@ -601,7 +601,7 @@ def _init_debugging_flags(
                 "Logging and checkpointing is suppressed."
             )

-        self.limit_train_batches = _determine_batch_limits(limit_train_batches, "limit_train_batches")
+        self.limit_train_batches: Union[int, float] = _determine_batch_limits(limit_train_batches, "limit_train_batches")
         self.limit_val_batches = _determine_batch_limits(limit_val_batches, "limit_val_batches")
         self.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches")
         self.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches")
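The explicit `Union[int, float]` matters because mypy infers an attribute's type from its first assignment in `__init__`; annotating that assignment pins the attribute to int-or-float regardless of what the checker infers for the helper's return. A minimal sketch of the pattern (hypothetical names, not the Lightning source):

    from typing import Union

    def _determine_limit(value: Union[int, float]) -> Union[int, float]:
        # stand-in for _determine_batch_limits: by Lightning's convention an
        # int means "this many batches", a float means "this fraction"
        return value

    class Trainer:
        def __init__(self, limit_train_batches: Union[int, float]) -> None:
            # annotating the first assignment fixes the attribute's type for
            # mypy, so later reassignment with either int or float type-checks
            self.limit_train_batches: Union[int, float] = _determine_limit(limit_train_batches)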

src/pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 15 additions & 5 deletions
@@ -14,12 +14,14 @@
 import logging
 import os
 import uuid
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Union, List, Optional, Tuple

 from torch.utils.data import DataLoader

 import pytorch_lightning as pl
-from pytorch_lightning.loggers.logger import DummyLogger
+from pytorch_lightning.callbacks.callback import Callback
+from pytorch_lightning.loggers.logger import DummyLogger, Logger
+
 from pytorch_lightning.utilities.data import has_len_all_ranks
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error
@@ -41,7 +43,7 @@ def scale_batch_size(
     """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.scale_batch_size`"""
     if trainer.fast_dev_run:
         rank_zero_warn("Skipping batch size scaler since fast_dev_run is enabled.")
-        return
+        return None

     if not lightning_hasattr(model, batch_arg_name):
         raise MisconfigurationException(f"Field {batch_arg_name} not found in both `model` and `model.hparams`")
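The explicit `return None` (rather than a bare `return`) keeps the early-exit path visibly consistent with the function's optional return annotation. A tiny sketch of the convention, assuming an `Optional[int]` return as in the tuner API (hypothetical function, not the Lightning source):

    from typing import Optional

    def scale_batch_size_sketch(fast_dev_run: bool) -> Optional[int]:
        if fast_dev_run:
            # explicit None: every exit path now states the value it returns
            return None
        return 32  # placeholder for a discovered batch size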
@@ -234,18 +236,26 @@ def _adjust_batch_size(
     """
     model = trainer.lightning_module
     batch_size = lightning_getattr(model, batch_arg_name)
-    new_size = value if value is not None else int(batch_size * factor)
+    if value is not None:
+        new_size = value
+    else:
+        if not isinstance(batch_size, int):
+            raise ValueError(f"value is None and batch_size is not int value: {batch_size}")
+        new_size = int(batch_size * factor)
+
     if desc:
         log.info(f"Batch size {batch_size} {desc}, trying batch size {new_size}")

     if not _is_valid_batch_size(new_size, trainer.train_dataloader, trainer):
+        if not isinstance(trainer.train_dataloader, DataLoader):
+            raise ValueError("train_dataloader is not a DataLoader")
         new_size = min(new_size, len(trainer.train_dataloader.dataset))

     changed = new_size != batch_size
     lightning_setattr(model, batch_arg_name, new_size)
     return new_size, changed


-def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer"):
+def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer") -> bool:
     module = trainer.lightning_module or trainer.datamodule
     return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader)
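The remaining changes satisfy the checker through narrowing: `lightning_getattr` and `trainer.train_dataloader` presumably have broader inferred types than the arithmetic and `len()` calls require, so the new `isinstance()` guards raise on the unexpected case and narrow the type for the code that follows. A minimal, self-contained sketch of the pattern (hypothetical names, not the Lightning source):

    from typing import Optional

    def next_size(batch_size: Optional[int], factor: float) -> int:
        # mypy treats isinstance() as a type guard: past the raise below,
        # batch_size is narrowed from Optional[int] to int, so the
        # arithmetic type-checks
        if not isinstance(batch_size, int):
            raise ValueError(f"expected an int batch size, got {batch_size!r}")
        return int(batch_size * factor)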
