Skip to content

Commit 1afc1ca

Browse files
karthikrangasaijustusschockananthsubcarmocca
authored
Logging Non-matching keys when loading from checkpoint in non-strict … (#8152)
Co-authored-by: Justus Schock <[email protected]> Co-authored-by: ananthsub <[email protected]> Co-authored-by: Carlos Mocholí <[email protected]>
1 parent acb6f26 commit 1afc1ca

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

pytorch_lightning/core/saving.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,17 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cl
202202
model.on_load_checkpoint(checkpoint)
203203

204204
# load the state_dict on the model automatically
205-
model.load_state_dict(checkpoint['state_dict'], strict=strict)
205+
keys = model.load_state_dict(checkpoint['state_dict'], strict=strict)
206+
207+
if not strict:
208+
if keys.missing_keys:
209+
rank_zero_warn(
210+
f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}"
211+
)
212+
if keys.unexpected_keys:
213+
rank_zero_warn(
214+
f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}"
215+
)
206216

207217
return model
208218

0 commit comments

Comments
 (0)