|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
| 14 | +import inspect |
14 | 15 | import os |
15 | 16 | import sys |
16 | 17 | from functools import partial, update_wrapper |
|
51 | 52 | locals()["ArgumentParser"] = object |
52 | 53 | locals()["Namespace"] = object |
53 | 54 |
|
| 55 | +ModuleType = TypeVar("ModuleType") |
| 56 | + |
54 | 57 |
|
55 | 58 | class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau): |
56 | 59 | def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None: |
@@ -198,30 +201,6 @@ def add_lr_scheduler_args( |
198 | 201 | self.add_class_arguments(lr_scheduler_class, nested_key, sub_configs=True, **kwargs) |
199 | 202 | self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to) |
200 | 203 |
|
201 | | - def class_instantiator(self, class_type, *args, **kwargs): |
202 | | - for key, (base_type, hparams) in getattr(self, "_hparam_context", {}).items(): |
203 | | - if issubclass(class_type, base_type): |
204 | | - with given_hyperparameters_context(hparams): |
205 | | - return super().class_instantiator(class_type, *args, **kwargs) |
206 | | - return super().class_instantiator(class_type, *args, **kwargs) |
207 | | - |
208 | | - def instantiate_classes( |
209 | | - self, |
210 | | - cfg: Namespace, |
211 | | - instantiate_groups: bool = True, |
212 | | - hparam_context: Optional[Dict[str, type]] = None, |
213 | | - ) -> Namespace: |
214 | | - if hparam_context: |
215 | | - cfg_dict = yaml.safe_load(self.dump(cfg)) # TODO: do not remove link targets! |
216 | | - self._hparam_context = {} |
217 | | - for key, base_type in hparam_context.items(): |
218 | | - hparams = cfg_dict.get(key, {}) |
219 | | - self._hparam_context[key] = (base_type, hparams) |
220 | | - init = super().instantiate_classes(cfg, instantiate_groups=instantiate_groups) |
221 | | - if hparam_context: |
222 | | - delattr(self, "_hparam_context") |
223 | | - return init |
224 | | - |
225 | 204 |
|
226 | 205 | class SaveConfigCallback(Callback): |
227 | 206 | """Saves a LightningCLI config to the log_dir when training starts. |
@@ -405,6 +384,7 @@ def __init__( |
405 | 384 |
|
406 | 385 | self._set_seed() |
407 | 386 |
|
| 387 | + self._add_instantiators() |
408 | 388 | self.before_instantiate_classes() |
409 | 389 | self.instantiate_classes() |
410 | 390 |
|
@@ -551,18 +531,28 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No |
551 | 531 | else: |
552 | 532 | self.config = parser.parse_args(args) |
553 | 533 |
|
| 534 | + def _add_instantiators(self) -> None: |
| 535 | + self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False)) |
| 536 | + if "subcommand" in self.config: |
| 537 | + self.config_dump = self.config_dump[self.config.subcommand] |
| 538 | + |
| 539 | + self.parser.add_instantiator( |
| 540 | + _InstantiatorFn(cli=self, key="model"), |
| 541 | + _get_module_type(self._model_class), |
| 542 | + subclasses=self.subclass_mode_model, |
| 543 | + ) |
| 544 | + self.parser.add_instantiator( |
| 545 | + _InstantiatorFn(cli=self, key="data"), |
| 546 | + _get_module_type(self._datamodule_class), |
| 547 | + subclasses=self.subclass_mode_data, |
| 548 | + ) |
| 549 | + |
554 | 550 | def before_instantiate_classes(self) -> None: |
555 | 551 | """Implement to run some code before instantiating the classes.""" |
556 | 552 |
|
557 | 553 | def instantiate_classes(self) -> None: |
558 | 554 | """Instantiates the classes and sets their attributes.""" |
559 | | - hparam_prefix = "" |
560 | | - if "subcommand" in self.config: |
561 | | - hparam_prefix = self.config["subcommand"] + "." |
562 | | - hparam_context = {hparam_prefix + "model": self._model_class} |
563 | | - if self.datamodule_class is not None: |
564 | | - hparam_context[hparam_prefix + "data"] = self._datamodule_class |
565 | | - self.config_init = self.parser.instantiate_classes(self.config, hparam_context=hparam_context) |
| 555 | + self.config_init = self.parser.instantiate_classes(self.config) |
566 | 556 | self.datamodule = self._get(self.config_init, "data") |
567 | 557 | self.model = self._get(self.config_init, "model") |
568 | 558 | self._add_configure_optimizers_method_to_model(self.subcommand) |
@@ -788,7 +778,20 @@ def _get_short_description(component: object) -> Optional[str]: |
788 | 778 | rank_zero_warn(f"Failed parsing docstring for {component}: {ex}") |
789 | 779 |
|
790 | 780 |
|
791 | | -ModuleType = TypeVar("ModuleType") |
| 781 | +def _get_module_type(value: Union[Callable, type]) -> type: |
| 782 | + if callable(value) and not isinstance(value, type): |
| 783 | + return inspect.signature(value).return_annotation |
| 784 | + return value |
| 785 | + |
| 786 | + |
| 787 | +class _InstantiatorFn: |
| 788 | + def __init__(self, cli: LightningCLI, key: str) -> None: |
| 789 | + self.cli = cli |
| 790 | + self.key = key |
| 791 | + |
| 792 | + def __call__(self, class_type: Type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType: |
| 793 | + with given_hyperparameters_context(self.cli.config_dump.get(self.key, {})): |
| 794 | + return class_type(*args, **kwargs) |
792 | 795 |
|
793 | 796 |
|
794 | 797 | def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType: |
|
0 commit comments