diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py
index 2ee9a169a300..b2f16aef5825 100644
--- a/tests/test_models_unet.py
+++ b/tests/test_models_unet.py
@@ -273,37 +273,39 @@ def test_gradient_checkpointing(self):
         model = self.model_class(**init_dict)
         model.to(torch_device)
 
+        assert not model.is_gradient_checkpointing and model.training
+
         out = model(**inputs_dict).sample
         # run the backwards pass on the model. For backwards pass, for simplicity purpose,
         # we won't calculate the loss and rather backprop on out.sum()
         model.zero_grad()
-        out.sum().backward()
 
-        # now we save the output and parameter gradients that we will use for comparison purposes with
-        # the non-checkpointed run.
-        output_not_checkpointed = out.data.clone()
-        grad_not_checkpointed = {}
-        for name, param in model.named_parameters():
-            grad_not_checkpointed[name] = param.grad.data.clone()
+        labels = torch.randn_like(out)
+        loss = (out - labels).mean()
+        loss.backward()
 
-        model.enable_gradient_checkpointing()
-        out = model(**inputs_dict).sample
+        # re-instantiate the model now enabling gradient checkpointing
+        model_2 = self.model_class(**init_dict)
+        # clone model
+        model_2.load_state_dict(model.state_dict())
+        model_2.to(torch_device)
+        model_2.enable_gradient_checkpointing()
+
+        assert model_2.is_gradient_checkpointing and model_2.training
+
+        out_2 = model_2(**inputs_dict).sample
         # run the backwards pass on the model. For backwards pass, for simplicity purpose,
         # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-        out.sum().backward()
-
-        # now we save the output and parameter gradients that we will use for comparison purposes with
-        # the non-checkpointed run.
-        output_checkpointed = out.data.clone()
-        grad_checkpointed = {}
-        for name, param in model.named_parameters():
-            grad_checkpointed[name] = param.grad.data.clone()
+        model_2.zero_grad()
+        loss_2 = (out_2 - labels).mean()
+        loss_2.backward()
 
         # compare the output and parameters gradients
-        self.assertTrue((output_checkpointed == output_not_checkpointed).all())
-        for name in grad_checkpointed:
-            self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5))
+        self.assertTrue((loss - loss_2).abs() < 1e-5)
+        named_params = dict(model.named_parameters())
+        named_params_2 = dict(model_2.named_parameters())
+        for name, param in named_params.items():
+            self.assertTrue(torch.allclose(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
 
     # TODO(Patrick) - Re-add this test after having cleaned up LDM
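The same equivalence check can be reproduced outside the test suite. The sketch below is not part of this PR: it uses a hypothetical `ToyBlock` module and plain `torch.utils.checkpoint` to show the pattern the rewritten test follows, namely cloning the weights with `load_state_dict`, enabling activation checkpointing on the copy, and asserting that losses and per-parameter gradients match.

```python
# Minimal sketch (assumptions: ToyBlock is a made-up stand-in for the UNet,
# and activation checkpointing is done via torch.utils.checkpoint rather than
# the model's own enable_gradient_checkpointing()).
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    def __init__(self, use_checkpointing=False):
        super().__init__()
        self.use_checkpointing = use_checkpointing
        self.net = nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Linear(32, 8))

    def forward(self, x):
        if self.use_checkpointing and self.training:
            # Recompute activations during the backward pass instead of storing them.
            return checkpoint(self.net, x, use_reentrant=False)
        return self.net(x)


torch.manual_seed(0)
x = torch.randn(4, 8)
labels = torch.randn(4, 8)

model = ToyBlock(use_checkpointing=False)
model_2 = ToyBlock(use_checkpointing=True)
model_2.load_state_dict(model.state_dict())  # identical weights, as in the rewritten test

loss = (model(x) - labels).mean()
loss.backward()

loss_2 = (model_2(x) - labels).mean()
loss_2.backward()

# Checkpointing must only change the memory/compute trade-off, not the result.
assert (loss - loss_2).abs() < 1e-5
for (name, p), (_, p_2) in zip(model.named_parameters(), model_2.named_parameters()):
    assert torch.allclose(p.grad, p_2.grad, atol=5e-5), name
```

Whichever way the two runs are set up, the invariant being tested is the same one the diff encodes: identical weights and inputs must produce numerically identical losses and gradients with and without gradient checkpointing.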