From 3042080947d2d673eb5c55f1dcbb9f7c497441ca Mon Sep 17 00:00:00 2001
From: JPPhoto
Date: Tue, 7 Feb 2023 13:40:51 -0600
Subject: [PATCH 1/4] Strategize slicing based on free [V]RAM

---
 ldm/invoke/generator/diffusers_pipeline.py | 24 ++++++++++++++++++----
 1 file changed, 20 insertions(+), 4 deletions(-)

diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py
index f065a0ec2d9..af135c4ec5c 100644
--- a/ldm/invoke/generator/diffusers_pipeline.py
+++ b/ldm/invoke/generator/diffusers_pipeline.py
@@ -301,10 +301,8 @@ def __init__(
             textual_inversion_manager=self.textual_inversion_manager
         )
 
 
-        self._enable_memory_efficient_attention()
-
-    def _enable_memory_efficient_attention(self):
+    def _enable_memory_efficient_attention(self, latents: Torch.tensor):
         """
         if xformers is available, use it, otherwise use sliced attention.
         """
@@ -317,7 +315,24 @@ def _enable_memory_efficient_attention(self):
                 # fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
                 pass
             else:
-                self.enable_attention_slicing(slice_size='max')
+                if self.device.type == 'cpu' or self.device.type == 'mps':
+                    mem_free = psutil.virtual_memory().free
+                elif self.device.type == 'cuda':
+                    mem_free, _ = torch.cuda.mem_get_info(self.device)
+                else:
+                    raise ValueError(f"unrecognized device {device}")
+                # input tensor of [1, 4, h/8, w/8]
+                # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
+                bytes_needed_for_baddbmm = latents.element_size() + 4
+                estimated_max_size_required =\
+                    16 * \
+                    latents.size(dim=2) * latents.size(dim=3) * latents.size(dim=2) * latents.size(dim=3) * \
+                    bytes_needed_for_baddbmm
+                if estimated_max_size_required > (mem_free * 3.3 / 4.0): # 3.3 / 4.0 is from old Invoke code
+                    self.enable_attention_slicing(slice_size='max')
+                else:
+                    self.disable_attention_slicing()
+
 
     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               conditioning_data: ConditioningData,
@@ -377,6 +392,7 @@ def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps,
                                          noise: torch.Tensor,
                                          run_id: str = None,
                                          additional_guidance: List[Callable] = None):
+        self._enable_memory_efficient_attention(latents)
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
         if additional_guidance is None:

From 78941bb723cb4eb0c842b5d3a61bf3c255f59c6b Mon Sep 17 00:00:00 2001
From: JPPhoto
Date: Tue, 7 Feb 2023 13:55:40 -0600
Subject: [PATCH 2/4] Cleaned up variable names

---
 ldm/invoke/generator/diffusers_pipeline.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py
index af135c4ec5c..1504cdf9c37 100644
--- a/ldm/invoke/generator/diffusers_pipeline.py
+++ b/ldm/invoke/generator/diffusers_pipeline.py
@@ -323,12 +323,12 @@ def _enable_memory_efficient_attention(self, latents: Torch.tensor):
                     raise ValueError(f"unrecognized device {device}")
                 # input tensor of [1, 4, h/8, w/8]
                 # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
-                bytes_needed_for_baddbmm = latents.element_size() + 4
-                estimated_max_size_required =\
+                bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
+                max_size_required_for_baddbmm = \
                     16 * \
                     latents.size(dim=2) * latents.size(dim=3) * latents.size(dim=2) * latents.size(dim=3) * \
-                    bytes_needed_for_baddbmm
-                if estimated_max_size_required > (mem_free * 3.3 / 4.0): # 3.3 / 4.0 is from old Invoke code
+                    bytes_per_element_needed_for_baddbmm_duplication
+                if max_size_required_for_baddbmm > (mem_free * 3.3 / 4.0): # 3.3 / 4.0 is from old Invoke code
                     self.enable_attention_slicing(slice_size='max')
                 else:
                     self.disable_attention_slicing()

From e81ef2504d96ef4090a13ce8f8568fb71ba0ce70 Mon Sep 17 00:00:00 2001
From: JPPhoto
Date: Wed, 8 Feb 2023 06:36:24 -0600
Subject: [PATCH 3/4] Aggregate better CUDA stats.

---
 ldm/generate.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/ldm/generate.py b/ldm/generate.py
index fa4e6034999..e1eb40820b0 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -223,7 +223,7 @@ def __init__(
         self.model_name = model or fallback
 
         # for VRAM usage statistics
-        self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
+        self.session_peakmem = torch.cuda.max_memory_allocated(self.device) if self._has_cuda else None
         transformers.logging.set_verbosity_error()
 
         # gets rid of annoying messages about random seed
@@ -590,20 +590,24 @@ def process_image(image,seed):
         self.print_cuda_stats()
         return results
 
-    def clear_cuda_cache(self):
+    def gather_cuda_stats(self):
         if self._has_cuda():
             self.max_memory_allocated = max(
                 self.max_memory_allocated,
-                torch.cuda.max_memory_allocated()
+                torch.cuda.max_memory_allocated(self.device)
             )
             self.memory_allocated = max(
                 self.memory_allocated,
-                torch.cuda.memory_allocated()
+                torch.cuda.memory_allocated(self.device)
             )
             self.session_peakmem = max(
                 self.session_peakmem,
-                torch.cuda.max_memory_allocated()
+                torch.cuda.max_memory_allocated(self.device)
             )
+
+    def clear_cuda_cache(self):
+        if self._has_cuda():
+            self.gather_cuda_stats()
             torch.cuda.empty_cache()
 
     def clear_cuda_stats(self):
@@ -612,6 +616,7 @@ def clear_cuda_stats(self):
 
     def print_cuda_stats(self):
         if self._has_cuda():
+            self.gather_cuda_stats()
             print(
                 '>> Max VRAM used for this generation:',
                 '%4.2fG.' % (self.max_memory_allocated / 1e9),

From 35641d1f03dc8ade4c8ee36dd76c61504b5175b5 Mon Sep 17 00:00:00 2001
From: JPPhoto
Date: Sun, 12 Feb 2023 12:08:02 -0600
Subject: [PATCH 4/4] Renamed _enable_memory_efficient_attention to
 _adjust_memory_efficient_attention as this happens every generation.

---
 ldm/invoke/generator/diffusers_pipeline.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py
index 1504cdf9c37..24626247cf3 100644
--- a/ldm/invoke/generator/diffusers_pipeline.py
+++ b/ldm/invoke/generator/diffusers_pipeline.py
@@ -302,7 +302,7 @@ def __init__(
         )
 
 
-    def _enable_memory_efficient_attention(self, latents: Torch.tensor):
+    def _adjust_memory_efficient_attention(self, latents: Torch.tensor):
         """
         if xformers is available, use it, otherwise use sliced attention.
         """
@@ -392,7 +392,7 @@ def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps,
                                          noise: torch.Tensor,
                                          run_id: str = None,
                                          additional_guidance: List[Callable] = None):
-        self._enable_memory_efficient_attention(latents)
+        self._adjust_memory_efficient_attention(latents)
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
         if additional_guidance is None:
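Note (illustrative, not part of the patch series): read in isolation, the slicing
decision these patches introduce works roughly as sketched below. This is a minimal
standalone sketch, assuming psutil is installed and PyTorch is recent enough to
provide torch.cuda.mem_get_info; the helper name should_slice_attention is invented
for illustration and does not exist in the InvokeAI codebase.

    import psutil
    import torch


    def should_slice_attention(latents: torch.Tensor, device: torch.device) -> bool:
        # Free memory on the device that will run attention.
        if device.type in ('cpu', 'mps'):
            mem_free = psutil.virtual_memory().free
        elif device.type == 'cuda':
            mem_free, _ = torch.cuda.mem_get_info(device)
        else:
            raise ValueError(f"unrecognized device {device}")

        # For an input latent of shape [1, 4, h/8, w/8], baddbmm produces an
        # attention buffer of roughly [16, (h/8 * w/8), (h/8 * w/8)], and the
        # patch budgets element_size() + 4 bytes per element for duplication.
        bytes_per_element = latents.element_size() + 4
        spatial = latents.size(dim=2) * latents.size(dim=3)
        max_size_required = 16 * spatial * spatial * bytes_per_element

        # 3.3 / 4.0 safety factor carried over from the patch.
        return max_size_required > mem_free * 3.3 / 4.0


    # Example: a 768x768 image corresponds to a 96x96 latent.
    latents = torch.zeros(1, 4, 96, 96, dtype=torch.float16)
    print(should_slice_attention(latents, torch.device('cpu')))

The trade-off the heuristic encodes: max-size attention slicing greatly reduces the
peak allocation at some cost in speed, so it is only enabled when the estimated
baddbmm buffer would not comfortably fit in the memory currently free.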