@@ -397,21 +397,22 @@ def test_inpaint_dpm(self):
397397
398398class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests (unittest .TestCase ):
399399 def test_pil_inputs (self ):
400- im = np .random .randint (0 , 255 , (32 , 32 , 3 ), dtype = np .uint8 )
400+ height , width = 32 , 32
401+ im = np .random .randint (0 , 255 , (height , width , 3 ), dtype = np .uint8 )
401402 im = Image .fromarray (im )
402- mask = np .random .randint (0 , 255 , (32 , 32 ), dtype = np .uint8 ) > 127.5
403+ mask = np .random .randint (0 , 255 , (height , width ), dtype = np .uint8 ) > 127.5
403404 mask = Image .fromarray ((mask * 255 ).astype (np .uint8 ))
404405
405- t_mask , t_masked = prepare_mask_and_masked_image (im , mask )
406+ t_mask , t_masked = prepare_mask_and_masked_image (im , mask , height , width )
406407
407408 self .assertTrue (isinstance (t_mask , torch .Tensor ))
408409 self .assertTrue (isinstance (t_masked , torch .Tensor ))
409410
410411 self .assertEqual (t_mask .ndim , 4 )
411412 self .assertEqual (t_masked .ndim , 4 )
412413
413- self .assertEqual (t_mask .shape , (1 , 1 , 32 , 32 ))
414- self .assertEqual (t_masked .shape , (1 , 3 , 32 , 32 ))
414+ self .assertEqual (t_mask .shape , (1 , 1 , height , width ))
415+ self .assertEqual (t_masked .shape , (1 , 3 , height , width ))
415416
416417 self .assertTrue (t_mask .dtype == torch .float32 )
417418 self .assertTrue (t_masked .dtype == torch .float32 )
@@ -424,141 +425,165 @@ def test_pil_inputs(self):
424425 self .assertTrue (t_mask .sum () > 0.0 )
425426
426427 def test_np_inputs (self ):
427- im_np = np .random .randint (0 , 255 , (32 , 32 , 3 ), dtype = np .uint8 )
428+ height , width = 32 , 32
429+
430+ im_np = np .random .randint (0 , 255 , (height , width , 3 ), dtype = np .uint8 )
428431 im_pil = Image .fromarray (im_np )
429- mask_np = np .random .randint (0 , 255 , (32 , 32 ), dtype = np .uint8 ) > 127.5
432+ mask_np = np .random .randint (0 , 255 , (height , width , ), dtype = np .uint8 ) > 127.5
430433 mask_pil = Image .fromarray ((mask_np * 255 ).astype (np .uint8 ))
431434
432- t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np )
433- t_mask_pil , t_masked_pil = prepare_mask_and_masked_image (im_pil , mask_pil )
435+ t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np , height , width )
436+ t_mask_pil , t_masked_pil = prepare_mask_and_masked_image (im_pil , mask_pil , height , width )
434437
435438 self .assertTrue ((t_mask_np == t_mask_pil ).all ())
436439 self .assertTrue ((t_masked_np == t_masked_pil ).all ())
437440
438441 def test_torch_3D_2D_inputs (self ):
439- im_tensor = torch .randint (0 , 255 , (3 , 32 , 32 ), dtype = torch .uint8 )
440- mask_tensor = torch .randint (0 , 255 , (32 , 32 ), dtype = torch .uint8 ) > 127.5
442+ height , width = 32 , 32
443+
444+ im_tensor = torch .randint (0 , 255 , (3 , height , width ,), dtype = torch .uint8 )
445+ mask_tensor = torch .randint (0 , 255 , (height , width ,), dtype = torch .uint8 ) > 127.5
441446 im_np = im_tensor .numpy ().transpose (1 , 2 , 0 )
442447 mask_np = mask_tensor .numpy ()
443448
444- t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor )
445- t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np )
449+ t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor , height , width )
450+ t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np , height , width )
446451
447452 self .assertTrue ((t_mask_tensor == t_mask_np ).all ())
448453 self .assertTrue ((t_masked_tensor == t_masked_np ).all ())
449454
450455 def test_torch_3D_3D_inputs (self ):
451- im_tensor = torch .randint (0 , 255 , (3 , 32 , 32 ), dtype = torch .uint8 )
452- mask_tensor = torch .randint (0 , 255 , (1 , 32 , 32 ), dtype = torch .uint8 ) > 127.5
456+ height , width = 32 , 32
457+
458+ im_tensor = torch .randint (0 , 255 , (3 , height , width ,), dtype = torch .uint8 )
459+ mask_tensor = torch .randint (0 , 255 , (1 , height , width ,), dtype = torch .uint8 ) > 127.5
453460 im_np = im_tensor .numpy ().transpose (1 , 2 , 0 )
454461 mask_np = mask_tensor .numpy ()[0 ]
455462
456- t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor )
457- t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np )
463+ t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor , height , width )
464+ t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np , height , width )
458465
459466 self .assertTrue ((t_mask_tensor == t_mask_np ).all ())
460467 self .assertTrue ((t_masked_tensor == t_masked_np ).all ())
461468
462469 def test_torch_4D_2D_inputs (self ):
463- im_tensor = torch .randint (0 , 255 , (1 , 3 , 32 , 32 ), dtype = torch .uint8 )
464- mask_tensor = torch .randint (0 , 255 , (32 , 32 ), dtype = torch .uint8 ) > 127.5
470+ height , width = 32 , 32
471+
472+ im_tensor = torch .randint (0 , 255 , (1 , 3 , height , width ,), dtype = torch .uint8 )
473+ mask_tensor = torch .randint (0 , 255 , (height , width ,), dtype = torch .uint8 ) > 127.5
465474 im_np = im_tensor .numpy ()[0 ].transpose (1 , 2 , 0 )
466475 mask_np = mask_tensor .numpy ()
467476
468- t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor )
469- t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np )
477+ t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor , height , width )
478+ t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np , height , width )
470479
471480 self .assertTrue ((t_mask_tensor == t_mask_np ).all ())
472481 self .assertTrue ((t_masked_tensor == t_masked_np ).all ())
473482
474483 def test_torch_4D_3D_inputs (self ):
475- im_tensor = torch .randint (0 , 255 , (1 , 3 , 32 , 32 ), dtype = torch .uint8 )
476- mask_tensor = torch .randint (0 , 255 , (1 , 32 , 32 ), dtype = torch .uint8 ) > 127.5
484+ height , width = 32 , 32
485+
486+ im_tensor = torch .randint (0 , 255 , (1 , 3 , height , width ,), dtype = torch .uint8 )
487+ mask_tensor = torch .randint (0 , 255 , (1 , height , width ,), dtype = torch .uint8 ) > 127.5
477488 im_np = im_tensor .numpy ()[0 ].transpose (1 , 2 , 0 )
478489 mask_np = mask_tensor .numpy ()[0 ]
479490
480- t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor )
481- t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np )
491+ t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor , height , width )
492+ t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np , height , width )
482493
483494 self .assertTrue ((t_mask_tensor == t_mask_np ).all ())
484495 self .assertTrue ((t_masked_tensor == t_masked_np ).all ())
485496
486497 def test_torch_4D_4D_inputs (self ):
487- im_tensor = torch .randint (0 , 255 , (1 , 3 , 32 , 32 ), dtype = torch .uint8 )
488- mask_tensor = torch .randint (0 , 255 , (1 , 1 , 32 , 32 ), dtype = torch .uint8 ) > 127.5
498+ height , width = 32 , 32
499+
500+ im_tensor = torch .randint (0 , 255 , (1 , 3 , height , width ,), dtype = torch .uint8 )
501+ mask_tensor = torch .randint (0 , 255 , (1 , 1 , height , width ,), dtype = torch .uint8 ) > 127.5
489502 im_np = im_tensor .numpy ()[0 ].transpose (1 , 2 , 0 )
490503 mask_np = mask_tensor .numpy ()[0 ][0 ]
491504
492- t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor )
493- t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np )
505+ t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor , height , width )
506+ t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np , height , width )
494507
495508 self .assertTrue ((t_mask_tensor == t_mask_np ).all ())
496509 self .assertTrue ((t_masked_tensor == t_masked_np ).all ())
497510
498511 def test_torch_batch_4D_3D (self ):
499- im_tensor = torch .randint (0 , 255 , (2 , 3 , 32 , 32 ), dtype = torch .uint8 )
500- mask_tensor = torch .randint (0 , 255 , (2 , 32 , 32 ), dtype = torch .uint8 ) > 127.5
512+ height , width = 32 , 32
513+
514+ im_tensor = torch .randint (0 , 255 , (2 , 3 , height , width ,), dtype = torch .uint8 )
515+ mask_tensor = torch .randint (0 , 255 , (2 , height , width ,), dtype = torch .uint8 ) > 127.5
501516
502517 im_nps = [im .numpy ().transpose (1 , 2 , 0 ) for im in im_tensor ]
503518 mask_nps = [mask .numpy () for mask in mask_tensor ]
504519
505- t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor )
506- nps = [prepare_mask_and_masked_image (i , m ) for i , m in zip (im_nps , mask_nps )]
520+ t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor , height , width )
521+ nps = [prepare_mask_and_masked_image (i , m , height , width ) for i , m in zip (im_nps , mask_nps )]
507522 t_mask_np = torch .cat ([n [0 ] for n in nps ])
508523 t_masked_np = torch .cat ([n [1 ] for n in nps ])
509524
510525 self .assertTrue ((t_mask_tensor == t_mask_np ).all ())
511526 self .assertTrue ((t_masked_tensor == t_masked_np ).all ())
512527
513528 def test_torch_batch_4D_4D (self ):
514- im_tensor = torch .randint (0 , 255 , (2 , 3 , 32 , 32 ), dtype = torch .uint8 )
515- mask_tensor = torch .randint (0 , 255 , (2 , 1 , 32 , 32 ), dtype = torch .uint8 ) > 127.5
529+ height , width = 32 , 32
530+
531+ im_tensor = torch .randint (0 , 255 , (2 , 3 , height , width ,), dtype = torch .uint8 )
532+ mask_tensor = torch .randint (0 , 255 , (2 , 1 , height , width ,), dtype = torch .uint8 ) > 127.5
516533
517534 im_nps = [im .numpy ().transpose (1 , 2 , 0 ) for im in im_tensor ]
518535 mask_nps = [mask .numpy ()[0 ] for mask in mask_tensor ]
519536
520- t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor )
521- nps = [prepare_mask_and_masked_image (i , m ) for i , m in zip (im_nps , mask_nps )]
537+ t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor , height , width )
538+ nps = [prepare_mask_and_masked_image (i , m , height , width ) for i , m in zip (im_nps , mask_nps )]
522539 t_mask_np = torch .cat ([n [0 ] for n in nps ])
523540 t_masked_np = torch .cat ([n [1 ] for n in nps ])
524541
525542 self .assertTrue ((t_mask_tensor == t_mask_np ).all ())
526543 self .assertTrue ((t_masked_tensor == t_masked_np ).all ())
527544
528545 def test_shape_mismatch (self ):
546+ height , width = 32 , 32
547+
529548 # test height and width
530549 with self .assertRaises (AssertionError ):
531- prepare_mask_and_masked_image (torch .randn (3 , 32 , 32 ), torch .randn (64 , 64 ))
550+ prepare_mask_and_masked_image (torch .randn (3 , height , width , ), torch .randn (64 , 64 ), height , width )
532551 # test batch dim
533552 with self .assertRaises (AssertionError ):
534- prepare_mask_and_masked_image (torch .randn (2 , 3 , 32 , 32 ), torch .randn (4 , 64 , 64 ))
553+ prepare_mask_and_masked_image (torch .randn (2 , 3 , height , width , ), torch .randn (4 , 64 , 64 ), height , width )
535554 # test batch dim
536555 with self .assertRaises (AssertionError ):
537- prepare_mask_and_masked_image (torch .randn (2 , 3 , 32 , 32 ), torch .randn (4 , 1 , 64 , 64 ))
556+ prepare_mask_and_masked_image (torch .randn (2 , 3 , height , width , ), torch .randn (4 , 1 , 64 , 64 ), height , width )
538557
539558 def test_type_mismatch (self ):
559+ height , width = 32 , 32
560+
540561 # test tensors-only
541562 with self .assertRaises (TypeError ):
542- prepare_mask_and_masked_image (torch .rand (3 , 32 , 32 ), torch .rand (3 , 32 , 32 ).numpy ())
563+ prepare_mask_and_masked_image (torch .rand (3 , height , width , ), torch .rand (3 , height , width , ).numpy (), height , width )
543564 # test tensors-only
544565 with self .assertRaises (TypeError ):
545- prepare_mask_and_masked_image (torch .rand (3 , 32 , 32 ).numpy (), torch .rand (3 , 32 , 32 ) )
566+ prepare_mask_and_masked_image (torch .rand (3 , height , width , ).numpy (), torch .rand (3 , height , width ,), height , width )
546567
547568 def test_channels_first (self ):
569+ height , width = 32 , 32
570+
548571 # test channels first for 3D tensors
549572 with self .assertRaises (AssertionError ):
550- prepare_mask_and_masked_image (torch .rand (32 , 32 , 3 ), torch .rand (3 , 32 , 32 ) )
573+ prepare_mask_and_masked_image (torch .rand (height , width , 3 ), torch .rand (3 , height , width ,), height , width )
551574
552575 def test_tensor_range (self ):
576+ height , width = 32 , 32
577+
553578 # test im <= 1
554579 with self .assertRaises (ValueError ):
555- prepare_mask_and_masked_image (torch .ones (3 , 32 , 32 ) * 2 , torch .rand (32 , 32 ) )
580+ prepare_mask_and_masked_image (torch .ones (3 , height , width , ) * 2 , torch .rand (height , width ,), height , width )
556581 # test im >= -1
557582 with self .assertRaises (ValueError ):
558- prepare_mask_and_masked_image (torch .ones (3 , 32 , 32 ) * (- 2 ), torch .rand (32 , 32 ) )
583+ prepare_mask_and_masked_image (torch .ones (3 , height , width , ) * (- 2 ), torch .rand (height , width ,), height , width )
559584 # test mask <= 1
560585 with self .assertRaises (ValueError ):
561- prepare_mask_and_masked_image (torch .rand (3 , 32 , 32 ), torch .ones (32 , 32 ) * 2 )
586+ prepare_mask_and_masked_image (torch .rand (3 , height , width , ), torch .ones (height , width , ) * 2 , height , width )
562587 # test mask >= 0
563588 with self .assertRaises (ValueError ):
564- prepare_mask_and_masked_image (torch .rand (3 , 32 , 32 ), torch .ones (32 , 32 ) * - 1 )
589+ prepare_mask_and_masked_image (torch .rand (3 , height , width , ), torch .ones (height , width , ) * - 1 , height , width )
0 commit comments