
Commit 7ccf7e7

Merge branch 'master' into codeq/tensorboard-logger
2 parents 894a623 + e451fa2 commit 7ccf7e7

File tree

23 files changed: +125 additions, -71 deletions


docs/source-app/code_samples/quickstart/app/app_1.py

Lines changed: 1 addition & 1 deletion
@@ -88,5 +88,5 @@ def run(self):
         # Step 4: download a dataset to your local directory under `/data`
         download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")

-# Initalize your Lightning app with 5 epochs
+# Initialize your Lightning app with 5 epochs
 app = L.LightningApp(RootFlow(5, "./data/hymenoptera_data"))

docs/source-app/code_samples/quickstart/hello_world/app.py

Lines changed: 1 addition & 1 deletion
@@ -12,5 +12,5 @@ def run(self):
         print("Hello World!")


-# Step 3: Initalize a LightningApp with the LightningFlow you defined (in step 1)
+# Step 3: Initialize a LightningApp with the LightningFlow you defined (in step 1)
 app = L.LightningApp(HelloWorld())

docs/source-app/get_started/go_beyond_training_content.rst

Lines changed: 1 addition & 1 deletion
@@ -308,7 +308,7 @@ Implement the ``configure_layout`` method to connect them together:

 5: Init the ``app`` object
 ^^^^^^^^^^^^^^^^^^^^^^^^^^
-Initalize an ``app`` object with the ``TrainDeploy`` component (this won't run the App yet):
+Initialize an ``app`` object with the ``TrainDeploy`` component (this won't run the App yet):

 .. code:: python
    :emphasize-lines: 29

docs/source-pytorch/common/checkpointing_basic.rst

Lines changed: 2 additions & 2 deletions
@@ -106,8 +106,8 @@ The LightningModule also has access to the Hyperparameters

 ----

-Initalize with other parameters
-===============================
+Initialize with other parameters
+================================
 If you used the *self.save_hyperparameters()* method in the init of the LightningModule, you can initialize the model with different hyperparameters.

 .. code-block:: python
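
Aside: the docs text touched above describes overriding saved hyperparameters when loading a checkpoint. A minimal sketch of what that looks like, assuming a LightningModule that called self.save_hyperparameters() (the class name, hyperparameter name, and checkpoint path below are illustrative, not part of the commit):

import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self, learning_rate: float = 1e-3):
        super().__init__()
        # records learning_rate into self.hparams and into saved checkpoints
        self.save_hyperparameters()

# load the weights from a checkpoint but override the stored learning_rate
model = LitModel.load_from_checkpoint("path/to/example.ckpt", learning_rate=1e-4)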

docs/source-pytorch/common/trainer.rst

Lines changed: 7 additions & 3 deletions
@@ -1479,7 +1479,8 @@ How often within one training epoch to check the validation set.
 Can specify as float or int.

 - pass a ``float`` in the range [0.0, 1.0] to check after a fraction of the training epoch.
-- pass an ``int`` to check after a fixed number of training batches.
+- pass an ``int`` to check after a fixed number of training batches. An ``int`` value can only be higher than the number of training
+  batches when ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches across epochs or iteration-based training.

 .. testcode::

@@ -1489,10 +1490,13 @@ Can specify as float or int.
     # check validation set 4 times during a training epoch
     trainer = Trainer(val_check_interval=0.25)

-    # check validation set every 1000 training batches
+    # check validation set every 1000 training batches in the current epoch
+    trainer = Trainer(val_check_interval=1000)
+
+    # check validation set every 1000 training batches across complete epochs or during iteration-based training
     # use this when using iterableDataset and your dataset has no length
     # (ie: production cases with streaming data)
-    trainer = Trainer(val_check_interval=1000)
+    trainer = Trainer(val_check_interval=1000, check_val_every_n_epoch=None)


 .. code-block:: python
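
Aside: to make the two ``val_check_interval`` modes documented above concrete, a short sketch (the interval of 500 is an arbitrary example):

from pytorch_lightning import Trainer

# validate every 500 training batches, counted within each epoch;
# this requires at least 500 training batches per epoch
trainer = Trainer(val_check_interval=500)

# validate every 500 training batches counted across epochs
# (e.g. an IterableDataset with no length); the per-epoch
# validation cadence must be disabled for this
trainer = Trainer(val_check_interval=500, check_val_every_n_epoch=None)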

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -47,7 +47,6 @@ warn_no_return = "False"
 # the list can be generated with:
 # mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
 module = [
-    "pytorch_lightning.callbacks.model_checkpoint",
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.callbacks.stochastic_weight_avg",

src/pytorch_lightning/CHANGELOG.md

Lines changed: 4 additions & 1 deletion
@@ -144,7 +144,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Enabled using any Sampler in distributed environment in Lite ([#13646](https://github.com/PyTorchLightning/pytorch-lightning/pull/13646))


--
+- Updated `val_check_interval`(int) to consider total train batches processed instead of `_batches_that_stepped` for validation check during training ([#12832](https://github.com/Lightning-AI/lightning/pull/12832)


 ### Deprecated

@@ -345,6 +345,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `Trainer.predict(return_predictions=False)` to track prediction's batch_indices ([#13629](https://github.com/Lightning-AI/lightning/pull/13629))


+- Fixed main progress bar counter when `val_check_interval=int` and `check_val_every_n_epoch=None` ([#12832](https://github.com/Lightning-AI/lightning/pull/12832)
+
 ## [1.6.5] - 2022-07-13

 ### Fixed

src/pytorch_lightning/callbacks/early_stopping.py

Lines changed: 1 addition & 1 deletion
@@ -135,7 +135,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
         # validation, then we run after validation instead of on train epoch end
         self._check_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1

-    def _validate_condition_metric(self, logs: Dict[str, float]) -> bool:
+    def _validate_condition_metric(self, logs: Dict[str, Tensor]) -> bool:
         monitor_val = logs.get(self.monitor)

         error_msg = (

src/pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 24 additions & 18 deletions
@@ -39,7 +39,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.logger import _name, _version
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
-from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache

 log = logging.getLogger(__name__)

@@ -231,13 +231,14 @@ def __init__(
         self._save_on_train_epoch_end = save_on_train_epoch_end
         self._last_global_step_saved = 0  # no need to save when no steps were taken
         self._last_time_checked: Optional[float] = None
-        self.current_score = None
-        self.best_k_models = {}
+        self.current_score: Optional[Tensor] = None
+        self.best_k_models: Dict[str, Tensor] = {}
         self.kth_best_model_path = ""
-        self.best_model_score = None
+        self.best_model_score: Optional[Tensor] = None
         self.best_model_path = ""
         self.last_model_path = ""

+        self.kth_value: Tensor
         self.__init_monitor_mode(mode)
         self.__init_ckpt_dir(dirpath, filename)
         self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval)

@@ -256,6 +257,7 @@ def state_key(self) -> str:

     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
         self.__resolve_ckpt_dir(trainer)
+        assert self.dirpath is not None
         if trainer.is_global_zero and stage == "fit":
             self.__warn_if_dir_not_empty(self.dirpath)

@@ -362,7 +364,7 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None: # pragma: no-cover
         self._save_topk_checkpoint(trainer, monitor_candidates)
         self._save_last_checkpoint(trainer, monitor_candidates)

-    def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
+    def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
         if self.save_top_k == 0:
             return

@@ -395,7 +397,7 @@ def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
         from pytorch_lightning.trainer.states import TrainerFn

         return (
-            trainer.fast_dev_run  # disable checkpointing with fast_dev_run
+            bool(trainer.fast_dev_run)  # disable checkpointing with fast_dev_run
             or trainer.state.fn != TrainerFn.FITTING  # don't save anything during non-fit
             or trainer.sanity_checking  # don't save anything during sanity check
             or self._last_global_step_saved == trainer.global_step  # already saved at the last step

@@ -493,15 +495,15 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] =
         should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])

         # If using multiple devices, make sure all processes are unanimous on the decision.
-        should_update_best_and_save = trainer.strategy.reduce_boolean_decision(should_update_best_and_save)
+        should_update_best_and_save = trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))

         return should_update_best_and_save

     @classmethod
     def _format_checkpoint_name(
         cls,
         filename: Optional[str],
-        metrics: Dict[str, _METRIC],
+        metrics: Dict[str, Tensor],
         prefix: str = "",
         auto_insert_metric_name: bool = True,
     ) -> str:

@@ -522,7 +524,7 @@ def _format_checkpoint_name(
                 filename = filename.replace(group, f"{{0[{name}]")

                 if name not in metrics:
-                    metrics[name] = 0
+                    metrics[name] = torch.tensor(0)
             filename = filename.format(metrics)

         if prefix:

@@ -531,7 +533,7 @@ def _format_checkpoint_name(
         return filename

     def format_checkpoint_name(
-        self, metrics: Dict[str, _METRIC], filename: Optional[str] = None, ver: Optional[int] = None
+        self, metrics: Dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None
     ) -> str:
         """Generate a filename according to the defined template.

@@ -591,6 +593,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
             ckpt_path = os.path.join(trainer._weights_save_path_internal, "checkpoints")
         elif trainer.loggers:
             if len(trainer.loggers) == 1:
+                assert trainer.logger is not None
                 save_dir = trainer.logger.save_dir or trainer.default_root_dir
             else:
                 save_dir = trainer.default_root_dir

@@ -613,7 +616,7 @@ def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
         rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

     def _get_metric_interpolated_filepath_name(
-        self, monitor_candidates: Dict[str, _METRIC], trainer: "pl.Trainer", del_filepath: Optional[str] = None
+        self, monitor_candidates: Dict[str, Tensor], trainer: "pl.Trainer", del_filepath: Optional[str] = None
     ) -> str:
         filepath = self.format_checkpoint_name(monitor_candidates)

@@ -624,7 +627,7 @@ def _get_metric_interpolated_filepath_name(

         return filepath

-    def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
+    def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, Tensor]:
         monitor_candidates = deepcopy(trainer.callback_metrics)
         # cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor
         # or does not exist we overwrite it as it's likely an error

@@ -634,7 +637,7 @@ def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
         monitor_candidates["step"] = step.int() if isinstance(step, Tensor) else torch.tensor(trainer.global_step)
         return monitor_candidates

-    def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
+    def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
         if not self.save_last:
             return

@@ -651,16 +654,18 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
         if previous and previous != filepath:
             trainer.strategy.remove_checkpoint(previous)

-    def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
+    def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
+        assert self.monitor
         current = monitor_candidates.get(self.monitor)
         if self.check_monitor_top_k(trainer, current):
+            assert current is not None
             self._update_best_and_save(current, trainer, monitor_candidates)
         elif self.verbose:
             epoch = monitor_candidates["epoch"]
             step = monitor_candidates["step"]
             rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}")

-    def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
+    def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
         filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
         # set the best model path before saving because it will be part of the state.
         previous, self.best_model_path = self.best_model_path, filepath

@@ -669,7 +674,7 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate
             trainer.strategy.remove_checkpoint(previous)

     def _update_best_and_save(
-        self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]
+        self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]
     ) -> None:
         k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k

@@ -691,11 +696,11 @@ def _update_best_and_save(
         if len(self.best_k_models) == k:
             # monitor dict has reached k elements
             _op = max if self.mode == "min" else min
-            self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
+            self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
             self.kth_value = self.best_k_models[self.kth_best_model_path]

         _op = min if self.mode == "min" else max
-        self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
+        self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
         self.best_model_score = self.best_k_models[self.best_model_path]

         if self.verbose:

@@ -715,6 +720,7 @@ def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
         file."""
         best_k = {k: v.item() for k, v in self.best_k_models.items()}
         if filepath is None:
+            assert self.dirpath
             filepath = os.path.join(self.dirpath, "best_k_models.yaml")
         with self._fs.open(filepath, "w") as fp:
             yaml.dump(best_k, fp)
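
Aside: the `metrics[name] = torch.tensor(0)` change above relies on a 0-dim tensor formatting like its Python scalar, so checkpoint filename templates keep working. A small illustrative sketch (the metric names and values are made up):

import torch

metrics = {"epoch": torch.tensor(3), "val_loss": torch.tensor(0.25)}
# 0-dim tensors delegate __format__ to their scalar value
print("epoch={epoch:02d}-val_loss={val_loss:.2f}".format(**metrics))  # -> epoch=03-val_loss=0.25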

src/pytorch_lightning/callbacks/progress/base.py

Lines changed: 21 additions & 0 deletions
@@ -172,6 +172,27 @@ def total_val_batches(self) -> Union[int, float]:
         assert self._trainer is not None
         return sum(self.trainer.num_val_batches) if self._trainer.fit_loop.epoch_loop._should_check_val_epoch() else 0

+    @property
+    def total_batches_current_epoch(self) -> Union[int, float]:
+        total_train_batches = self.total_train_batches
+        total_val_batches = self.total_val_batches
+        assert self._trainer is not None
+
+        if total_train_batches != float("inf") and total_val_batches != float("inf"):
+            # val can be checked multiple times per epoch
+            val_check_batch = self.trainer.val_check_batch
+            if self.trainer.check_val_every_n_epoch is None:
+                train_batches_processed = self.trainer.fit_loop.total_batch_idx + 1
+                val_checks_per_epoch = ((train_batches_processed + total_train_batches) // val_check_batch) - (
+                    train_batches_processed // val_check_batch
+                )
+            else:
+                val_checks_per_epoch = total_train_batches // val_check_batch
+
+            total_val_batches = total_val_batches * val_checks_per_epoch
+
+        return total_train_batches + total_val_batches
+
     def has_dataloader_changed(self, dataloader_idx: int) -> bool:
         old_dataloader_idx = self._current_eval_dataloader_idx
         self._current_eval_dataloader_idx = dataloader_idx
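
Aside: the windowed count in `total_batches_current_epoch` above decides how many validation runs fall inside the current epoch when `check_val_every_n_epoch` is None. A standalone sketch of just that arithmetic with made-up numbers (100 train batches per epoch, validation every 70 batches counted across epochs):

total_train_batches = 100
val_check_batch = 70

def val_checks_in_epoch(train_batches_processed: int) -> int:
    # same integer-division window as the property above
    return ((train_batches_processed + total_train_batches) // val_check_batch) - (
        train_batches_processed // val_check_batch
    )

print(val_checks_in_epoch(0))    # epoch 1: batches 1-100   -> 1 (validation after batch 70)
print(val_checks_in_epoch(100))  # epoch 2: batches 101-200 -> 1 (validation after batch 140)
print(val_checks_in_epoch(200))  # epoch 3: batches 201-300 -> 2 (validation after batches 210 and 280)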
