Skip to content

Commit fe05ea2

Browse files
committed
Fix intermediate state tests for Stable Diffusion pipelines
Signed-off-by: James R T <[email protected]>
1 parent 53a845a commit fe05ea2

File tree

1 file changed: 51 additions and 39 deletions.

tests/test_pipelines.py

Lines changed: 51 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -1446,31 +1446,35 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
14461446
nonlocal number_of_steps
14471447
number_of_steps += 1
14481448
if step == 0:
1449-
latents = np.array(latents)
1449+
latents = latents.detach().cpu().numpy()
14501450
assert latents.shape == (1, 4, 64, 64)
14511451
latents_slice = latents[0, -3:, -3:, -1]
14521452
expected_slice = np.array(
1453-
[-1.2277, -0.3692, -0.2123, -1.3709, -1.4505, -0.6718, -0.3112, -1.2481, -1.0674]
1453+
[1.8285, 1.2857, -0.1024, 1.2406, -2.3068, 1.0747, -0.0818, -0.6520, -2.9506]
14541454
)
14551455
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
14561456

14571457
test_callback_fn.has_been_called = False
14581458

1459-
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
1459+
pipe = StableDiffusionPipeline.from_pretrained(
1460+
"CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="fp16", torch_dtype=torch.float16
1461+
)
14601462
pipe.to(torch_device)
14611463
pipe.set_progress_bar_config(disable=None)
1464+
pipe.enable_attention_slicing()
14621465

14631466
prompt = "Andromeda galaxy in a bottle"
14641467

14651468
generator = torch.Generator(device=torch_device).manual_seed(0)
1466-
pipe(
1467-
prompt=prompt,
1468-
num_inference_steps=50,
1469-
guidance_scale=7.5,
1470-
generator=generator,
1471-
callback=test_callback_fn,
1472-
callback_steps=1,
1473-
)
1469+
with torch.autocast(torch_device):
1470+
pipe(
1471+
prompt=prompt,
1472+
num_inference_steps=50,
1473+
guidance_scale=7.5,
1474+
generator=generator,
1475+
callback=test_callback_fn,
1476+
callback_steps=1,
1477+
)
14741478
assert test_callback_fn.has_been_called
14751479
assert number_of_steps == 51
14761480

@@ -1484,10 +1488,10 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
14841488
nonlocal number_of_steps
14851489
number_of_steps += 1
14861490
if step == 0:
1487-
latents = np.array(latents)
1491+
latents = latents.detach().cpu().numpy()
14881492
assert latents.shape == (1, 4, 64, 96)
14891493
latents_slice = latents[0, -3:, -3:, -1]
1490-
expected_slice = np.array([0.5486, 0.8705, 1.4053, 1.6771, 2.0729, 0.7256, 1.5693, -0.1298, -1.3520])
1494+
expected_slice = np.array([0.9052, -0.0184, 0.4810, 0.2898, 0.5851, 1.4920, 0.5362, 1.9838, 0.0530])
14911495
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
14921496

14931497
test_callback_fn.has_been_called = False
@@ -1498,23 +1502,27 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
14981502
)
14991503
init_image = init_image.resize((768, 512))
15001504

1501-
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
1505+
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
1506+
"CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="fp16", torch_dtype=torch.float16
1507+
)
15021508
pipe.to(torch_device)
15031509
pipe.set_progress_bar_config(disable=None)
1510+
pipe.enable_attention_slicing()
15041511

15051512
prompt = "A fantasy landscape, trending on artstation"
15061513

15071514
generator = torch.Generator(device=torch_device).manual_seed(0)
1508-
pipe(
1509-
prompt=prompt,
1510-
init_image=init_image,
1511-
strength=0.75,
1512-
num_inference_steps=50,
1513-
guidance_scale=7.5,
1514-
generator=generator,
1515-
callback=test_callback_fn,
1516-
callback_steps=1,
1517-
)
1515+
with torch.autocast(torch_device):
1516+
pipe(
1517+
prompt=prompt,
1518+
init_image=init_image,
1519+
strength=0.75,
1520+
num_inference_steps=50,
1521+
guidance_scale=7.5,
1522+
generator=generator,
1523+
callback=test_callback_fn,
1524+
callback_steps=1,
1525+
)
15181526
assert test_callback_fn.has_been_called
15191527
assert number_of_steps == 38
15201528

@@ -1528,11 +1536,11 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
15281536
nonlocal number_of_steps
15291537
number_of_steps += 1
15301538
if step == 0:
1531-
latents = np.array(latents)
1539+
latents = latents.detach().cpu().numpy()
15321540
assert latents.shape == (1, 4, 64, 64)
15331541
latents_slice = latents[0, -3:, -3:, -1]
15341542
expected_slice = np.array(
1535-
[-0.4155, -0.4140, 1.1430, -2.0722, 2.2523, -1.8766, -0.4917, 0.3338, 0.9667]
1543+
[-0.5472, 1.1218, -0.5505, -0.9390, -1.0794, 0.4063, 0.5158, 0.6429, -1.5246]
15361544
)
15371545
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
15381546

@@ -1547,24 +1555,28 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
15471555
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
15481556
)
15491557

1550-
pipe = StableDiffusionInpaintPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
1558+
pipe = StableDiffusionInpaintPipeline.from_pretrained(
1559+
"CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="fp16", torch_dtype=torch.float16
1560+
)
15511561
pipe.to(torch_device)
15521562
pipe.set_progress_bar_config(disable=None)
1563+
pipe.enable_attention_slicing()
15531564

15541565
prompt = "A red cat sitting on a park bench"
15551566

15561567
generator = torch.Generator(device=torch_device).manual_seed(0)
1557-
pipe(
1558-
prompt=prompt,
1559-
init_image=init_image,
1560-
mask_image=mask_image,
1561-
strength=0.75,
1562-
num_inference_steps=50,
1563-
guidance_scale=7.5,
1564-
generator=generator,
1565-
callback=test_callback_fn,
1566-
callback_steps=1,
1567-
)
1568+
with torch.autocast(torch_device):
1569+
pipe(
1570+
prompt=prompt,
1571+
init_image=init_image,
1572+
mask_image=mask_image,
1573+
strength=0.75,
1574+
num_inference_steps=50,
1575+
guidance_scale=7.5,
1576+
generator=generator,
1577+
callback=test_callback_fn,
1578+
callback_steps=1,
1579+
)
15681580
assert test_callback_fn.has_been_called
15691581
assert number_of_steps == 38
15701582

@@ -1587,7 +1599,7 @@ def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None:
15871599
test_callback_fn.has_been_called = False
15881600

15891601
pipe = StableDiffusionOnnxPipeline.from_pretrained(
1590-
"CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="onnx", provider="CUDAExecutionProvider"
1602+
"CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="onnx", provider="CPUExecutionProvider"
15911603
)
15921604
pipe.set_progress_bar_config(disable=None)
15931605

0 commit comments

Comments
 (0)