diff --git a/.github/workflows/test-invoke-conda.yml b/.github/workflows/test-invoke-conda.yml
index 99ae0c256ae..118ca150d8b 100644
--- a/.github/workflows/test-invoke-conda.yml
+++ b/.github/workflows/test-invoke-conda.yml
@@ -1,21 +1,20 @@
 name: Test invoke.py
-on:
-  push:
-    branches:
-      - 'main'
-      - 'development'
-      - 'fix-gh-actions-fork'
-  pull_request:
-    branches:
-      - 'main'
-      - 'development'
+on: [push, pull_request]
 
 jobs:
   matrix:
+    # Run on:
+    # - pull requests
+    # - pushes to forks (will run in the forked project with that fork's secrets)
+    # - pushes to branches that are *not* pull requests
+    if: |
+      github.event_name == 'pull_request'
+      || github.repository != 'invoke-ai/InvokeAI'
+      || github.ref_protected
     strategy:
       matrix:
         stable-diffusion-model:
-          - 'stable-diffusion-1.5'
+          - diffusers-1.4
         environment-yaml:
           - environment-lin-amd.yml
           - environment-lin-cuda.yml
@@ -30,15 +29,13 @@ jobs:
           - environment-yaml: environment-mac.yml
             os: macos-12
             default-shell: bash -l {0}
-          - stable-diffusion-model: stable-diffusion-1.5
-            stable-diffusion-model-url: https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt
-            stable-diffusion-model-dl-path: models/ldm/stable-diffusion-v1
-            stable-diffusion-model-dl-name: v1-5-pruned-emaonly.ckpt
     name: ${{ matrix.environment-yaml }} on ${{ matrix.os }}
     runs-on: ${{ matrix.os }}
     env:
       CONDA_ENV_NAME: invokeai
       INVOKEAI_ROOT: '${{ github.workspace }}/invokeai'
+      PYTHONUNBUFFERED: 1
+      HAVE_SECRETS: ${{ secrets.HUGGINGFACE_TOKEN != '' }}
     defaults:
       run:
         shell: ${{ matrix.default-shell }}
@@ -55,6 +52,15 @@ jobs:
       - name: create environment.yml
         run: cp "environments-and-requirements/${{ matrix.environment-yaml }}" environment.yml
 
+      - name: Use Cached Stable Diffusion Model
+        id: cache-sd-model
+        uses: actions/cache@v3
+        env:
+          cache-name: huggingface-${{ matrix.stable-diffusion-model }}
+        with:
+          path: ~/.cache/huggingface
+          key: ${{ env.cache-name }}
+
       - name: Use cached conda packages
         id: use-cached-conda-packages
         uses: actions/cache@v3
@@ -82,32 +88,24 @@ jobs:
         if: ${{ github.ref != 'refs/heads/main' && github.ref != 'refs/heads/development' }}
         run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> $GITHUB_ENV
 
-      - name: Use Cached Stable Diffusion Model
-        id: cache-sd-model
-        uses: actions/cache@v3
-        env:
-          cache-name: cache-${{ matrix.stable-diffusion-model }}
-        with:
-          path: ${{ env.INVOKEAI_ROOT }}/${{ matrix.stable-diffusion-model-dl-path }}
-          key: ${{ env.cache-name }}
-
-      - name: Download ${{ matrix.stable-diffusion-model }}
-        id: download-stable-diffusion-model
-        if: ${{ steps.cache-sd-model.outputs.cache-hit != 'true' }}
-        run: |
-          mkdir -p "${{ env.INVOKEAI_ROOT }}/${{ matrix.stable-diffusion-model-dl-path }}"
-          curl \
-            -H "Authorization: Bearer ${{ secrets.HUGGINGFACE_TOKEN }}" \
-            -o "${{ env.INVOKEAI_ROOT }}/${{ matrix.stable-diffusion-model-dl-path }}/${{ matrix.stable-diffusion-model-dl-name }}" \
-            -L ${{ matrix.stable-diffusion-model-url }}
-
       - name: run configure_invokeai.py
         id: run-preload-models
         run: |
-          python scripts/configure_invokeai.py --no-interactive --yes
+          if [ "${HAVE_SECRETS}" == true ] ; then
+            mkdir -p ~/.huggingface
+            echo -n '${{ secrets.HUGGINGFACE_TOKEN }}' > ~/.huggingface/token
+          fi
+          python scripts/configure_invokeai.py \
+            --no-interactive --yes \
+            --full-precision # can't use fp16 weights without a GPU
 
       - name: Run the tests
         id: run-tests
+        env:
+          # Set offline mode to make sure configure preloaded successfully.
+          HF_HUB_OFFLINE: 1
+          HF_DATASETS_OFFLINE: 1
+          TRANSFORMERS_OFFLINE: 1
         run: |
           time python scripts/invoke.py \
             --model ${{ matrix.stable-diffusion-model }} \
diff --git a/.github/workflows/test-invoke-pip.yml b/.github/workflows/test-invoke-pip.yml
index ce1d1ad6d77..15c0e377b2b 100644
--- a/.github/workflows/test-invoke-pip.yml
+++ b/.github/workflows/test-invoke-pip.yml
@@ -1,20 +1,20 @@
 name: Test invoke.py pip
-on:
-  push:
-    branches:
-      - 'main'
-      - 'development'
-  pull_request:
-    branches:
-      - 'main'
-      - 'development'
+on: [push, pull_request]
 
 jobs:
   matrix:
+    # Run on:
+    # - pull requests
+    # - pushes to forks (will run in the forked project with that fork's secrets)
+    # - pushes to branches that are *not* pull requests
+    if: |
+      github.event_name == 'pull_request'
+      || github.repository != 'invoke-ai/InvokeAI'
+      || github.ref_protected
     strategy:
       matrix:
         stable-diffusion-model:
-          - stable-diffusion-1.5
+          - diffusers-1.4
         requirements-file:
          - requirements-lin-cuda.txt
          - requirements-lin-amd.txt
@@ -32,10 +32,6 @@ jobs:
           - requirements-file: requirements-mac-mps-cpu.txt
             os: macOS-12
             default-shell: bash -l {0}
-          - stable-diffusion-model: stable-diffusion-1.5
-            stable-diffusion-model-url: https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt
-            stable-diffusion-model-dl-path: models/ldm/stable-diffusion-v1
-            stable-diffusion-model-dl-name: v1-5-pruned-emaonly.ckpt
     name: ${{ matrix.requirements-file }} on ${{ matrix.python-version }}
     runs-on: ${{ matrix.os }}
     defaults:
@@ -43,6 +39,8 @@ jobs:
         shell: ${{ matrix.default-shell }}
     env:
       INVOKEAI_ROOT: '${{ github.workspace }}/invokeai'
+      PYTHONUNBUFFERED: 1
+      HAVE_SECRETS: ${{ secrets.HUGGINGFACE_TOKEN != '' }}
     steps:
       - name: Checkout sources
         id: checkout-sources
@@ -53,6 +51,15 @@ jobs:
           mkdir -p ${{ env.INVOKEAI_ROOT }}/configs
           cp configs/models.yaml.example ${{ env.INVOKEAI_ROOT }}/configs/models.yaml
 
+      - name: Use Cached Stable Diffusion Model
+        id: cache-sd-model
+        uses: actions/cache@v3
+        env:
+          cache-name: huggingface-${{ matrix.stable-diffusion-model }}
+        with:
+          path: ~/.cache/huggingface
+          key: ${{ env.cache-name }}
+
       - name: set test prompt to main branch validation
         if: ${{ github.ref == 'refs/heads/main' }}
         run: echo "TEST_PROMPTS=tests/preflight_prompts.txt" >> $GITHUB_ENV
