
Commit 78667fa

Change implementation to use add_instantiator.
1 parent 0a7388a commit 78667fa
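
In short: the previous implementation overrode class_instantiator and instantiate_classes on LightningArgumentParser, and depended on an unreleased jsonargparse branch, to thread the parsed hyperparameters into module instantiation. This commit switches to the ArgumentParser.add_instantiator hook, available from jsonargparse 4.24.0 per the requirements pin below. A minimal sketch of how that hook behaves, assuming jsonargparse >= 4.24.0; the Model class and wrapping_instantiator names are illustrative, not part of this diff:

# Minimal sketch of ArgumentParser.add_instantiator (jsonargparse >= 4.24.0).
# The parser calls the registered callable instead of the class itself,
# giving a hook to run code around every instantiation.
from jsonargparse import ArgumentParser


class Model:
    def __init__(self, hidden_dim: int = 32) -> None:
        self.hidden_dim = hidden_dim


def wrapping_instantiator(class_type, *args, **kwargs):
    # LightningCLI's hook sets a context variable here; this sketch just prints.
    print(f"instantiating {class_type.__name__} with {kwargs}")
    return class_type(*args, **kwargs)


parser = ArgumentParser()
parser.add_class_arguments(Model, "model")
parser.add_instantiator(wrapping_instantiator, Model, subclasses=True)

cfg = parser.parse_args(["--model.hidden_dim=64"])
init = parser.instantiate_classes(cfg)
assert init.model.hidden_dim == 64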

File tree: 4 files changed, +39 −36 lines

  requirements/pytorch/extra.txt
  src/lightning/pytorch/cli.py
  src/lightning/pytorch/core/mixins/hparams_mixin.py
  src/lightning/pytorch/core/saving.py

requirements/pytorch/extra.txt

Lines changed: 1 addition & 1 deletion
@@ -5,6 +5,6 @@
 matplotlib>3.1, <3.7.3
 omegaconf >=2.0.5, <2.4.0
 hydra-core >=1.0.5, <1.4.0
-jsonargparse[signatures] @ https://github.com/omni-us/jsonargparse/zipball/issue-170-class-instantiator
+jsonargparse[signatures] >=4.24.0, <4.25.0
 rich >=12.3.0, <=13.5.2
 tensorboardX >=2.2, <=2.6.2 # min version is set by torch.onnx missing attribute

src/lightning/pytorch/cli.py

Lines changed: 35 additions & 32 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import os
 import sys
 from functools import partial, update_wrapper
@@ -51,6 +52,8 @@
     locals()["ArgumentParser"] = object
     locals()["Namespace"] = object
 
+ModuleType = TypeVar("ModuleType")
+
 
 class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
     def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
@@ -198,30 +201,6 @@ def add_lr_scheduler_args(
         self.add_class_arguments(lr_scheduler_class, nested_key, sub_configs=True, **kwargs)
         self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to)
 
-    def class_instantiator(self, class_type, *args, **kwargs):
-        for key, (base_type, hparams) in getattr(self, "_hparam_context", {}).items():
-            if issubclass(class_type, base_type):
-                with given_hyperparameters_context(hparams):
-                    return super().class_instantiator(class_type, *args, **kwargs)
-        return super().class_instantiator(class_type, *args, **kwargs)
-
-    def instantiate_classes(
-        self,
-        cfg: Namespace,
-        instantiate_groups: bool = True,
-        hparam_context: Optional[Dict[str, type]] = None,
-    ) -> Namespace:
-        if hparam_context:
-            cfg_dict = yaml.safe_load(self.dump(cfg))  # TODO: do not remove link targets!
-            self._hparam_context = {}
-            for key, base_type in hparam_context.items():
-                hparams = cfg_dict.get(key, {})
-                self._hparam_context[key] = (base_type, hparams)
-        init = super().instantiate_classes(cfg, instantiate_groups=instantiate_groups)
-        if hparam_context:
-            delattr(self, "_hparam_context")
-        return init
-
 
 class SaveConfigCallback(Callback):
     """Saves a LightningCLI config to the log_dir when training starts.
@@ -405,6 +384,7 @@ def __init__(
 
         self._set_seed()
 
+        self._add_instantiators()
         self.before_instantiate_classes()
         self.instantiate_classes()
 
@@ -551,18 +531,28 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> None:
         else:
             self.config = parser.parse_args(args)
 
+    def _add_instantiators(self) -> None:
+        self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False))
+        if "subcommand" in self.config:
+            self.config_dump = self.config_dump[self.config.subcommand]
+
+        self.parser.add_instantiator(
+            _InstantiatorFn(cli=self, key="model"),
+            _get_module_type(self._model_class),
+            subclasses=self.subclass_mode_model,
+        )
+        self.parser.add_instantiator(
+            _InstantiatorFn(cli=self, key="data"),
+            _get_module_type(self._datamodule_class),
+            subclasses=self.subclass_mode_data,
+        )
+
     def before_instantiate_classes(self) -> None:
         """Implement to run some code before instantiating the classes."""
 
     def instantiate_classes(self) -> None:
         """Instantiates the classes and sets their attributes."""
-        hparam_prefix = ""
-        if "subcommand" in self.config:
-            hparam_prefix = self.config["subcommand"] + "."
-        hparam_context = {hparam_prefix + "model": self._model_class}
-        if self.datamodule_class is not None:
-            hparam_context[hparam_prefix + "data"] = self._datamodule_class
-        self.config_init = self.parser.instantiate_classes(self.config, hparam_context=hparam_context)
+        self.config_init = self.parser.instantiate_classes(self.config)
         self.datamodule = self._get(self.config_init, "data")
         self.model = self._get(self.config_init, "model")
         self._add_configure_optimizers_method_to_model(self.subcommand)
@@ -788,7 +778,20 @@ def _get_short_description(component: object) -> Optional[str]:
         rank_zero_warn(f"Failed parsing docstring for {component}: {ex}")
 
 
-ModuleType = TypeVar("ModuleType")
+def _get_module_type(value: Union[Callable, type]) -> type:
+    if callable(value) and not isinstance(value, type):
+        return inspect.signature(value).return_annotation
+    return value
+
+
+class _InstantiatorFn:
+    def __init__(self, cli: LightningCLI, key: str) -> None:
+        self.cli = cli
+        self.key = key
+
+    def __call__(self, class_type: Type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType:
+        with given_hyperparameters_context(self.cli.config_dump.get(self.key, {})):
+            return class_type(*args, **kwargs)
 
 
 def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType:
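
The new _get_module_type helper covers the case where the CLI was constructed with a factory callable rather than a class: the module type is then read off the callable's return annotation. A small self-contained check of that behavior, with the function copied from the diff; MyModel and build_model are hypothetical stand-ins:

# Checks _get_module_type as defined in the diff above;
# MyModel and build_model are made-up names for illustration.
import inspect
from typing import Callable, Union


def _get_module_type(value: Union[Callable, type]) -> type:
    if callable(value) and not isinstance(value, type):
        return inspect.signature(value).return_annotation
    return value


class MyModel:
    pass


def build_model(depth: int = 3) -> MyModel:
    return MyModel()


assert _get_module_type(MyModel) is MyModel      # a class passes through unchanged
assert _get_module_type(build_model) is MyModel  # a factory's return annotation is used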

src/lightning/pytorch/core/mixins/hparams_mixin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from argparse import Namespace
1818
from contextlib import contextmanager
1919
from contextvars import ContextVar
20-
from typing import Any, List, MutableMapping, Optional, Sequence, Union
20+
from typing import Any, Iterator, List, MutableMapping, Optional, Sequence, Union
2121

2222
from lightning.pytorch.utilities.parsing import AttributeDict, save_hyperparameters
2323

@@ -29,7 +29,7 @@
2929

3030

3131
@contextmanager
32-
def given_hyperparameters_context(value):
32+
def given_hyperparameters_context(value: dict) -> Iterator[None]:
3333
token = given_hyperparameters.set(value)
3434
try:
3535
yield
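
For orientation, the pattern here is a ContextVar set for the duration of a with block, so that save_hyperparameters() called inside a module's __init__ can pick up the CLI-parsed config instead of inspecting call frames. A rough sketch of the mechanism with simplified names; the finally/reset step is an assumption, since the hunk above is cut off after the yield:

# Rough sketch of the context-variable mechanism, with simplified names.
# The finally/reset is assumed; the diff hunk above ends at `yield`.
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Iterator

given_hyperparameters: ContextVar = ContextVar("given_hyperparameters", default=None)


@contextmanager
def given_hyperparameters_context(value: dict) -> Iterator[None]:
    token = given_hyperparameters.set(value)
    try:
        yield
    finally:
        given_hyperparameters.reset(token)


class Module:
    def __init__(self) -> None:
        # Stand-in for save_hyperparameters(): inside the context, the
        # CLI-parsed dict is available without frame inspection.
        self.hparams = given_hyperparameters.get()


with given_hyperparameters_context({"hidden_dim": 64}):
    assert Module().hparams == {"hidden_dim": 64}
assert Module().hparams is None  # value restored outside the context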

src/lightning/pytorch/core/saving.py

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ def _load_state(
     cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
     checkpoint: Dict[str, Any],
     strict: Optional[bool] = None,
-    instantiator=None,
+    instantiator: Optional[Callable] = None,
     **cls_kwargs_new: Any,
 ) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
     cls_spec = inspect.getfullargspec(cls.__init__)
