Skip to content

Commit 3fc29a2

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent cfc69b7 commit 3fc29a2

File tree

3 files changed

+10
-7
lines changed

3 files changed

+10
-7
lines changed

src/pytorch_lightning/callbacks/progress/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,9 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
222222
if not trainer.is_global_zero:
223223
self.disable()
224224

225-
def get_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str, float, Dict[str, float]]]:
225+
def get_metrics(
226+
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
227+
) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
226228
r"""
227229
Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics.
228230
Implement this to override the items displayed in the progress bar.

src/pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,9 @@ def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[Dat
439439
return dataloader
440440

441441
@staticmethod
442-
def _resolve_overfit_batches(dataloaders: Union[Collection[DataLoader], DataLoader], mode: RunningStage) -> Collection[DataLoader]:
442+
def _resolve_overfit_batches(
443+
dataloaders: Union[Collection[DataLoader], DataLoader], mode: RunningStage
444+
) -> Collection[DataLoader]:
443445
all_have_sequential_sampler = True
444446

445447
def resolve_has_no_sequential_sampler(dataloader: DataLoader):

src/pytorch_lightning/trainer/trainer.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -634,12 +634,11 @@ def _setup_on_init(self) -> None:
634634
self.num_test_batches: List[Union[int, float]] = []
635635
self.num_val_batches: List[Union[int, float]] = []
636636
self.num_predict_batches: List[Union[int, float]] = []
637-
637+
638638
self.test_dataloaders: Optional[List[DataLoader]] = None
639639
self.val_dataloaders: Optional[List[DataLoader]] = None
640640
self._last_train_dl_reload_epoch = float("-inf")
641641
self._last_val_dl_reload_epoch = float("-inf")
642-
643642

644643
def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
645644
r"""
@@ -986,7 +985,7 @@ def _predict_impl(
986985
self.state.status = TrainerStatus.RUNNING
987986
self.predicting = True
988987

989-
self.predict_loop.return_predictions = return_predictions # type: ignore
988+
self.predict_loop.return_predictions = return_predictions # type: ignore
990989

991990
# if a datamodule comes in as the second arg, then fix it for the user
992991
if isinstance(dataloaders, LightningDataModule):
@@ -1451,7 +1450,7 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_
14511450
" or last checkpoint available. No checkpoint will be loaded."
14521451
)
14531452
return None
1454-
ckpt_path = max(candidates_ts.keys(), key=partial(operator.getitem, candidates_ts)) # type: ignore
1453+
ckpt_path = max(candidates_ts.keys(), key=partial(operator.getitem, candidates_ts)) # type: ignore
14551454

14561455
if not ckpt_path:
14571456
raise MisconfigurationException(
@@ -2128,7 +2127,7 @@ def data_parallel_device_ids(self) -> Optional[List[int]]:
21282127
return self.device_ids if isinstance(self.accelerator, CUDAAccelerator) else None
21292128

21302129
@property
2131-
def lightning_module(self) -> "pl.LightningModule": # type: ignore
2130+
def lightning_module(self) -> "pl.LightningModule": # type: ignore
21322131
# TODO: this is actually an optional return
21332132
assert self.strategy.lightning_module is not None
21342133
return self.strategy.lightning_module

0 commit comments

Comments (0)