|
15 | 15 | import sys |
16 | 16 | from functools import partial, update_wrapper |
17 | 17 | from types import MethodType |
18 | | -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union |
| 18 | +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union |
19 | 19 |
|
20 | 20 | import torch |
| 21 | +import yaml |
21 | 22 | from lightning_utilities.core.imports import RequirementCache |
22 | 23 | from lightning_utilities.core.rank_zero import _warn |
23 | 24 | from torch.optim import Optimizer |
|
26 | 27 | from lightning.fabric.utilities.cloud_io import get_filesystem |
27 | 28 | from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER |
28 | 29 | from lightning.pytorch import Callback, LightningDataModule, LightningModule, seed_everything, Trainer |
| 30 | +from lightning.pytorch.core.mixins.hparams_mixin import given_hyperparameters_context |
29 | 31 | from lightning.pytorch.utilities.exceptions import MisconfigurationException |
30 | 32 | from lightning.pytorch.utilities.model_helpers import is_overridden |
31 | 33 | from lightning.pytorch.utilities.rank_zero import rank_zero_warn |
@@ -196,6 +198,30 @@ def add_lr_scheduler_args( |
196 | 198 | self.add_class_arguments(lr_scheduler_class, nested_key, sub_configs=True, **kwargs) |
197 | 199 | self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to) |
198 | 200 |
|
| 201 | + def class_instantiator(self, class_type, *args, **kwargs): |
| 202 | + for key, (base_type, hparams) in getattr(self, "_hparam_context", {}).items(): |
| 203 | + if issubclass(class_type, base_type): |
| 204 | + with given_hyperparameters_context(hparams): |
| 205 | + return super().class_instantiator(class_type, *args, **kwargs) |
| 206 | + return super().class_instantiator(class_type, *args, **kwargs) |
| 207 | + |
| 208 | + def instantiate_classes( |
| 209 | + self, |
| 210 | + cfg: Namespace, |
| 211 | + instantiate_groups: bool = True, |
| 212 | + hparam_context: Optional[Dict[str, type]] = None, |
| 213 | + ) -> Namespace: |
| 214 | + if hparam_context: |
| 215 | + cfg_dict = yaml.safe_load(self.dump(cfg)) # TODO: do not remove link targets! |
| 216 | + self._hparam_context = {} |
| 217 | + for key, base_type in hparam_context.items(): |
| 218 | + hparams = cfg_dict.get(key, {}) |
| 219 | + self._hparam_context[key] = (base_type, hparams) |
| 220 | + init = super().instantiate_classes(cfg, instantiate_groups=instantiate_groups) |
| 221 | + if hparam_context: |
| 222 | + delattr(self, "_hparam_context") |
| 223 | + return init |
| 224 | + |
199 | 225 |
|
200 | 226 | class SaveConfigCallback(Callback): |
201 | 227 | """Saves a LightningCLI config to the log_dir when training starts. |
@@ -530,7 +556,13 @@ def before_instantiate_classes(self) -> None: |
530 | 556 |
|
531 | 557 | def instantiate_classes(self) -> None: |
532 | 558 | """Instantiates the classes and sets their attributes.""" |
533 | | - self.config_init = self.parser.instantiate_classes(self.config) |
| 559 | + hparam_prefix = "" |
| 560 | + if "subcommand" in self.config: |
| 561 | + hparam_prefix = self.config["subcommand"] + "." |
| 562 | + hparam_context = {hparam_prefix + "model": self._model_class} |
| 563 | + if self.datamodule_class is not None: |
| 564 | + hparam_context[hparam_prefix + "data"] = self._datamodule_class |
| 565 | + self.config_init = self.parser.instantiate_classes(self.config, hparam_context=hparam_context) |
534 | 566 | self.datamodule = self._get(self.config_init, "data") |
535 | 567 | self.model = self._get(self.config_init, "model") |
536 | 568 | self._add_configure_optimizers_method_to_model(self.subcommand) |
@@ -754,3 +786,17 @@ def _get_short_description(component: object) -> Optional[str]: |
754 | 786 | return docstring.short_description |
755 | 787 | except (ValueError, docstring_parser.ParseError) as ex: |
756 | 788 | rank_zero_warn(f"Failed parsing docstring for {component}: {ex}") |
| 789 | + |
| 790 | + |
| 791 | +ModuleType = TypeVar("ModuleType") |
| 792 | + |
| 793 | + |
| 794 | +def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType: |
| 795 | + parser = ArgumentParser(exit_on_error=False) |
| 796 | + if "class_path" in config: |
| 797 | + parser.add_subclass_arguments(class_type, "module") |
| 798 | + else: |
| 799 | + parser.add_class_arguments(class_type, "module") |
| 800 | + cfg = parser.parse_object({"module": config}) |
| 801 | + init = parser.instantiate_classes(cfg) |
| 802 | + return init.module |
0 commit comments