Skip to content
Merged
12 changes: 11 additions & 1 deletion pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,17 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cl
model.on_load_checkpoint(checkpoint)

# load the state_dict on the model automatically
model.load_state_dict(checkpoint['state_dict'], strict=strict)
keys = model.load_state_dict(checkpoint['state_dict'], strict=strict)

if not strict:
if keys.missing_keys:
rank_zero_warn(
f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}"
)
if keys.unexpected_keys:
rank_zero_warn(
f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}"
)

return model

Expand Down