@@ -14,15 +14,16 @@
 """
 Convention:
 - Do not include any `_TYPE` suffix
-- Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no trailing `_`)
+- Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`)
 """
 from pathlib import Path
-from typing import Any, Dict, Iterator, List, Mapping, Sequence, Type, Union
+from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, Union

 import torch
-from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 from torchmetrics import Metric
+from typing_extensions import TypedDict

 _NUMBER = Union[int, float]
 _METRIC = Union[Metric, torch.Tensor, _NUMBER]
@@ -43,7 +44,75 @@
     Dict[str, Sequence[DataLoader]],
 ]
 EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
+
+
+# Copied from `torch.optim.lr_scheduler.pyi`
+# Missing attributes were added to improve typing
+class _LRScheduler:
+    optimizer: Optimizer
+
+    def __init__(self, optimizer: Optimizer, last_epoch: int = ...) -> None:
+        ...
+
+    def state_dict(self) -> dict:
+        ...
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        ...
+
+    def get_last_lr(self) -> List[float]:
+        ...
+
+    def get_lr(self) -> float:
+        ...
+
+    def step(self, epoch: Optional[int] = ...) -> None:
+        ...
+
+
+# Copied from `torch.optim.lr_scheduler.pyi`
+# Missing attributes were added to improve typing
+class ReduceLROnPlateau:
+    in_cooldown: bool
+    optimizer: Optimizer
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        mode: str = ...,
+        factor: float = ...,
+        patience: int = ...,
+        verbose: bool = ...,
+        threshold: float = ...,
+        threshold_mode: str = ...,
+        cooldown: int = ...,
+        min_lr: float = ...,
+        eps: float = ...,
+    ) -> None:
+        ...
+
+    def step(self, metrics: Any, epoch: Optional[int] = ...) -> None:
+        ...
+
+    def state_dict(self) -> dict:
+        ...
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        ...
+
+
 # todo: improve LRSchedulerType naming/typing
-LRSchedulerTypeTuple = (_LRScheduler, ReduceLROnPlateau)
-LRSchedulerTypeUnion = Union[_LRScheduler, ReduceLROnPlateau]
-LRSchedulerType = Union[Type[_LRScheduler], Type[ReduceLROnPlateau]]
+LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
+LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
+LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
+
+
+class LRSchedulerConfig(TypedDict):
+    scheduler: Union[_LRScheduler, ReduceLROnPlateau]
+    name: Optional[str]
+    interval: str
+    frequency: int
+    reduce_on_plateau: bool
+    monitor: Optional[str]
+    strict: bool
+    opt_idx: Optional[int]
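
Per the comments in the diff, the two classes are copied from `torch.optim.lr_scheduler.pyi` with extra attributes (`optimizer`, `in_cooldown`) so that code touching those attributes can be type-checked. A minimal sketch of the kind of code this enables, assuming the definitions above are in scope; the helper name below is hypothetical, not part of the change:

def _describe_scheduler(scheduler: _LRScheduler) -> str:
    # `optimizer` is one of the attributes added to the copied stub, so this
    # attribute access type-checks without a cast or `type: ignore`.
    return f"{type(scheduler).__name__} wrapping {type(scheduler.optimizer).__name__}"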
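
`LRSchedulerConfig` itself only declares key names and value types. Below is a rough sketch of a dictionary shaped like it; the example values and the key semantics in the comments follow common Lightning usage and are assumptions, not something stated in this diff:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

_model = torch.nn.Linear(4, 2)
_optimizer = SGD(_model.parameters(), lr=0.1)

# Keys mirror LRSchedulerConfig. A plain dict is used here because torch's real
# scheduler classes are not subclasses of the stubs above, so annotating this
# literal as LRSchedulerConfig would require a cast under a strict type checker.
_example = {
    "scheduler": StepLR(_optimizer, step_size=10),
    "name": None,                # optional display name, e.g. for LR monitoring
    "interval": "epoch",         # step the scheduler per "epoch" or per "step"
    "frequency": 1,              # call `step()` every N intervals
    "reduce_on_plateau": False,  # True when the scheduler is a ReduceLROnPlateau
    "monitor": None,             # metric to monitor (ReduceLROnPlateau only)
    "strict": True,              # fail rather than warn if `monitor` is missing
    "opt_idx": None,             # index of the optimizer this scheduler belongs to
}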