Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 23 additions & 21 deletions tests/test_models_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,37 +273,39 @@ def test_gradient_checkpointing(self):
model = self.model_class(**init_dict)
model.to(torch_device)

assert not model.is_gradient_checkpointing and model.training

out = model(**inputs_dict).sample
# Run the backward pass on the model. For simplicity, we do not compute a loss
# and instead backpropagate directly on out.sum().
model.zero_grad()
out.sum().backward()

# Save the output and parameter gradients of this non-checkpointed run so they can be
# compared with the checkpointed run below.
output_not_checkpointed = out.data.clone()
grad_not_checkpointed = {}
for name, param in model.named_parameters():
grad_not_checkpointed[name] = param.grad.data.clone()
labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()

model.enable_gradient_checkpointing()
out = model(**inputs_dict).sample
# re-instantiate the model now enabling gradient checkpointing
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()

assert model_2.is_gradient_checkpointing and model_2.training

out_2 = model_2(**inputs_dict).sample
# Run the backward pass on the model. For simplicity, we do not compute a loss
# and instead backpropagate directly on out.sum().
model.zero_grad()
out.sum().backward()

# now we save the output and parameter gradients that we will use for comparison purposes with
# the non-checkpointed run.
output_checkpointed = out.data.clone()
grad_checkpointed = {}
for name, param in model.named_parameters():
grad_checkpointed[name] = param.grad.data.clone()
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()

# compare the output and parameters gradients
self.assertTrue((output_checkpointed == output_not_checkpointed).all())
for name in grad_checkpointed:
self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5))
self.assertTrue((loss - loss_2).abs() < 1e-5)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
self.assertTrue(torch.allclose(param.grad.data, named_params_2[name].grad.data, atol=5e-5))


# TODO(Patrick) - Re-add this test after having cleaned up LDM
Expand Down