|
11 | 11 |
|
12 | 12 | from __future__ import annotations
13 | 13 |
|
14 | | -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
| 14 | +from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
15 | 15 |
|
16 | 16 | import torch |
17 | 17 | from monai.config import IgniteInfo |
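Note on the import change above: because the module keeps `from __future__ import annotations` at the top of the file, every annotation is stored as a string and never evaluated at runtime, so PEP 604 unions (`X | None`) and PEP 585 built-in generics (`dict[str, Metric]`) are safe even on interpreters that predate their runtime support. A minimal standalone sketch of the pattern (not MONAI code, names are illustrative only):

```python
# Minimal sketch of the annotation style adopted in this diff.
from __future__ import annotations  # annotations become lazily-evaluated strings


def summarize(scores: dict[str, float] | None = None) -> list[str]:
    # `dict[str, float] | None` replaces Optional[Dict[str, float]]; no
    # Optional/Union/Dict/List/Tuple imports from `typing` are needed.
    if scores is None:
        return []
    return [f"{name}: {value:.3f}" for name, value in scores.items()]
```

Static type checkers read the same types as before; only the spelling of the annotations changes, which is why the `typing` import list above can shrink.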
@@ -88,34 +88,34 @@ class AdversarialTrainer(Trainer):
88 | 88 |
|
89 | 89 | def __init__( |
90 | 90 | self, |
91 | | - device: Union[torch.device, str], |
| 91 | + device: torch.device | str, |
92 | 92 | max_epochs: int, |
93 | | - train_data_loader: Union[Iterable, DataLoader], |
| 93 | + train_data_loader: Iterable | DataLoader, |
94 | 94 | g_network: torch.nn.Module, |
95 | 95 | g_optimizer: Optimizer, |
96 | 96 | g_loss_function: Callable, |
97 | 97 | recon_loss_function: Callable, |
98 | 98 | d_network: torch.nn.Module, |
99 | 99 | d_optimizer: Optimizer, |
100 | 100 | d_loss_function: Callable, |
101 | | - epoch_length: Optional[int] = None, |
| 101 | + epoch_length: int | None = None, |
102 | 102 | non_blocking: bool = False, |
103 | | - prepare_batch: Union[Callable[[Engine, Any], Any], None] = default_prepare_batch, |
104 | | - iteration_update: Optional[Callable] = None, |
105 | | - g_inferer: Optional[Inferer] = None, |
106 | | - d_inferer: Optional[Inferer] = None, |
107 | | - postprocessing: Optional[Transform] = None, |
108 | | - key_train_metric: Optional[Dict[str, Metric]] = None, |
109 | | - additional_metrics: Optional[Dict[str, Metric]] = None, |
| 103 | + prepare_batch: Callable[[Engine, Any], Any] | None = default_prepare_batch, |
| 104 | + iteration_update: Callable | None = None, |
| 105 | + g_inferer: Inferer | None = None, |
| 106 | + d_inferer: Inferer | None = None, |
| 107 | + postprocessing: Transform | None = None, |
| 108 | + key_train_metric: dict[str, Metric] | None = None, |
| 109 | + additional_metrics: dict[str, Metric] | None = None, |
110 | 110 | metric_cmp_fn: Callable = default_metric_cmp_fn, |
111 | | - train_handlers: Optional[Sequence] = None, |
| 111 | + train_handlers: Sequence | None = None, |
112 | 112 | amp: bool = False, |
113 | | - event_names: Union[List[Union[str, EventEnum]], None] = None, |
114 | | - event_to_attr: Union[dict, None] = None, |
| 113 | + event_names: list[str | EventEnum] | None = None, |
| 114 | + event_to_attr: dict | None = None, |
115 | 115 | decollate: bool = True, |
116 | 116 | optim_set_to_none: bool = False, |
117 | | - to_kwargs: Union[dict, None] = None, |
118 | | - amp_kwargs: Union[dict, None] = None, |
| 117 | + to_kwargs: dict | None = None, |
| 118 | + amp_kwargs: dict | None = None, |
119 | 119 | ): |
120 | 120 | super().__init__( |
121 | 121 | device=device, |
@@ -183,8 +183,8 @@ def _complete_state_dict_user_keys(self) -> None:
183 | 183 | self._state_dict_user_keys.append("recon_loss_function") |
184 | 184 |
|
185 | 185 | def _iteration( |
186 | | - self, engine: AdversarialTrainer, batchdata: Dict[str, torch.Tensor] |
187 | | - ) -> Dict[str, Union[torch.Tensor, int, float, bool]]: |
| 186 | + self, engine: AdversarialTrainer, batchdata: dict[str, torch.Tensor] |
| 187 | + ) -> dict[str, torch.Tensor | int | float | bool]: |
188 | 188 | """ |
189 | 189 | Callback function for the Adversarial Training processing logic of 1 iteration in Ignite Engine. |
190 | 190 | Return below items in a dictionary: |
@@ -219,8 +219,8 @@ def _iteration(
219 | 219 |
|
220 | 220 | if len(batch) == 2: |
221 | 221 | inputs, targets = batch |
222 | | - args: Tuple = () |
223 | | - kwargs: Dict = {} |
| 222 | + args: tuple = () |
| 223 | + kwargs: dict = {} |
224 | 224 | else: |
225 | 225 | inputs, targets, args, kwargs = batch |
226 | 226 |
|
|
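For context on the last hunk: a prepared batch is either `(inputs, targets)` or `(inputs, targets, args, kwargs)`, and the extra call arguments now use built-in annotations (`tuple`, `dict`) instead of `typing.Tuple`/`typing.Dict`. A hypothetical standalone sketch of that unpacking (not part of MONAI, helper name invented):

```python
# Hypothetical helper mirroring the batch unpacking shown in _iteration above.
from __future__ import annotations

import torch


def unpack_batch(batch: tuple) -> tuple[torch.Tensor, torch.Tensor, tuple, dict]:
    if len(batch) == 2:
        # Two-element batch: no extra positional/keyword arguments for the network call.
        inputs, targets = batch
        args: tuple = ()   # was `args: Tuple = ()` before this diff
        kwargs: dict = {}  # was `kwargs: Dict = {}` before this diff
    else:
        # Four-element batch already carries the extra call arguments.
        inputs, targets, args, kwargs = batch
    return inputs, targets, args, kwargs
```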