Merged
21 commits
0cfebdf
new OffloadingDevice loads one model at a time, on demand
keturn Feb 9, 2023
79c454a
Merge remote-tracking branch 'origin/main' into spike/offloading-device
keturn Feb 9, 2023
4a9b0fc
fixup! new OffloadingDevice loads one model at a time, on demand
keturn Feb 9, 2023
69873d9
fix(prompt_to_embeddings): call the text encoder directly instead of …
keturn Feb 9, 2023
9d5ab9e
more attempts to get things on the right device from the offloader
keturn Feb 9, 2023
f39c806
more attempts to get things on the right device from the offloader
keturn Feb 9, 2023
337d179
Merge remote-tracking branch 'origin/main' into spike/offloading-device
keturn Feb 9, 2023
20df847
make offloading methods an explicit part of the pipeline interface
keturn Feb 10, 2023
f3e03e4
Merge remote-tracking branch 'origin/main' into spike/offloading-device
keturn Feb 10, 2023
ac0746f
inlining some calls where device is only used once
keturn Feb 10, 2023
127c1b8
ensure model group is ready after pipeline.to is called
keturn Feb 10, 2023
52563ae
Merge remote-tracking branch 'origin/main' into spike/offloading-device
keturn Feb 13, 2023
36bbb09
Merge remote-tracking branch 'origin/main' into spike/offloading-device
keturn Feb 15, 2023
26444af
fixup! Strategize slicing based on free [V]RAM (#2572)
keturn Feb 15, 2023
42ee1c6
doc(offloading): docstrings for offloading.ModelGroup
keturn Feb 15, 2023
0dcfb6f
doc(offloading): docstrings for offloading-related pipeline methods
keturn Feb 15, 2023
ae73997
Merge remote-tracking branch 'origin/main' into spike/offloading-device
keturn Feb 16, 2023
4a28326
refactor(offloading): s/SimpleModelGroup/FullyLoadedModelGroup
keturn Feb 16, 2023
10547e4
refactor(offloading): s/HotSeatModelGroup/LazilyLoadedModelGroup
keturn Feb 16, 2023
b4355a1
Merge branch 'main' into spike/offloading-device
damian0815 Feb 16, 2023
dca5561
Merge branch 'main' into spike/offloading-device
lstein Feb 16, 2023
5 changes: 3 additions & 2 deletions ldm/generate.py
@@ -213,7 +213,9 @@ def __init__(
print('>> xformers not installed')

# model caching system for fast switching
self.model_manager = ModelManager(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)
self.model_manager = ModelManager(mconfig, self.device, self.precision,
max_loaded_models=max_loaded_models,
sequential_offload=self.free_gpu_mem)
# don't accept invalid models
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
model = model or fallback
@@ -480,7 +482,6 @@ def process_image(image,seed):
self.model.cond_stage_model.device = self.model.device
self.model.cond_stage_model.to(self.model.device)
except AttributeError:
print(">> Warning: '--free_gpu_mem' is not yet supported when generating image using model based on HuggingFace Diffuser.")
pass

try:
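The net effect of this hunk is that the existing --free_gpu_mem flag now reaches ModelManager as sequential_offload. A minimal sketch of the resulting constructor call, assuming a CUDA device and a placeholder config path (neither is specified by this PR):

import torch
from omegaconf import OmegaConf
from ldm.invoke.model_manager import ModelManager

mconfig = OmegaConf.load("configs/models.yaml")  # placeholder path, not from this PR
manager = ModelManager(
    mconfig,
    torch.device("cuda"),        # device_type now accepts str | torch.device
    precision="float16",
    max_loaded_models=2,
    sequential_offload=True,     # Generate passes self.free_gpu_mem here
)
# With sequential_offload=True, _load_diffusers_model() calls
# pipeline.enable_offload_submodels(device) instead of pipeline.to(device).
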
129 changes: 100 additions & 29 deletions ldm/invoke/generator/diffusers_pipeline.py
@@ -3,39 +3,34 @@
import dataclasses
import inspect
import secrets
import sys
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any

if sys.version_info < (3, 10):
from typing_extensions import ParamSpec
else:
from typing import ParamSpec

import PIL.Image
import einops
import psutil
import torch
import torchvision.transforms as T
from diffusers.utils.import_utils import is_xformers_available

from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter


from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.outputs import BaseOutput
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from typing_extensions import ParamSpec

from ldm.invoke.globals import Globals
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
from ldm.modules.textual_inversion_manager import TextualInversionManager
from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter


@dataclass
@@ -264,6 +259,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_model_group: ModelGroup

ID_LENGTH = 8

@@ -273,7 +269,7 @@ def __init__(
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
scheduler: KarrasDiffusionSchedulers,
safety_checker: Optional[StableDiffusionSafetyChecker],
feature_extractor: Optional[CLIPFeatureExtractor],
requires_safety_checker: bool = False,
@@ -303,8 +299,11 @@ def __init__(
textual_inversion_manager=self.textual_inversion_manager
)

self._model_group = FullyLoadedModelGroup(self.unet.device)
self._model_group.install(*self._submodels)

def _adjust_memory_efficient_attention(self, latents: Torch.tensor):

def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
"""
if xformers is available, use it, otherwise use sliced attention.
"""
@@ -322,7 +321,7 @@ def _adjust_memory_efficient_attention(self, latents: Torch.tensor):
elif self.device.type == 'cuda':
mem_free, _ = torch.cuda.mem_get_info(self.device)
else:
raise ValueError(f"unrecognized device {device}")
raise ValueError(f"unrecognized device {self.device}")
# input tensor of [1, 4, h/8, w/8]
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
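As a rough illustration of the estimate being built up in this hunk, for a 512x512 image in float16 (the threshold that finally chooses between xformers, sliced attention, and no slicing lies outside the lines shown here):

# Illustrative arithmetic only; not part of the PR.
h = w = 512 // 8                                   # latents are [1, 4, 64, 64]
attn_elements = 16 * (h * w) ** 2                  # output tensor [16, 4096, 4096]
bytes_per_element = 2 + 4                          # element_size() for float16, plus 4
print(attn_elements * bytes_per_element / 2**30)   # ~1.5 GiB for the baddbmm duplication
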
@@ -336,6 +335,66 @@ def _adjust_memory_efficient_attention(self, latents: Torch.tensor):
self.disable_attention_slicing()


def enable_offload_submodels(self, device: torch.device):
"""
Offload each submodel when it's not in use.

Useful for low-vRAM situations where the size of the model in memory is a big chunk of
the total available resource, and you want to free up as much for inference as possible.

This requires more moving parts and may add some delay as the U-Net is swapped out for the
VAE and vice-versa.
"""
models = self._submodels
if self._model_group is not None:
self._model_group.uninstall(*models)
group = LazilyLoadedModelGroup(device)
group.install(*models)
self._model_group = group

def disable_offload_submodels(self):
"""
Leave all submodels loaded.

Appropriate for cases where the size of the model in memory is small compared to the memory
required for inference. Avoids the delay and complexity of shuffling the submodels to and
from the GPU.
"""
models = self._submodels
if self._model_group is not None:
self._model_group.uninstall(*models)
group = FullyLoadedModelGroup(self._model_group.execution_device)
group.install(*models)
self._model_group = group

def offload_all(self):
"""Offload all this pipeline's models to CPU."""
self._model_group.offload_current()

def ready(self):
"""
Ready this pipeline's models.

i.e. pre-load them to the GPU if appropriate.
"""
self._model_group.ready()

def to(self, torch_device: Optional[Union[str, torch.device]] = None):
if torch_device is None:
return self
self._model_group.set_device(torch_device)
self._model_group.ready()

@property
def device(self) -> torch.device:
return self._model_group.execution_device

@property
def _submodels(self) -> Sequence[torch.nn.Module]:
module_names, _, _ = self.extract_init_dict(dict(self.config))
values = [getattr(self, name) for name in module_names.keys()]
return [m for m in values if isinstance(m, torch.nn.Module)]

def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
conditioning_data: ConditioningData,
*,
@@ -377,7 +436,7 @@ def latents_from_embeddings(self, latents: torch.Tensor, num_inference_steps: in
callback: Callable[[PipelineIntermediateState], None] = None
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if timesteps is None:
self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
self.scheduler.set_timesteps(num_inference_steps, device=self._model_group.device_for(self.unet))
timesteps = self.scheduler.timesteps
infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
result: PipelineIntermediateState = infer_latents_from_embeddings(
@@ -409,7 +468,7 @@ def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps,

batch_size = latents.shape[0]
batched_t = torch.full((batch_size,), timesteps[0],
dtype=timesteps.dtype, device=self.unet.device)
dtype=timesteps.dtype, device=self._model_group.device_for(self.unet))
latents = self.scheduler.add_noise(latents, noise, batched_t)

attention_map_saver: Optional[AttentionMapSaver] = None
@@ -493,9 +552,8 @@ def _unet_forward(self, latents, t, text_embeddings, cross_attention_kwargs: Opt
initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype)
).add_mask_channels(latents)

return self.unet(sample=latents,
timestep=t,
encoder_hidden_states=text_embeddings,
# First three args should be positional, not keywords, so torch hooks can see them.
return self.unet(latents, t, text_embeddings,
cross_attention_kwargs=cross_attention_kwargs).sample

def img2img_from_embeddings(self,
@@ -514,9 +572,9 @@ def img2img_from_embeddings(self,
init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')

# 6. Prepare latent variables
device = self.unet.device
latents_dtype = self.unet.dtype
initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
initial_latents = self.non_noised_latents_from_image(
init_image, device=self._model_group.device_for(self.unet),
dtype=self.unet.dtype)
noise = noise_func(initial_latents)

return self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps,
@@ -529,7 +587,8 @@ def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_ste
strength,
noise: torch.Tensor, run_id=None, callback=None
) -> InvokeAIStableDiffusionPipelineOutput:
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, self.unet.device)
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength,
device=self._model_group.device_for(self.unet))
result_latents, result_attention_maps = self.latents_from_embeddings(
initial_latents, num_inference_steps, conditioning_data,
timesteps=timesteps,
@@ -568,7 +627,7 @@ def inpaint_from_embeddings(
run_id=None,
noise_func=None,
) -> InvokeAIStableDiffusionPipelineOutput:
device = self.unet.device
device = self._model_group.device_for(self.unet)
latents_dtype = self.unet.dtype

if isinstance(init_image, PIL.Image.Image):
@@ -632,6 +691,8 @@ def non_noised_latents_from_image(self, init_image, *, device: torch.device, dty
# TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
self.vae.to('cpu')
init_image = init_image.to('cpu')
else:
self._model_group.load(self.vae)
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample().to(dtype=dtype) # FIXME: uses torch.randn. make reproducible!
if device.type == 'mps':
@@ -643,8 +704,7 @@ def non_noised_latents_from_image(self, init_image, *, device: torch.device, dty

def check_for_safety(self, output, dtype):
with torch.inference_mode():
screened_images, has_nsfw_concept = self.run_safety_checker(
output.images, device=self._execution_device, dtype=dtype)
screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype)
screened_attention_map_saver = None
if has_nsfw_concept is None or not has_nsfw_concept:
screened_attention_map_saver = output.attention_map_saver
@@ -653,6 +713,12 @@ def check_for_safety(self, output, dtype):
# block the attention maps if NSFW content is detected
attention_map_saver=screened_attention_map_saver)

def run_safety_checker(self, image, device=None, dtype=None):
# overriding to use the model group for device info instead of requiring the caller to know.
if self.safety_checker is not None:
device = self._model_group.device_for(self.safety_checker)
return super().run_safety_checker(image, device, dtype)

@torch.inference_mode()
def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
"""
@@ -662,7 +728,7 @@ def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fr
text=c,
fragment_weights=fragment_weights,
should_return_tokens=return_tokens,
device=self.device)
device=self._model_group.device_for(self.unet))

@property
def cond_stage_model(self):
@@ -683,6 +749,11 @@ def channels(self) -> int:
"""Compatible with DiffusionWrapper"""
return self.unet.in_channels

def decode_latents(self, latents):
# Explicit call to get the vae loaded, since `decode` isn't the forward method.
self._model_group.load(self.vae)
return super().decode_latents(latents)

def debug_latents(self, latents, msg):
with torch.inference_mode():
from ldm.util import debug_image
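Taken together, the methods added above give StableDiffusionGeneratorPipeline a small offloading API. A hypothetical usage sketch, assuming the pipeline is constructed via diffusers' from_pretrained (the model id is a placeholder, not something this PR prescribes):

import torch
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline

# Placeholder model id; any diffusers-format Stable Diffusion checkpoint would do.
pipe = StableDiffusionGeneratorPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
cuda = torch.device("cuda")

# Low-VRAM mode: submodels are installed into a LazilyLoadedModelGroup and each
# one is moved to the GPU only while it is being used.
pipe.enable_offload_submodels(cuda)

# ... generate images; the U-Net and VAE are swapped in and out on demand ...

# Keep everything resident instead (FullyLoadedModelGroup), then manage it explicitly.
pipe.disable_offload_submodels()
pipe.ready()        # pre-load the group onto its execution device
pipe.offload_all()  # push every managed submodel back to CPU while idle
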
26 changes: 16 additions & 10 deletions ldm/invoke/model_manager.py
@@ -25,8 +25,6 @@
import transformers
from diffusers import AutoencoderKL
from diffusers import logging as dlogging
from diffusers.utils.logging import (get_verbosity, set_verbosity,
set_verbosity_error)
from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
@@ -49,9 +47,10 @@ class ModelManager(object):
def __init__(
self,
config: OmegaConf,
device_type: str = "cpu",
device_type: str | torch.device = "cpu",
precision: str = "float16",
max_loaded_models=DEFAULT_MAX_MODELS,
sequential_offload = False
):
"""
Initialize with the path to the models.yaml config file,
@@ -69,6 +68,7 @@ def __init__(
self.models = {}
self.stack = [] # this is an LRU FIFO
self.current_model = None
self.sequential_offload = sequential_offload

def valid_model(self, model_name: str) -> bool:
"""
@@ -529,7 +529,10 @@ def _load_diffusers_model(self, mconfig):
dlogging.set_verbosity(verbosity)
assert pipeline is not None, OSError(f'"{name_or_path}" could not be loaded')

pipeline.to(self.device)
if self.sequential_offload:
pipeline.enable_offload_submodels(self.device)
else:
pipeline.to(self.device)

model_hash = self._diffuser_sha256(name_or_path)

@@ -748,7 +751,6 @@ def convert_and_import(
into models.yaml.
"""
new_config = None
import transformers

from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser

@@ -995,12 +997,12 @@ def _model_to_cpu(self, model):
if self.device == "cpu":
return model

# diffusers really really doesn't like us moving a float16 model onto CPU
verbosity = get_verbosity()
set_verbosity_error()
if isinstance(model, StableDiffusionGeneratorPipeline):
model.offload_all()
return model

model.cond_stage_model.device = "cpu"
model.to("cpu")
set_verbosity(verbosity)

for submodel in ("first_stage_model", "cond_stage_model", "model"):
try:
@@ -1013,6 +1015,10 @@ def _model_from_cpu(self, model):
if self.device == "cpu":
return model

if isinstance(model, StableDiffusionGeneratorPipeline):
model.ready()
return model

model.to(self.device)
model.cond_stage_model.device = self.device

@@ -1163,7 +1169,7 @@ def _delete_model_from_cache(repo_id):
strategy.execute()

@staticmethod
def _abs_path(path: Union(str, Path)) -> Path:
def _abs_path(path: str | Path) -> Path:
if path is None or Path(path).is_absolute():
return path
return Path(Globals.root, path).resolve()
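The ModelGroup classes imported from ldm.invoke.offloading are not part of the hunks shown on this page, but the contract they must satisfy can be read off the call sites above. A sketch of that interface, reconstructed from usage rather than copied from the PR:

import torch
from abc import ABC, abstractmethod


class ModelGroup(ABC):
    """A set of torch.nn.Modules sharing one execution device (inferred interface)."""

    def __init__(self, execution_device: torch.device):
        self.execution_device = execution_device

    @abstractmethod
    def install(self, *models: torch.nn.Module): ...    # begin managing these modules

    @abstractmethod
    def uninstall(self, *models: torch.nn.Module): ...  # stop managing them

    @abstractmethod
    def load(self, model: torch.nn.Module): ...         # ensure one module is on the execution device

    @abstractmethod
    def offload_current(self): ...                      # move currently loaded modules back to CPU

    @abstractmethod
    def ready(self): ...                                # pre-load according to the group's strategy

    @abstractmethod
    def set_device(self, device: torch.device): ...     # change the execution device

    @abstractmethod
    def device_for(self, model: torch.nn.Module) -> torch.device: ...

# FullyLoadedModelGroup keeps every installed module on the execution device;
# LazilyLoadedModelGroup loads one module at a time, on demand (per the commit messages above).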