We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 21d570d commit b07538eCopy full SHA for b07538e
tests/test_models_unet.py
@@ -271,7 +271,7 @@ def prepare_init_args_and_inputs_for_common(self):
271
def test_gradient_checkpointing(self):
272
# enable deterministic behavior for gradient checkpointing
273
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
274
- model = self.model_class(**init_dict)
+ model = self.model_class(**init_dict).eval()
275
model.to(torch_device)
276
277
assert not model.is_gradient_checkpointing and model.training
0 commit comments