@@ -451,10 +451,11 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
451451
452452 def get_views (self , panorama_height , panorama_width , window_size = 64 , stride = 8 ):
453453 # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
454+ # if the panorama's height/width is smaller than window_size, the number of blocks along that dimension should be 1
454455 panorama_height /= 8
455456 panorama_width /= 8
456- num_blocks_height = (panorama_height - window_size ) // stride + 1
457- num_blocks_width = (panorama_width - window_size ) // stride + 1
457+ num_blocks_height = (panorama_height - window_size ) // stride + 1 if panorama_height > window_size else 1
458+ num_blocks_width = (panorama_width - window_size ) // stride + 1 if panorama_height > window_size else 1
458459 total_num_blocks = int (num_blocks_height * num_blocks_width )
459460 views = []
460461 for i in range (total_num_blocks ):
@@ -474,6 +475,7 @@ def __call__(
474475 width : Optional [int ] = 2048 ,
475476 num_inference_steps : int = 50 ,
476477 guidance_scale : float = 7.5 ,
478+ view_batch_size : int = 1 ,
477479 negative_prompt : Optional [Union [str , List [str ]]] = None ,
478480 num_images_per_prompt : Optional [int ] = 1 ,
479481 eta : float = 0.0 ,
@@ -508,6 +510,9 @@ def __call__(
508510 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
509511 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
510512 usually at the expense of lower image quality.
513+ view_batch_size (`int`, *optional*, defaults to 1):
514+ The batch size used to denoise the split views. On high-performance GPUs, a larger view batch size
515+ can speed up generation at the cost of increased VRAM usage.
511516 negative_prompt (`str` or `List[str]`, *optional*):
512517 The prompt or prompts not to guide the image generation. If not defined, one has to pass
513518 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -609,8 +614,11 @@ def __call__(
609614 )
610615
611616 # 6. Define panorama grid and initialize views for synthesis.
617+ # prepare batch grid
612618 views = self .get_views (height , width )
613- views_scheduler_status = [copy .deepcopy (self .scheduler .__dict__ )] * len (views )
619+ views_batch = [views [i : i + view_batch_size ] for i in range (0 , len (views ), view_batch_size )]
620+ views_scheduler_status = [copy .deepcopy (self .scheduler .__dict__ )] * len (views_batch )
621+
614622 count = torch .zeros_like (latents )
615623 value = torch .zeros_like (latents )
616624
@@ -631,42 +639,55 @@ def __call__(
631639 # denoised (latent) crops are then averaged to produce the final latent
632640 # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the
633641 # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113
634- for j , (h_start , h_end , w_start , w_end ) in enumerate (views ):
642+ # Batch views denoise
643+ for j , batch_view in enumerate (views_batch ):
644+ vb_size = len (batch_view )
635645 # get the latents corresponding to the current view coordinates
636- latents_for_view = latents [:, :, h_start :h_end , w_start :w_end ]
646+ latents_for_view = torch .cat (
647+ [latents [:, :, h_start :h_end , w_start :w_end ] for h_start , h_end , w_start , w_end in batch_view ]
648+ )
637649
638650 # rematch block's scheduler status
639651 self .scheduler .__dict__ .update (views_scheduler_status [j ])
640652
641653 # expand the latents if we are doing classifier free guidance
642654 latent_model_input = (
643- torch .cat ([latents_for_view ] * 2 ) if do_classifier_free_guidance else latents_for_view
655+ latents_for_view .repeat_interleave (2 , dim = 0 )
656+ if do_classifier_free_guidance
657+ else latents_for_view
644658 )
645659 latent_model_input = self .scheduler .scale_model_input (latent_model_input , t )
646660
661+ # repeat prompt_embeds for batch
662+ prompt_embeds_input = torch .cat ([prompt_embeds ] * vb_size )
663+
647664 # predict the noise residual
648665 noise_pred = self .unet (
649666 latent_model_input ,
650667 t ,
651- encoder_hidden_states = prompt_embeds ,
668+ encoder_hidden_states = prompt_embeds_input ,
652669 cross_attention_kwargs = cross_attention_kwargs ,
653670 ).sample
654671
655672 # perform guidance
656673 if do_classifier_free_guidance :
657- noise_pred_uncond , noise_pred_text = noise_pred . chunk ( 2 )
674+ noise_pred_uncond , noise_pred_text = noise_pred [:: 2 ], noise_pred [ 1 :: 2 ]
658675 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond )
659676
660677 # compute the previous noisy sample x_t -> x_t-1
661- latents_view_denoised = self .scheduler .step (
678+ latents_denoised_batch = self .scheduler .step (
662679 noise_pred , t , latents_for_view , ** extra_step_kwargs
663680 ).prev_sample
664681
665682 # save views scheduler status after sample
666683 views_scheduler_status [j ] = copy .deepcopy (self .scheduler .__dict__ )
667684
668- value [:, :, h_start :h_end , w_start :w_end ] += latents_view_denoised
669- count [:, :, h_start :h_end , w_start :w_end ] += 1
685+ # extract value from batch
686+ for latents_view_denoised , (h_start , h_end , w_start , w_end ) in zip (
687+ latents_denoised_batch .chunk (vb_size ), batch_view
688+ ):
689+ value [:, :, h_start :h_end , w_start :w_end ] += latents_view_denoised
690+ count [:, :, h_start :h_end , w_start :w_end ] += 1
670691
671692 # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
672693 latents = torch .where (count > 0 , value / count , value )
0 commit comments