Skip to content

Commit e0fece2

Browse files
authored
Add final latent slice checks to SD pipeline intermediate state tests (#731)
This is to ensure that the final latent slices stay somewhat consistent as more changes are introduced into the library. Signed-off-by: James R T <[email protected]> Signed-off-by: James R T <[email protected]>
1 parent 75bb6d2 commit e0fece2

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

tests/test_pipelines.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1733,6 +1733,14 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
17331733
[1.8285, 1.2857, -0.1024, 1.2406, -2.3068, 1.0747, -0.0818, -0.6520, -2.9506]
17341734
)
17351735
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
1736+
elif step == 50:
1737+
latents = latents.detach().cpu().numpy()
1738+
assert latents.shape == (1, 4, 64, 64)
1739+
latents_slice = latents[0, -3:, -3:, -1]
1740+
expected_slice = np.array(
1741+
[1.1078, 1.5803, 0.2773, -0.0589, -1.7928, -0.3665, -0.4695, -1.0727, -1.1601]
1742+
)
1743+
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
17361744

17371745
test_callback_fn.has_been_called = False
17381746

@@ -1773,6 +1781,12 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
17731781
latents_slice = latents[0, -3:, -3:, -1]
17741782
expected_slice = np.array([0.9052, -0.0184, 0.4810, 0.2898, 0.5851, 1.4920, 0.5362, 1.9838, 0.0530])
17751783
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
1784+
elif step == 37:
1785+
latents = latents.detach().cpu().numpy()
1786+
assert latents.shape == (1, 4, 64, 96)
1787+
latents_slice = latents[0, -3:, -3:, -1]
1788+
expected_slice = np.array([0.7071, 0.7831, 0.8300, 1.8140, 1.7840, 1.9402, 1.3651, 1.6590, 1.2828])
1789+
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
17761790

17771791
test_callback_fn.has_been_called = False
17781792

@@ -1823,6 +1837,12 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
18231837
[-0.5472, 1.1218, -0.5505, -0.9390, -1.0794, 0.4063, 0.5158, 0.6429, -1.5246]
18241838
)
18251839
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
1840+
elif step == 37:
1841+
latents = latents.detach().cpu().numpy()
1842+
assert latents.shape == (1, 4, 64, 64)
1843+
latents_slice = latents[0, -3:, -3:, -1]
1844+
expected_slice = np.array([0.4781, 1.1572, 0.6258, 0.2291, 0.2554, -0.1443, 0.7085, -0.1598, -0.5659])
1845+
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
18261846

18271847
test_callback_fn.has_been_called = False
18281848

@@ -1875,6 +1895,13 @@ def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None:
18751895
[-0.5950, -0.3039, -1.1672, 0.1594, -1.1572, 0.6719, -1.9712, -0.0403, 0.9592]
18761896
)
18771897
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
1898+
elif step == 5:
1899+
assert latents.shape == (1, 4, 64, 64)
1900+
latents_slice = latents[0, -3:, -3:, -1]
1901+
expected_slice = np.array(
1902+
[-0.4776, -0.0119, -0.8519, -0.0275, -0.9764, 0.9820, -0.3843, 0.3788, 1.2264]
1903+
)
1904+
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
18781905

18791906
test_callback_fn.has_been_called = False
18801907

0 commit comments

Comments
 (0)