@@ -81,32 +88,24 @@ jobs:
       - name: install requirements
         run: ${{ env.pythonLocation }}/bin/pip install -r '${{ matrix.requirements-file }}'
 
-      - name: Use Cached Stable Diffusion Model
-        id: cache-sd-model
-        uses: actions/cache@v3
-        env:
-          cache-name: cache-${{ matrix.stable-diffusion-model }}
-        with:
-          path: ${{ env.INVOKEAI_ROOT }}/${{ matrix.stable-diffusion-model-dl-path }}
-          key: ${{ env.cache-name }}
-
-      - name: Download ${{ matrix.stable-diffusion-model }}
-        id: download-stable-diffusion-model
-        if: ${{ steps.cache-sd-model.outputs.cache-hit != 'true' }}
-        run: |
-          mkdir -p "${{ env.INVOKEAI_ROOT }}/${{ matrix.stable-diffusion-model-dl-path }}"
-          curl \
-            -H "Authorization: Bearer ${{ secrets.HUGGINGFACE_TOKEN }}" \
-            -o "${{ env.INVOKEAI_ROOT }}/${{ matrix.stable-diffusion-model-dl-path }}/${{ matrix.stable-diffusion-model-dl-name }}" \
-            -L ${{ matrix.stable-diffusion-model-url }}
-
       - name: run configure_invokeai.py
         id: run-preload-models
         run: |
-          ${{ env.pythonLocation }}/bin/python scripts/configure_invokeai.py --no-interactive --yes
+          if [ "${HAVE_SECRETS}" == true ] ; then
+            mkdir -p ~/.huggingface
+            echo -n '${{ secrets.HUGGINGFACE_TOKEN }}' > ~/.huggingface/token
+          fi
+          ${{ env.pythonLocation }}/bin/python scripts/configure_invokeai.py \
+            --no-interactive --yes \
+            --full-precision # can't use fp16 weights without a GPU
 
       - name: Run the tests
         id: run-tests
+        env:
+          # Set offline mode to make sure configure preloaded successfully.
+          HF_HUB_OFFLINE: 1
+          HF_DATASETS_OFFLINE: 1
+          TRANSFORMERS_OFFLINE: 1
         run: |
           time ${{ env.pythonLocation }}/bin/python scripts/invoke.py \
             --model ${{ matrix.stable-diffusion-model }} \
diff --git a/backend/invoke_ai_web_server.py b/backend/invoke_ai_web_server.py
index ac8edc6a324..d525cf87f89 100644
--- a/backend/invoke_ai_web_server.py
+++ b/backend/invoke_ai_web_server.py
@@ -19,6 +19,7 @@ from threading import Event
 
 from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
+from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
 from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
 from ldm.invoke.prompt_parser import split_weighted_subprompts
 from ldm.invoke.generator.inpaint import infill_methods
@@ -847,7 +848,9 @@ def generate_images(
             init_img_path = self.get_image_path_from_url(init_img_url)
             generation_parameters["init_img"] = Image.open(init_img_path).convert('RGB')
 
-        def image_progress(sample, step):
+        def image_progress(progress_state: PipelineIntermediateState):
+            step = progress_state.step
+            sample = progress_state.latents
             if self.canceled.is_set():
                 raise CanceledException
diff --git a/backend/modules/parameters.py b/backend/modules/parameters.py
index f3079e04973..4cc0831c764 100644
--- a/backend/modules/parameters.py
+++ b/backend/modules/parameters.py
@@ -10,6 +10,8 @@
     "k_heun",
     "k_lms",
     "plms",
+    # diffusers:
+    "pndm",
 ]
diff --git a/configs/models.yaml.example b/configs/models.yaml.example
index 31401cd02d8..87bc13645d1 100644
--- a/configs/models.yaml.example
+++ b/configs/models.yaml.example
@@ -5,6 +5,15 @@
 # model requires a model config file, a weights file,
 # and the width and height of the images it
 # was trained on.
+diffusers-1.4:
+  description: Diffusers version of Stable Diffusion version 1.4
+  format: diffusers
+  repo_name: CompVis/stable-diffusion-v1-4
+  default: true
+diffusers-1.5:
+  description: Diffusers version of Stable Diffusion version 1.5
+  format: diffusers
+  repo_name: runwayml/stable-diffusion-v1-5
 stable-diffusion-1.5:
   description: The newest Stable Diffusion version 1.5 weight file (4.27 GB)
   weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
diff --git a/environments-and-requirements/environment-lin-amd.yml b/environments-and-requirements/environment-lin-amd.yml
index 69de31aa19d..8128dbc6736 100644
--- a/environments-and-requirements/environment-lin-amd.yml
+++ b/environments-and-requirements/environment-lin-amd.yml
@@ -11,7 +11,7 @@ dependencies:
     - --extra-index-url https://download.pytorch.org/whl/rocm5.2/
     - albumentations==0.4.3
     - dependency_injector==4.40.0
-    - diffusers==0.6.0
+    - diffusers~=0.9
    - einops==0.3.0
    - eventlet
    - flask==2.1.3
diff --git a/environments-and-requirements/environment-lin-cuda.yml b/environments-and-requirements/environment-lin-cuda.yml
index d214ea519e3..e633a7101b8 100644
--- a/environments-and-requirements/environment-lin-cuda.yml
+++ b/environments-and-requirements/environment-lin-cuda.yml
@@ -12,9 +12,10 @@ dependencies:
   - pytorch=1.12.1
   - cudatoolkit=11.6
   - pip:
+    - accelerate~=0.13
     - albumentations==0.4.3
     - dependency_injector==4.40.0
-    - diffusers==0.6.0
+    - diffusers~=0.9
     - einops==0.3.0
     - eventlet
     - flask==2.1.3
diff --git a/environments-and-requirements/environment-mac.yml b/environments-and-requirements/environment-mac.yml
index 67489cbc09f..fea4fa8bf27 100644
--- a/environments-and-requirements/environment-mac.yml
+++ b/environments-and-requirements/environment-mac.yml
@@ -22,7 +22,7 @@ dependencies:
 
   - albumentations=1.2
   - coloredlogs=15.0
-  - diffusers=0.6
+  - diffusers~=0.9
   - einops=0.3
   - eventlet
   - grpcio=1.46
diff --git a/environments-and-requirements/environment-win-cuda.yml b/environments-and-requirements/environment-win-cuda.yml
index 9b43a30540e..baf10f5a13b 100644
--- a/environments-and-requirements/environment-win-cuda.yml
+++ b/environments-and-requirements/environment-win-cuda.yml
@@ -15,7 +15,7 @@ dependencies:
     - albumentations==0.4.3
     - basicsr==1.4.1
     - dependency_injector==4.40.0
-    - diffusers==0.6.0
+    - diffusers~=0.9
     - einops==0.3.0
     - eventlet
     - flask==2.1.3
diff --git a/environments-and-requirements/requirements-base.txt b/environments-and-requirements/requirements-base.txt
index e1b605db97b..ac17601d103 100644
--- a/environments-and-requirements/requirements-base.txt
+++ b/environments-and-requirements/requirements-base.txt
@@ -1,7 +1,7 @@
 # pip will resolve the version which matches torch
 albumentations
 dependency_injector==4.40.0
