@@ -303,6 +303,25 @@ def test_inpaint_compile(self):
303303 assert np .abs (expected_slice - image_slice ).max () < 1e-4
304304 assert np .abs (expected_slice - image_slice ).max () < 1e-3
305305
def test_stable_diffusion_inpaint_pil_input_resolution_test(self):
    """Run the inpainting pipeline on PIL inputs whose size differs from the requested one.

    The image and mask are resized to 127x127 while ``height`` and ``width``
    are set to 128; the output is expected to have the requested 128x128
    resolution rather than the resolution of the inputs.
    """
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        safety_checker=None,
    )
    pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)
    pipe.enable_attention_slicing()

    inputs = self.get_inputs(torch_device)

    # 127 is not a multiple of 8; per the original author's note, feeding
    # such a size through unchanged would cause a tensor mismatch error.
    odd_size = (127, 127)
    for key in ("image", "mask_image"):
        inputs[key] = inputs[key].resize(odd_size)
    inputs["height"] = 128
    inputs["width"] = 128

    image = pipe(**inputs).images

    # The output must follow the requested height/width, not the input size.
    expected_shape = (1, inputs["height"], inputs["width"], 3)
    assert image.shape == expected_shape
324+
306325
307326@nightly
308327@require_torch_gpu
@@ -400,21 +419,22 @@ def test_inpaint_dpm(self):
400419
401420class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests (unittest .TestCase ):
402421 def test_pil_inputs (self ):
403- im = np .random .randint (0 , 255 , (32 , 32 , 3 ), dtype = np .uint8 )
422+ height , width = 32 , 32
423+ im = np .random .randint (0 , 255 , (height , width , 3 ), dtype = np .uint8 )
404424 im = Image .fromarray (im )
405- mask = np .random .randint (0 , 255 , (32 , 32 ), dtype = np .uint8 ) > 127.5
425+ mask = np .random .randint (0 , 255 , (height , width ), dtype = np .uint8 ) > 127.5
406426 mask = Image .fromarray ((mask * 255 ).astype (np .uint8 ))
407427
408- t_mask , t_masked = prepare_mask_and_masked_image (im , mask )
428+ t_mask , t_masked = prepare_mask_and_masked_image (im , mask , height , width )
409429
410430 self .assertTrue (isinstance (t_mask , torch .Tensor ))
411431 self .assertTrue (isinstance (t_masked , torch .Tensor ))
412432
413433 self .assertEqual (t_mask .ndim , 4 )
414434 self .assertEqual (t_masked .ndim , 4 )
415435
416- self .assertEqual (t_mask .shape , (1 , 1 , 32 , 32 ))
417- self .assertEqual (t_masked .shape , (1 , 3 , 32 , 32 ))
436+ self .assertEqual (t_mask .shape , (1 , 1 , height , width ))
437+ self .assertEqual (t_masked .shape , (1 , 3 , height , width ))
418438
419439 self .assertTrue (t_mask .dtype == torch .float32 )
420440 self .assertTrue (t_masked .dtype == torch .float32 )
@@ -427,141 +447,165 @@ def test_pil_inputs(self):
427447 self .assertTrue (t_mask .sum () > 0.0 )
428448
def test_np_inputs(self):
    """Check that NumPy inputs and their PIL equivalents prepare identically.

    One random 32x32 RGB image and one random boolean mask are passed to
    ``prepare_mask_and_masked_image`` as NumPy arrays and again as PIL
    images; the resulting mask and masked image must match element-wise.
    """
    height, width = 32, 32

    # Draw the image before the mask, as the original test does.
    im_np = np.random.randint(0, 255, size=(height, width, 3), dtype=np.uint8)
    mask_np = np.random.randint(0, 255, size=(height, width), dtype=np.uint8) > 127.5

    # PIL views of the same data; the boolean mask is scaled to 0/255.
    im_pil = Image.fromarray(im_np)
    mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8))

    t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
    t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil, height, width)

    masks_agree = (t_mask_np == t_mask_pil).all()
    self.assertTrue(masks_agree)
    masked_images_agree = (t_masked_np == t_masked_pil).all()
    self.assertTrue(masked_images_agree)
440462
def test_torch_3D_2D_inputs(self):
    """Compare a (3, H, W) tensor image with an (H, W) tensor mask against the NumPy path."""
    height, width = 32, 32

    im_tensor = torch.randint(0, 255, (3, height, width), dtype=torch.uint8)
    mask_tensor = torch.randint(0, 255, (height, width), dtype=torch.uint8) > 127.5

    # Reference inputs: channels-last uint8 image and the same boolean mask.
    im_np = im_tensor.numpy().transpose(1, 2, 0)
    mask_np = mask_tensor.numpy()

    # The tensor image is rescaled by /127.5 - 1 before being passed in.
    scaled_im = im_tensor / 127.5 - 1
    t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(scaled_im, mask_tensor, height, width)
    t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)

    masks_agree = (t_mask_tensor == t_mask_np).all()
    self.assertTrue(masks_agree)
    masked_images_agree = (t_masked_tensor == t_masked_np).all()
    self.assertTrue(masked_images_agree)
