
Commit 715cf01

Darktex authored and Borda committed
Adding deserialization of OmegaConf
1 parent d73d19c commit 715cf01

File tree

1 file changed: +64 -123 lines changed


pytorch_lightning/core/lightning.py

Lines changed: 64 additions & 123 deletions
@@ -1,7 +1,6 @@
 import collections
 import inspect
 import os
-import warnings
 from abc import ABC, abstractmethod
 from argparse import Namespace
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
@@ -30,12 +29,22 @@
 else:
     XLA_AVAILABLE = True

+try:
+    import omegaconf
+except ImportError:
+    OMEGACONF_AVAILABLE = False
+else:
+    OMEGACONF_AVAILABLE = True

-class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, ModelHooks):
+
+class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

+        #: Current dtype
+        self.dtype = torch.FloatTensor
+
         self.exp_save_path = None

         #: The current epoch
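
The try/except/else block added in this hunk is the usual optional-dependency guard: a module-level flag is set once at import time and consulted wherever the package is needed. A minimal standalone sketch of the pattern (the require_omegaconf helper is hypothetical and not part of this diff):

# Optional-dependency guard, as used above: set a module-level flag at import
# time, then consult it before touching the optional package.
try:
    import omegaconf  # noqa: F401
except ImportError:
    OMEGACONF_AVAILABLE = False
else:
    OMEGACONF_AVAILABLE = True


def require_omegaconf() -> None:
    # Hypothetical helper, not part of this diff: fail loudly when the
    # optional package is needed but missing.
    if not OMEGACONF_AVAILABLE:
        raise ImportError("omegaconf is required for this code path: pip install omegaconf")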
@@ -53,6 +62,10 @@ def __init__(self, *args, **kwargs):
         self.logger = None
         self.example_input_array = None

+        #: True if your model is currently running on GPUs.
+        #: Useful to set flags around the LightningModule for different CPU vs GPU behavior.
+        self.on_gpu = False
+
         #: True if using dp
         self.use_dp = False

@@ -67,20 +80,6 @@ def __init__(self, *args, **kwargs):

         self.hparams = None

-        #: Current dtype
-        self._dtype = torch.float
-
-        #: device reference
-        self._device = torch.device('cpu')
-
-    @property
-    def on_gpu(self):
-        """
-        True if your model is currently running on GPUs.
-        Useful to set flags around the LightningModule for different CPU vs GPU behavior.
-        """
-        return self.device.type == 'cuda'
-
     def print(self, *args, **kwargs) -> None:
         r"""
         Prints only from process 0. Use this in any distributed mode to log only once.
@@ -266,7 +265,6 @@ def training_epoch_end(
             May contain the following optional keys:

             - log (metrics to be added to the logger; only tensors)
-            - progress_bar (dict for progress bar display)
             - any metric used in a callback (e.g. early stopping).

         Note:
@@ -290,8 +288,7 @@ def training_epoch_end(self, outputs):

                # log training accuracy at the end of an epoch
                results = {
-                    'log': {'train_acc': train_acc_mean.item()},
-                    'progress_bar': {'train_acc': train_acc_mean},
+                    'log': {'train_acc': train_acc_mean.item()}
                }
                return results

@@ -314,7 +311,6 @@ def training_epoch_end(self, outputs):
                # log training accuracy at the end of an epoch
                results = {
                    'log': {'train_acc': train_acc_mean.item(), 'step': self.current_epoch}
-                    'progress_bar': {'train_acc': train_acc_mean},
                }
                return results
        """
@@ -887,7 +883,7 @@ def configure_ddp(self, model, device_ids):

     def _init_slurm_connection(self) -> None:
         """
-        Sets up environment variables necessary for pytorch distributed communications
+        Sets up environemnt variables necessary for pytorch distributed communications
         based on slurm environment.
         """
         # use slurm job id for the port number
@@ -942,19 +938,16 @@ def init_ddp_connection(
         if 'MASTER_ADDR' not in os.environ:
             log.warning("MASTER_ADDR environment variable is not defined. Set as localhost")
             os.environ['MASTER_ADDR'] = '127.0.0.1'
-        log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")

         if 'MASTER_PORT' not in os.environ:
             log.warning("MASTER_PORT environment variable is not defined. Set as 12910")
             os.environ['MASTER_PORT'] = '12910'
-        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

-        if 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) != world_size:
-            log.warning(f"WORLD_SIZE environment variable ({os.environ['WORLD_SIZE']}) "
-                        f"is not equal to the computed world size ({world_size}). Ignored.")
+        if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != world_size:
+            log.warning("WORLD_SIZE environment variable is not equal to the computed "
+                        "world size. Ignored.")

         torch_backend = "nccl" if self.trainer.on_gpu else "gloo"
-        log.info(f"initializing proc_rank {proc_rank} world {world_size}")
         torch_distrib.init_process_group(torch_backend, rank=proc_rank, world_size=world_size)

     def configure_apex(
@@ -1175,9 +1168,9 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer,

            # native amp + lbfgs is a no go right now
            if self.trainer.use_amp and self.trainer.use_native_amp:
-                raise MisconfigurationException(
-                    'native PyTorch amp and lbfgs are not compatible.'
-                    ' To request, please file a Github issue in PyTorch and tag @mcarilli')
+                m = 'native PyTorch amp and lbfgs are not compatible. To request, please file' \
+                    'a Github issue in PyTorch and tag @mcarilli'
+                raise MisconfigurationException(m)
            optimizer.step(second_order_closure)
        else:
            if self.trainer.use_amp and self.trainer.use_native_amp:
@@ -1444,93 +1437,50 @@ def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
     def load_from_checkpoint(
             cls,
             checkpoint_path: str,
-            *args,
             map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
-            hparams_file: Optional[str] = None,
-            tags_csv: Optional[str] = None,  # backward compatible, todo: remove in v0.9.0
-            hparam_overrides: Optional[Dict] = None,
-            **kwargs
+            tags_csv: Optional[str] = None,
+            *args, **kwargs
     ) -> 'LightningModule':
         r"""
         Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
         it stores the hyperparameters in the checkpoint if you initialized your :class:`LightningModule`
-        with an argument called ``hparams`` which is an object of :class:`~dict` or
-        :class:`~argparse.Namespace` (output of :meth:`~argparse.ArgumentParser.parse_args`
-        when parsing command line arguments).
-        If you want `hparams` to have a hierarchical structure, you have to define it as :class:`~dict`.
+        with an argument called ``hparams`` which is a :class:`~argparse.Namespace`
+        (output of :meth:`~argparse.ArgumentParser.parse_args` when parsing command line arguments).
         Any other arguments specified through \*args and \*\*kwargs will be passed to the model.

         Example:
             .. code-block:: python

-                # define hparams as Namespace
                 from argparse import Namespace
                 hparams = Namespace(**{'learning_rate': 0.1})

                 model = MyModel(hparams)

                 class MyModel(LightningModule):
-                    def __init__(self, hparams: Namespace):
+                    def __init__(self, hparams):
                         self.learning_rate = hparams.learning_rate

-                # ----------
-
-                # define hparams as dict
-                hparams = {
-                    drop_prob: 0.2,
-                    dataloader: {
-                        batch_size: 32
-                    }
-                }
-
-                model = MyModel(hparams)
-
-                class MyModel(LightningModule):
-                    def __init__(self, hparams: dict):
-                        self.learning_rate = hparams['learning_rate']
-
         Args:
             checkpoint_path: Path to checkpoint.
-            args: Any positional args needed to init the model.
+            model_args: Any keyword args needed to init the model.
             map_location:
                 If your checkpoint saved a GPU model and you now load on CPUs
                 or a different number of GPUs, use this to map to the new setup.
                 The behaviour is the same as in :func:`torch.load`.
-            hparams_file: Optional path to a .yaml file with hierarchical structure
+            tags_csv: Optional path to a .csv file with two columns (key, value)
                 as in this example::

-                    drop_prob: 0.2
-                    dataloader:
-                        batch_size: 32
+                    key,value
+                    drop_prob,0.2
+                    batch_size,32

                 You most likely won't need this since Lightning will always save the hyperparameters
                 to the checkpoint.
                 However, if your checkpoint weights don't have the hyperparameters saved,
-                use this method to pass in a .yaml file with the hparams you'd like to use.
-                These will be converted into a :class:`~dict` and passed into your
+                use this method to pass in a .csv file with the hparams you'd like to use.
+                These will be converted into a :class:`~argparse.Namespace` and passed into your
                 :class:`LightningModule` for use.

-                If your model's `hparams` argument is :class:`~argparse.Namespace`
-                and .yaml file has hierarchical structure, you need to refactor your model to treat
-                `hparams` as :class:`~dict`.
-
-                .csv files are acceptable here till v0.9.0, see tags_csv argument for detailed usage.
-            tags_csv:
-                .. warning:: .. deprecated:: 0.7.6
-
-                    `tags_csv` argument is deprecated in v0.7.6. Will be removed v0.9.0.
-
-                Optional path to a .csv file with two columns (key, value)
-                as in this example::
-
-                    key,value
-                    drop_prob,0.2
-                    batch_size,32
-
-                Use this method to pass in a .csv file with the hparams you'd like to use.
-            hparam_overrides: A dictionary with keys to override in the hparams
-            kwargs: Any keyword args needed to init the model.
-
         Return:
             :class:`LightningModule` with loaded weights and hyperparameters (if available).

@@ -1550,13 +1500,7 @@ def __init__(self, hparams: dict):
                 # or load weights and hyperparameters from separate files.
                 MyLightningModule.load_from_checkpoint(
                     'path/to/checkpoint.ckpt',
-                    hparams_file='/path/to/hparams_file.yaml'
-                )
-
-                # override some of the params with new values
-                MyLightningModule.load_from_checkpoint(
-                    PATH,
-                    hparam_overrides={'num_layers': 128, 'pretrained_ckpt_path': NEW_PATH}
+                    tags_csv='/path/to/hparams_file.csv'
                 )

                 # or load passing whatever args the model takes to load
@@ -1577,28 +1521,11 @@ def __init__(self, hparams: dict):
         else:
             checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

-        # add the hparams from csv file to checkpoint
         if tags_csv is not None:
-            hparams_file = tags_csv
-            rank_zero_warn('`tags_csv` argument is deprecated in v0.7.6. Will be removed v0.9.0', DeprecationWarning)
-
-        if hparams_file is not None:
-            extension = hparams_file.split('.')[-1]
-            if extension.lower() in ('csv'):
-                hparams = load_hparams_from_tags_csv(hparams_file)
-            elif extension.lower() in ('yml', 'yaml'):
-                hparams = load_hparams_from_yaml(hparams_file)
-            else:
-                raise ValueError('.csv, .yml or .yaml is required for `hparams_file`')
-
-            hparams['on_gpu'] = False
-
-            # overwrite hparams by the given file
-            checkpoint['hparams'] = hparams
-
-            # override the hparam keys that were passed in
-            if hparam_overrides is not None:
-                update_hparams(hparams, hparam_overrides)
+            # add the hparams from csv file to checkpoint
+            hparams = load_hparams_from_tags_csv(tags_csv)
+            hparams.__setattr__('on_gpu', False)
+            checkpoint['hparams'] = vars(hparams)

         model = cls._load_model_state(checkpoint, *args, **kwargs)
         return model
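
In the simplified branch above, the tags CSV becomes an argparse Namespace, gets an on_gpu=False flag, and is written back into the checkpoint dict before the model is built. A rough standalone sketch of that flow, assuming a two-column key,value file like the one in the docstring (read_tags_csv only approximates load_hparams_from_tags_csv):

import csv
from argparse import Namespace

# Write a tiny tags file matching the docstring example, then read it back.
with open('hparams_file.csv', 'w', newline='') as fp:
    writer = csv.writer(fp)
    writer.writerows([('key', 'value'), ('drop_prob', '0.2'), ('batch_size', '32')])


def read_tags_csv(path: str) -> Namespace:
    # Approximation of load_hparams_from_tags_csv: one key,value pair per row.
    with open(path) as fp:
        tags = {row['key']: row['value'] for row in csv.DictReader(fp)}
    return Namespace(**tags)


hparams = read_tags_csv('hparams_file.csv')
hparams.on_gpu = False                  # same effect as __setattr__('on_gpu', False)
checkpoint = {'state_dict': {}}         # stand-in for a loaded checkpoint dict
checkpoint['hparams'] = vars(hparams)   # stored back as a plain dict, as in the diff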
@@ -1607,21 +1534,36 @@ def __init__(self, hparams: dict):
     def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'LightningModule':
         cls_takes_hparams = 'hparams' in inspect.signature(cls.__init__).parameters
         ckpt_hparams = checkpoint.get('hparams')
+        print(f"CKPT HPARAMS: {ckpt_hparams}")
+        print(f"CKPT Type: {type(ckpt_hparams)}")

         if cls_takes_hparams:
             if ckpt_hparams is not None:
-                hparams_type = checkpoint.get('hparams_type', 'Namespace')
-                if hparams_type.lower() == 'dict':
-                    hparams = ckpt_hparams
-                elif hparams_type.lower() == 'namespace':
+                hparams_type = checkpoint.get('hparams_type', None)
+                if hparams_type == 'Namespace':
                     hparams = Namespace(**ckpt_hparams)
+                elif hparams_type == 'DictConfig':
+                    if not OMEGACONF_AVAILABLE:
+                        raise ImportError(
+                            "This checkpoint's hparams were saved with OmegaConf "
+                            "but you don't have it installed here, so we can't load it."
+                        )
+                    hparams = ckpt_hparams
+                elif hparams_type == 'dict':
+                    hparams = ckpt_hparams
+                else:
+                    raise ValueError(
+                        f"The hparams in the checkpoint were saved as {hparams_type} "
+                        "but we only support dict, ArgParse Namespace and "
+                        f"OmegaConf DictConfig. Please add support for {hparams_type}!"
+                    )
             else:
                 rank_zero_warn(
-                    f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__"
-                    " contains argument 'hparams'. Will pass in an empty Namespace instead."
+                    f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__ "
+                    f"contains argument 'hparams'. Will pass in an empty Namespace instead."
                     " Did you forget to store your model hyperparameters in self.hparams?"
                 )
-                hparams = {}
+                hparams = Namespace()
         else:  # The user's LightningModule does not define a hparams argument
             if ckpt_hparams is None:
                 hparams = None
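
The new 'DictConfig' branch only works if the saving side also records an hparams_type tag next to the hparams themselves. A hedged round-trip sketch under that assumption (the checkpoint layout, file name, and tag convention below are illustrative, not taken from this commit):

import torch
from omegaconf import OmegaConf

# Saving side (assumed convention): keep the hparams object plus a type tag
# that _load_model_state can dispatch on.
hparams = OmegaConf.create({'learning_rate': 0.1, 'dataloader': {'batch_size': 32}})
checkpoint = {
    'state_dict': {},                        # model weights would live here
    'hparams': hparams,                      # unpickling this later needs omegaconf installed
    'hparams_type': type(hparams).__name__,  # 'DictConfig'
}
torch.save(checkpoint, 'omegaconf_example.ckpt')

# Loading side, mirroring the branch added above: the stored object is used as-is
# once OMEGACONF_AVAILABLE has been confirmed.
ckpt = torch.load('omegaconf_example.ckpt', map_location='cpu')
if ckpt['hparams_type'] == 'DictConfig':
    restored = ckpt['hparams']               # still an omegaconf DictConfig
    print(restored.dataloader.batch_size)    # 32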
@@ -1632,9 +1574,8 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'Ligh
                 )

         # load the state_dict on the model automatically
-        if cls_takes_hparams:
-            kwargs.update(hparams=hparams)
-        model = cls(*args, **kwargs)
+        model_args = [hparams] if hparams else []
+        model = cls(*model_args, *args, **kwargs)
         model.load_state_dict(checkpoint['state_dict'])

         # give model a chance to load something
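
One detail of the new instantiation path: hparams is now forwarded as the first positional argument, and only when it is truthy; since argparse.Namespace does not define __bool__, even the empty Namespace() fallback above is still passed through. A small sketch with a hypothetical MyModel:

from argparse import Namespace


class MyModel:
    # Hypothetical stand-in for a user's LightningModule.
    def __init__(self, hparams=None):
        self.hparams = hparams


def construct(cls, hparams, *args, **kwargs):
    # Mirrors the new lines: prepend hparams positionally only when it is truthy.
    model_args = [hparams] if hparams else []
    return cls(*model_args, *args, **kwargs)


print(construct(MyModel, Namespace(learning_rate=0.1)).hparams)  # Namespace(learning_rate=0.1)
print(construct(MyModel, Namespace()).hparams)                   # empty Namespace, still truthy
print(construct(MyModel, None).hparams)                          # None -> falls back to default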
