diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 78a22ec3138b..00419e548f7f 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -1452,6 +1452,14 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No [1.8285, 1.2857, -0.1024, 1.2406, -2.3068, 1.0747, -0.0818, -0.6520, -2.9506] ) assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + elif step == 50: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [1.1078, 1.5803, 0.2773, -0.0589, -1.7928, -0.3665, -0.4695, -1.0727, -1.1601] + ) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 test_callback_fn.has_been_called = False @@ -1492,6 +1500,12 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No latents_slice = latents[0, -3:, -3:, -1] expected_slice = np.array([0.9052, -0.0184, 0.4810, 0.2898, 0.5851, 1.4920, 0.5362, 1.9838, 0.0530]) assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + elif step == 37: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 96) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array([0.7071, 0.7831, 0.8300, 1.8140, 1.7840, 1.9402, 1.3651, 1.6590, 1.2828]) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 test_callback_fn.has_been_called = False @@ -1542,6 +1556,12 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No [-0.5472, 1.1218, -0.5505, -0.9390, -1.0794, 0.4063, 0.5158, 0.6429, -1.5246] ) assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + elif step == 37: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array([0.4781, 1.1572, 0.6258, 0.2291, 0.2554, -0.1443, 0.7085, -0.1598, -0.5659]) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 test_callback_fn.has_been_called = False @@ -1594,6 +1614,13 @@ def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None: [-0.5950, -0.3039, -1.1672, 0.1594, -1.1572, 0.6719, -1.9712, -0.0403, 0.9592] ) assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + elif step == 5: + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [-0.4776, -0.0119, -0.8519, -0.0275, -0.9764, 0.9820, -0.3843, 0.3788, 1.2264] + ) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 test_callback_fn.has_been_called = False