452476
def test_torch_3D_3D_inputs(self):
    """Compare a (3, H, W) tensor image with a (1, H, W) tensor mask against the NumPy path."""
    height, width = 32, 32

    im_tensor = torch.randint(0, 255, (3, height, width), dtype=torch.uint8)
    mask_tensor = torch.randint(0, 255, (1, height, width), dtype=torch.uint8) > 127.5

    # Reference inputs: channels-last image; the mask's leading axis is dropped.
    im_np = im_tensor.numpy().transpose(1, 2, 0)
    mask_np = mask_tensor.numpy()[0]

    # The tensor image is rescaled by /127.5 - 1 before being passed in.
    scaled_im = im_tensor / 127.5 - 1
    t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(scaled_im, mask_tensor, height, width)
    t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)

    masks_agree = (t_mask_tensor == t_mask_np).all()
    self.assertTrue(masks_agree)
    masked_images_agree = (t_masked_tensor == t_masked_np).all()
    self.assertTrue(masked_images_agree)
464490
def test_torch_4D_2D_inputs(self):
    """Compare a (1, 3, H, W) tensor image with an (H, W) tensor mask against the NumPy path."""
    height, width = 32, 32

    im_tensor = torch.randint(0, 255, (1, 3, height, width), dtype=torch.uint8)
    mask_tensor = torch.randint(0, 255, (height, width), dtype=torch.uint8) > 127.5

    # Reference inputs: the single batch entry as a channels-last image.
    im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
    mask_np = mask_tensor.numpy()

    # The tensor image is rescaled by /127.5 - 1 before being passed in.
    scaled_im = im_tensor / 127.5 - 1
    t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(scaled_im, mask_tensor, height, width)
    t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)

    masks_agree = (t_mask_tensor == t_mask_np).all()
    self.assertTrue(masks_agree)
    masked_images_agree = (t_masked_tensor == t_masked_np).all()
    self.assertTrue(masked_images_agree)
476504
def test_torch_4D_3D_inputs(self):
    """Compare a (1, 3, H, W) tensor image with a (1, H, W) tensor mask against the NumPy path."""
    height, width = 32, 32

    im_tensor = torch.randint(0, 255, (1, 3, height, width), dtype=torch.uint8)
    mask_tensor = torch.randint(0, 255, (1, height, width), dtype=torch.uint8) > 127.5

    # Reference inputs: first (only) entry of each tensor, image channels-last.
    im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
    mask_np = mask_tensor.numpy()[0]

    # The tensor image is rescaled by /127.5 - 1 before being passed in.
    scaled_im = im_tensor / 127.5 - 1
    t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(scaled_im, mask_tensor, height, width)
    t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)

    masks_agree = (t_mask_tensor == t_mask_np).all()
    self.assertTrue(masks_agree)
    masked_images_agree = (t_masked_tensor == t_masked_np).all()
    self.assertTrue(masked_images_agree)
488518
def test_torch_4D_4D_inputs(self):
    """Compare a (1, 3, H, W) tensor image with a (1, 1, H, W) tensor mask against the NumPy path."""
    height, width = 32, 32

    im_tensor = torch.randint(0, 255, (1, 3, height, width), dtype=torch.uint8)
    mask_tensor = torch.randint(0, 255, (1, 1, height, width), dtype=torch.uint8) > 127.5

    # Reference inputs: strip the batch axis (and the mask's channel axis).
    im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
    mask_np = mask_tensor.numpy()[0][0]

    # The tensor image is rescaled by /127.5 - 1 before being passed in.
    scaled_im = im_tensor / 127.5 - 1
    t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(scaled_im, mask_tensor, height, width)
    t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)

    masks_agree = (t_mask_tensor == t_mask_np).all()
    self.assertTrue(masks_agree)
    masked_images_agree = (t_masked_tensor == t_masked_np).all()
    self.assertTrue(masked_images_agree)
500532
def test_torch_batch_4D_3D(self):
    """Compare a batched (2, 3, H, W) image / (2, H, W) mask call with per-sample NumPy calls.

    The per-sample results are concatenated with ``torch.cat`` and must equal
    the tensors produced by the single batched call.
    """
    height, width = 32, 32

    im_tensor = torch.randint(0, 255, (2, 3, height, width), dtype=torch.uint8)
    mask_tensor = torch.randint(0, 255, (2, height, width), dtype=torch.uint8) > 127.5

    # Per-sample reference inputs, converted before the batched call is made.
    im_nps = [sample.numpy().transpose(1, 2, 0) for sample in im_tensor]
    mask_nps = [sample_mask.numpy() for sample_mask in mask_tensor]

    # The tensor image is rescaled by /127.5 - 1 before being passed in.
    scaled_im = im_tensor / 127.5 - 1
    t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(scaled_im, mask_tensor, height, width)

    mask_parts = []
    masked_parts = []
    for sample_im, sample_mask in zip(im_nps, mask_nps):
        prepared = prepare_mask_and_masked_image(sample_im, sample_mask, height, width)
        mask_parts.append(prepared[0])
        masked_parts.append(prepared[1])
    t_mask_np = torch.cat(mask_parts)
    t_masked_np = torch.cat(masked_parts)

    masks_agree = (t_mask_tensor == t_mask_np).all()
    self.assertTrue(masks_agree)
    masked_images_agree = (t_masked_tensor == t_masked_np).all()
    self.assertTrue(masked_images_agree)
