
Commit 04c794c

williamFalcon, awaelchli, rohitgr7, Borda
authored
[WIP] Rename overfit_pct to overfit_batches (and fix) and val_percent_check and test_percent_check (and fix) (#2213)
* fixed percent check for val/test
* overfit_pct now uses train loaders for val and test and does not shuffle
* add on fit_start on fit_end hooks

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 97dfd3a commit 04c794c

26 files changed (+424 lines, −216 lines)

CHANGELOG.md

Lines changed: 5 additions & 0 deletions

@@ -21,6 +21,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added overfit_batches, limit_xxx_batches flags (overfit now uses training set for all three) ([#2213](https://github.com/PyTorchLightning/pytorch-lightning/pull/2213))
+- Added metric Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
+- Added Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
+- Added Native torch metrics ([#1488](https://github.com/PyTorchLightning/pytorch-lightning/pull/1488))
 - Added metrics
     * Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
     * Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
@@ -54,6 +58,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Deprecated
 
+- Deprecated `overfit_pct`, `val_percent_check`, `test_percent_check` ([#2213](https://github.com/PyTorchLightning/pytorch-lightning/pull/2213))
 - Deprecated `ModelCheckpoint`'s attributes `best` and `kth_best_model` ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799))
 - Dropped official support/testing for older PyTorch versions <1.3 ([#1917](https://github.com/PyTorchLightning/pytorch-lightning/pull/1917))
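The renames in this changelog map the deprecated Trainer arguments onto the new ones one-for-one. A minimal migration sketch (`migrate_trainer_kwargs` and the plain dicts are hypothetical illustrations, not part of the library):

```python
# Mapping from deprecated (0.8.0) Trainer kwargs to their replacements,
# per the Deprecated section above.
DEPRECATED_TO_NEW = {
    "overfit_pct": "overfit_batches",
    "val_percent_check": "limit_val_batches",
    "test_percent_check": "limit_test_batches",
}

def migrate_trainer_kwargs(kwargs):
    """Rewrite deprecated Trainer kwargs to their new names; values are unchanged."""
    return {DEPRECATED_TO_NEW.get(key, key): value for key, value in kwargs.items()}

old = {"overfit_pct": 0.01, "val_percent_check": 0.2}
print(migrate_trainer_kwargs(old))
# {'overfit_batches': 0.01, 'limit_val_batches': 0.2}
```

Unrecognized keys pass through untouched, so the helper is safe to run over a full kwargs dict.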

docs/source/debugging.rst

Lines changed: 9 additions & 2 deletions

@@ -48,12 +48,19 @@ Make model overfit on subset of data
 A good debugging technique is to take a tiny portion of your data (say 2 samples per class),
 and try to get your model to overfit. If it can't, it's a sign it won't work with large datasets.
 
-(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.overfit_pct`
+(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.overfit_batches`
 argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)
 
 .. testcode::
 
-    trainer = Trainer(overfit_pct=0.01)
+    # use only 1% of the training data (the same training dataloader, with shuffle off, is used for val and test)
+    trainer = Trainer(overfit_batches=0.01)
+
+    # or overfit on a fixed number of batches
+    trainer = Trainer(overfit_batches=10)
+
+With this flag, the train, val, and test sets will all be the same train set. We will also replace the sampler
+in the training set to turn off shuffle for you.
 
 Print a summary of your LightningModule
 ---------------------------------------

docs/source/fast_training.rst

Lines changed: 6 additions & 6 deletions

@@ -56,17 +56,17 @@ If you don't want to check 100% of the training/validation/test set (for debuggi
     # DEFAULT
     trainer = Trainer(
         train_percent_check=1.0,
-        val_percent_check=1.0,
-        test_percent_check=1.0
+        limit_val_batches=1.0,
+        limit_test_batches=1.0
     )
 
     # check 10%, 20%, 30% only, respectively for training, validation and test set
     trainer = Trainer(
         train_percent_check=0.1,
-        val_percent_check=0.2,
-        test_percent_check=0.3
+        limit_val_batches=0.2,
+        limit_test_batches=0.3
    )
 
-.. note:: ``train_percent_check``, ``val_percent_check`` and ``test_percent_check`` will be overwritten by ``overfit_pct`` if ``overfit_pct`` > 0. ``val_percent_check`` will be ignored if ``fast_dev_run=True``.
+.. note:: ``train_percent_check``, ``limit_val_batches`` and ``limit_test_batches`` will be overwritten by ``overfit_batches`` if ``overfit_batches`` > 0. ``limit_val_batches`` will be ignored if ``fast_dev_run=True``.
 
-.. note:: If you set ``val_percent_check=0``, validation will be disabled.
+.. note:: If you set ``limit_val_batches=0``, validation will be disabled.
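The notes above rely on the float/int semantics of the new limit flags: a float is read as a fraction of the dataloader, an int as an absolute batch count. A rough sketch of that resolution (`resolve_limit` is a hypothetical helper for illustration, not Lightning's actual implementation):

```python
def resolve_limit(limit, num_batches_in_loader):
    """Turn a limit flag into an absolute number of batches.

    Sketch of the documented semantics:
      - int   -> run exactly that many batches (capped at the loader length)
      - float -> run that fraction of the loader
    """
    if isinstance(limit, int):
        return min(limit, num_batches_in_loader)
    return int(num_batches_in_loader * limit)

print(resolve_limit(0.25, 100))  # 25 batches (25% of the loader)
print(resolve_limit(10, 100))    # 10 batches (absolute count)
print(resolve_limit(0, 100))     # 0 batches -> validation disabled
```

Note how `limit_val_batches=0` naturally disables validation: zero batches are resolved, matching the second note above.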

pytorch_lightning/callbacks/progress.py

Lines changed: 2 additions & 0 deletions

@@ -98,6 +98,7 @@ def total_val_batches(self) -> int:
         elif not self.trainer.disable_validation:
             is_val_epoch = trainer.current_epoch % trainer.check_val_every_n_epoch == 0
             total_val_batches = trainer.num_val_batches if is_val_epoch else 0
+            total_val_batches = sum(total_val_batches)
         return total_val_batches
 
     @property
@@ -111,6 +112,7 @@ def total_test_batches(self) -> int:
             total_test_batches = len(self.trainer.test_dataloaders)
         else:
             total_test_batches = self.trainer.num_test_batches
+            total_test_batches = sum(total_test_batches)
         return total_test_batches
 
     def disable(self):
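The two added `sum(...)` calls suggest that `num_val_batches` and `num_test_batches` now hold one count per dataloader rather than a single int. A minimal sketch of the totaling the progress bar needs, under that assumption:

```python
def total_batches(num_batches_per_loader):
    """Sum per-dataloader batch counts into one progress-bar total.

    Assumes the trainer tracks a list with one entry per dataloader,
    which is why the progress callback now wraps the value in sum().
    """
    return sum(num_batches_per_loader)

print(total_batches([10, 5, 3]))  # three dataloaders -> 18 batches total
```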

pytorch_lightning/trainer/__init__.py

Lines changed: 74 additions & 57 deletions

@@ -433,6 +433,40 @@ def on_train_end(self, trainer, pl_module):
     # default used by the Trainer
     trainer = Trainer(gradient_clip_val=0.0)
 
+
+limit_test_batches
+^^^^^^^^^^^^^^^^^^
+
+How much of the test dataset to check.
+
+Example::
+
+    # default used by the Trainer
+    trainer = Trainer(limit_test_batches=1.0)
+
+    # run through only 25% of the test set each epoch
+    trainer = Trainer(limit_test_batches=0.25)
+
+    # run for only 10 batches
+    trainer = Trainer(limit_test_batches=10)
+
+limit_val_batches
+^^^^^^^^^^^^^^^^^
+
+How much of the validation dataset to check.
+Useful when debugging or testing something that happens at the end of an epoch.
+
+Example::
+
+    # default used by the Trainer
+    trainer = Trainer(limit_val_batches=1.0)
+
+    # run through only 25% of the validation set each epoch
+    trainer = Trainer(limit_val_batches=0.25)
+
+    # run for only 10 batches
+    trainer = Trainer(limit_val_batches=10)
+
 log_gpu_memory
 ^^^^^^^^^^^^^^
 Options:
@@ -652,29 +686,28 @@ def on_train_end(self, trainer, pl_module):
 
 overfit_pct
 ^^^^^^^^^^^
-Uses this much data of all datasets (training, validation, test).
+
+.. warning:: .. deprecated:: 0.8.0
+
+    Use `overfit_batches`. Will be removed in 1.0.0.
+
+overfit_batches
+^^^^^^^^^^^^^^^
+Uses this much data of the training set. It will use the same training set for validation and testing.
+If the training dataloaders have ``shuffle=True``, Lightning will automatically disable it.
+
 Useful for quickly debugging or trying to overfit on purpose.
 
 Example::
 
     # default used by the Trainer
-    trainer = Trainer(overfit_pct=0.0)
-
-    # use only 1% of the train, test, val datasets
-    trainer = Trainer(overfit_pct=0.01)
+    trainer = Trainer(overfit_batches=0.0)
 
-    # equivalent:
-    trainer = Trainer(
-        train_percent_check=0.01,
-        val_percent_check=0.01,
-        test_percent_check=0.01
-    )
-
-See Also:
-    - `train_percent_check`_
-    - `val_percent_check`_
-    - `test_percent_check`_
+    # use only 1% of the train set (and use the train set for val and test)
+    trainer = Trainer(overfit_batches=0.01)
 
+    # overfit on 10 of the same batches
+    trainer = Trainer(overfit_batches=10)
 
 precision
 ^^^^^^^^^
@@ -829,39 +862,7 @@ def on_train_end(self, trainer, pl_module):
 
 test_percent_check
 ^^^^^^^^^^^^^^^^^^
 
-How much of test dataset to check.
-
-Example::
-
-    # default used by the Trainer
-    trainer = Trainer(test_percent_check=1.0)
-
-    # run through only 25% of the test set each epoch
-    trainer = Trainer(test_percent_check=0.25)
-
-val_check_interval
-^^^^^^^^^^^^^^^^^^
-
-How often within one training epoch to check the validation set.
-Can specify as float or int.
-
-- use (float) to check within a training epoch
-- use (int) to check every n steps (batches)
-
-.. code-block:: python
-
-    # default used by the Trainer
-    trainer = Trainer(val_check_interval=1.0)
-
-Example::
-
-    # check validation set 4 times during a training epoch
-    trainer = Trainer(val_check_interval=0.25)
-
-    # check validation set every 1000 training batches
-    # use this when using iterableDataset and your dataset has no length
-    # (ie: production cases with streaming data)
-    trainer = Trainer(val_check_interval=1000)
+.. warning:: Deprecated in v0.8.0. Please use `limit_test_batches`. Will be removed in 1.0.0.
 
 track_grad_norm
 ^^^^^^^^^^^^^^^
@@ -955,20 +956,36 @@ def tbptt_split_batch(self, batch, split_size):
     # do your own splitting on the batch
     return splits
 
-val_percent_check
-^^^^^^^^^^^^^^^^^
-
-How much of validation dataset to check.
-Useful when debugging or testing something that happens at the end of an epoch.
-
-Example::
-
-    # default used by the Trainer
-    trainer = Trainer(val_percent_check=1.0)
-
-    # run through only 25% of the validation set each epoch
-    trainer = Trainer(val_percent_check=0.25)
+val_check_interval
+^^^^^^^^^^^^^^^^^^
+
+How often within one training epoch to check the validation set.
+Can specify as float or int.
+
+- use (float) to check within a training epoch
+- use (int) to check every n steps (batches)
+
+.. code-block:: python
+
+    # default used by the Trainer
+    trainer = Trainer(val_check_interval=1.0)
+
+Example::
+
+    # check validation set 4 times during a training epoch
+    trainer = Trainer(val_check_interval=0.25)
+
+    # check validation set every 1000 training batches
+    # use this when using iterableDataset and your dataset has no length
+    # (ie: production cases with streaming data)
+    trainer = Trainer(val_check_interval=1000)
+
+
+val_percent_check
+^^^^^^^^^^^^^^^^^
+
+.. warning:: Deprecated in v0.8.0. Please use `limit_val_batches`. Will be removed in 1.0.0.
 
 weights_save_path
 ^^^^^^^^^^^^^^^^^
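The overfit_batches docs above state that the training set is reused for val and test with shuffling turned off. A toy sketch of that rule (`build_eval_loaders` and the plain config dicts are hypothetical illustrations, not Lightning internals):

```python
def build_eval_loaders(train_cfg, val_cfg, test_cfg, overfit_batches=0.0):
    """When overfit_batches > 0, reuse the train config for val and test,
    forcing shuffle off (mirroring the documented overfit behaviour)."""
    if overfit_batches > 0:
        reused = dict(train_cfg, shuffle=False)  # same data, shuffling disabled
        return reused, reused
    return val_cfg, test_cfg

val_cfg, test_cfg = build_eval_loaders(
    {"dataset": "train", "shuffle": True},
    {"dataset": "val"},
    {"dataset": "test"},
    overfit_batches=0.01,
)
print(val_cfg, test_cfg)
# both now point at the train dataset with shuffle disabled
```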
