Merged: changes from all 18 commits.
ldm/generate.py (15 changes: 10 additions & 5 deletions)
@@ -223,7 +223,7 @@ def __init__(
         self.model_name = model or fallback
 
         # for VRAM usage statistics
-        self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
+        self.session_peakmem = torch.cuda.max_memory_allocated(self.device) if self._has_cuda else None
         transformers.logging.set_verbosity_error()
 
         # gets rid of annoying messages about random seed
@@ -592,20 +592,24 @@ def process_image(image,seed):
         self.print_cuda_stats()
         return results
 
-    def clear_cuda_cache(self):
+    def gather_cuda_stats(self):
         if self._has_cuda():
             self.max_memory_allocated = max(
                 self.max_memory_allocated,
-                torch.cuda.max_memory_allocated()
+                torch.cuda.max_memory_allocated(self.device)
             )
             self.memory_allocated = max(
                 self.memory_allocated,
-                torch.cuda.memory_allocated()
+                torch.cuda.memory_allocated(self.device)
             )
             self.session_peakmem = max(
                 self.session_peakmem,
-                torch.cuda.max_memory_allocated()
+                torch.cuda.max_memory_allocated(self.device)
             )
+
+    def clear_cuda_cache(self):
+        if self._has_cuda():
+            self.gather_cuda_stats()
             torch.cuda.empty_cache()
 
     def clear_cuda_stats(self):
@@ -614,6 +618,7 @@ def clear_cuda_stats(self):
 
     def print_cuda_stats(self):
         if self._has_cuda():
+            self.gather_cuda_stats()
             print(
                 '>> Max VRAM used for this generation:',
                 '%4.2fG.' % (self.max_memory_allocated / 1e9),
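For readers unfamiliar with the torch.cuda statistics API, the refactor above boils down to the pattern sketched below: query the counters for the specific device the model runs on, fold them into running maxima, and do that both before emptying the cache and before printing, so neither path loses the peak. This is a minimal illustrative sketch, not InvokeAI code; the VRAMStats class and its method names are hypothetical.

# Minimal sketch (hypothetical VRAMStats class) of per-device VRAM statistics.
# Passing an explicit device keeps the numbers scoped to the GPU the model runs on
# rather than whatever the current default CUDA device happens to be.
import torch

class VRAMStats:
    def __init__(self, device: torch.device):
        self.device = device
        self.memory_allocated = 0
        self.max_memory_allocated = 0
        self.session_peakmem = (
            torch.cuda.max_memory_allocated(device) if device.type == 'cuda' else 0
        )

    def gather(self):
        # Fold the current per-device counters into the running maxima.
        if self.device.type != 'cuda':
            return
        self.max_memory_allocated = max(
            self.max_memory_allocated, torch.cuda.max_memory_allocated(self.device))
        self.memory_allocated = max(
            self.memory_allocated, torch.cuda.memory_allocated(self.device))
        self.session_peakmem = max(
            self.session_peakmem, torch.cuda.max_memory_allocated(self.device))

    def clear_cache(self):
        # Gather first so emptying the cache cannot hide a peak from later reporting.
        if self.device.type == 'cuda':
            self.gather()
            torch.cuda.empty_cache()

In the diff above, gather_cuda_stats() plays the role of gather() and is called from both clear_cuda_cache() and print_cuda_stats().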
ldm/invoke/generator/diffusers_pipeline.py (24 changes: 20 additions & 4 deletions)
@@ -301,10 +301,8 @@ def __init__(
             textual_inversion_manager=self.textual_inversion_manager
         )
 
-        self._enable_memory_efficient_attention()
-
 
-    def _enable_memory_efficient_attention(self):
+    def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
         """
         if xformers is available, use it, otherwise use sliced attention.
         """
@@ -317,7 +315,24 @@ def _enable_memory_efficient_attention(self):
             # fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
             pass
         else:
-            self.enable_attention_slicing(slice_size='max')
+            if self.device.type == 'cpu' or self.device.type == 'mps':
+                mem_free = psutil.virtual_memory().free
+            elif self.device.type == 'cuda':
+                mem_free, _ = torch.cuda.mem_get_info(self.device)
+            else:
+                raise ValueError(f"unrecognized device {self.device}")
+            # input tensor of [1, 4, h/8, w/8]
+            # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
+            bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
+            max_size_required_for_baddbmm = \
+                16 * \
+                latents.size(dim=2) * latents.size(dim=3) * latents.size(dim=2) * latents.size(dim=3) * \
+                bytes_per_element_needed_for_baddbmm_duplication
+            if max_size_required_for_baddbmm > (mem_free * 3.3 / 4.0):  # 3.3 / 4.0 is from old Invoke code
+                self.enable_attention_slicing(slice_size='max')
+            else:
+                self.disable_attention_slicing()
+
 
     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               conditioning_data: ConditioningData,
@@ -377,6 +392,7 @@ def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps,
                                          noise: torch.Tensor,
                                          run_id: str = None,
                                          additional_guidance: List[Callable] = None):
+        self._adjust_memory_efficient_attention(latents)
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
         if additional_guidance is None:
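To make the slicing heuristic above concrete: for a 512x512 image the latents have shape [1, 4, 64, 64], so the attention scores form a [16, 4096, 4096] tensor; with fp16 latents (element_size() == 2) the heuristic budgets 2 + 4 = 6 bytes per element, roughly 1.6 GB, and enables max attention slicing only when that exceeds about 3.3/4 of free memory. The sketch below reproduces the estimate outside the pipeline; the helper name needs_attention_slicing is hypothetical, the constants come from the diff, and the rest is illustrative rather than the pipeline's actual method.

# Illustrative, standalone version of the baddbmm memory estimate (hypothetical
# helper, not part of the PR). Assumes SD-style latents of shape [1, 4, H/8, W/8].
import psutil
import torch

def needs_attention_slicing(latents: torch.Tensor, device: torch.device) -> bool:
    if device.type in ('cpu', 'mps'):
        mem_free = psutil.virtual_memory().free
    elif device.type == 'cuda':
        mem_free, _ = torch.cuda.mem_get_info(device)
    else:
        raise ValueError(f'unrecognized device {device}')

    # Attention scores for a [1, 4, h, w] latent are a [16, h*w, h*w] tensor,
    # and baddbmm temporarily needs element_size() + 4 bytes per element.
    h, w = latents.size(dim=2), latents.size(dim=3)
    bytes_per_element = latents.element_size() + 4
    max_size_required = 16 * (h * w) ** 2 * bytes_per_element
    return max_size_required > mem_free * 3.3 / 4.0

# Example: fp16 latents for a 512x512 image -> 16 * 4096**2 * 6 bytes ~= 1.6 GB.
latents = torch.zeros(1, 4, 64, 64, dtype=torch.float16)
print(needs_attention_slicing(latents, torch.device('cpu')))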