515549
def test_torch_batch_4D_4D(self):
    """Compare a batched (2, 3, H, W) image / (2, 1, H, W) mask call with per-sample NumPy calls.

    The per-sample results are concatenated with ``torch.cat`` and must equal
    the tensors produced by the single batched call.
    """
    height, width = 32, 32

    im_tensor = torch.randint(0, 255, (2, 3, height, width), dtype=torch.uint8)
    mask_tensor = torch.randint(0, 255, (2, 1, height, width), dtype=torch.uint8) > 127.5

    # Per-sample reference inputs; each mask drops its leading channel axis.
    im_nps = [sample.numpy().transpose(1, 2, 0) for sample in im_tensor]
    mask_nps = [sample_mask.numpy()[0] for sample_mask in mask_tensor]

    # The tensor image is rescaled by /127.5 - 1 before being passed in.
    scaled_im = im_tensor / 127.5 - 1
    t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(scaled_im, mask_tensor, height, width)

    mask_parts = []
    masked_parts = []
    for sample_im, sample_mask in zip(im_nps, mask_nps):
        prepared = prepare_mask_and_masked_image(sample_im, sample_mask, height, width)
        mask_parts.append(prepared[0])
        masked_parts.append(prepared[1])
    t_mask_np = torch.cat(mask_parts)
    t_masked_np = torch.cat(masked_parts)

    masks_agree = (t_mask_tensor == t_mask_np).all()
    self.assertTrue(masks_agree)
    masked_images_agree = (t_masked_tensor == t_masked_np).all()
    self.assertTrue(masked_images_agree)
530566
def test_shape_mismatch(self):
    """Expect ``AssertionError`` when image and mask disagree in spatial size or batch size."""
    height, width = 32, 32

    # Spatial mismatch: 32x32 image against a 64x64 mask.
    with self.assertRaises(AssertionError):
        image = torch.randn(3, height, width)
        mask = torch.randn(64, 64)
        prepare_mask_and_masked_image(image, mask, height, width)

    # Batch mismatch with a 3D mask: 2 images against 4 masks.
    with self.assertRaises(AssertionError):
        image = torch.randn(2, 3, height, width)
        mask = torch.randn(4, 64, 64)
        prepare_mask_and_masked_image(image, mask, height, width)

    # Batch mismatch with a 4D mask: 2 images against 4 masks.
    with self.assertRaises(AssertionError):
        image = torch.randn(2, 3, height, width)
        mask = torch.randn(4, 1, 64, 64)
        prepare_mask_and_masked_image(image, mask, height, width)
541579
def test_type_mismatch(self):
    """Expect ``TypeError`` when exactly one of image and mask is a torch tensor."""
    height, width = 32, 32

    # Tensor image paired with a NumPy mask.
    with self.assertRaises(TypeError):
        image = torch.rand(3, height, width)
        mask = torch.rand(3, height, width).numpy()
        prepare_mask_and_masked_image(image, mask, height, width)

    # NumPy image paired with a tensor mask.
    with self.assertRaises(TypeError):
        image = torch.rand(3, height, width).numpy()
        mask = torch.rand(3, height, width)
        prepare_mask_and_masked_image(image, mask, height, width)
549589
def test_channels_first(self):
    """Expect ``AssertionError`` for a 3D tensor image laid out channels-last."""
    height, width = 32, 32

    # The image is (H, W, 3) instead of (3, H, W).
    with self.assertRaises(AssertionError):
        channels_last_image = torch.rand(height, width, 3)
        mask = torch.rand(3, height, width)
        prepare_mask_and_masked_image(channels_last_image, mask, height, width)
554596
def test_tensor_range(self):
    """Expect ``ValueError`` when tensor image or mask values are out of range.

    Covers an image above 1, an image below -1, a mask above 1 and a mask
    below 0.
    """
    height, width = 32, 32

    # Image values of 2 exceed the upper bound of 1.
    with self.assertRaises(ValueError):
        image = torch.ones(3, height, width) * 2
        mask = torch.rand(height, width)
        prepare_mask_and_masked_image(image, mask, height, width)

    # Image values of -2 fall below the lower bound of -1.
    with self.assertRaises(ValueError):
        image = torch.ones(3, height, width) * (-2)
        mask = torch.rand(height, width)
        prepare_mask_and_masked_image(image, mask, height, width)

    # Mask values of 2 exceed the upper bound of 1.
    with self.assertRaises(ValueError):
        image = torch.rand(3, height, width)
        mask = torch.ones(height, width) * 2
        prepare_mask_and_masked_image(image, mask, height, width)

    # Mask values of -1 fall below the lower bound of 0.
    with self.assertRaises(ValueError):
        image = torch.rand(3, height, width)
        mask = torch.ones(height, width) * -1
        prepare_mask_and_masked_image(image, mask, height, width)
0 commit comments