diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index ffa9b0a1359ee..74862735aba61 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -202,7 +202,17 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cl model.on_load_checkpoint(checkpoint) # load the state_dict on the model automatically - model.load_state_dict(checkpoint['state_dict'], strict=strict) + keys = model.load_state_dict(checkpoint['state_dict'], strict=strict) + + if not strict: + if keys.missing_keys: + rank_zero_warn( + f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}" + ) + if keys.unexpected_keys: + rank_zero_warn( + f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}" + ) return model