Skip to content

Commit 798813b

Browse files
committed
Add final latent slice checks to SB pipeline intermediate state tests
This is to ensure that the final latent slices stay somewhat consistent as more changes are introduced into the library. Signed-off-by: James R T <[email protected]>
1 parent 7265dd8 commit 798813b

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

tests/test_pipelines.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,6 +1452,14 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
14521452
[1.8285, 1.2857, -0.1024, 1.2406, -2.3068, 1.0747, -0.0818, -0.6520, -2.9506]
14531453
)
14541454
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
1455+
elif step == 50:
1456+
latents = latents.detach().cpu().numpy()
1457+
assert latents.shape == (1, 4, 64, 64)
1458+
latents_slice = latents[0, -3:, -3:, -1]
1459+
expected_slice = np.array(
1460+
[1.1078, 1.5803, 0.2773, -0.0589, -1.7928, -0.3665, -0.4695, -1.0727, -1.1601]
1461+
)
1462+
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
14551463

14561464
test_callback_fn.has_been_called = False
14571465

@@ -1492,6 +1500,12 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
14921500
latents_slice = latents[0, -3:, -3:, -1]
14931501
expected_slice = np.array([0.9052, -0.0184, 0.4810, 0.2898, 0.5851, 1.4920, 0.5362, 1.9838, 0.0530])
14941502
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
1503+
elif step == 37:
1504+
latents = latents.detach().cpu().numpy()
1505+
assert latents.shape == (1, 4, 64, 96)
1506+
latents_slice = latents[0, -3:, -3:, -1]
1507+
expected_slice = np.array([0.7071, 0.7831, 0.8300, 1.8140, 1.7840, 1.9402, 1.3651, 1.6590, 1.2828])
1508+
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
14951509

14961510
test_callback_fn.has_been_called = False
14971511

@@ -1542,6 +1556,12 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
15421556
[-0.5472, 1.1218, -0.5505, -0.9390, -1.0794, 0.4063, 0.5158, 0.6429, -1.5246]
15431557
)
15441558
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
1559+
elif step == 37:
1560+
latents = latents.detach().cpu().numpy()
1561+
assert latents.shape == (1, 4, 64, 64)
1562+
latents_slice = latents[0, -3:, -3:, -1]
1563+
expected_slice = np.array([0.4781, 1.1572, 0.6258, 0.2291, 0.2554, -0.1443, 0.7085, -0.1598, -0.5659])
1564+
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
15451565

15461566
test_callback_fn.has_been_called = False
15471567

@@ -1594,6 +1614,13 @@ def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None:
15941614
[-0.5950, -0.3039, -1.1672, 0.1594, -1.1572, 0.6719, -1.9712, -0.0403, 0.9592]
15951615
)
15961616
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
1617+
elif step == 5:
1618+
assert latents.shape == (1, 4, 64, 64)
1619+
latents_slice = latents[0, -3:, -3:, -1]
1620+
expected_slice = np.array(
1621+
[-0.4776, -0.0119, -0.8519, -0.0275, -0.9764, 0.9820, -0.3843, 0.3788, 1.2264]
1622+
)
1623+
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
15971624

15981625
test_callback_fn.has_been_called = False
15991626

0 commit comments

Comments
 (0)