@@ -1452,6 +1452,14 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
14521452 [1.8285 , 1.2857 , - 0.1024 , 1.2406 , - 2.3068 , 1.0747 , - 0.0818 , - 0.6520 , - 2.9506 ]
14531453 )
14541454 assert np .abs (latents_slice .flatten () - expected_slice ).max () < 1e-3
1455+ elif step == 50 :
1456+ latents = latents .detach ().cpu ().numpy ()
1457+ assert latents .shape == (1 , 4 , 64 , 64 )
1458+ latents_slice = latents [0 , - 3 :, - 3 :, - 1 ]
1459+ expected_slice = np .array (
1460+ [1.1078 , 1.5803 , 0.2773 , - 0.0589 , - 1.7928 , - 0.3665 , - 0.4695 , - 1.0727 , - 1.1601 ]
1461+ )
1462+ assert np .abs (latents_slice .flatten () - expected_slice ).max () < 1e-3
14551463
14561464 test_callback_fn .has_been_called = False
14571465
@@ -1492,6 +1500,12 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
14921500 latents_slice = latents [0 , - 3 :, - 3 :, - 1 ]
14931501 expected_slice = np .array ([0.9052 , - 0.0184 , 0.4810 , 0.2898 , 0.5851 , 1.4920 , 0.5362 , 1.9838 , 0.0530 ])
14941502 assert np .abs (latents_slice .flatten () - expected_slice ).max () < 1e-3
1503+ elif step == 37 :
1504+ latents = latents .detach ().cpu ().numpy ()
1505+ assert latents .shape == (1 , 4 , 64 , 96 )
1506+ latents_slice = latents [0 , - 3 :, - 3 :, - 1 ]
1507+ expected_slice = np .array ([0.7071 , 0.7831 , 0.8300 , 1.8140 , 1.7840 , 1.9402 , 1.3651 , 1.6590 , 1.2828 ])
1508+ assert np .abs (latents_slice .flatten () - expected_slice ).max () < 1e-3
14951509
14961510 test_callback_fn .has_been_called = False
14971511
@@ -1542,6 +1556,12 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
15421556 [- 0.5472 , 1.1218 , - 0.5505 , - 0.9390 , - 1.0794 , 0.4063 , 0.5158 , 0.6429 , - 1.5246 ]
15431557 )
15441558 assert np .abs (latents_slice .flatten () - expected_slice ).max () < 1e-3
1559+ elif step == 37 :
1560+ latents = latents .detach ().cpu ().numpy ()
1561+ assert latents .shape == (1 , 4 , 64 , 64 )
1562+ latents_slice = latents [0 , - 3 :, - 3 :, - 1 ]
1563+ expected_slice = np .array ([0.4781 , 1.1572 , 0.6258 , 0.2291 , 0.2554 , - 0.1443 , 0.7085 , - 0.1598 , - 0.5659 ])
1564+ assert np .abs (latents_slice .flatten () - expected_slice ).max () < 1e-3
15451565
15461566 test_callback_fn .has_been_called = False
15471567
@@ -1594,6 +1614,13 @@ def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None:
15941614 [- 0.5950 , - 0.3039 , - 1.1672 , 0.1594 , - 1.1572 , 0.6719 , - 1.9712 , - 0.0403 , 0.9592 ]
15951615 )
15961616 assert np .abs (latents_slice .flatten () - expected_slice ).max () < 1e-3
1617+ elif step == 5 :
1618+ assert latents .shape == (1 , 4 , 64 , 64 )
1619+ latents_slice = latents [0 , - 3 :, - 3 :, - 1 ]
1620+ expected_slice = np .array (
1621+ [- 0.4776 , - 0.0119 , - 0.8519 , - 0.0275 , - 0.9764 , 0.9820 , - 0.3843 , 0.3788 , 1.2264 ]
1622+ )
1623+ assert np .abs (latents_slice .flatten () - expected_slice ).max () < 1e-3
15971624
15981625 test_callback_fn .has_been_called = False
15991626
0 commit comments