187 changes: 64 additions & 123 deletions pytorch_lightning/core/lightning.py
@@ -1,7 +1,6 @@
import collections
import inspect
import os
import warnings
from abc import ABC, abstractmethod
from argparse import Namespace
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
@@ -30,12 +29,22 @@
else:
XLA_AVAILABLE = True

try:
import omegaconf
except ImportError:
OMEGACONF_AVAILABLE = False
else:
OMEGACONF_AVAILABLE = True
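# Illustrative sketch (not part of this diff) of how the availability flag is
# consumed: guard isinstance checks so omegaconf stays an optional dependency.
#
#   if OMEGACONF_AVAILABLE and isinstance(hparams, omegaconf.DictConfig):
#       plain_dict = omegaconf.OmegaConf.to_container(hparams, resolve=True)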

class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, ModelHooks):

class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

#: Current dtype
self.dtype = torch.FloatTensor
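# note: torch.FloatTensor is a legacy tensor *type*; the matching torch.dtype
# would be torch.float (which the removed `self._dtype` below used)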

self.exp_save_path = None

#: The current epoch
@@ -53,6 +62,10 @@ def __init__(self, *args, **kwargs):
self.logger = None
self.example_input_array = None

#: True if your model is currently running on GPUs.
#: Useful to set flags around the LightningModule for different CPU vs GPU behavior.
self.on_gpu = False

#: True if using dp
self.use_dp = False

@@ -67,20 +80,6 @@

self.hparams = None

#: Current dtype
self._dtype = torch.float

#: device reference
self._device = torch.device('cpu')

@property
def on_gpu(self):
"""
True if your model is currently running on GPUs.
Useful to set flags around the LightningModule for different CPU vs GPU behavior.
"""
return self.device.type == 'cuda'

def print(self, *args, **kwargs) -> None:
r"""
Prints only from process 0. Use this in any distributed mode to log only once.
@@ -266,7 +265,6 @@ def training_epoch_end(
May contain the following optional keys:

- log (metrics to be added to the logger; only tensors)
- progress_bar (dict for progress bar display)
- any metric used in a callback (e.g. early stopping).

Note:
@@ -290,8 +288,7 @@ def training_epoch_end(self, outputs):

# log training accuracy at the end of an epoch
results = {
'log': {'train_acc': train_acc_mean.item()},
'progress_bar': {'train_acc': train_acc_mean},
'log': {'train_acc': train_acc_mean.item()}
}
return results

@@ -314,7 +311,6 @@ def training_epoch_end(self, outputs):
# log training accuracy at the end of an epoch
results = {
'log': {'train_acc': train_acc_mean.item(), 'step': self.current_epoch}
'progress_bar': {'train_acc': train_acc_mean},
}
return results
"""
@@ -887,7 +883,7 @@ def configure_ddp(self, model, device_ids):

def _init_slurm_connection(self) -> None:
"""
Sets up environment variables necessary for pytorch distributed communications
Sets up environemnt variables necessary for pytorch distributed communications
based on slurm environment.
"""
# use slurm job id for the port number
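# (sketch, assuming SLURM_JOB_ID is set — roughly how this method derives a
# stable port from the job id:)
#
#   default_port = int(os.environ['SLURM_JOB_ID'][-4:]) + 15000
#   os.environ['MASTER_PORT'] = str(default_port)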
@@ -942,19 +938,16 @@ def init_ddp_connection(
if 'MASTER_ADDR' not in os.environ:
log.warning("MASTER_ADDR environment variable is not defined. Set as localhost")
os.environ['MASTER_ADDR'] = '127.0.0.1'
log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")

if 'MASTER_PORT' not in os.environ:
log.warning("MASTER_PORT environment variable is not defined. Set as 12910")
os.environ['MASTER_PORT'] = '12910'
log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

if 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) != world_size:
log.warning(f"WORLD_SIZE environment variable ({os.environ['WORLD_SIZE']}) "
f"is not equal to the computed world size ({world_size}). Ignored.")
if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != world_size:
log.warning("WORLD_SIZE environment variable is not equal to the computed "
"world size. Ignored.")

torch_backend = "nccl" if self.trainer.on_gpu else "gloo"
log.info(f"initializing proc_rank {proc_rank} world {world_size}")
torch_distrib.init_process_group(torch_backend, rank=proc_rank, world_size=world_size)
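# Standalone sketch of the same rendezvous (assumed values; 'gloo' works on
# CPU-only hosts, 'nccl' requires CUDA):
#
#   import os
#   import torch.distributed as torch_distrib
#   os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
#   os.environ.setdefault('MASTER_PORT', '12910')
#   torch_distrib.init_process_group('gloo', rank=0, world_size=1)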

def configure_apex(
@@ -1175,9 +1168,9 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer,

# native amp + lbfgs is a no go right now
if self.trainer.use_amp and self.trainer.use_native_amp:
raise MisconfigurationException(
'native PyTorch amp and lbfgs are not compatible.'
' To request, please file a Github issue in PyTorch and tag @mcarilli')
m = 'native PyTorch amp and lbfgs are not compatible. To request, please file' \
' a Github issue in PyTorch and tag @mcarilli'
raise MisconfigurationException(m)
optimizer.step(second_order_closure)
else:
if self.trainer.use_amp and self.trainer.use_native_amp:
@@ -1444,93 +1437,50 @@ def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
def load_from_checkpoint(
cls,
checkpoint_path: str,
*args,
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
hparams_file: Optional[str] = None,
tags_csv: Optional[str] = None, # backward compatible, todo: remove in v0.9.0
hparam_overrides: Optional[Dict] = None,
**kwargs
tags_csv: Optional[str] = None,
*args, **kwargs
) -> 'LightningModule':
r"""
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
it stores the hyperparameters in the checkpoint if you initialized your :class:`LightningModule`
with an argument called ``hparams`` which is an object of :class:`~dict` or
:class:`~argparse.Namespace` (output of :meth:`~argparse.ArgumentParser.parse_args`
when parsing command line arguments).
If you want `hparams` to have a hierarchical structure, you have to define it as :class:`~dict`.
with an argument called ``hparams`` which is a :class:`~argparse.Namespace`
(output of :meth:`~argparse.ArgumentParser.parse_args` when parsing command line arguments).
Any other arguments specified through \*args and \*\*kwargs will be passed to the model.

Example:
.. code-block:: python

# define hparams as Namespace
from argparse import Namespace
hparams = Namespace(**{'learning_rate': 0.1})

model = MyModel(hparams)

class MyModel(LightningModule):
def __init__(self, hparams: Namespace):
def __init__(self, hparams):
self.learning_rate = hparams.learning_rate

# ----------

# define hparams as dict
hparams = {
'drop_prob': 0.2,
'dataloader': {
'batch_size': 32
}
}

model = MyModel(hparams)

class MyModel(LightningModule):
def __init__(self, hparams: dict):
self.learning_rate = hparams['learning_rate']

Args:
checkpoint_path: Path to checkpoint.
args: Any positional args needed to init the model.
model_args: Any keyword args needed to init the model.
map_location:
If your checkpoint saved a GPU model and you now load on CPUs
or a different number of GPUs, use this to map to the new setup.
The behaviour is the same as in :func:`torch.load`.
hparams_file: Optional path to a .yaml file with hierarchical structure
tags_csv: Optional path to a .csv file with two columns (key, value)
as in this example::

drop_prob: 0.2
dataloader:
batch_size: 32
key,value
drop_prob,0.2
batch_size,32

You most likely won't need this since Lightning will always save the hyperparameters
to the checkpoint.
However, if your checkpoint weights don't have the hyperparameters saved,
use this method to pass in a .yaml file with the hparams you'd like to use.
These will be converted into a :class:`~dict` and passed into your
use this method to pass in a .csv file with the hparams you'd like to use.
These will be converted into a :class:`~argparse.Namespace` and passed into your
:class:`LightningModule` for use.

If your model's `hparams` argument is :class:`~argparse.Namespace`
and .yaml file has hierarchical structure, you need to refactor your model to treat
`hparams` as :class:`~dict`.

.csv files are acceptable here till v0.9.0, see tags_csv argument for detailed usage.
tags_csv:
.. warning:: .. deprecated:: 0.7.6

`tags_csv` argument is deprecated in v0.7.6. Will be removed v0.9.0.

Optional path to a .csv file with two columns (key, value)
as in this example::

key,value
drop_prob,0.2
batch_size,32

Use this method to pass in a .csv file with the hparams you'd like to use.
hparam_overrides: A dictionary with keys to override in the hparams
kwargs: Any keyword args needed to init the model.

Return:
:class:`LightningModule` with loaded weights and hyperparameters (if available).

@@ -1550,13 +1500,7 @@ def __init__(self, hparams: dict):
# or load weights and hyperparameters from separate files.
MyLightningModule.load_from_checkpoint(
'path/to/checkpoint.ckpt',
hparams_file='/path/to/hparams_file.yaml'
)

# override some of the params with new values
MyLightningModule.load_from_checkpoint(
PATH,
hparam_overrides={'num_layers': 128, 'pretrained_ckpt_path': NEW_PATH}
tags_csv='/path/to/hparams_file.csv'
)

# or load passing whatever args the model takes to load
@@ -1577,28 +1521,11 @@ def __init__(self, hparams: dict):
else:
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

# add the hparams from csv file to checkpoint
if tags_csv is not None:
hparams_file = tags_csv
rank_zero_warn('`tags_csv` argument is deprecated in v0.7.6. Will be removed v0.9.0', DeprecationWarning)

if hparams_file is not None:
extension = hparams_file.split('.')[-1]
if extension.lower() in ('csv',):
hparams = load_hparams_from_tags_csv(hparams_file)
elif extension.lower() in ('yml', 'yaml'):
hparams = load_hparams_from_yaml(hparams_file)
else:
raise ValueError('.csv, .yml or .yaml is required for `hparams_file`')

hparams['on_gpu'] = False

# overwrite hparams by the given file
checkpoint['hparams'] = hparams

# override the hparam keys that were passed in
if hparam_overrides is not None:
update_hparams(hparams, hparam_overrides)
# add the hparams from csv file to checkpoint
hparams = load_hparams_from_tags_csv(tags_csv)
hparams.on_gpu = False
checkpoint['hparams'] = vars(hparams)
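# Rough sketch (an assumption, not the library's exact code) of what
# load_hparams_from_tags_csv does with the two-column file:
#
#   import csv
#   from argparse import Namespace
#
#   def _tags_csv_to_namespace(path):
#       with open(path) as fp:
#           rows = csv.reader(fp)
#           next(rows)  # skip the 'key,value' header row
#           return Namespace(**{key: value for key, value in rows})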

model = cls._load_model_state(checkpoint, *args, **kwargs)
return model
@@ -1607,21 +1534,36 @@ def __init__(self, hparams: dict):
def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'LightningModule':
cls_takes_hparams = 'hparams' in inspect.signature(cls.__init__).parameters
ckpt_hparams = checkpoint.get('hparams')
print(f"CKPT HPARAMS: {ckpt_hparams}")
print(f"CKPT Type: {type(ckpt_hparams)}")

if cls_takes_hparams:
if ckpt_hparams is not None:
hparams_type = checkpoint.get('hparams_type', 'Namespace')
if hparams_type.lower() == 'dict':
hparams = ckpt_hparams
elif hparams_type.lower() == 'namespace':
hparams_type = checkpoint.get('hparams_type', None)
if hparams_type == 'Namespace':
hparams = Namespace(**ckpt_hparams)
elif hparams_type == 'DictConfig':
if not OMEGACONF_AVAILABLE:
raise ImportError(
"This checkpoint's hparams were saved with OmegaConf "
"but you don't have it installed here, so we can't load it."
)
hparams = ckpt_hparams
elif hparams_type == 'dict':
hparams = ckpt_hparams
else:
raise ValueError(
f"The hparams in the checkpoint were saved as {hparams_type} "
"but we only support dict, ArgParse Namespace and "
f"OmegaConf DictConfig. Please add support for {hparams_type}!"
)
else:
rank_zero_warn(
f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__"
" contains argument 'hparams'. Will pass in an empty Namespace instead."
f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__ "
f"contains argument 'hparams'. Will pass in an empty Namespace instead."
" Did you forget to store your model hyperparameters in self.hparams?"
)
hparams = {}
hparams = Namespace()
else: # The user's LightningModule does not define a hparams argument
if ckpt_hparams is None:
hparams = None
@@ -1632,9 +1574,8 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'LightningModule':
)

# load the state_dict on the model automatically
if cls_takes_hparams:
kwargs.update(hparams=hparams)
model = cls(*args, **kwargs)
model_args = [hparams] if hparams else []
model = cls(*model_args, *args, **kwargs)
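# note: `if hparams` is a truthiness test — an empty dict of hparams is
# dropped here, while an empty Namespace() (always truthy) is still passed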
model.load_state_dict(checkpoint['state_dict'])

# give model a chance to load something
14 changes: 12 additions & 2 deletions pytorch_lightning/trainer/training_io.py
@@ -119,6 +119,13 @@
else:
HOROVOD_AVAILABLE = True

try:
import omegaconf
except ImportError:
OMEGACONF_AVAILABLE = False
else:
OMEGACONF_AVAILABLE = True


class TrainerIOMixin(ABC):

@@ -351,10 +358,13 @@ def dump_checkpoint(self, weights_only: bool = False):
elif isinstance(model.hparams, Namespace):
checkpoint['hparams_type'] = 'Namespace'
checkpoint['hparams'] = vars(model.hparams)
elif OMEGACONF_AVAILABLE and isinstance(model.hparams, omegaconf.DictConfig):
checkpoint['hparams_type'] = 'DictConfig'
checkpoint['hparams'] = model.hparams
else:
raise ValueError(
'The acceptable hparams type is dict or argparse.Namespace,',
f' not {checkpoint["hparams_type"]}'
'The acceptable hparams type is dict, argparse.Namespace,'
f' or omegaconf.DictConfig, not {type(model.hparams)}'
)
else:
rank_zero_warn(
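# Illustrative usage of the DictConfig branch above (assumed values; requires
# omegaconf to be installed):
#
#   from omegaconf import OmegaConf
#   model.hparams = OmegaConf.create({'learning_rate': 0.1, 'batch_size': 32})
#   # dump_checkpoint() will then record hparams_type='DictConfig'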