From 0cfebdf69ea475525f01b9dfa76ecbd7db098863 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Wed, 8 Feb 2023 20:13:03 -0800 Subject: [PATCH 01/13] new OffloadingDevice loads one model at a time, on demand --- ldm/generate.py | 5 +- ldm/invoke/generator/diffusers_pipeline.py | 6 +++ ldm/invoke/model_manager.py | 25 +++++++-- ldm/invoke/offloading.py | 63 ++++++++++++++++++++++ 4 files changed, 92 insertions(+), 7 deletions(-) create mode 100644 ldm/invoke/offloading.py diff --git a/ldm/generate.py b/ldm/generate.py index ca054788234..6a1fb34a827 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -213,7 +213,9 @@ def __init__( print('>> xformers not installed') # model caching system for fast switching - self.model_manager = ModelManager(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models) + self.model_manager = ModelManager(mconfig, self.device, self.precision, + max_loaded_models=max_loaded_models, + sequential_offload=self.free_gpu_mem) # don't accept invalid models fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME model = model or fallback @@ -479,7 +481,6 @@ def process_image(image,seed): self.model.cond_stage_model.device = self.model.device self.model.cond_stage_model.to(self.model.device) except AttributeError: - print(">> Warning: '--free_gpu_mem' is not yet supported when generating image using model based on HuggingFace Diffuser.") pass try: diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index f065a0ec2d9..6c5c9b7157c 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -4,6 +4,7 @@ import inspect import secrets import sys +from collections.abc import Sequence from dataclasses import dataclass, field from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any @@ -657,6 +658,11 @@ def channels(self) -> int: """Compatible with DiffusionWrapper""" return self.unet.in_channels + @property + def submodels(self) -> Sequence[torch.nn.Module]: + models = self.text_encoder, self.unet, self.vae, self.feature_extractor, self.safety_checker + return [m for m in models if m is not None] + def debug_latents(self, latents, msg): with torch.inference_mode(): from ldm.util import debug_image diff --git a/ldm/invoke/model_manager.py b/ldm/invoke/model_manager.py index 3135931eea4..e4948221cfc 100644 --- a/ldm/invoke/model_manager.py +++ b/ldm/invoke/model_manager.py @@ -36,6 +36,7 @@ StableDiffusionGeneratorPipeline from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir, global_models_dir) +from ldm.invoke.offloading import OffloadingDevice from ldm.util import (ask_user, download_with_progress_bar, instantiate_from_config) @@ -49,9 +50,10 @@ class ModelManager(object): def __init__( self, config: OmegaConf, - device_type: str = "cpu", + device_type: str | torch.device = "cpu", precision: str = "float16", max_loaded_models=DEFAULT_MAX_MODELS, + sequential_offload = False ): """ Initialize with the path to the models.yaml config file, @@ -69,6 +71,10 @@ def __init__( self.models = {} self.stack = [] # this is an LRU FIFO self.current_model = None + if sequential_offload: + self.offloader = OffloadingDevice(self.device) + else: + self.offloader = None def valid_model(self, model_name: str) -> bool: """ @@ -92,7 +98,10 @@ def get_model(self, model_name: str): if self.current_model != model_name: if model_name not in self.models: # make room for a new one 
self._make_cache_room() - self.offload_model(self.current_model) + if self.offloader: + self.offloader.offload_current() + else: + self.offload_model(self.current_model) if model_name in self.models: requested_model = self.models[model_name]["model"] @@ -529,7 +538,10 @@ def _load_diffusers_model(self, mconfig): dlogging.set_verbosity(verbosity) assert pipeline is not None, OSError(f'"{name_or_path}" could not be loaded') - pipeline.to(self.device) + if self.offloader: + self.offloader.install(*pipeline.submodels) + else: + pipeline.to(self.device) model_hash = self._diffuser_sha256(name_or_path) @@ -748,7 +760,6 @@ def convert_and_import( into models.yaml. """ new_config = None - import transformers from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser @@ -1011,6 +1022,10 @@ def _model_from_cpu(self, model): if self.device == "cpu": return model + if self.offloader and isinstance(model, StableDiffusionGeneratorPipeline): + # Offloader handles it on demand. + return model + model.to(self.device) model.cond_stage_model.device = self.device @@ -1161,7 +1176,7 @@ def _delete_model_from_cache(repo_id): strategy.execute() @staticmethod - def _abs_path(path: Union(str, Path)) -> Path: + def _abs_path(path: str | Path) -> Path: if path is None or Path(path).is_absolute(): return path return Path(Globals.root, path).resolve() diff --git a/ldm/invoke/offloading.py b/ldm/invoke/offloading.py new file mode 100644 index 00000000000..06fe37f88df --- /dev/null +++ b/ldm/invoke/offloading.py @@ -0,0 +1,63 @@ +import weakref +from collections.abc import MutableMapping +from typing import Optional, Callable + +import torch +from torch.utils.hooks import RemovableHandle + + +class OffloadingDevice: + _hooks: MutableMapping[torch.nn.Module, RemovableHandle] + _current_model_ref: Callable[[], Optional[torch.nn.Module]] + + def __init__(self, execution_device: torch.device): + self.execution_device = execution_device + self._hooks = weakref.WeakKeyDictionary() + self._current_model_ref = lambda: None + + def install(self, *models: torch.nn.Module): + for model in models: + self._hooks[model] = model.register_forward_pre_hook(self._pre_hook) + + def uninstall(self, *models: torch.nn.Module): + for model in models: + hook = self._hooks.pop(model) + hook.remove() + if self.is_current_model(model): + # no longer hooked by this object, so don't claim to manage it + self.clear_current_model() + + def _pre_hook(self, module: torch.nn.Module, forward_input): + self.load(module) + return forward_input + + def load(self, module): + if not self.is_current_model(module): + self.offload_current() + self._load(module) + + def offload_current(self) -> torch.nn.Module: + # noinspection PyNoneFunctionAssignment + module: Optional[torch.nn.Module] = self._current_model_ref() + if module is not None: + module.cpu() + self.clear_current_model() + return module + + def _load(self, module: torch.nn.Module) -> torch.nn.Module: + assert self.is_empty(), f"A model is already loaded: {self._current_model_ref()}" + module = module.to(self.execution_device) + self.set_current_model(module) + return module + + def is_current_model(self, model: torch.nn.Module) -> bool: + return self._current_model_ref() is model + + def is_empty(self): + return self._current_model_ref() is None + + def set_current_model(self, value): + self._current_model_ref = weakref.ref(value) + + def clear_current_model(self): + self._current_model_ref = lambda: None From 4a9b0fceb235bb22c7c8f9616e0ef971b1f326c4 Mon Sep 17 00:00:00 2001 From: Kevin Turner 
<83819+keturn@users.noreply.github.com> Date: Wed, 8 Feb 2023 20:48:06 -0800 Subject: [PATCH 02/13] fixup! new OffloadingDevice loads one model at a time, on demand --- ldm/invoke/generator/diffusers_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 6c5c9b7157c..c8980235a44 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -660,7 +660,7 @@ def channels(self) -> int: @property def submodels(self) -> Sequence[torch.nn.Module]: - models = self.text_encoder, self.unet, self.vae, self.feature_extractor, self.safety_checker + models = self.text_encoder, self.unet, self.vae, self.safety_checker return [m for m in models if m is not None] def debug_latents(self, latents, msg): From 69873d98057ca5e78ae4154d70e56d3b8bbc0d62 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Wed, 8 Feb 2023 20:49:33 -0800 Subject: [PATCH 03/13] fix(prompt_to_embeddings): call the text encoder directly instead of its forward method allowing any associated hooks to run with it. --- ldm/modules/prompt_to_embeddings_converter.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ldm/modules/prompt_to_embeddings_converter.py b/ldm/modules/prompt_to_embeddings_converter.py index dea15d61b43..45014af02f2 100644 --- a/ldm/modules/prompt_to_embeddings_converter.py +++ b/ldm/modules/prompt_to_embeddings_converter.py @@ -214,7 +214,7 @@ def get_token_ids_and_expand_weights(self, fragments: list[str], weights: list[f def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor) -> torch.Tensor: ''' - Build a tensor that embeds the passed-in token IDs and applyies the given per_token weights + Build a tensor that embeds the passed-in token IDs and applies the given per_token weights :param token_ids: A tensor of shape `[self.max_length]` containing token IDs (ints) :param per_token_weights: A tensor of shape `[self.max_length]` containing weights (floats) :return: A tensor of shape `[1, self.max_length, token_dim]` representing the requested weighted embeddings @@ -224,8 +224,7 @@ def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_wei if token_ids.shape != torch.Size([self.max_length]): raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]") - z = self.text_encoder.forward(input_ids=token_ids.unsqueeze(0), - return_dict=False)[0] + z = self.text_encoder(input_ids=token_ids.unsqueeze(0), return_dict=False)[0] empty_token_ids = torch.tensor([self.tokenizer.bos_token_id] + [self.tokenizer.pad_token_id] * (self.max_length-2) + [self.tokenizer.eos_token_id], dtype=torch.int, device=token_ids.device).unsqueeze(0) From 9d5ab9ea49b1ab3bfda8d6b78b17e5878da7213f Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Wed, 8 Feb 2023 22:07:45 -0800 Subject: [PATCH 04/13] more attempts to get things on the right device from the offloader --- ldm/invoke/generator/diffusers_pipeline.py | 10 +++++++--- ldm/invoke/offloading.py | 11 ++++++++++- ldm/modules/prompt_to_embeddings_converter.py | 8 ++++---- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index c8980235a44..3640b1a30c4 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ 
b/ldm/invoke/generator/diffusers_pipeline.py @@ -468,9 +468,8 @@ def _unet_forward(self, latents, t, text_embeddings, cross_attention_kwargs: Opt initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype) ).add_mask_channels(latents) - return self.unet(sample=latents, - timestep=t, - encoder_hidden_states=text_embeddings, + # First three args should be positional, not keywords, so torch hooks can see them. + return self.unet(latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs).sample def img2img_from_embeddings(self, @@ -663,6 +662,11 @@ def submodels(self) -> Sequence[torch.nn.Module]: models = self.text_encoder, self.unet, self.vae, self.safety_checker return [m for m in models if m is not None] + def decode_latents(self, latents): + # Super ugly kludge to get the vae loaded! (since `decode` isn't the forward method.) + self.vae() + return super().decode_latents(latents) + def debug_latents(self, latents, msg): with torch.inference_mode(): from ldm.util import debug_image diff --git a/ldm/invoke/offloading.py b/ldm/invoke/offloading.py index 06fe37f88df..191c58be531 100644 --- a/ldm/invoke/offloading.py +++ b/ldm/invoke/offloading.py @@ -1,8 +1,10 @@ +import warnings import weakref from collections.abc import MutableMapping from typing import Optional, Callable import torch +from accelerate.utils import send_to_device from torch.utils.hooks import RemovableHandle @@ -29,7 +31,10 @@ def uninstall(self, *models: torch.nn.Module): def _pre_hook(self, module: torch.nn.Module, forward_input): self.load(module) - return forward_input + if len(forward_input) == 0: + warnings.warn(f"Hook for {module.__class__.__name__} got no input. " + f"Inputs must be positional, not keywords.", stacklevel=3) + return send_to_device(forward_input, self.execution_device) def load(self, module): if not self.is_current_model(module): @@ -61,3 +66,7 @@ def set_current_model(self, value): def clear_current_model(self): self._current_model_ref = lambda: None + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} object at {id(self):x}: " \ + f"current_model={type(self._current_model_ref()).__name__} >" diff --git a/ldm/modules/prompt_to_embeddings_converter.py b/ldm/modules/prompt_to_embeddings_converter.py index 45014af02f2..84d927d48b0 100644 --- a/ldm/modules/prompt_to_embeddings_converter.py +++ b/ldm/modules/prompt_to_embeddings_converter.py @@ -224,12 +224,12 @@ def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_wei if token_ids.shape != torch.Size([self.max_length]): raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]") - z = self.text_encoder(input_ids=token_ids.unsqueeze(0), return_dict=False)[0] + z = self.text_encoder(token_ids.unsqueeze(0), return_dict=False)[0] empty_token_ids = torch.tensor([self.tokenizer.bos_token_id] + [self.tokenizer.pad_token_id] * (self.max_length-2) + - [self.tokenizer.eos_token_id], dtype=torch.int, device=token_ids.device).unsqueeze(0) - empty_z = self.text_encoder(input_ids=empty_token_ids).last_hidden_state - batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape) + [self.tokenizer.eos_token_id], dtype=torch.int, device=z.device).unsqueeze(0) + empty_z = self.text_encoder(empty_token_ids).last_hidden_state + batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape).to(z) z_delta_from_empty = z - empty_z weighted_z = empty_z + (z_delta_from_empty * 
batch_weights_expanded) From f39c806aa6b33f637b7f7fecc2af4bd041e5fde7 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Wed, 8 Feb 2023 23:02:23 -0800 Subject: [PATCH 05/13] more attempts to get things on the right device from the offloader --- ldm/invoke/generator/diffusers_pipeline.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 3640b1a30c4..78ebefa73d0 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -444,7 +444,7 @@ def step(self, t: torch.Tensor, latents: torch.Tensor, ) # compute the previous noisy sample x_t -> x_t-1 - step_output = self.scheduler.step(noise_pred, timestep, latents, + step_output = self.scheduler.step(noise_pred, timestep, latents.to(noise_pred.device), **conditioning_data.scheduler_args) # TODO: this additional_guidance extension point feels redundant with InvokeAIDiffusionComponent. @@ -664,7 +664,10 @@ def submodels(self) -> Sequence[torch.nn.Module]: def decode_latents(self, latents): # Super ugly kludge to get the vae loaded! (since `decode` isn't the forward method.) - self.vae() + try: + self.vae(tuple()) + except TypeError: + pass # we didn't expect it to work, just needed its side-effects. return super().decode_latents(latents) def debug_latents(self, latents, msg): From 20df847cde536793d1dc813876ed57f8269bf2a1 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Thu, 9 Feb 2023 16:41:30 -0800 Subject: [PATCH 06/13] make offloading methods an explicit part of the pipeline interface --- ldm/invoke/generator/diffusers_pipeline.py | 92 ++++++++++++------- ldm/invoke/model_manager.py | 29 ++---- ldm/invoke/offloading.py | 102 ++++++++++++++++++--- 3 files changed, 160 insertions(+), 63 deletions(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 78ebefa73d0..54a9d851992 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -3,40 +3,33 @@ import dataclasses import inspect import secrets -import sys from collections.abc import Sequence from dataclasses import dataclass, field from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any -if sys.version_info < (3, 10): - from typing_extensions import ParamSpec -else: - from typing import ParamSpec - import PIL.Image import einops import torch import torchvision.transforms as T -from diffusers.utils.import_utils import is_xformers_available - -from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver -from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter - - from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput -from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from diffusers.utils.import_utils import 
is_xformers_available from diffusers.utils.outputs import BaseOutput from torchvision.transforms.functional import resize as tv_resize from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from typing_extensions import ParamSpec from ldm.invoke.globals import Globals from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, ThresholdSettings from ldm.modules.textual_inversion_manager import TextualInversionManager +from ..offloading import HotSeatModelGroup, SimpleModelGroup +from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver +from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter @dataclass @@ -272,7 +265,7 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: KarrasDiffusionSchedulers, safety_checker: Optional[StableDiffusionSafetyChecker], feature_extractor: Optional[CLIPFeatureExtractor], requires_safety_checker: bool = False, @@ -303,7 +296,8 @@ def __init__( ) self._enable_memory_efficient_attention() - + self._model_group = SimpleModelGroup(self.unet.device) + self._model_group.install(*self._submodels) def _enable_memory_efficient_attention(self): """ @@ -320,6 +314,43 @@ def _enable_memory_efficient_attention(self): else: self.enable_attention_slicing(slice_size='max') + def enable_offload_submodels(self, device: torch.device): + models = self._submodels + if self._model_group is not None: + self._model_group.uninstall(*models) + group = HotSeatModelGroup(device) + group.install(*models) + self._model_group = group + + def disable_offload_submodels(self): + models = self._submodels + if self._model_group is not None: + self._model_group.uninstall(*models) + group = SimpleModelGroup(self._model_group.execution_device) + group.install(*models) + self._model_group = group + + def offload_all(self): + self._model_group.offload_current() + + def ready(self): + self._model_group.ready() + + def to(self, torch_device: Optional[Union[str, torch.device]] = None): + if torch_device is None: + return self + self._model_group.set_device(torch_device) + + @property + def device(self) -> torch.device: + return self._model_group.execution_device + + @property + def _submodels(self) -> Sequence[torch.nn.Module]: + module_names, _, _ = self.extract_init_dict(dict(self.config)) + values = [getattr(self, name) for name in module_names.keys()] + return [m for m in values if isinstance(m, torch.nn.Module)] + def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int, conditioning_data: ConditioningData, *, @@ -360,8 +391,9 @@ def latents_from_embeddings(self, latents: torch.Tensor, num_inference_steps: in additional_guidance: List[Callable] = None, run_id=None, callback: Callable[[PipelineIntermediateState], None] = None ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]: + device = self._model_group.device_for(self.unet) if timesteps is None: - self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device) + self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState) result: PipelineIntermediateState = infer_latents_from_embeddings( @@ -391,8 +423,9 @@ def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps, latents=latents) batch_size 
= latents.shape[0] + device = self._model_group.device_for(self.unet) batched_t = torch.full((batch_size,), timesteps[0], - dtype=timesteps.dtype, device=self.unet.device) + dtype=timesteps.dtype, device=device) latents = self.scheduler.add_noise(latents, noise, batched_t) attention_map_saver: Optional[AttentionMapSaver] = None @@ -444,7 +477,7 @@ def step(self, t: torch.Tensor, latents: torch.Tensor, ) # compute the previous noisy sample x_t -> x_t-1 - step_output = self.scheduler.step(noise_pred, timestep, latents.to(noise_pred.device), + step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args) # TODO: this additional_guidance extension point feels redundant with InvokeAIDiffusionComponent. @@ -488,7 +521,7 @@ def img2img_from_embeddings(self, init_image = einops.rearrange(init_image, 'c h w -> 1 c h w') # 6. Prepare latent variables - device = self.unet.device + device = self._model_group.device_for(self.unet) latents_dtype = self.unet.dtype initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype) noise = noise_func(initial_latents) @@ -503,7 +536,8 @@ def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_ste strength, noise: torch.Tensor, run_id=None, callback=None ) -> InvokeAIStableDiffusionPipelineOutput: - timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, self.unet.device) + timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, + device=self._model_group.device_for(self.unet)) result_latents, result_attention_maps = self.latents_from_embeddings( initial_latents, num_inference_steps, conditioning_data, timesteps=timesteps, @@ -542,7 +576,7 @@ def inpaint_from_embeddings( run_id=None, noise_func=None, ) -> InvokeAIStableDiffusionPipelineOutput: - device = self.unet.device + device = self._model_group.device_for(self.unet) latents_dtype = self.unet.dtype if isinstance(init_image, PIL.Image.Image): @@ -606,6 +640,8 @@ def non_noised_latents_from_image(self, init_image, *, device: torch.device, dty # TODO remove this workaround once kulinseth#222 is merged to pytorch mainline self.vae.to('cpu') init_image = init_image.to('cpu') + else: + self._model_group.load(self.vae) init_latent_dist = self.vae.encode(init_image).latent_dist init_latents = init_latent_dist.sample().to(dtype=dtype) # FIXME: uses torch.randn. make reproducible! if device.type == 'mps': @@ -636,7 +672,7 @@ def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fr text=c, fragment_weights=fragment_weights, should_return_tokens=return_tokens, - device=self.device) + device=self._model_group.device_for(self.unet)) @property def cond_stage_model(self): @@ -657,17 +693,9 @@ def channels(self) -> int: """Compatible with DiffusionWrapper""" return self.unet.in_channels - @property - def submodels(self) -> Sequence[torch.nn.Module]: - models = self.text_encoder, self.unet, self.vae, self.safety_checker - return [m for m in models if m is not None] - def decode_latents(self, latents): - # Super ugly kludge to get the vae loaded! (since `decode` isn't the forward method.) - try: - self.vae(tuple()) - except TypeError: - pass # we didn't expect it to work, just needed its side-effects. + # Explicit call to get the vae loaded, since `decode` isn't the forward method. 
+ self._model_group.load(self.vae) return super().decode_latents(latents) def debug_latents(self, latents, msg): diff --git a/ldm/invoke/model_manager.py b/ldm/invoke/model_manager.py index e4948221cfc..04c1d920892 100644 --- a/ldm/invoke/model_manager.py +++ b/ldm/invoke/model_manager.py @@ -25,8 +25,6 @@ import transformers from diffusers import AutoencoderKL from diffusers import logging as dlogging -from diffusers.utils.logging import (get_verbosity, set_verbosity, - set_verbosity_error) from huggingface_hub import scan_cache_dir from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig @@ -36,7 +34,6 @@ StableDiffusionGeneratorPipeline from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir, global_models_dir) -from ldm.invoke.offloading import OffloadingDevice from ldm.util import (ask_user, download_with_progress_bar, instantiate_from_config) @@ -71,10 +68,7 @@ def __init__( self.models = {} self.stack = [] # this is an LRU FIFO self.current_model = None - if sequential_offload: - self.offloader = OffloadingDevice(self.device) - else: - self.offloader = None + self.sequential_offload = sequential_offload def valid_model(self, model_name: str) -> bool: """ @@ -98,10 +92,7 @@ def get_model(self, model_name: str): if self.current_model != model_name: if model_name not in self.models: # make room for a new one self._make_cache_room() - if self.offloader: - self.offloader.offload_current() - else: - self.offload_model(self.current_model) + self.offload_model(self.current_model) if model_name in self.models: requested_model = self.models[model_name]["model"] @@ -538,8 +529,8 @@ def _load_diffusers_model(self, mconfig): dlogging.set_verbosity(verbosity) assert pipeline is not None, OSError(f'"{name_or_path}" could not be loaded') - if self.offloader: - self.offloader.install(*pipeline.submodels) + if self.sequential_offload: + pipeline.enable_offload_submodels(self.device) else: pipeline.to(self.device) @@ -1004,12 +995,12 @@ def _model_to_cpu(self, model): if self.device == "cpu": return model - # diffusers really really doesn't like us moving a float16 model onto CPU - verbosity = get_verbosity() - set_verbosity_error() + if isinstance(model, StableDiffusionGeneratorPipeline): + model.offload_all() + return model + model.cond_stage_model.device = "cpu" model.to("cpu") - set_verbosity(verbosity) for submodel in ("first_stage_model", "cond_stage_model", "model"): try: @@ -1022,8 +1013,8 @@ def _model_from_cpu(self, model): if self.device == "cpu": return model - if self.offloader and isinstance(model, StableDiffusionGeneratorPipeline): - # Offloader handles it on demand. 
+ if isinstance(model, StableDiffusionGeneratorPipeline): + model.ready() return model model.to(self.device) diff --git a/ldm/invoke/offloading.py b/ldm/invoke/offloading.py index 191c58be531..af7d50c9454 100644 --- a/ldm/invoke/offloading.py +++ b/ldm/invoke/offloading.py @@ -1,21 +1,34 @@ +from __future__ import annotations + import warnings import weakref from collections.abc import MutableMapping -from typing import Optional, Callable +from typing import Callable import torch from accelerate.utils import send_to_device from torch.utils.hooks import RemovableHandle +OFFLOAD_DEVICE = torch.device("cpu") + +class _NoModel: + def __bool__(self): + return False + + def to(self, device: torch.device): + pass + +NO_MODEL = _NoModel() -class OffloadingDevice: + +class HotSeatModelGroup: _hooks: MutableMapping[torch.nn.Module, RemovableHandle] - _current_model_ref: Callable[[], Optional[torch.nn.Module]] + _current_model_ref: Callable[[], torch.nn.Module | _NoModel] def __init__(self, execution_device: torch.device): self.execution_device = execution_device self._hooks = weakref.WeakKeyDictionary() - self._current_model_ref = lambda: None + self._current_model_ref = weakref.ref(NO_MODEL) def install(self, *models: torch.nn.Module): for model in models: @@ -29,6 +42,9 @@ def uninstall(self, *models: torch.nn.Module): # no longer hooked by this object, so don't claim to manage it self.clear_current_model() + def uninstall_all(self): + self.uninstall(*self._hooks.keys()) + def _pre_hook(self, module: torch.nn.Module, forward_input): self.load(module) if len(forward_input) == 0: @@ -41,13 +57,11 @@ def load(self, module): self.offload_current() self._load(module) - def offload_current(self) -> torch.nn.Module: - # noinspection PyNoneFunctionAssignment - module: Optional[torch.nn.Module] = self._current_model_ref() - if module is not None: - module.cpu() + def offload_current(self): + module = self._current_model_ref() + if module is not NO_MODEL: + module.to(device=OFFLOAD_DEVICE) self.clear_current_model() - return module def _load(self, module: torch.nn.Module) -> torch.nn.Module: assert self.is_empty(), f"A model is already loaded: {self._current_model_ref()}" @@ -59,14 +73,78 @@ def is_current_model(self, model: torch.nn.Module) -> bool: return self._current_model_ref() is model def is_empty(self): - return self._current_model_ref() is None + return self._current_model_ref() is NO_MODEL def set_current_model(self, value): self._current_model_ref = weakref.ref(value) def clear_current_model(self): - self._current_model_ref = lambda: None + self._current_model_ref = weakref.ref(NO_MODEL) + + def set_device(self, device: torch.device): + if device == self.execution_device: + return + self.execution_device = device + current = self._current_model_ref() + if current is not NO_MODEL: + current.to(device) + + def device_for(self, model): + if model not in self: + raise KeyError("This does not manage this model f{type(model).__name__}", model) + return self.execution_device # this implementation only dispatches to one device + + def ready(self): + pass # always ready to load on-demand + + def __contains__(self, model): + return model in self._hooks def __repr__(self) -> str: return f"<{self.__class__.__name__} object at {id(self):x}: " \ f"current_model={type(self._current_model_ref()).__name__} >" + + +class SimpleModelGroup: + _models: weakref.WeakSet + + def __init__(self, execution_device: torch.device): + self.execution_device = execution_device + self._models = weakref.WeakSet() + + def 
install(self, *models: torch.nn.Module): + for model in models: + self._models.add(model) + model.to(device=self.execution_device) + + def uninstall(self, *models: torch.nn.Module): + for model in models: + self._models.remove(model) + + def uninstall_all(self): + self.uninstall(*self._models) + + def load(self, model): + model.to(device=self.execution_device) + + def offload_current(self): + for model in self._models: + model.to(device=OFFLOAD_DEVICE) + + def ready(self): + for model in self._models: + self.load(model) + + def set_device(self, device: torch.device): + self.execution_device = device + for model in self._models: + if model.device != OFFLOAD_DEVICE: + model.to(device=device) + + def device_for(self, model): + if model not in self: + raise KeyError("This does not manage this model f{type(model).__name__}", model) + return self.execution_device # this implementation only dispatches to one device + + def __contains__(self, model): + return model in self._models From ac0746f1b3c17cc304ee4da4446744061844462e Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Thu, 9 Feb 2023 17:04:19 -0800 Subject: [PATCH 07/13] inlining some calls where device is only used once --- ldm/invoke/generator/diffusers_pipeline.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 54a9d851992..58f08202b51 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -391,9 +391,8 @@ def latents_from_embeddings(self, latents: torch.Tensor, num_inference_steps: in additional_guidance: List[Callable] = None, run_id=None, callback: Callable[[PipelineIntermediateState], None] = None ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]: - device = self._model_group.device_for(self.unet) if timesteps is None: - self.scheduler.set_timesteps(num_inference_steps, device=device) + self.scheduler.set_timesteps(num_inference_steps, device=self._model_group.device_for(self.unet)) timesteps = self.scheduler.timesteps infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState) result: PipelineIntermediateState = infer_latents_from_embeddings( @@ -423,9 +422,8 @@ def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps, latents=latents) batch_size = latents.shape[0] - device = self._model_group.device_for(self.unet) batched_t = torch.full((batch_size,), timesteps[0], - dtype=timesteps.dtype, device=device) + dtype=timesteps.dtype, device=self._model_group.device_for(self.unet)) latents = self.scheduler.add_noise(latents, noise, batched_t) attention_map_saver: Optional[AttentionMapSaver] = None @@ -521,9 +519,9 @@ def img2img_from_embeddings(self, init_image = einops.rearrange(init_image, 'c h w -> 1 c h w') # 6. 
Prepare latent variables - device = self._model_group.device_for(self.unet) - latents_dtype = self.unet.dtype - initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype) + initial_latents = self.non_noised_latents_from_image( + init_image, device=self._model_group.device_for(self.unet), + dtype=self.unet.dtype) noise = noise_func(initial_latents) return self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps, @@ -653,8 +651,7 @@ def non_noised_latents_from_image(self, init_image, *, device: torch.device, dty def check_for_safety(self, output, dtype): with torch.inference_mode(): - screened_images, has_nsfw_concept = self.run_safety_checker( - output.images, device=self._execution_device, dtype=dtype) + screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype) screened_attention_map_saver = None if has_nsfw_concept is None or not has_nsfw_concept: screened_attention_map_saver = output.attention_map_saver @@ -663,6 +660,12 @@ def check_for_safety(self, output, dtype): # block the attention maps if NSFW content is detected attention_map_saver=screened_attention_map_saver) + def run_safety_checker(self, image, device=None, dtype=None): + # overriding to use the model group for device info instead of requiring the caller to know. + if self.safety_checker is not None: + device = self._model_group.device_for(self.safety_checker) + return super().run_safety_checker(image, device, dtype) + @torch.inference_mode() def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None): """ From 127c1b8fe5a2de5242fcf9f453b2b8b29d8dcbea Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Thu, 9 Feb 2023 17:11:16 -0800 Subject: [PATCH 08/13] ensure model group is ready after pipeline.to is called --- ldm/invoke/generator/diffusers_pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 58f08202b51..4968ad68ab9 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -340,6 +340,7 @@ def to(self, torch_device: Optional[Union[str, torch.device]] = None): if torch_device is None: return self self._model_group.set_device(torch_device) + self._model_group.ready() @property def device(self) -> torch.device: From 26444afd251c0accc6659e2e12c94b719d70cadb Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Tue, 14 Feb 2023 19:17:57 -0800 Subject: [PATCH 09/13] fixup! Strategize slicing based on free [V]RAM (#2572) --- ldm/invoke/generator/diffusers_pipeline.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 26200793dc2..84521175bb7 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -9,6 +9,7 @@ import PIL.Image import einops +import psutil import torch import torchvision.transforms as T from diffusers.models import AutoencoderKL, UNet2DConditionModel @@ -301,7 +302,7 @@ def __init__( self._model_group.install(*self._submodels) - def _adjust_memory_efficient_attention(self, latents: Torch.tensor): + def _adjust_memory_efficient_attention(self, latents: torch.Tensor): """ if xformers is available, use it, otherwise use sliced attention. 
""" @@ -319,7 +320,7 @@ def _adjust_memory_efficient_attention(self, latents: Torch.tensor): elif self.device.type == 'cuda': mem_free, _ = torch.cuda.mem_get_info(self.device) else: - raise ValueError(f"unrecognized device {device}") + raise ValueError(f"unrecognized device {self.device}") # input tensor of [1, 4, h/8, w/8] # output tensor of [16, (h/8 * w/8), (h/8 * w/8)] bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4 From 42ee1c66d194ae825fe422d59bdbee5deb63a2cb Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Tue, 14 Feb 2023 20:28:45 -0800 Subject: [PATCH 10/13] doc(offloading): docstrings for offloading.ModelGroup --- ldm/invoke/generator/diffusers_pipeline.py | 3 +- ldm/invoke/offloading.py | 107 ++++++++++++++++++++- 2 files changed, 104 insertions(+), 6 deletions(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 84521175bb7..326c413c1d5 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -28,7 +28,7 @@ from ldm.invoke.globals import Globals from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings from ldm.modules.textual_inversion_manager import TextualInversionManager -from ..offloading import HotSeatModelGroup, SimpleModelGroup +from ..offloading import HotSeatModelGroup, SimpleModelGroup, ModelGroup from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter @@ -259,6 +259,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ + _model_group: ModelGroup ID_LENGTH = 8 diff --git a/ldm/invoke/offloading.py b/ldm/invoke/offloading.py index af7d50c9454..da9e27801f7 100644 --- a/ldm/invoke/offloading.py +++ b/ldm/invoke/offloading.py @@ -2,6 +2,7 @@ import warnings import weakref +from abc import ABCMeta, abstractmethod from collections.abc import MutableMapping from typing import Callable @@ -12,21 +13,110 @@ OFFLOAD_DEVICE = torch.device("cpu") class _NoModel: + """Symbol that indicates no model is loaded. + + (We can't weakref.ref(None), so this was my best idea at the time to come up with something + type-checkable.) + """ + def __bool__(self): return False def to(self, device: torch.device): pass + def __repr__(self): + return "" + NO_MODEL = _NoModel() -class HotSeatModelGroup: +class ModelGroup(metaclass=ABCMeta): + """ + A group of models. + + The use case I had in mind when writing this is the sub-models used by a DiffusionPipeline, + e.g. its text encoder, U-net, VAE, etc. + + Those models are :py:class:`diffusers.ModelMixin`, but "model" is interchangeable with + :py:class:`torch.nn.Module` here. 
+ """ + + def __init__(self, execution_device: torch.device): + self.execution_device = execution_device + + @abstractmethod + def install(self, *models: torch.nn.Module): + """Add models to this group.""" + pass + + @abstractmethod + def uninstall(self, models: torch.nn.Module): + """Remove models to this group.""" + pass + + @abstractmethod + def uninstall_all(self): + """Remove all models from this group.""" + + @abstractmethod + def load(self, model: torch.nn.Module): + """Load this model to the execution device.""" + pass + + @abstractmethod + def offload_current(self): + """Offload the current model(s) from the execution device.""" + pass + + @abstractmethod + def ready(self): + """Ready this group for use.""" + pass + + @abstractmethod + def set_device(self, device: torch.device): + """Change which device models from this group will execute on.""" + pass + + @abstractmethod + def device_for(self, model) -> torch.device: + """Get the device the given model will execute on. + + The model should already be a member of this group. + """ + pass + + @abstractmethod + def __contains__(self, model): + """Check if the model is a member of this group.""" + pass + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} object at {id(self):x}: " \ + f"device={self.execution_device} >" + + +class HotSeatModelGroup(ModelGroup): + """ + Only one model from this group is loaded on the GPU at a time. + + Running the forward method of a model will displace the previously-loaded model, + offloading it to CPU. + + If you call other methods on the model, e.g. ``model.encode(x)`` instead of ``model(x)``, + you will need to explicitly load it with :py:method:`.load(model)`. + + This implementation relies on pytorch forward-pre-hooks, and it will copy forward arguments + to the appropriate execution device, as long as they are positional arguments and not keyword + arguments. (I didn't make the rules; that's the way the pytorch 1.13 API works for hooks.) + """ + _hooks: MutableMapping[torch.nn.Module, RemovableHandle] _current_model_ref: Callable[[], torch.nn.Module | _NoModel] def __init__(self, execution_device: torch.device): - self.execution_device = execution_device + super().__init__(execution_device) self._hooks = weakref.WeakKeyDictionary() self._current_model_ref = weakref.ref(NO_MODEL) @@ -70,9 +160,11 @@ def _load(self, module: torch.nn.Module) -> torch.nn.Module: return module def is_current_model(self, model: torch.nn.Module) -> bool: + """Is the given model the one currently loaded on the execution device?""" return self._current_model_ref() is model def is_empty(self): + """Are none of this group's models loaded on the execution device?""" return self._current_model_ref() is NO_MODEL def set_current_model(self, value): @@ -91,7 +183,7 @@ def set_device(self, device: torch.device): def device_for(self, model): if model not in self: - raise KeyError("This does not manage this model f{type(model).__name__}", model) + raise KeyError(f"This does not manage this model {type(model).__name__}", model) return self.execution_device # this implementation only dispatches to one device def ready(self): @@ -105,11 +197,16 @@ def __repr__(self) -> str: f"current_model={type(self._current_model_ref()).__name__} >" -class SimpleModelGroup: +class SimpleModelGroup(ModelGroup): + """ + A group of models without any implicit loading or unloading. + + :py:meth:`.ready` loads _all_ the models to the execution device at once. 
+ """ _models: weakref.WeakSet def __init__(self, execution_device: torch.device): - self.execution_device = execution_device + super().__init__(execution_device) self._models = weakref.WeakSet() def install(self, *models: torch.nn.Module): From 0dcfb6f469eb09a424abde520059ef9186983e70 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Tue, 14 Feb 2023 21:01:50 -0800 Subject: [PATCH 11/13] doc(offloading): docstrings for offloading-related pipeline methods --- ldm/invoke/generator/diffusers_pipeline.py | 22 ++++++++++++++++++++++ ldm/invoke/offloading.py | 2 +- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 326c413c1d5..49216be1dc5 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -336,6 +336,15 @@ def _adjust_memory_efficient_attention(self, latents: torch.Tensor): def enable_offload_submodels(self, device: torch.device): + """ + Offload each submodel when it's not in use. + + Useful for low-vRAM situations where the size of the model in memory is a big chunk of + the total available resource, and you want to free up as much for inference as possible. + + This requires more moving parts and may add some delay as the U-Net is swapped out for the + VAE and vice-versa. + """ models = self._submodels if self._model_group is not None: self._model_group.uninstall(*models) @@ -344,6 +353,13 @@ def enable_offload_submodels(self, device: torch.device): self._model_group = group def disable_offload_submodels(self): + """ + Leave all submodels loaded. + + Appropriate for cases where the size of the model in memory is small compared to the memory + required for inference. Avoids the delay and complexity of shuffling the submodels to and + from the GPU. + """ models = self._submodels if self._model_group is not None: self._model_group.uninstall(*models) @@ -352,9 +368,15 @@ def disable_offload_submodels(self): self._model_group = group def offload_all(self): + """Offload all this pipeline's models to CPU.""" self._model_group.offload_current() def ready(self): + """ + Ready this pipeline's models. + + i.e. pre-load them to the GPU if appropriate. 
+ """ self._model_group.ready() def to(self, torch_device: Optional[Union[str, torch.device]] = None): diff --git a/ldm/invoke/offloading.py b/ldm/invoke/offloading.py index da9e27801f7..10bd8376d2a 100644 --- a/ldm/invoke/offloading.py +++ b/ldm/invoke/offloading.py @@ -52,7 +52,7 @@ def install(self, *models: torch.nn.Module): @abstractmethod def uninstall(self, models: torch.nn.Module): - """Remove models to this group.""" + """Remove models from this group.""" pass @abstractmethod From 4a283269ba9c70271d5e07ca17b5a2340553837c Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Wed, 15 Feb 2023 18:08:19 -0800 Subject: [PATCH 12/13] refactor(offloading): s/SimpleModelGroup/FullyLoadedModelGroup --- ldm/invoke/generator/diffusers_pipeline.py | 6 +++--- ldm/invoke/offloading.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 49216be1dc5..4a5d2370b8c 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -28,7 +28,7 @@ from ldm.invoke.globals import Globals from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings from ldm.modules.textual_inversion_manager import TextualInversionManager -from ..offloading import HotSeatModelGroup, SimpleModelGroup, ModelGroup +from ..offloading import HotSeatModelGroup, FullyLoadedModelGroup, ModelGroup from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter @@ -299,7 +299,7 @@ def __init__( textual_inversion_manager=self.textual_inversion_manager ) - self._model_group = SimpleModelGroup(self.unet.device) + self._model_group = FullyLoadedModelGroup(self.unet.device) self._model_group.install(*self._submodels) @@ -363,7 +363,7 @@ def disable_offload_submodels(self): models = self._submodels if self._model_group is not None: self._model_group.uninstall(*models) - group = SimpleModelGroup(self._model_group.execution_device) + group = FullyLoadedModelGroup(self._model_group.execution_device) group.install(*models) self._model_group = group diff --git a/ldm/invoke/offloading.py b/ldm/invoke/offloading.py index 10bd8376d2a..7f1b0d0f11f 100644 --- a/ldm/invoke/offloading.py +++ b/ldm/invoke/offloading.py @@ -197,7 +197,7 @@ def __repr__(self) -> str: f"current_model={type(self._current_model_ref()).__name__} >" -class SimpleModelGroup(ModelGroup): +class FullyLoadedModelGroup(ModelGroup): """ A group of models without any implicit loading or unloading. 
From 10547e4a2e4c9834fe0689b1b5336abd970b21b6 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Wed, 15 Feb 2023 18:28:41 -0800 Subject: [PATCH 13/13] refactor(offloading): s/HotSeatModelGroup/LazilyLoadedModelGroup to frame it is the same terms as "FullyLoadedModelGroup" --- ldm/invoke/generator/diffusers_pipeline.py | 4 ++-- ldm/invoke/offloading.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 4a5d2370b8c..5990eb42a17 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -28,7 +28,7 @@ from ldm.invoke.globals import Globals from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings from ldm.modules.textual_inversion_manager import TextualInversionManager -from ..offloading import HotSeatModelGroup, FullyLoadedModelGroup, ModelGroup +from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter @@ -348,7 +348,7 @@ def enable_offload_submodels(self, device: torch.device): models = self._submodels if self._model_group is not None: self._model_group.uninstall(*models) - group = HotSeatModelGroup(device) + group = LazilyLoadedModelGroup(device) group.install(*models) self._model_group = group diff --git a/ldm/invoke/offloading.py b/ldm/invoke/offloading.py index 7f1b0d0f11f..e049f5fe099 100644 --- a/ldm/invoke/offloading.py +++ b/ldm/invoke/offloading.py @@ -97,7 +97,7 @@ def __repr__(self) -> str: f"device={self.execution_device} >" -class HotSeatModelGroup(ModelGroup): +class LazilyLoadedModelGroup(ModelGroup): """ Only one model from this group is loaded on the GPU at a time.
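
A minimal usage sketch of the two model groups this series introduces (LazilyLoadedModelGroup and FullyLoadedModelGroup from ldm/invoke/offloading.py), assuming a CUDA execution device and the accelerate dependency that module imports; the torch.nn.Linear modules below are purely illustrative stand-ins for a pipeline's real submodels:

import torch

from ldm.invoke.offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup

# Hypothetical stand-ins for a pipeline's submodels (text encoder, U-Net, VAE).
text_encoder = torch.nn.Linear(8, 8)
unet = torch.nn.Linear(8, 8)
vae = torch.nn.Linear(8, 8)

gpu = torch.device("cuda")

# Low-VRAM path: only one installed model occupies the GPU at a time. The
# forward-pre-hook added by install() loads the called module (and its
# positional inputs) onto the execution device and offloads whichever model
# was loaded before it.
group = LazilyLoadedModelGroup(gpu)
group.install(text_encoder, unet, vae)
_ = text_encoder(torch.randn(1, 8))   # text_encoder is loaded on demand by the hook
_ = unet(torch.randn(1, 8))           # text_encoder is offloaded, unet loaded
group.load(vae)                       # explicit load for non-forward methods such as vae.decode
group.offload_current()               # current model moved back to the CPU offload device
group.uninstall_all()

# Ample-VRAM path: every installed model stays resident on the execution device.
group = FullyLoadedModelGroup(gpu)
group.install(text_encoder, unet, vae)
group.ready()                         # loads all installed models onto the execution device

At the pipeline level the same choice is exposed as StableDiffusionGeneratorPipeline.enable_offload_submodels(device) and disable_offload_submodels(), which the model manager selects from based on its sequential_offload flag.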