@@ -301,10 +301,8 @@ def __init__(
             textual_inversion_manager=self.textual_inversion_manager
         )
 
-        self._enable_memory_efficient_attention()
 
-
-    def _enable_memory_efficient_attention(self):
+    def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
308306 """
309307 if xformers is available, use it, otherwise use sliced attention.
310308 """
@@ -317,7 +315,24 @@ def _enable_memory_efficient_attention(self):
                 # fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
                 pass
             else:
-                self.enable_attention_slicing(slice_size='max')
+                if self.device.type == 'cpu' or self.device.type == 'mps':
+                    mem_free = psutil.virtual_memory().free
+                elif self.device.type == 'cuda':
+                    mem_free, _ = torch.cuda.mem_get_info(self.device)
+                else:
+                    raise ValueError(f"unrecognized device {self.device}")
+                # input tensor of [1, 4, h/8, w/8]
+                # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
+                bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
+                max_size_required_for_baddbmm = \
+                    16 * \
+                    latents.size(dim=2) * latents.size(dim=3) * latents.size(dim=2) * latents.size(dim=3) * \
+                    bytes_per_element_needed_for_baddbmm_duplication
+                if max_size_required_for_baddbmm > (mem_free * 3.3 / 4.0):  # 3.3 / 4.0 is from old Invoke code
+                    self.enable_attention_slicing(slice_size='max')
+                else:
+                    self.disable_attention_slicing()
+
 
     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               conditioning_data: ConditioningData,
@@ -377,6 +392,7 @@ def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps,
                                          noise: torch.Tensor,
                                          run_id: str = None,
                                          additional_guidance: List[Callable] = None):
+        self._adjust_memory_efficient_attention(latents)
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
         if additional_guidance is None:
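
Note: the adjustment call moved from __init__ into generate_latents_from_embeddings because the baddbmm estimate depends on the latent resolution, which is only known per generation. As a sanity check on the arithmetic, here is a minimal standalone sketch (not part of the commit; the 8 GiB free-memory figure is hypothetical), assuming a 512x512 generation, i.e. float16 latents of shape [1, 4, 64, 64]:

    # Illustrative recomputation of the commit's baddbmm memory estimate.
    import torch

    latents = torch.zeros(1, 4, 64, 64, dtype=torch.float16)  # 512/8 = 64

    # element_size() is 2 for float16; the "+ 4" mirrors the commit's allowance
    # for float32 workspace duplicated per element during baddbmm.
    bytes_per_element = latents.element_size() + 4

    h, w = latents.size(dim=2), latents.size(dim=3)
    max_size_required_for_baddbmm = 16 * (h * w) * (h * w) * bytes_per_element
    print(max_size_required_for_baddbmm / 2**30)  # 1.5 (GiB)

    mem_free = 8 * 2**30  # hypothetical: 8 GiB free
    print(max_size_required_for_baddbmm > mem_free * 3.3 / 4.0)  # False -> slicing stays off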