From 4b524e030028854d4a960c99fb756a4b97ff504d Mon Sep 17 00:00:00 2001 From: James R T Date: Wed, 5 Oct 2022 20:20:36 +0800 Subject: [PATCH] Add final latent slice checks to SD pipeline intermediate state tests This is to ensure that the final latent slices stay somewhat consistent as more changes are introduced into the library. Signed-off-by: James R T --- tests/test_pipelines.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 78a22ec3138b..00419e548f7f 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -1452,6 +1452,14 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No [1.8285, 1.2857, -0.1024, 1.2406, -2.3068, 1.0747, -0.0818, -0.6520, -2.9506] ) assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + elif step == 50: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [1.1078, 1.5803, 0.2773, -0.0589, -1.7928, -0.3665, -0.4695, -1.0727, -1.1601] + ) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 test_callback_fn.has_been_called = False @@ -1492,6 +1500,12 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No latents_slice = latents[0, -3:, -3:, -1] expected_slice = np.array([0.9052, -0.0184, 0.4810, 0.2898, 0.5851, 1.4920, 0.5362, 1.9838, 0.0530]) assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + elif step == 37: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 96) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array([0.7071, 0.7831, 0.8300, 1.8140, 1.7840, 1.9402, 1.3651, 1.6590, 1.2828]) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 test_callback_fn.has_been_called = False @@ -1542,6 +1556,12 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No [-0.5472, 1.1218, -0.5505, -0.9390, -1.0794, 0.4063, 0.5158, 0.6429, -1.5246] ) assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + elif step == 37: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array([0.4781, 1.1572, 0.6258, 0.2291, 0.2554, -0.1443, 0.7085, -0.1598, -0.5659]) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 test_callback_fn.has_been_called = False @@ -1594,6 +1614,13 @@ def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None: [-0.5950, -0.3039, -1.1672, 0.1594, -1.1572, 0.6719, -1.9712, -0.0403, 0.9592] ) assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + elif step == 5: + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [-0.4776, -0.0119, -0.8519, -0.0275, -0.9764, 0.9820, -0.3843, 0.3788, 1.2264] + ) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 test_callback_fn.has_been_called = False