 import collections
 import inspect
 import os
-import warnings
 from abc import ABC, abstractmethod
 from argparse import Namespace
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
 else:
     XLA_AVAILABLE = True
 
+try:
+    import omegaconf
+except ImportError:
+    OMEGACONF_AVAILABLE = False
+else:
+    OMEGACONF_AVAILABLE = True
 
-class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, ModelHooks):
+
+class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
+        #: Current dtype
+        self.dtype = torch.FloatTensor
+
         self.exp_save_path = None
 
         #: The current epoch
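
For context on the import guard added above: a minimal, self-contained sketch (not taken from this diff) of how an availability flag such as OMEGACONF_AVAILABLE is typically consumed when hyperparameters were saved as an OmegaConf DictConfig. The helper name restore_hparams is hypothetical.

    try:
        from omegaconf import OmegaConf
    except ImportError:
        OMEGACONF_AVAILABLE = False
    else:
        OMEGACONF_AVAILABLE = True

    def restore_hparams(raw_hparams: dict, saved_as_dictconfig: bool):
        # Hypothetical helper: rebuild a DictConfig only when omegaconf is importable.
        if saved_as_dictconfig:
            if not OMEGACONF_AVAILABLE:
                raise ImportError("These hparams need omegaconf, which is not installed.")
            return OmegaConf.create(raw_hparams)  # plain dict -> DictConfig
        return raw_hparams
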
@@ -53,6 +62,10 @@ def __init__(self, *args, **kwargs):
         self.logger = None
         self.example_input_array = None
 
+        #: True if your model is currently running on GPUs.
+        #: Useful to set flags around the LightningModule for different CPU vs GPU behavior.
+        self.on_gpu = False
+
         #: True if using dp
         self.use_dp = False
 
@@ -67,20 +80,6 @@ def __init__(self, *args, **kwargs):
 
         self.hparams = None
 
-        #: Current dtype
-        self._dtype = torch.float
-
-        #: device reference
-        self._device = torch.device('cpu')
-
-    @property
-    def on_gpu(self):
-        """
-        True if your model is currently running on GPUs.
-        Useful to set flags around the LightningModule for different CPU vs GPU behavior.
-        """
-        return self.device.type == 'cuda'
-
     def print(self, *args, **kwargs) -> None:
         r"""
         Prints only from process 0. Use this in any distributed mode to log only once.
@@ -266,7 +265,6 @@ def training_epoch_end(
             May contain the following optional keys:
 
             - log (metrics to be added to the logger; only tensors)
-            - progress_bar (dict for progress bar display)
             - any metric used in a callback (e.g. early stopping).
 
         Note:
@@ -290,8 +288,7 @@ def training_epoch_end(self, outputs):
 
                 # log training accuracy at the end of an epoch
                 results = {
-                    'log': {'train_acc': train_acc_mean.item()},
-                    'progress_bar': {'train_acc': train_acc_mean},
+                    'log': {'train_acc': train_acc_mean.item()}
                 }
                 return results
 
@@ -314,7 +311,6 @@ def training_epoch_end(self, outputs):
                 # log training accuracy at the end of an epoch
                 results = {
                     'log': {'train_acc': train_acc_mean.item(), 'step': self.current_epoch}
-                    'progress_bar': {'train_acc': train_acc_mean},
                 }
                 return results
         """
@@ -887,7 +883,7 @@ def configure_ddp(self, model, device_ids):
 
     def _init_slurm_connection(self) -> None:
         """
         Sets up environment variables necessary for pytorch distributed communications
         based on slurm environment.
         """
         # use slurm job id for the port number
@@ -942,19 +938,16 @@ def init_ddp_connection(
         if 'MASTER_ADDR' not in os.environ:
             log.warning("MASTER_ADDR environment variable is not defined. Set as localhost")
             os.environ['MASTER_ADDR'] = '127.0.0.1'
-        log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
 
         if 'MASTER_PORT' not in os.environ:
             log.warning("MASTER_PORT environment variable is not defined. Set as 12910")
             os.environ['MASTER_PORT'] = '12910'
-        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
 
         if 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) != world_size:
-            log.warning(f"WORLD_SIZE environment variable ({os.environ['WORLD_SIZE']}) "
-                        f"is not equal to the computed world size ({world_size}). Ignored.")
+            log.warning("WORLD_SIZE environment variable is not equal to the computed "
+                        "world size. Ignored.")
 
         torch_backend = "nccl" if self.trainer.on_gpu else "gloo"
-        log.info(f"initializing proc_rank {proc_rank} world {world_size}")
         torch_distrib.init_process_group(torch_backend, rank=proc_rank, world_size=world_size)
 
     def configure_apex(
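
The hunk above can be summarized by the following standalone sketch (an assumed simplification, not the library's exact code, and omitting the WORLD_SIZE check): supply single-node defaults for MASTER_ADDR and MASTER_PORT when the scheduler did not export them, pick NCCL on GPU or Gloo on CPU, and initialize the process group.

    import os
    import torch.distributed as dist

    def init_connection_sketch(proc_rank: int, world_size: int, use_gpu: bool) -> None:
        # Fall back to a single-node setup when these are not already exported.
        os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
        os.environ.setdefault('MASTER_PORT', '12910')
        backend = 'nccl' if use_gpu else 'gloo'
        dist.init_process_group(backend, rank=proc_rank, world_size=world_size)
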
@@ -1175,9 +1168,9 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer,
 
             # native amp + lbfgs is a no go right now
             if self.trainer.use_amp and self.trainer.use_native_amp:
-                raise MisconfigurationException(
-                    'native PyTorch amp and lbfgs are not compatible. '
-                    ' To request, please file a Github issue in PyTorch and tag @mcarilli')
+                m = 'native PyTorch amp and lbfgs are not compatible. To request, please file ' \
+                    'a Github issue in PyTorch and tag @mcarilli'
+                raise MisconfigurationException(m)
             optimizer.step(second_order_closure)
         else:
             if self.trainer.use_amp and self.trainer.use_native_amp:
@@ -1444,93 +1437,50 @@ def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
     def load_from_checkpoint(
             cls,
             checkpoint_path: str,
-            *args,
             map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
-            hparams_file: Optional[str] = None,
-            tags_csv: Optional[str] = None,  # backward compatible, todo: remove in v0.9.0
-            hparam_overrides: Optional[Dict] = None,
-            **kwargs
+            tags_csv: Optional[str] = None,
+            *args, **kwargs
     ) -> 'LightningModule':
         r"""
         Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
         it stores the hyperparameters in the checkpoint if you initialized your :class:`LightningModule`
-        with an argument called ``hparams`` which is an object of :class:`~dict` or
-        :class:`~argparse.Namespace` (output of :meth:`~argparse.ArgumentParser.parse_args`
-        when parsing command line arguments).
-        If you want `hparams` to have a hierarchical structure, you have to define it as :class:`~dict`.
+        with an argument called ``hparams`` which is a :class:`~argparse.Namespace`
+        (output of :meth:`~argparse.ArgumentParser.parse_args` when parsing command line arguments).
         Any other arguments specified through \*args and \*\*kwargs will be passed to the model.
 
         Example:
             .. code-block:: python
 
-                # define hparams as Namespace
                 from argparse import Namespace
                 hparams = Namespace(**{'learning_rate': 0.1})
 
                 model = MyModel(hparams)
 
                 class MyModel(LightningModule):
-                    def __init__(self, hparams: Namespace):
+                    def __init__(self, hparams):
                         self.learning_rate = hparams.learning_rate
 
-                # ----------
-
-                # define hparams as dict
-                hparams = {
-                    drop_prob: 0.2,
-                    dataloader: {
-                        batch_size: 32
-                    }
-                }
-
-                model = MyModel(hparams)
-
-                class MyModel(LightningModule):
-                    def __init__(self, hparams: dict):
-                        self.learning_rate = hparams['learning_rate']
-
         Args:
             checkpoint_path: Path to checkpoint.
             args: Any positional args needed to init the model.
             map_location:
                 If your checkpoint saved a GPU model and you now load on CPUs
                 or a different number of GPUs, use this to map to the new setup.
                 The behaviour is the same as in :func:`torch.load`.
-            hparams_file: Optional path to a .yaml file with hierarchical structure
+            tags_csv: Optional path to a .csv file with two columns (key, value)
                 as in this example::
 
-                    drop_prob: 0.2
-                    dataloader:
-                        batch_size: 32
+                    key,value
+                    drop_prob,0.2
+                    batch_size,32
 
                 You most likely won't need this since Lightning will always save the hyperparameters
                 to the checkpoint.
                 However, if your checkpoint weights don't have the hyperparameters saved,
-                use this method to pass in a .yaml file with the hparams you'd like to use.
-                These will be converted into a :class:`~dict` and passed into your
+                use this method to pass in a .csv file with the hparams you'd like to use.
+                These will be converted into a :class:`~argparse.Namespace` and passed into your
                 :class:`LightningModule` for use.
 
-                If your model's `hparams` argument is :class:`~argparse.Namespace`
-                and .yaml file has hierarchical structure, you need to refactor your model to treat
-                `hparams` as :class:`~dict`.
-
-                .csv files are acceptable here till v0.9.0, see tags_csv argument for detailed usage.
-            tags_csv:
-                .. warning:: .. deprecated:: 0.7.6
-
-                    `tags_csv` argument is deprecated in v0.7.6. Will be removed v0.9.0.
-
-                Optional path to a .csv file with two columns (key, value)
-                as in this example::
-
-                    key,value
-                    drop_prob,0.2
-                    batch_size,32
-
-                Use this method to pass in a .csv file with the hparams you'd like to use.
-            hparam_overrides: A dictionary with keys to override in the hparams
-            kwargs: Any keyword args needed to init the model.
-
         Return:
             :class:`LightningModule` with loaded weights and hyperparameters (if available).
 
@@ -1550,13 +1500,7 @@ def __init__(self, hparams: dict):
                 # or load weights and hyperparameters from separate files.
                 MyLightningModule.load_from_checkpoint(
                     'path/to/checkpoint.ckpt',
-                    hparams_file='/path/to/hparams_file.yaml'
-                )
-
-                # override some of the params with new values
-                MyLightningModule.load_from_checkpoint(
-                    PATH,
-                    hparam_overrides={'num_layers': 128, 'pretrained_ckpt_path': NEW_PATH}
+                    tags_csv='/path/to/hparams_file.csv'
                 )
 
                 # or load passing whatever args the model takes to load
@@ -1577,28 +1521,11 @@ def __init__(self, hparams: dict):
         else:
             checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
 
-        # add the hparams from csv file to checkpoint
         if tags_csv is not None:
-            hparams_file = tags_csv
-            rank_zero_warn('`tags_csv` argument is deprecated in v0.7.6. Will be removed v0.9.0', DeprecationWarning)
-
-        if hparams_file is not None:
-            extension = hparams_file.split('.')[-1]
-            if extension.lower() in ('csv'):
-                hparams = load_hparams_from_tags_csv(hparams_file)
-            elif extension.lower() in ('yml', 'yaml'):
-                hparams = load_hparams_from_yaml(hparams_file)
-            else:
-                raise ValueError('.csv, .yml or .yaml is required for `hparams_file`')
-
-            hparams['on_gpu'] = False
-
-            # overwrite hparams by the given file
-            checkpoint['hparams'] = hparams
-
-        # override the hparam keys that were passed in
-        if hparam_overrides is not None:
-            update_hparams(hparams, hparam_overrides)
+            # add the hparams from csv file to checkpoint
+            hparams = load_hparams_from_tags_csv(tags_csv)
+            hparams.on_gpu = False
+            checkpoint['hparams'] = vars(hparams)
 
         model = cls._load_model_state(checkpoint, *args, **kwargs)
         return model
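
As a companion to the tags_csv path above: a minimal sketch of reading the two-column CSV described in the docstring into a Namespace. load_hparams_from_tags_csv is the library helper; read_tags_csv below is a hypothetical stand-in, and values are left as strings for simplicity.

    import csv
    from argparse import Namespace

    def read_tags_csv(path: str) -> Namespace:
        # Expects a header row 'key,value' followed by one hyperparameter per line.
        with open(path, newline='') as f:
            tags = {row['key']: row['value'] for row in csv.DictReader(f)}
        return Namespace(**tags)
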
@@ -1607,21 +1534,36 @@ def __init__(self, hparams: dict):
     def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'LightningModule':
         cls_takes_hparams = 'hparams' in inspect.signature(cls.__init__).parameters
         ckpt_hparams = checkpoint.get('hparams')
+        log.debug(f"CKPT HPARAMS: {ckpt_hparams}")
+        log.debug(f"CKPT Type: {type(ckpt_hparams)}")
 
         if cls_takes_hparams:
             if ckpt_hparams is not None:
-                hparams_type = checkpoint.get('hparams_type', 'Namespace')
-                if hparams_type.lower() == 'dict':
-                    hparams = ckpt_hparams
-                elif hparams_type.lower() == 'namespace':
+                hparams_type = checkpoint.get('hparams_type', None)
+                if hparams_type == 'Namespace':
                     hparams = Namespace(**ckpt_hparams)
+                elif hparams_type == 'DictConfig':
+                    if not OMEGACONF_AVAILABLE:
+                        raise ImportError(
+                            "This checkpoint's hparams were saved with OmegaConf "
+                            "but you don't have it installed here, so we can't load it."
+                        )
+                    hparams = ckpt_hparams
+                elif hparams_type == 'dict':
+                    hparams = ckpt_hparams
+                else:
+                    raise ValueError(
+                        f"The hparams in the checkpoint were saved as {hparams_type} "
+                        "but we only support dict, argparse Namespace and "
+                        f"OmegaConf DictConfig. Please add support for {hparams_type}!"
+                    )
             else:
                 rank_zero_warn(
-                    f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__"
-                    " contains argument 'hparams'. Will pass in an empty Namespace instead."
+                    f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__ "
+                    "contains argument 'hparams'. Will pass in an empty Namespace instead."
                     " Did you forget to store your model hyperparameters in self.hparams?"
                 )
-                hparams = {}
+                hparams = Namespace()
         else:  # The user's LightningModule does not define a hparams argument
             if ckpt_hparams is None:
                 hparams = None
@@ -1632,9 +1574,8 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'LightningModule':
                 )
 
         # load the state_dict on the model automatically
-        if cls_takes_hparams:
-            kwargs.update(hparams=hparams)
-        model = cls(*args, **kwargs)
+        model_args = [hparams] if hparams else []
+        model = cls(*model_args, *args, **kwargs)
         model.load_state_dict(checkpoint['state_dict'])
 
         # give model a chance to load something
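
With the instantiation change above, a loaded hparams object is forwarded as the first positional argument rather than as the hparams= keyword. A small sketch of that call pattern, using a hypothetical TinyModule in place of a real LightningModule subclass:

    from argparse import Namespace

    class TinyModule:
        # Hypothetical stand-in for a LightningModule subclass.
        def __init__(self, hparams, *args, **kwargs):
            self.hparams = hparams

    hparams = Namespace(learning_rate=0.1)
    model_args = [hparams] if hparams else []   # mirrors the added line in the diff
    model = TinyModule(*model_args)
    assert model.hparams.learning_rate == 0.1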