-diffusers
+diffusers[torch]~=0.9
 einops
 eventlet
 facexlib
@@ -30,9 +30,10 @@ taming-transformers-rom1504
 test-tube>=0.7.5
 torch-fidelity
 torchmetrics
-transformers==4.21.*
+transformers~=4.24
 picklescan
 git+https://github.com/openai/CLIP.git@main#egg=clip
 git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
 git+https://github.com/invoke-ai/clipseg.git@relaxed-python-requirement#egg=clipseg
 git+https://github.com/invoke-ai/PyPatchMatch@0.1.1#egg=pypatchmatch
+
diff --git a/installer/requirements.in b/installer/requirements.in
index ab6b2a1ff5c..de97f06f1b9 100644
--- a/installer/requirements.in
+++ b/installer/requirements.in
@@ -3,7 +3,7 @@
 --trusted-host https://download.pytorch.org
 accelerate~=0.14
 albumentations
-diffusers
+diffusers[torch]~=0.9
 einops
 eventlet
 facexlib
diff --git a/ldm/generate.py b/ldm/generate.py
index bbc2cc50782..01cb97cf41a 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -18,6 +18,8 @@ import hashlib
 import cv2
 import skimage
 
+from diffusers import DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, \
+    EulerAncestralDiscreteScheduler, PNDMScheduler, IPNDMScheduler
 from omegaconf import OmegaConf
 from ldm.invoke.generator.base import downsampling
@@ -401,7 +403,10 @@ def process_image(image,seed):
         width = width or self.width
         height = height or self.height
 
-        configure_model_padding(model, seamless, seamless_axes)
+        if isinstance(model, DiffusionPipeline):
+            configure_model_padding(model.unet, seamless, seamless_axes)
+        else:
+            configure_model_padding(model, seamless, seamless_axes)
 
         assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
         assert threshold >= 0.0, '--threshold must be >=0.0'
@@ -854,8 +859,8 @@ def set_model(self,model_name):
             self.embedding_path,
             self.precision == 'float32' or self.precision == 'autocast'
         )
-        self._set_sampler()
         self.model_name = model_name
