Skip to content

Commit 84b9df5

Browse files
authored
[gradient checkpointing] lower tolerance for test (#652)
* lowe tolerance * put model in eval mode
1 parent 210be4f commit 84b9df5

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/test_models_unet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def prepare_init_args_and_inputs_for_common(self):
199199

200200
def test_gradient_checkpointing(self):
201201
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
202-
model = self.model_class(**init_dict)
202+
model = self.model_class(**init_dict).eval()
203203
model.to(torch_device)
204204

205205
out = model(**inputs_dict).sample

0 commit comments

Comments
 (0)