9 changes: 5 additions & 4 deletions ldm/generate.py
@@ -211,9 +211,11 @@ def __init__(
print('>> xformers memory-efficient attention is available but disabled')
else:
print('>> xformers not installed')

# model caching system for fast switching
self.model_manager = ModelManager(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)
self.model_manager = ModelManager(mconfig, self.device, self.precision,
max_loaded_models=max_loaded_models,
free_gpu_mem=self.free_gpu_mem)
# don't accept invalid models
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
model = model or fallback
@@ -478,8 +480,7 @@ def process_image(image,seed):
self.model.cond_stage_model.device = self.model.device
self.model.cond_stage_model.to(self.model.device)
except AttributeError:
print(">> Warning: '--free_gpu_mem' is not yet supported when generating image using model based on HuggingFace Diffuser.")
pass
pass # free_gpu_mem is handled by model_manager for diffusers

try:
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
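Taken together, the ldm/generate.py hunks make Generate forward its free_gpu_mem setting into the model cache instead of warning that the flag is unsupported for diffusers models. A minimal sketch of the resulting constructor call; the config path and device are placeholders, not code from the repository:

```python
import torch
from omegaconf import OmegaConf

from ldm.invoke.model_manager import ModelManager

# Placeholder config path; InvokeAI resolves the real one from its Globals.
mconfig = OmegaConf.load('configs/models.yaml')

manager = ModelManager(
    mconfig,
    torch.device('cuda'),      # device_type now also accepts a torch.device
    precision='float16',
    max_loaded_models=2,
    free_gpu_mem=True,         # new keyword, forwarded from Generate.free_gpu_mem
)
```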
35 changes: 28 additions & 7 deletions ldm/invoke/generator/diffusers_pipeline.py
@@ -28,7 +28,7 @@
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils.outputs import BaseOutput
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -271,7 +271,7 @@ def __init__(
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
scheduler: KarrasDiffusionSchedulers,
safety_checker: Optional[StableDiffusionSafetyChecker],
feature_extractor: Optional[CLIPFeatureExtractor],
requires_safety_checker: bool = False,
@@ -360,7 +360,7 @@ def latents_from_embeddings(self, latents: torch.Tensor, num_inference_steps: in
callback: Callable[[PipelineIntermediateState], None] = None
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if timesteps is None:
self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
timesteps = self.scheduler.timesteps
infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
result: PipelineIntermediateState = infer_latents_from_embeddings(
@@ -379,6 +379,8 @@ def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps,
additional_guidance: List[Callable] = None):
if run_id is None:
run_id = secrets.token_urlsafe(self.ID_LENGTH)
assert not latents.is_meta
assert not noise.is_meta
if additional_guidance is None:
additional_guidance = []
extra_conditioning_info = conditioning_data.extra
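The two new assertions reject latents or noise tensors that live on PyTorch's meta device, which is roughly where module weights appear to live once accelerate's sequential offload is active; a meta tensor has shape and dtype but no storage. A quick, purely illustrative look at the property being checked:

```python
import torch

real = torch.zeros(1, 4, 64, 64)
meta = torch.empty(1, 4, 64, 64, device='meta')  # shape/dtype only, no data

assert not real.is_meta   # passes
assert not meta.is_meta   # fails: exactly the condition the new checks catch early
```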
@@ -391,7 +393,7 @@ def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps,

batch_size = latents.shape[0]
batched_t = torch.full((batch_size,), timesteps[0],
dtype=timesteps.dtype, device=self.unet.device)
dtype=timesteps.dtype, device=self.device)
latents = self.scheduler.add_noise(latents, noise, batched_t)

attention_map_saver: Optional[AttentionMapSaver] = None
@@ -488,7 +490,7 @@ def img2img_from_embeddings(self,
init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')

# 6. Prepare latent variables
device = self.unet.device
device = self.device
latents_dtype = self.unet.dtype
initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
noise = noise_func(initial_latents)
@@ -503,7 +505,7 @@ def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_ste
strength,
noise: torch.Tensor, run_id=None, callback=None
) -> InvokeAIStableDiffusionPipelineOutput:
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, self.unet.device)
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, self.device)
result_latents, result_attention_maps = self.latents_from_embeddings(
initial_latents, num_inference_steps, conditioning_data,
timesteps=timesteps,
@@ -542,7 +544,7 @@ def inpaint_from_embeddings(
run_id=None,
noise_func=None,
) -> InvokeAIStableDiffusionPipelineOutput:
device = self.unet.device
device = self.device
latents_dtype = self.unet.dtype

if isinstance(init_image, PIL.Image.Image):
@@ -657,6 +659,25 @@ def channels(self) -> int:
"""Compatible with DiffusionWrapper"""
return self.unet.in_channels

@property
def device(self) -> torch.device:
maybe_device = super().device
if maybe_device.type != 'meta':
return maybe_device
# copied from StableDiffusionPipeline._execution_device:
# Returns the device on which the pipeline's models will be executed. After calling
# `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred
# from Accelerate's module hooks.
# FIXME: poking around in implementation details of accelerate hooks is Bad News
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return maybe_device

def debug_latents(self, latents, msg):
with torch.inference_mode():
from ldm.util import debug_image
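The new device property exists because enable_sequential_cpu_offload() leaves the pipeline's modules reporting the meta device, so the call sites changed above (set_timesteps, img2img, inpaint) could no longer rely on self.unet.device; the property instead asks accelerate's module hook where execution will actually happen. A rough caller-side sketch, assuming pipe is a StableDiffusionGeneratorPipeline already loaded with offloading enabled (not code from the repository):

```python
import torch

# While offloading is active, the UNet's parameters report the meta device:
print(pipe.unet.device)      # e.g. device(type='meta')

# The overridden `device` property falls back to the _hf_hook's execution_device,
# so tensors created with it land where the UNet will actually run (e.g. cuda:0):
noise = torch.randn(1, pipe.unet.in_channels, 64, 64, device=pipe.device)
```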
35 changes: 23 additions & 12 deletions ldm/invoke/model_manager.py
@@ -15,24 +15,24 @@
import textwrap
import time
import warnings
import safetensors.torch
from pathlib import Path
from shutil import move, rmtree
from typing import Any, Optional, Union
from huggingface_hub import scan_cache_dir
from ldm.util import download_with_progress_bar

import torch
import safetensors
import safetensors.torch
import torch
import transformers
from diffusers import AutoencoderKL, logging as dlogging
from diffusers import AutoencoderKL, logging as dlogging, DiffusionPipeline
from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error
from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from picklescan.scanner import scan_file_path

from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ldm.invoke.globals import Globals, global_models_dir, global_autoscan_dir, global_cache_dir
from ldm.util import download_with_progress_bar
from ldm.util import instantiate_from_config, ask_user

DEFAULT_MAX_MODELS=2
@@ -43,9 +43,10 @@
class ModelManager(object):
def __init__(self,
config:OmegaConf,
device_type:str='cpu',
device_type:str | torch.device='cpu',
precision:str='float16',
max_loaded_models=DEFAULT_MAX_MODELS):
max_loaded_models=DEFAULT_MAX_MODELS,
free_gpu_mem=False):
'''
Initialize with the path to the models.yaml config file,
the torch device type, and precision. The optional
@@ -58,6 +59,7 @@ def __init__(self,
self.config = config
self.precision = precision
self.device = torch.device(device_type)
self.free_gpu_mem = free_gpu_mem
self.max_loaded_models = max_loaded_models
self.models = {}
self.stack = [] # this is an LRU FIFO
@@ -497,6 +499,9 @@ def _load_diffusers_model(self, mconfig):

pipeline.to(self.device)

if self.free_gpu_mem:
pipeline.enable_sequential_cpu_offload()

model_hash = self._diffuser_sha256(name_or_path)

# square images???
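_load_diffusers_model() now opts the freshly loaded pipeline into accelerate-managed offloading whenever free_gpu_mem is set. Outside the model manager, the equivalent call against the public diffusers API looks roughly like this; the model id and prompt are placeholders:

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5')
pipe.enable_sequential_cpu_offload()   # accelerate streams submodules to the GPU as needed
image = pipe('a lighthouse at dusk').images[0]
```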
@@ -697,7 +702,6 @@ def convert_and_import(self,
'''
new_config = None
from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser
import transformers
if diffusers_path.exists():
print(f'ERROR: The path {str(diffusers_path)} already exists. Please move or remove it and try again.')
return
@@ -706,7 +710,7 @@ def convert_and_import(self,
model_description = model_description or 'Optimized version of {model_name}'
print(f'>> Optimizing {model_name} (30-60s)')
try:
# By passing the specified VAE too the conversion function, the autoencoder
# By passing the specified VAE to the conversion function, the autoencoder
# will be built into the model rather than tacked on afterward via the config file
vae_model = self._load_vae(vae) if vae else None
convert_ckpt_to_diffuser(
@@ -753,7 +757,6 @@ def search_models(self, search_folder):
return search_folder, found_models

def _choose_diffusers_vae(self, model_name:str, vae:str=None)->Union[dict,str]:

# In the event that the original entry is using a custom ckpt VAE, we try to
# map that VAE onto a diffuser VAE using a hard-coded dictionary.
# I would prefer to do this differently: We load the ckpt model into memory, swap the
@@ -901,10 +904,14 @@ def _invalidate_cached_model(self,model_name:str) -> None:
self.stack.remove(model_name)
self.models.pop(model_name,None)

def _model_to_cpu(self,model):
def _model_to_cpu(self, model):
if self.device == 'cpu':
return model

if isinstance(model, DiffusionPipeline) and self.free_gpu_mem:
# diffusers CPU offloading is being handled by accelerate
return model

# diffusers really really doesn't like us moving a float16 model onto CPU
verbosity = get_verbosity()
set_verbosity_error()
@@ -923,6 +930,10 @@ def _model_from_cpu(self,model):
if self.device == 'cpu':
return model

if isinstance(model, DiffusionPipeline) and self.free_gpu_mem:
# diffusers CPU offloading is being handled by accelerate
return model

model.to(self.device)
model.cond_stage_model.device = self.device

@@ -1066,7 +1077,7 @@ def _delete_model_from_cache(repo_id):
strategy.execute()

@staticmethod
def _abs_path(path:Union(str,Path))->Path:
def _abs_path(path: str | Path)->Path:
if path is None or Path(path).is_absolute():
return path
return Path(Globals.root,path).resolve()
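The guards added to _model_to_cpu() and _model_from_cpu() mean the LRU cache no longer fights accelerate for control of a diffusers pipeline's weights when free_gpu_mem is on: the pipeline stays registered in the cache, but its placement is left alone. A small illustrative helper expressing the same condition (not part of the PR):

```python
from diffusers import DiffusionPipeline

def managed_by_accelerate(model, free_gpu_mem: bool) -> bool:
    """True when the model cache should not move this model between CPU and GPU,
    mirroring the early returns added to _model_to_cpu()/_model_from_cpu()."""
    return isinstance(model, DiffusionPipeline) and free_gpu_mem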
8 changes: 5 additions & 3 deletions ldm/modules/prompt_to_embeddings_converter.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import math

import torch
@@ -7,7 +9,7 @@
from ldm.modules.textual_inversion_manager import TextualInversionManager


class WeightedPromptFragmentsToEmbeddingsConverter():
class WeightedPromptFragmentsToEmbeddingsConverter:

def __init__(self,
tokenizer: CLIPTokenizer, # converts strings to lists of int token ids
@@ -159,7 +161,7 @@ def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weight
# lerped embeddings has shape (77, 768)


def get_token_ids_and_expand_weights(self, fragments: list[str], weights: list[float], device: str) -> (torch.Tensor, torch.Tensor):
def get_token_ids_and_expand_weights(self, fragments: list[str], weights: list[float], device: str | torch.device) -> (torch.Tensor, torch.Tensor):
'''
Given a list of text fragments and corresponding weights: tokenize each fragment, append the token sequences
together and return a padded token sequence starting with the bos marker, ending with the eos marker, and padded
@@ -208,7 +210,7 @@ def get_token_ids_and_expand_weights(self, fragments: list[str], weights: list[f
per_token_weights += [1.0] * pad_length

all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long, device=device)
per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch_dtype(self.text_encoder.device), device=device)
per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch_dtype(device), device=device)
#print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}")
return all_token_ids_tensor, per_token_weights_tensor

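The from __future__ import annotations line added at the top of this file is what lets the new str | torch.device annotations import cleanly on Python versions before 3.10, where PEP 604 unions would otherwise be evaluated eagerly and fail; the torch_dtype(device) change likewise keys the weights tensor's dtype off the requested device rather than off wherever the text encoder currently sits. A minimal illustration of the annotation point, in a hypothetical module:

```python
from __future__ import annotations  # defer evaluation of annotations

import torch

def to_device(x: torch.Tensor, device: str | torch.device) -> torch.Tensor:
    # Without the __future__ import this `def` raises TypeError on Python 3.9,
    # because `str | torch.device` would be evaluated at definition time.
    return x.to(device)
```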