+        self._set_sampler() # requires self.model_name to be set first
         return self.model
 
     def correct_colors(self,
@@ -951,9 +956,15 @@ def sample_to_image(self, samples):
     def sample_to_lowres_estimated_image(self, samples):
         return self._make_base().sample_to_lowres_estimated_image(samples)
 
+    def _set_sampler(self):
+        if isinstance(self.model, DiffusionPipeline):
+            return self._set_scheduler()
+        else:
+            return self._set_sampler_legacy()
+
     # very repetitive code - can this be simplified? The KSampler names are
     # consistent, at least
-    def _set_sampler(self):
+    def _set_sampler_legacy(self):
         msg = f'>> Setting Sampler to {self.sampler_name}'
         if self.sampler_name == 'plms':
             self.sampler = PLMSSampler(self.model, device=self.device)
@@ -977,6 +988,47 @@ def _set_sampler(self):
 
         print(msg)
 
+    def _set_scheduler(self):
+        default = self.model.scheduler
+
+        higher_order_samplers = [
+            'k_dpm_2',
+            'k_dpm_2_a',
+            'k_heun',
+            'plms',  # Its first step is like Heun
+        ]
+        scheduler_map = dict(
+            ddim=DDIMScheduler,
+            ipndm=IPNDMScheduler,
+            k_euler=EulerDiscreteScheduler,
+            k_euler_a=EulerAncestralDiscreteScheduler,
+            k_lms=LMSDiscreteScheduler,
+            pndm=PNDMScheduler,
+        )
+
+        if self.sampler_name in scheduler_map:
+            sampler_class = scheduler_map[self.sampler_name]
+            msg = f'>> Setting Sampler to {self.sampler_name} ({sampler_class.__name__})'
+            self.sampler = sampler_class.from_config(
+                self.model_cache.model_name_or_path(self.model_name),
+                subfolder="scheduler"
+            )
+        elif self.sampler_name in higher_order_samplers:
+            msg = (f'>> Unsupported Sampler: {self.sampler_name} '
+                   f'— diffusers does not yet support higher-order samplers, '
+                   f'Defaulting to {default}')
+            self.sampler = default
+        else:
+            msg = (f'>> Unsupported Sampler: {self.sampler_name} '
+                   f'Defaulting to {default}')
+            self.sampler = default
+
+        print(msg)
+
+        if not hasattr(self.sampler, 'uses_inpainting_model'):
+            # FIXME: terrible kludge!
+            self.sampler.uses_inpainting_model = lambda: False
+
     def _load_img(self, img)->Image:
         if isinstance(img, Image.Image):
             image = img
diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py
index 5d60153a600..71ab31c133d 100644
--- a/ldm/invoke/args.py
+++ b/ldm/invoke/args.py
@@ -105,6 +105,8 @@
     'k_heun',
     'k_lms',
     'plms',
+    # diffusers:
+    "pndm",
 ]
 
 PRECISION_CHOICES = [
diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py
index a8ef62822d0..2bebc5c4083 100644
--- a/ldm/invoke/generator/base.py
+++ b/ldm/invoke/generator/base.py
@@ -2,24 +2,35 @@
 Base class for ldm.invoke.generator.* including img2img, txt2img, and inpaint
 '''
-import torch
-import numpy as np
-import random
+from __future__ import annotations
+
 import os
+import random
 import traceback
-from tqdm import tqdm, trange
+
+import numpy as np
+import torch
 from PIL import Image, ImageFilter, ImageChops
 import cv2 as cv
-from einops import rearrange, repeat
+from diffusers import DiffusionPipeline
+from einops import rearrange
 from pytorch_lightning import seed_everything
+from tqdm import trange
+
 from ldm.invoke.devices import choose_autocast
+from ldm.models.diffusion.ddpm import DiffusionWrapper
 from ldm.util import rand_perlin_2d
 
 downsampling = 8
 CAUTION_IMG = 'assets/caution.png'
 
-class Generator():
-    def __init__(self, model, precision):
+class Generator:
+    downsampling_factor: int
+    latent_channels: int
+    precision: str
+    model: DiffusionWrapper | DiffusionPipeline
+
+    def __init__(self, model: DiffusionWrapper | DiffusionPipeline, precision: str):
         self.model = model
         self.precision = precision
         self.seed = None
@@ -161,12 +172,12 @@ def repaste_and_color_correct(self, result: Image.Image, init_image: Image.Image
             blurred_init_mask = pil_init_mask
 
         multiplied_blurred_init_mask = ImageChops.multiply(blurred_init_mask, self.pil_image.split()[-1])
-        
+
         # Paste original on color-corrected generation (using blurred mask)
         matched_result.paste(init_image, (0,0), mask = multiplied_blurred_init_mask)
         return matched_result
-    
+
     def sample_to_lowres_estimated_image(self,samples):
         # origingally adapted from code by @erucipe and @keturn here:
diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py
new file mode 100644
index 00000000000..2d3f694687a
--- /dev/null
+++ b/ldm/invoke/generator/diffusers_pipeline.py
@@ -0,0 +1,329 @@
+from __future__ import annotations
+
+import secrets
+import warnings
+from dataclasses import dataclass
+from typing import List, Optional, Union, Callable
+
+import PIL.Image
+import torch
+from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import preprocess
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
+from ldm.modules.encoders.modules import WeightedFrozenCLIPEmbedder
+
+
+@dataclass
+class PipelineIntermediateState:
+    run_id: str
+    step: int
+    timestep: int
+    latents: torch.Tensor
+    predicted_original: Optional[torch.Tensor] = None
+
+
+class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
+    r"""
+    Pipeline for text-to-image generation using Stable Diffusion.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Implementation note: This class started as a refactored copy of diffusers.StableDiffusionPipeline.
+    Hopefully future versions of diffusers provide access to more of these functions so that we don't
+    need to duplicate them here: https://github.com/huggingface/diffusers/issues/551#issuecomment-1281508384
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+        feature_extractor ([`CLIPFeatureExtractor`]):
+            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+    """
+
+    ID_LENGTH = 8
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        safety_checker: Optional[StableDiffusionSafetyChecker],
+        feature_extractor: Optional[CLIPFeatureExtractor],
+    ):
+        super().__init__(vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor)
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+        )
+        # InvokeAI's interface for text embeddings and whatnot
+        self.clip_embedder = WeightedFrozenCLIPEmbedder(
+            tokenizer=self.tokenizer,
+            transformer=self.text_encoder
+        )
+        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
+
+    def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
+                              text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
+                              guidance_scale: float,
+                              *, callback: Callable[[PipelineIntermediateState], None]=None,
+                              extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
+                              run_id=None,
+                              **extra_step_kwargs) -> StableDiffusionPipelineOutput:
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        :param latents: Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for
+            image generation. Can be used to tweak the same generation with different prompts.
+        :param num_inference_steps: The number of denoising steps. More denoising steps usually lead to a higher quality
+            image at the expense of slower inference.
+        :param text_embeddings:
+        :param guidance_scale: Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+            `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
+            Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
+            images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
+        :param callback:
+        :param extra_conditioning_info:
+        :param run_id:
+        :param extra_step_kwargs:
+        """
+        self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
+        result = None
+        for result in self.generate_from_embeddings(
+                latents, text_embeddings, unconditioned_embeddings, guidance_scale,
+                extra_conditioning_info=extra_conditioning_info,
+                run_id=run_id, **extra_step_kwargs):
+            if callback is not None and isinstance(result, PipelineIntermediateState):
+                callback(result)
+        if result is None:
+            raise AssertionError("why was that an empty generator?")
+        return result
+
+    def generate(
+        self,
+        prompt: Union[str, List[str]],
+        *,
+        opposing_prompt: Union[str, List[str]] = None,
+        height: Optional[int] = 512,
+        width: Optional[int] = 512,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 7.5,
+        generator: Optional[torch.Generator] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        run_id: str = None,
+        **extra_step_kwargs,
+    ):
+        if isinstance(prompt, str):
+            batch_size = 1
+        else:
+            batch_size = len(prompt)
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        combined_embeddings = self._encode_prompt(prompt, device=self._execution_device, num_images_per_prompt=1,
+                                                  do_classifier_free_guidance=do_classifier_free_guidance,
+                                                  negative_prompt=opposing_prompt)
+        text_embeddings, unconditioned_embeddings = combined_embeddings.chunk(2)
+        self.scheduler.set_timesteps(num_inference_steps)
+        latents = self.prepare_latents(batch_size=batch_size, num_channels_latents=self.unet.in_channels,
+                                       height=height, width=width,
+                                       dtype=self.unet.dtype, device=self._execution_device,
+                                       generator=generator,
+                                       latents=latents)
+
+        yield from self.generate_from_embeddings(latents, text_embeddings, unconditioned_embeddings,
+                                                 guidance_scale, run_id=run_id, **extra_step_kwargs)
+
+    def generate_from_embeddings(
+            self,
+            latents: torch.Tensor,
+            text_embeddings: torch.Tensor,
+            unconditioned_embeddings: torch.Tensor,
+            guidance_scale: float,
+            *,
+            run_id: str = None,
+            extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
+            timesteps = None,
+            **extra_step_kwargs):
+        if run_id is None:
+            run_id = secrets.token_urlsafe(self.ID_LENGTH)
+
+        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
+            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
+                                                                 step_count=len(self.scheduler.timesteps))
+        else:
+            self.invokeai_diffuser.remove_cross_attention_control()
+
+        if timesteps is None:
+            timesteps = self.scheduler.timesteps
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents *= self.scheduler.init_noise_sigma
+        yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
+                                        latents=latents)
+
+        batch_size = latents.shape[0]
+        batched_t = torch.full((batch_size,), timesteps[0],
+                               dtype=timesteps.dtype, device=self.unet.device)
+        # NOTE: Depends on scheduler being already initialized!
+        for i, t in enumerate(self.progress_bar(timesteps)):
+            batched_t.fill_(t)
+            step_output = self.step(batched_t, latents, guidance_scale,
+                                    text_embeddings, unconditioned_embeddings,
+                                    i, **extra_step_kwargs)
+            latents = step_output.prev_sample
+            predicted_original = getattr(step_output, 'pred_original_sample', None)
+            yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
+                                            predicted_original=predicted_original)
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        torch.cuda.empty_cache()
+
+        with torch.inference_mode():
+            image = self.decode_latents(latents)
+            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
+            yield self.check_for_safety(output, dtype=text_embeddings.dtype)
+
+    @torch.inference_mode()
+    def step(self, t: torch.Tensor, latents: torch.Tensor, guidance_scale: float,
+             text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
+             step_index:int | None = None,
+             **extra_step_kwargs):
+        # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
+        timestep = t[0]
+
+        # TODO: should this scaling happen here or inside self._unet_forward?
+        #     i.e. before or after passing it to InvokeAIDiffuserComponent
+        latent_model_input = self.scheduler.scale_model_input(latents, timestep)
+
+        # predict the noise residual
+        noise_pred = self.invokeai_diffuser.do_diffusion_step(
+            latent_model_input, t,
+            unconditioned_embeddings, text_embeddings,
+            guidance_scale,
+            step_index=step_index)
+
+        # compute the previous noisy sample x_t -> x_t-1
+        return self.scheduler.step(noise_pred, timestep, latents, **extra_step_kwargs)
+
+    def _unet_forward(self, latents, t, text_embeddings):
+        # predict the noise residual
+        return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
+
+    def img2img_from_embeddings(self,
+                                init_image: Union[torch.FloatTensor, PIL.Image.Image],
+                                strength: float,
+                                num_inference_steps: int,
+                                text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
+                                guidance_scale: float,
+                                *, callback: Callable[[PipelineIntermediateState], None] = None,
+                                extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
+                                run_id=None,
+                                noise_func=None,
+                                **extra_step_kwargs) -> StableDiffusionPipelineOutput:
+        device = self.unet.device
+        latents_dtype = text_embeddings.dtype
+        batch_size = 1
+        num_images_per_prompt = 1
+
+        if isinstance(init_image, PIL.Image.Image):
+            init_image = preprocess(init_image.convert('RGB'))
+
+        img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
+        img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        # 6. Prepare latent variables
+        latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
+
+        result = None
+        for result in self.generate_from_embeddings(
+                latents, text_embeddings, unconditioned_embeddings, guidance_scale,
+                extra_conditioning_info=extra_conditioning_info,
+                timesteps=timesteps,
+                run_id=run_id, **extra_step_kwargs):
+            if callback is not None and isinstance(result, PipelineIntermediateState):
+                callback(result)
+        if result is None:
+            raise AssertionError("why was that an empty generator?")
+        return result
+
+    def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_func) -> torch.FloatTensor:
+        # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
+        # because we have our own noise function
+        init_image = init_image.to(device=device, dtype=dtype)
+        with torch.inference_mode():
+            init_latent_dist = self.vae.encode(init_image).latent_dist
+            init_latents = init_latent_dist.sample()  # FIXME: uses torch.randn. make reproducible!
+        init_latents = 0.18215 * init_latents
+
+        noise = noise_func(init_latents)
+
+        return self.scheduler.add_noise(init_latents, noise, timestep)
+
+    def check_for_safety(self, output, dtype):
+        with torch.inference_mode():
+            screened_images, has_nsfw_concept = self.run_safety_checker(
+                output.images, device=self._execution_device, dtype=dtype)
+        return StableDiffusionPipelineOutput(screened_images, has_nsfw_concept)
+
+    @torch.inference_mode()
+    def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
+        """
+        Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
+ """ + return self.clip_embedder.encode(c, return_tokens=return_tokens, fragment_weights=fragment_weights) + + @property + def cond_stage_model(self): + warnings.warn("legacy compatibility layer", DeprecationWarning) + return self.clip_embedder + + @torch.inference_mode() + def _tokenize(self, prompt: Union[str, List[str]]): + return self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + @property + def channels(self) -> int: + """Compatible with DiffusionWrapper""" + return self.unet.in_channels diff --git a/ldm/invoke/generator/embiggen.py b/ldm/invoke/generator/embiggen.py index dc6af35a6c9..0b9fda7ac29 100644 --- a/ldm/invoke/generator/embiggen.py +++ b/ldm/invoke/generator/embiggen.py @@ -3,14 +3,16 @@ and generates with ldm.invoke.generator.img2img ''' +import numpy as np import torch -import numpy as np +from PIL import Image from tqdm import trange -from PIL import Image -from ldm.invoke.generator.base import Generator -from ldm.invoke.generator.img2img import Img2Img + from ldm.invoke.devices import choose_autocast -from ldm.models.diffusion.ddim import DDIMSampler +from ldm.invoke.generator.base import Generator +from ldm.invoke.generator.img2img import Img2Img +from ldm.models.diffusion.ddim import DDIMSampler + class Embiggen(Generator): def __init__(self, model, precision): @@ -493,7 +495,7 @@ def make_image(): # Layer tile onto final image outputsuperimage.alpha_composite(intileimage, (left, top)) else: - print(f'Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation.') + print('Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation.') # after internal loops and patching up return Embiggen image return outputsuperimage diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py index 1981b4eacb6..6ea41fda33c 100644 --- a/ldm/invoke/generator/img2img.py +++ b/ldm/invoke/generator/img2img.py @@ -3,14 +3,10 @@ ''' import torch -import numpy as np -import PIL -from torch import Tensor -from PIL import Image -from ldm.invoke.devices import choose_autocast + from ldm.invoke.generator.base import Generator -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent +from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline + class Img2Img(Generator): def __init__(self, model, precision): @@ -25,66 +21,51 @@ def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, """ self.perlin = perlin - sampler.make_schedule( - ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False - ) - - if isinstance(init_image, PIL.Image.Image): - init_image = self._image_to_tensor(init_image.convert('RGB')) - - scope = choose_autocast(self.precision) - with scope(self.model.device.type): - self.init_latent = self.model.get_first_stage_encoding( - self.model.encode_first_stage(init_image) - ) # move to latent space - - t_enc = int(strength * steps) uc, c, extra_conditioning_info = conditioning + # noinspection PyTypeChecker + pipeline: StableDiffusionGeneratorPipeline = self.model + pipeline.scheduler = sampler + def make_image(x_T): - # encode (scaled latent) - z_enc = sampler.stochastic_encode( - self.init_latent, - torch.tensor([t_enc]).to(self.model.device), - noise=x_T - ) - # decode it - samples = sampler.decode( - z_enc, - c, - t_enc, - img_callback = step_callback, - 
-                unconditional_guidance_scale=cfg_scale,
-                unconditional_conditioning=uc,
-                init_latent = self.init_latent,  # changes how noising is performed in ksampler
-                extra_conditioning_info = extra_conditioning_info,
-                all_timesteps_count = steps
+            # FIXME: use x_T for initial seeded noise
+            pipeline_output = pipeline.img2img_from_embeddings(
+                init_image, strength, steps, c, uc, cfg_scale,
+                extra_conditioning_info=extra_conditioning_info,
+                noise_func=self.get_noise_like,
+                callback=step_callback
             )
-
-            return self.sample_to_image(samples)
+            return pipeline.numpy_to_pil(pipeline_output.images)[0]
 
         return make_image
 
-    def get_noise(self,width,height):
-        device      = self.model.device
-        init_latent = self.init_latent
-        assert init_latent is not None,'call to get_noise() when init_latent not set'
+    def get_noise_like(self, like: torch.Tensor):
+        device = like.device
         if device.type == 'mps':
-            x = torch.randn_like(init_latent, device='cpu').to(device)
+            x = torch.randn_like(like, device='cpu').to(device)
         else:
-            x = torch.randn_like(init_latent, device=device)
+            x = torch.randn_like(like, device=device)
         if self.perlin > 0.0:
-            shape = init_latent.shape
+            shape = like.shape
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
         return x
 
-    def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
-        image = np.array(image).astype(np.float32) / 255.0
-        if len(image.shape) == 2:  # 'L' image, as in a mask
-            image = image[None,None]
-        else:                      # 'RGB' image
-            image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        if normalize:
-            image = 2.0 * image - 1.0
-        return image.to(self.model.device)
+    def get_noise(self,width,height):
+        # copy of the Txt2Img.get_noise
+        device      = self.model.device
+        if self.use_mps_noise or device.type == 'mps':
+            x = torch.randn([1,
+                             self.latent_channels,
+                             height // self.downsampling_factor,
+                             width  // self.downsampling_factor],
+                            device='cpu').to(device)
+        else:
+            x = torch.randn([1,
+                             self.latent_channels,
+                             height // self.downsampling_factor,
+                             width  // self.downsampling_factor],
+                            device=device)
+        if self.perlin > 0.0:
+            x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
+        return x
diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py
index ae1e4a946bb..16b31e5cf02 100644
--- a/ldm/invoke/generator/inpaint.py
+++ b/ldm/invoke/generator/inpaint.py
@@ -3,17 +3,18 @@
 '''
 
 import math
-import torch
-import torchvision.transforms as T
-import numpy as np
-import cv2 as cv
+
 import PIL
+import cv2 as cv
+import numpy as np
+import torch
 from PIL import Image, ImageFilter, ImageOps, ImageChops
-from skimage.exposure.histogram_matching import match_histograms
-from einops import rearrange, repeat
-from ldm.invoke.devices import choose_autocast
-from ldm.invoke.generator.img2img import Img2Img
-from ldm.models.diffusion.ddim import DDIMSampler
+from einops import repeat
+
+from ldm.invoke.devices import choose_autocast
+from ldm.invoke.generator.base import downsampling
+from ldm.invoke.generator.img2img import Img2Img
+from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.ksampler import KSampler
 from ldm.invoke.generator.base import downsampling
 from ldm.util import debug_image
@@ -53,7 +54,7 @@ def get_tile_images(self, image: np.ndarray, width=8, height=8):
             writeable=False
         )
 
-    def infill_patchmatch(self, im: Image.Image) -> Image: 
+    def infill_patchmatch(self, im: Image.Image) -> Image:
         if im.mode != 'RGBA':
             return  im
@@ -215,7 +216,7 @@ def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
             init_filled = init_filled.resize((inpaint_width, inpaint_height))
 
             debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging)
-    
+
             # Create init tensor
             init_image = self._image_to_tensor(init_filled.convert('RGB'))
@@ -245,7 +246,7 @@ def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
         # klms samplers not supported yet, so ignore previous sampler
         if isinstance(sampler,KSampler):
             print(
-                f">> Using recommended DDIM sampler for inpainting."
+                ">> Using recommended DDIM sampler for inpainting."
             )
             sampler = DDIMSampler(self.model, device=self.model.device)
diff --git a/ldm/invoke/generator/omnibus.py b/ldm/invoke/generator/omnibus.py
index 35c4a62d664..d2d3ee1ed66 100644
--- a/ldm/invoke/generator/omnibus.py
+++ b/ldm/invoke/generator/omnibus.py
@@ -5,7 +5,6 @@
 from einops import repeat
 from PIL import Image, ImageOps, ImageChops
 from ldm.invoke.devices import choose_autocast
-from ldm.invoke.generator.base import downsampling
 from ldm.invoke.generator.img2img import Img2Img
 from ldm.invoke.generator.txt2img import Txt2Img
@@ -56,8 +55,6 @@ def get_make_image(
 
         self.mask_blur_radius = mask_blur_radius
 
-        t_enc = steps
-
         if init_image is not None and mask_image is not None: # inpainting
             masked_image = init_image * (1 - mask_image)  # masked image is the image masked by mask - masked regions zero
@@ -162,10 +159,10 @@ def get_noise(self, width:int, height:int):
 
     def sample_to_image(self, samples)->Image.Image:
         gen_result = super().sample_to_image(samples).convert('RGB')
-        
+
         if self.pil_image is None or self.pil_mask is None:
             return gen_result
 
         corrected_result = super(Img2Img, self).repaste_and_color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius)
-        
+
         return corrected_result
diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py
index ba49d2ef558..f9af1ac3ed7 100644
--- a/ldm/invoke/generator/txt2img.py
+++ b/ldm/invoke/generator/txt2img.py
@@ -1,11 +1,11 @@
 '''
 ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
 '''
-
+import PIL.Image
 import torch
-import numpy as np
-from ldm.invoke.generator.base import Generator
-from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
+
+from .base import Generator
+from .diffusers_pipeline import StableDiffusionGeneratorPipeline
 
 class Txt2Img(Generator):
@@ -14,7 +14,8 @@ def __init__(self, model, precision):
 
     @torch.no_grad()
     def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
-                       conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
+                       conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,
+                       **kwargs):
         """
         Returns a function returning an image derived from the prompt and the initial image
         Return value depends on the seed at the time you call it
@@ -23,38 +24,32 @@ def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
         self.perlin = perlin
         uc, c, extra_conditioning_info   = conditioning
 
-        @torch.no_grad()
-        def make_image(x_T):
-            shape = [
-                self.latent_channels,
-                height // self.downsampling_factor,
-                width // self.downsampling_factor,
-            ]
+        # noinspection PyTypeChecker
+        pipeline: StableDiffusionGeneratorPipeline = self.model
+        pipeline.scheduler = sampler
 
-            if self.free_gpu_mem and self.model.model.device != self.model.device:
-                self.model.model.to(self.model.device)
-
-            sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
+        def make_image(x_T) -> PIL.Image.Image:
+            # FIXME: restore free_gpu_mem functionality
+            # if self.free_gpu_mem and self.model.model.device != self.model.device:
+            #     self.model.model.to(self.model.device)
 
-            samples, _ = sampler.sample(
-                batch_size                   = 1,
-                S                            = steps,
-                x_T                          = x_T,
-                conditioning                 = c,
-                shape                        = shape,
-                verbose                      = False,
-                unconditional_guidance_scale = cfg_scale,
-                unconditional_conditioning   = uc,
-                extra_conditioning_info      = extra_conditioning_info,
-                eta                          = ddim_eta,
-                img_callback                 = step_callback,
-                threshold                    = threshold,
+            pipeline_output = pipeline.image_from_embeddings(
+                latents=x_T,
+                num_inference_steps=steps,
+                text_embeddings=c,
+                unconditioned_embeddings=uc,
+                guidance_scale=cfg_scale,
+                callback=step_callback,
+                extra_conditioning_info=extra_conditioning_info,
+                # TODO: eta = ddim_eta,
+                # TODO: threshold = threshold,
             )
 
-            if self.free_gpu_mem:
-                self.model.model.to("cpu")
+            # FIXME: restore free_gpu_mem functionality
+            # if self.free_gpu_mem:
+            #     self.model.model.to("cpu")
 
-            return self.sample_to_image(samples)
+            return pipeline.numpy_to_pil(pipeline_output.images)[0]
 
         return make_image
diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py
index 759ba2dba4e..3da42ebb8af 100644
--- a/ldm/invoke/generator/txt2img2img.py
+++ b/ldm/invoke/generator/txt2img2img.py
@@ -2,14 +2,15 @@
 ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
 '''
 
-import torch
-import numpy as np
 import math
+
+import torch
+from PIL import Image
+
 from ldm.invoke.generator.base import Generator
-from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.invoke.generator.omnibus import Omnibus
-from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
-from PIL import Image
+from ldm.models.diffusion.ddim import DDIMSampler
+
 
 class Txt2Img2Img(Generator):
     def __init__(self, model, precision):
diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py
index 645a6fd4da1..60b9b06d38b 100644
--- a/ldm/invoke/model_cache.py
+++ b/ldm/invoke/model_cache.py
@@ -4,26 +4,31 @@
 below a preset minimum, the least recently used model will be
 cleared and loaded from disk when next needed.
 '''
+from __future__ import annotations
 
-import torch
-import os
-import io
-import time
+import contextlib
 import gc
 import hashlib
-import psutil
+import io
+import os
 import sys
-import transformers
-import traceback
 import textwrap
-import contextlib
+import time
+import traceback
+import warnings
+from pathlib import Path
 from typing import Union
+
+import torch
+import transformers
 from omegaconf import OmegaConf
 from omegaconf.errors import ConfigAttributeError
-from ldm.util import instantiate_from_config, ask_user
-from ldm.invoke.globals import Globals
 from picklescan.scanner import scan_file_path
 
+from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
+from ldm.invoke.globals import Globals
+from ldm.util import instantiate_from_config, ask_user
+
 DEFAULT_MAX_MODELS=2
 
 class ModelCache(object):
@@ -101,7 +106,7 @@ def get_model(self, model_name:str):
             'hash': hash
         }
 
-    def default_model(self) -> str:
+    def default_model(self) -> str | None:
         '''
         Returns the name of the default model, or None
         if none is defined.
@@ -197,12 +202,46 @@ def _load_model(self, model_name:str):
             print(f'"{model_name}" is not a known model name. Please check your models.yaml file')
 
         mconfig = self.config[model_name]
+
+        # for usage statistics
+        if self._has_cuda():
+            torch.cuda.reset_peak_memory_stats()
+            torch.cuda.empty_cache()
+
+        tic = time.time()
+
+        # this does the work
+        model_format = mconfig.get('format', 'ckpt')
+        if model_format == 'ckpt':
+            weights = mconfig.weights
+            print(f'>> Loading {model_name} from {weights}')
+            model, width, height, model_hash = self._load_ckpt_model(model_name, mconfig)
+        elif model_format == 'diffusers':
+            model, width, height, model_hash = self._load_diffusers_model(mconfig)
+        else:
+            raise NotImplementedError(f"Unknown model format {model_name}: {model_format}")
+
+        # usage statistics
+        toc = time.time()
+        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
+        if self._has_cuda():
+            print(
+                '>> Max VRAM used to load the model:',
+                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
+                '\n>> Current VRAM usage:'
+                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
+            )
+        return model, width, height, model_hash
+
+    def _load_ckpt_model(self, model_name, mconfig):
         config = mconfig.config
         weights = mconfig.weights
         vae = mconfig.get('vae')
         width = mconfig.width
         height = mconfig.height
 
+        if not os.path.isabs(config):
+            config = os.path.join(Globals.root,config)
         if not os.path.isabs(weights):
             weights = os.path.normpath(os.path.join(Globals.root,weights))
         # scan model
@@ -223,7 +262,7 @@ def _load_model(self, model_name:str):
         omega_config = OmegaConf.load(config)
         with open(weights,'rb') as f:
             weight_bytes = f.read()
-        model_hash = self._cached_sha256(weights,weight_bytes)
+        model_hash = self._cached_sha256(weights, weight_bytes)
         sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
         del weight_bytes
         sd = sd['state_dict']
@@ -271,7 +310,60 @@ def _load_model(self, model_name:str):
         )
 
         return model, width, height, model_hash
-    
+
+    def _load_diffusers_model(self, mconfig):
+        pipeline_args = {}
+
+        if 'repo_name' in mconfig:
+            name_or_path = mconfig['repo_name']
+            model_hash = "FIXME"
+            # model_hash = huggingface_hub.get_hf_file_metadata(url).commit_hash
+        elif 'path' in mconfig:
+            name_or_path = Path(mconfig['path'])
+            # FIXME: What should the model_hash be? A hash of the unet weights? Of all files of all
+            #     the submodels hashed together? The commit ID from the repo?
+            model_hash = "FIXME TOO"
+        else:
+            raise ValueError("Model config must specify either repo_name or path.")
+
+        print(f'>> Loading diffusers model from {name_or_path}')
+
+        # TODO: scan weights maybe?
+
+        if self.precision == 'float16':
+            print('   | Using faster float16 precision')
+            pipeline_args.update(revision="fp16", torch_dtype=torch.float16)
+        else:
+            # TODO: more accurately, "using the model's default precision."
+            #     How do we find out what that is?
+            print('   | Using more accurate float32 precision')
+
+        pipeline = StableDiffusionGeneratorPipeline.from_pretrained(
+            name_or_path,
+            safety_checker=None,  # TODO
+            # TODO: alternate VAE
+            # TODO: local_files_only=True
+            **pipeline_args
+        )
+        pipeline.to(self.device)
+
+        width = pipeline.vae.block_out_channels[-1]
+        height = pipeline.vae.block_out_channels[-1]
+
+        return pipeline, width, height, model_hash
+
+    def model_name_or_path(self, model_name:str) -> str | Path:
+        if model_name not in self.config:
+            raise ValueError(f'"{model_name}" is not a known model name. Please check your models.yaml file')
+
+        mconfig = self.config[model_name]
+        if 'repo_name' in mconfig:
+            return mconfig['repo_name']
+        elif 'path' in mconfig:
+            return Path(mconfig['path'])
+        else:
+            raise ValueError("Model config must specify either repo_name or path.")
+
     def offload_model(self, model_name:str) -> None:
         '''
         Offload the indicated model to CPU. Will call
@@ -355,10 +447,13 @@ def _invalidate_cached_model(self,model_name:str) -> None:
 
     def _model_to_cpu(self,model):
         if self.device != 'cpu':
-            model.cond_stage_model.device = 'cpu'
-            model.first_stage_model.to('cpu')
-            model.cond_stage_model.to('cpu')
-            model.model.to('cpu')
+            try:
+                model.cond_stage_model.device = 'cpu'
+                model.first_stage_model.to('cpu')
+                model.cond_stage_model.to('cpu')
+                model.model.to('cpu')
+            except AttributeError as e:
+                warnings.warn(f"TODO: clean up legacy model-management: {e}")
             return model.to('cpu')
         else:
             return model
diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py
index a4362e07704..ec7c3c215cc 100644
--- a/ldm/models/diffusion/cross_attention_control.py
+++ b/ldm/models/diffusion/cross_attention_control.py
@@ -1,8 +1,10 @@
 import enum
+import warnings
 from typing import Optional
 
 import torch
 
+
 # adapted from bloc97's CrossAttentionControl colab
 # https://github.com/bloc97/CrossAttentionControl
@@ -244,19 +246,39 @@ def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim
 
         return attention_slice
 
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == "CrossAttention":
-            module.identifier = name
+    cross_attention_modules = [(name, module) for (name, module) in unet.named_modules()
+                               if type(module).__name__ == "CrossAttention"]
+    for identifier, module in cross_attention_modules:
+        module.identifier = identifier
+        try:
             module.set_attention_slice_wrangler(attention_slice_wrangler)
-            module.set_slicing_strategy_getter(lambda module, module_identifier=name: \
-                                                   context.get_slicing_strategy(module_identifier))
+            module.set_slicing_strategy_getter(
+                lambda module: context.get_slicing_strategy(identifier)
+            )
+        except AttributeError as e:
+            if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
+                warnings.warn(f"TODO: implement for {type(module)}")  # TODO
+            else:
+                raise
 
 
 def remove_attention_function(unet):
-    # clear wrangler callback
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == "CrossAttention":
+    cross_attention_modules = [module for (_, module) in unet.named_modules()
+                               if type(module).__name__ == "CrossAttention"]
+    for module in cross_attention_modules:
+        try:
+            # clear wrangler callback
             module.set_attention_slice_wrangler(None)
             module.set_slicing_strategy_getter(None)
+        except AttributeError as e:
+            if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
+                warnings.warn(f"TODO: implement for {type(module)}")  # TODO
+            else:
+                raise
+
+
+def is_attribute_error_about(error: AttributeError, attribute: str):
+    if hasattr(error, 'name'):  # Python 3.10
+        return error.name == attribute
+    else:  # Python 3.9
+        return attribute in str(error)
diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py
index d748c9a6735..d6ec1ea44bf 100644
--- a/ldm/models/diffusion/shared_invokeai_diffusion.py
+++ b/ldm/models/diffusion/shared_invokeai_diffusion.py
@@ -35,6 +35,7 @@ def __init__(self, model, model_forward_callback:
         :param model: the unet model to pass through to cross attention control
         :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
         """
+        self.conditioning = None
         self.model = model
         self.model_forward_callback = model_forward_callback
         self.cross_attention_control_context = None
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index be9f88cdd2e..c76edce3c59 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -1,5 +1,7 @@
 import math
 import os.path
+from typing import Optional
+
 import torch
 import torch.nn as nn
 from functools import partial
@@ -235,26 +237,28 @@ def encode(self, x):
 
 class FrozenCLIPEmbedder(AbstractEncoder):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+    tokenizer: CLIPTokenizer
+    transformer: CLIPTextModel
 
     def __init__(
         self,
-        version='openai/clip-vit-large-patch14',
-        device=choose_torch_device(),
-        max_length=77,
+        version:str='openai/clip-vit-large-patch14',
+        max_length:int=77,
+        tokenizer:Optional[CLIPTokenizer]=None,
+        transformer:Optional[CLIPTextModel]=None,
     ):
         super().__init__()
         cache = os.path.join(Globals.root,'models',version)
-        self.tokenizer = CLIPTokenizer.from_pretrained(
+        self.tokenizer = tokenizer or CLIPTokenizer.from_pretrained(
             version,
             cache_dir=cache,
             local_files_only=True
         )
-        self.transformer = CLIPTextModel.from_pretrained(
+        self.transformer = transformer or CLIPTextModel.from_pretrained(
             version,
             cache_dir=cache,
             local_files_only=True
         )
-        self.device = device
         self.max_length = max_length
         self.freeze()
@@ -460,6 +464,14 @@ def forward(self, text, **kwargs):
     def encode(self, text, **kwargs):
         return self(text, **kwargs)
 
+    @property
+    def device(self):
+        return self.transformer.device
+
+    @device.setter
+    def device(self, device):
+        self.transformer.to(device=device)
+
 
 class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
 
     fragment_weights_key = "fragment_weights"
diff --git a/scripts/configure_invokeai.py b/scripts/configure_invokeai.py
index 0cc8c9d9928..7089795db02 100755
--- a/scripts/configure_invokeai.py
+++ b/scripts/configure_invokeai.py
@@ -8,30 +8,38 @@
 #
 print('Loading Python libraries...\n')
 import argparse
-import sys
 import os
 import re
-import warnings
 import shutil
-from urllib import request
-from tqdm import tqdm
-from omegaconf import OmegaConf
-from huggingface_hub import HfFolder, hf_hub_url
+import sys
+import traceback
+import warnings
 from pathlib import Path
+from typing import Dict
+from urllib import request
+
+import requests
+import transformers
+from diffusers import StableDiffusionPipeline, AutoencoderKL
 from getpass_asterisk import getpass_asterisk
+from huggingface_hub import HfFolder, hf_hub_url, whoami as hf_whoami
+from omegaconf import OmegaConf
+from tqdm import tqdm
 from transformers import CLIPTokenizer, CLIPTextModel
+
 from ldm.invoke.globals import Globals
 from ldm.invoke.readline import generic_completer
-import traceback
-import requests
-import clip
-import transformers
-import warnings
 
 warnings.filterwarnings('ignore')
 import torch
 transformers.logging.set_verbosity_error()
 
+try:
+    from ldm.invoke.model_cache import ModelCache
+except ImportError:
+    sys.path.append('.')
+    from ldm.invoke.model_cache import ModelCache
+
 #--------------------------globals-----------------------
 Model_dir = 'models'
 Weights_dir = 'ldm/stable-diffusion-v1/'
@@ -263,6 +271,19 @@ def download_weight_datasets(models:dict, access_token:str):
     print(f'Successfully installed {keys}')
     return successful
 
+#---------------------------------------------
+def is_huggingface_authenticated():
+    # huggingface_hub 0.10 API isn't great for this, it could be OSError, ValueError,
+    # maybe other things, not all end-user-friendly.
+    # noinspection PyBroadException
+    try:
+        response = hf_whoami()
+        if response.get('id') is not None:
+            return True
+    except Exception:
+        pass
+    return False
+
 #---------------------------------------------
 def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_token:str=None)->bool:
     model_dest = os.path.join(model_dir, model_name)
@@ -330,6 +351,58 @@ def download_with_progress_bar(model_url:str, model_dest:str, label:str='the'):
         print(traceback.format_exc())
 
 
+#---------------------------------------------
+def download_diffusers(models: Dict, full_precision: bool):
+    # This is a minimal implementation until https://github.com/invoke-ai/InvokeAI/pull/1490 lands,
+    # which moves a bunch of stuff.
+    # We can be more complete after we know it won't be all merge conflicts.
+    diffusers_repos = {
+        'CompVis/stable-diffusion-v1-4-original': 'CompVis/stable-diffusion-v1-4',
+        'runwayml/stable-diffusion-v1-5': 'runwayml/stable-diffusion-v1-5',
+        'runwayml/stable-diffusion-inpainting': 'runwayml/stable-diffusion-inpainting',
+        'hakurei/waifu-diffusion-v1-3': 'hakurei/waifu-diffusion'
+    }
+    vae_repos = {
+        'stabilityai/sd-vae-ft-mse-original': 'stabilityai/sd-vae-ft-mse',
+    }
+    precision_args = {}
+    if not full_precision:
+        precision_args.update(revision='fp16')
+
+    for model_name, model in models.items():
+        repo_id = model['repo_id']
+        if repo_id in vae_repos:
+            print(f"  * Downloading diffusers VAE {model_name}...")
+            # TODO: can we autodetect when a repo has no fp16 revision?
+            AutoencoderKL.from_pretrained(repo_id)
+        elif repo_id not in diffusers_repos:
+            print(f"  * Downloading diffusers {model_name}...")
+            StableDiffusionPipeline.from_pretrained(repo_id, **precision_args)
+        else:
+            warnings.warn(f" ⚠ FIXME: add diffusers repo for {repo_id}")
+            continue
+
+
+def download_diffusers_in_config(config_path: Path, full_precision: bool):
+    # This is a minimal implementation until https://github.com/invoke-ai/InvokeAI/pull/1490 lands,
+    # which moves a bunch of stuff.
+    # We can be more complete after we know it won't be all merge conflicts.
+    if not is_huggingface_authenticated():
+        print("*⚠ No Hugging Face access token; some downloads may be blocked.")
+
+    precision = 'full' if full_precision else 'float16'
+    cache = ModelCache(OmegaConf.load(config_path), precision=precision,
+                       device_type='cpu', max_loaded_models=1)
+    for model_name in cache.list_models():
+        # TODO: download model without loading it.
+        #     https://github.com/huggingface/diffusers/issues/1301
+        model_config = cache.config[model_name]
+        if model_config.get('format') == 'diffusers':
+            print(f"  * Downloading diffusers {model_name}...")
+            cache.get_model(model_name)
+            cache.offload_model(model_name)
+
+
 #---------------------------------------------
 def update_config_file(successfully_downloaded:dict,opt:dict):
     config_file = opt.config_file or Default_config_file
@@ -402,7 +475,7 @@ def download_bert():
     print('Installing bert tokenizer (ignore deprecation errors)...', end='',file=sys.stderr)
     with warnings.catch_warnings():
         warnings.filterwarnings('ignore', category=DeprecationWarning)
-        from transformers import BertTokenizerFast, AutoFeatureExtractor
+        from transformers import BertTokenizerFast
         download_from_hf(BertTokenizerFast,'bert-base-uncased')
         print('...success',file=sys.stderr)
@@ -649,6 +722,12 @@ def main():
                         action=argparse.BooleanOptionalAction,
                         default=True,
                         help='run in interactive mode (default)')
+    parser.add_argument('--full-precision',
+                        dest='full_precision',
+                        action=argparse.BooleanOptionalAction,
+                        type=bool,
+                        default=False,
+                        help='use 32-bit weights instead of faster 16-bit weights')
    parser.add_argument('--yes','-y',
                        dest='yes_to_all',
                        action='store_true',
@@ -686,6 +765,12 @@ def main():
         if opt.interactive:
             print('** DOWNLOADING DIFFUSION WEIGHTS **')
             download_weights(opt)
+        else:
+            config_path = Path(Globals.root, opt.config_file or Default_config_file)
+            if config_path.exists():
+                download_diffusers_in_config(config_path, full_precision=opt.full_precision)
+            else:
+                print(f"*⚠ No config file found; downloading no weights. Looked in {config_path}")
         print('\n** DOWNLOADING SUPPORT MODELS **')
         download_bert()
         download_clip()