Skip to content

Commit da7e399

Browse files
authored
Fix vae tests for cpu and gpu (#480)
1 parent 55f7ca3 commit da7e399

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

tests/test_models_vae.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,15 @@ def test_output_pretrained(self):
104104

105105
output_slice = output[0, -1, -3:, -3:].flatten().cpu()
106106

107-
# fmt: off
108-
expected_output_slice = torch.tensor([-0.1352, 0.0878, 0.0419, -0.0818, -0.1069, 0.0688, -0.1458, -0.4446, -0.0026])
109-
# fmt: on
107+
# Since the VAE Gaussian prior's generator is seeded on the appropriate device,
108+
# the expected output slices are not the same for CPU and GPU.
109+
if torch_device in ("mps", "cpu"):
110+
expected_output_slice = torch.tensor(
111+
[-0.1352, 0.0878, 0.0419, -0.0818, -0.1069, 0.0688, -0.1458, -0.4446, -0.0026]
112+
)
113+
else:
114+
expected_output_slice = torch.tensor(
115+
[-0.2421, 0.4642, 0.2507, -0.0438, 0.0682, 0.3160, -0.2018, -0.0727, 0.2485]
116+
)
117+
110118
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))

0 commit comments

Comments
 (0)