
Commit 9eed191

Strategize slicing based on free [V]RAM (#2572)
Strategize slicing based on free [V]RAM when not using xformers. Free [V]RAM is evaluated at every generation: when there is enough memory, the entire generation runs without slicing; when there is not, we fall back to diffusers' sliced attention.
1 parent 7c86130 commit 9eed191
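
In outline, the change estimates how much scratch memory diffusers' unsliced attention would need for the current latents and compares that against free [V]RAM. The following is a minimal, standalone sketch of that decision, assuming a latent tensor shaped [1, 4, h/8, w/8] as the diff's comments note; the should_slice helper name and its argument list are illustrative, not part of the commit:

import psutil
import torch

def should_slice(latents: torch.Tensor, device: torch.device) -> bool:
    # Free memory: system RAM for cpu/mps, free VRAM for cuda.
    if device.type in ('cpu', 'mps'):
        mem_free = psutil.virtual_memory().free
    elif device.type == 'cuda':
        mem_free, _ = torch.cuda.mem_get_info(device)
    else:
        raise ValueError(f'unrecognized device {device}')

    # Unsliced attention materializes a [16, (h/8 * w/8), (h/8 * w/8)] score
    # tensor via baddbmm; the +4 bytes per element mirrors the commit's
    # duplication allowance.
    bytes_per_element = latents.element_size() + 4
    hw = latents.size(2) * latents.size(3)
    needed = 16 * hw * hw * bytes_per_element

    # Slice only when the full tensor would not fit in ~82% of free memory
    # (the 3.3 / 4.0 factor is carried over from older Invoke code).
    return needed > mem_free * 3.3 / 4.0

Because the check now runs at the top of every denoising call rather than once at pipeline construction, the slicing choice is made per generation.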

File tree

2 files changed: +30 -9 lines changed

  ldm/generate.py
  ldm/invoke/generator/diffusers_pipeline.py
ldm/generate.py

Lines changed: 10 additions & 5 deletions
@@ -223,7 +223,7 @@ def __init__(
         self.model_name = model or fallback
 
         # for VRAM usage statistics
-        self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
+        self.session_peakmem = torch.cuda.max_memory_allocated(self.device) if self._has_cuda else None
         transformers.logging.set_verbosity_error()
 
         # gets rid of annoying messages about random seed
@@ -592,20 +592,24 @@ def process_image(image,seed):
         self.print_cuda_stats()
         return results
 
-    def clear_cuda_cache(self):
+    def gather_cuda_stats(self):
         if self._has_cuda():
             self.max_memory_allocated = max(
                 self.max_memory_allocated,
-                torch.cuda.max_memory_allocated()
+                torch.cuda.max_memory_allocated(self.device)
             )
             self.memory_allocated = max(
                 self.memory_allocated,
-                torch.cuda.memory_allocated()
+                torch.cuda.memory_allocated(self.device)
             )
             self.session_peakmem = max(
                 self.session_peakmem,
-                torch.cuda.max_memory_allocated()
+                torch.cuda.max_memory_allocated(self.device)
             )
+
+    def clear_cuda_cache(self):
+        if self._has_cuda():
+            self.gather_cuda_stats()
             torch.cuda.empty_cache()
 
     def clear_cuda_stats(self):
@@ -614,6 +618,7 @@ def clear_cuda_stats(self):
 
     def print_cuda_stats(self):
         if self._has_cuda():
+            self.gather_cuda_stats()
             print(
                 '>> Max VRAM used for this generation:',
                 '%4.2fG.' % (self.max_memory_allocated / 1e9),
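
One note on the ldm/generate.py hunks above: torch.cuda.memory_allocated() and torch.cuda.max_memory_allocated() default to the current CUDA device when called without an argument, so passing self.device presumably keeps the statistics tied to the device the generator actually uses. A small sketch of the device-scoped queries, with the device value assumed for illustration:

import torch

device = torch.device('cuda:0')  # illustrative; Generate stores its own device in self.device

peak = torch.cuda.max_memory_allocated(device)   # high-water mark for this device
in_use = torch.cuda.memory_allocated(device)     # bytes currently held by tensors on it
free, total = torch.cuda.mem_get_info(device)    # free/total VRAM reported by the driver

print('>> peak %4.2fG, in use %4.2fG, free %4.2fG of %4.2fG'
      % (peak / 1e9, in_use / 1e9, free / 1e9, total / 1e9))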

ldm/invoke/generator/diffusers_pipeline.py

Lines changed: 20 additions & 4 deletions
@@ -301,10 +301,8 @@ def __init__(
             textual_inversion_manager=self.textual_inversion_manager
         )
 
-        self._enable_memory_efficient_attention()
 
-
-    def _enable_memory_efficient_attention(self):
+    def _adjust_memory_efficient_attention(self, latents: Torch.tensor):
         """
         if xformers is available, use it, otherwise use sliced attention.
         """
@@ -317,7 +315,24 @@ def _enable_memory_efficient_attention(self):
                 # fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
                 pass
             else:
-                self.enable_attention_slicing(slice_size='max')
+                if self.device.type == 'cpu' or self.device.type == 'mps':
+                    mem_free = psutil.virtual_memory().free
+                elif self.device.type == 'cuda':
+                    mem_free, _ = torch.cuda.mem_get_info(self.device)
+                else:
+                    raise ValueError(f"unrecognized device {device}")
+                # input tensor of [1, 4, h/8, w/8]
+                # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
+                bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
+                max_size_required_for_baddbmm = \
+                    16 * \
+                    latents.size(dim=2) * latents.size(dim=3) * latents.size(dim=2) * latents.size(dim=3) * \
+                    bytes_per_element_needed_for_baddbmm_duplication
+                if max_size_required_for_baddbmm > (mem_free * 3.3 / 4.0):  # 3.3 / 4.0 is from old Invoke code
+                    self.enable_attention_slicing(slice_size='max')
+                else:
+                    self.disable_attention_slicing()
+
 
     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               conditioning_data: ConditioningData,
@@ -377,6 +392,7 @@ def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps,
                                          noise: torch.Tensor,
                                          run_id: str = None,
                                          additional_guidance: List[Callable] = None):
+        self._adjust_memory_efficient_attention(latents)
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
         if additional_guidance is None:
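
For a sense of scale, here is a rough worked example of the new threshold for a 512x512 generation with float16 latents; the numbers are illustrative arithmetic, not measurements from the commit:

# 512x512 image -> latents of shape [1, 4, 64, 64], i.e. h/8 = w/8 = 64
element_size = 2                      # bytes per float16 element
bytes_per_element = element_size + 4  # the commit's duplication allowance
needed = 16 * (64 * 64) * (64 * 64) * bytes_per_element
print(needed / 1e9)                   # ~1.61 GB for the attention-score tensor

# Slicing is enabled when this exceeds mem_free * 3.3 / 4.0, so at this
# resolution it kicks in once free [V]RAM drops below roughly 1.95 GB.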
