@@ -1733,6 +1733,14 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
17331733 [1.8285 , 1.2857 , - 0.1024 , 1.2406 , - 2.3068 , 1.0747 , - 0.0818 , - 0.6520 , - 2.9506 ]
17341734 )
17351735 assert np .abs (latents_slice .flatten () - expected_slice ).max () < 1e-3
1736+ elif step == 50 :
1737+ latents = latents .detach ().cpu ().numpy ()
1738+ assert latents .shape == (1 , 4 , 64 , 64 )
1739+ latents_slice = latents [0 , - 3 :, - 3 :, - 1 ]
1740+ expected_slice = np .array (
1741+ [1.1078 , 1.5803 , 0.2773 , - 0.0589 , - 1.7928 , - 0.3665 , - 0.4695 , - 1.0727 , - 1.1601 ]
1742+ )
1743+ assert np .abs (latents_slice .flatten () - expected_slice ).max () < 1e-3
17361744
17371745 test_callback_fn .has_been_called = False
17381746
@@ -1773,6 +1781,12 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
17731781 latents_slice = latents [0 , - 3 :, - 3 :, - 1 ]
17741782 expected_slice = np .array ([0.9052 , - 0.0184 , 0.4810 , 0.2898 , 0.5851 , 1.4920 , 0.5362 , 1.9838 , 0.0530 ])
17751783 assert np .abs (latents_slice .flatten () - expected_slice ).max () < 1e-3
1784+ elif step == 37 :
1785+ latents = latents .detach ().cpu ().numpy ()
1786+ assert latents .shape == (1 , 4 , 64 , 96 )
1787+ latents_slice = latents [0 , - 3 :, - 3 :, - 1 ]
1788+ expected_slice = np .array ([0.7071 , 0.7831 , 0.8300 , 1.8140 , 1.7840 , 1.9402 , 1.3651 , 1.6590 , 1.2828 ])
1789+ assert np .abs (latents_slice .flatten () - expected_slice ).max () < 1e-3
17761790
17771791 test_callback_fn .has_been_called = False
17781792
@@ -1823,6 +1837,12 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
18231837 [- 0.5472 , 1.1218 , - 0.5505 , - 0.9390 , - 1.0794 , 0.4063 , 0.5158 , 0.6429 , - 1.5246 ]
18241838 )
18251839 assert np .abs (latents_slice .flatten () - expected_slice ).max () < 1e-3
1840+ elif step == 37 :
1841+ latents = latents .detach ().cpu ().numpy ()
1842+ assert latents .shape == (1 , 4 , 64 , 64 )
1843+ latents_slice = latents [0 , - 3 :, - 3 :, - 1 ]
1844+ expected_slice = np .array ([0.4781 , 1.1572 , 0.6258 , 0.2291 , 0.2554 , - 0.1443 , 0.7085 , - 0.1598 , - 0.5659 ])
1845+ assert np .abs (latents_slice .flatten () - expected_slice ).max () < 1e-3
18261846
18271847 test_callback_fn .has_been_called = False
18281848
@@ -1875,6 +1895,13 @@ def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None:
18751895 [- 0.5950 , - 0.3039 , - 1.1672 , 0.1594 , - 1.1572 , 0.6719 , - 1.9712 , - 0.0403 , 0.9592 ]
18761896 )
18771897 assert np .abs (latents_slice .flatten () - expected_slice ).max () < 1e-3
1898+ elif step == 5 :
1899+ assert latents .shape == (1 , 4 , 64 , 64 )
1900+ latents_slice = latents [0 , - 3 :, - 3 :, - 1 ]
1901+ expected_slice = np .array (
1902+ [- 0.4776 , - 0.0119 , - 0.8519 , - 0.0275 , - 0.9764 , 0.9820 , - 0.3843 , 0.3788 , 1.2264 ]
1903+ )
1904+ assert np .abs (latents_slice .flatten () - expected_slice ).max () < 1e-3
18781905
18791906 test_callback_fn .has_been_called = False
18801907
0 commit comments