From 6012e9396a577d8498283b207d34f8bc079f1bf4 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Mon, 16 Jan 2023 11:15:57 -0300 Subject: [PATCH 01/41] add load textual inversion embeddings draft --- .../pipeline_stable_diffusion.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 24447c6a6729..d710f1132139 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -178,6 +178,34 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) + def load_textual_inversion_embeddings(self, embeddings): + r""" + Loads textual inversion embeddings. Receives a dictionary with the following keys: + - `token`: name of the token to be added to the tokenizers' vocabulary + - `embedding`: path to the embedding of the token to be added to the text encoder's embedding matrix + + Iters through the dictionary and adds the token to the tokenizer's vocabulary and the embedding to the + text encoder's embedding matrix. + """ + for token, embedding_path in embeddings.items(): + # check if token in tokenizer vocab + # if yes, raise exception + if token in self.tokenizer.get_vocab(): + raise ValueError( + f"Token {token} already in tokenizer vocabulary. Please choose a different token name." + ) + + # load embedding from embedding path then convert it to self.text_encoder's device and dtype + embedding = torch.load(embedding_path) + embedding = embedding.to(self.text_encoder.device) + embedding = embedding.to(self.text_encoder.dtype) + + self.tokenizer.add_tokens([token]) + + token_id = self.tokenizer.convert_tokens_to_ids("token") + self.text_encoder.resize_token_embeddings(len(self.tokenizer) + 1) + self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding + def enable_vae_slicing(self): r""" Enable sliced VAE decoding. From d4642c796d1749fb5d206ab89df33b58f903f2be Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Mon, 16 Jan 2023 11:18:44 -0300 Subject: [PATCH 02/41] fix quality --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index d710f1132139..9eba41e668f9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -183,9 +183,9 @@ def load_textual_inversion_embeddings(self, embeddings): Loads textual inversion embeddings. Receives a dictionary with the following keys: - `token`: name of the token to be added to the tokenizers' vocabulary - `embedding`: path to the embedding of the token to be added to the text encoder's embedding matrix - - Iters through the dictionary and adds the token to the tokenizer's vocabulary and the embedding to the - text encoder's embedding matrix. + + Iters through the dictionary and adds the token to the tokenizer's vocabulary and the embedding to the text + encoder's embedding matrix. 
""" for token, embedding_path in embeddings.items(): # check if token in tokenizer vocab From c5ffdc3d169de43f0dc56808eca2ee74da99acee Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Mon, 16 Jan 2023 11:38:23 -0300 Subject: [PATCH 03/41] fix typo --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 9eba41e668f9..9149afd82f47 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -202,7 +202,7 @@ def load_textual_inversion_embeddings(self, embeddings): self.tokenizer.add_tokens([token]) - token_id = self.tokenizer.convert_tokens_to_ids("token") + token_id = self.tokenizer.convert_tokens_to_ids(token) self.text_encoder.resize_token_embeddings(len(self.tokenizer) + 1) self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding From 525428de62f54f49ac63ea2ffd9a71248ed23549 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Mon, 16 Jan 2023 11:42:22 -0300 Subject: [PATCH 04/41] make fix copies --- .../alt_diffusion/pipeline_alt_diffusion.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 06919596df8a..2482cf997ec5 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -181,6 +181,34 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) + def load_textual_inversion_embeddings(self, embeddings): + r""" + Loads textual inversion embeddings. Receives a dictionary with the following keys: + - `token`: name of the token to be added to the tokenizers' vocabulary + - `embedding`: path to the embedding of the token to be added to the text encoder's embedding matrix + + Iters through the dictionary and adds the token to the tokenizer's vocabulary and the embedding to the text + encoder's embedding matrix. + """ + for token, embedding_path in embeddings.items(): + # check if token in tokenizer vocab + # if yes, raise exception + if token in self.tokenizer.get_vocab(): + raise ValueError( + f"Token {token} already in tokenizer vocabulary. Please choose a different token name." + ) + + # load embedding from embedding path then convert it to self.text_encoder's device and dtype + embedding = torch.load(embedding_path) + embedding = embedding.to(self.text_encoder.device) + embedding = embedding.to(self.text_encoder.dtype) + + self.tokenizer.add_tokens([token]) + + token_id = self.tokenizer.convert_tokens_to_ids(token) + self.text_encoder.resize_token_embeddings(len(self.tokenizer) + 1) + self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding + def enable_vae_slicing(self): r""" Enable sliced VAE decoding. 
From fdec2d05bfb27ba05e1cca8d4945b62028730e3c Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Tue, 17 Jan 2023 09:35:54 -0300 Subject: [PATCH 05/41] move to textual inversion mixin --- src/diffusers/__init__.py | 1 + .../alt_diffusion/pipeline_alt_diffusion.py | 28 ----------------- .../pipeline_stable_diffusion.py | 31 ++----------------- src/diffusers/textual_inversion_utils.py | 31 +++++++++++++++++++ 4 files changed, 34 insertions(+), 57 deletions(-) create mode 100644 src/diffusers/textual_inversion_utils.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 681c598eb6e9..1f421f287e5e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -84,6 +84,7 @@ UnCLIPScheduler, VQDiffusionScheduler, ) + from .textual_inversion_utils import TextualInversionMixin from .training_utils import EMAModel try: diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 49492456dcb9..5166cbb294c6 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -181,34 +181,6 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) - def load_textual_inversion_embeddings(self, embeddings): - r""" - Loads textual inversion embeddings. Receives a dictionary with the following keys: - - `token`: name of the token to be added to the tokenizers' vocabulary - - `embedding`: path to the embedding of the token to be added to the text encoder's embedding matrix - - Iters through the dictionary and adds the token to the tokenizer's vocabulary and the embedding to the text - encoder's embedding matrix. - """ - for token, embedding_path in embeddings.items(): - # check if token in tokenizer vocab - # if yes, raise exception - if token in self.tokenizer.get_vocab(): - raise ValueError( - f"Token {token} already in tokenizer vocabulary. Please choose a different token name." - ) - - # load embedding from embedding path then convert it to self.text_encoder's device and dtype - embedding = torch.load(embedding_path) - embedding = embedding.to(self.text_encoder.device) - embedding = embedding.to(self.text_encoder.dtype) - - self.tokenizer.add_tokens([token]) - - token_id = self.tokenizer.convert_tokens_to_ids(token) - self.text_encoder.resize_token_embeddings(len(self.tokenizer) + 1) - self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding - def enable_vae_slicing(self): r""" Enable sliced VAE decoding. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index b619f2b8b26c..e0854673e0e8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -34,6 +34,7 @@ from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker +from ... textual_inversion_utils import TextualInversionMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -53,7 +54,7 @@ """ -class StableDiffusionPipeline(DiffusionPipeline): +class StableDiffusionPipeline(DiffusionPipeline, TextualInversionMixin): r""" Pipeline for text-to-image generation using Stable Diffusion. 
@@ -178,34 +179,6 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) - def load_textual_inversion_embeddings(self, embeddings): - r""" - Loads textual inversion embeddings. Receives a dictionary with the following keys: - - `token`: name of the token to be added to the tokenizers' vocabulary - - `embedding`: path to the embedding of the token to be added to the text encoder's embedding matrix - - Iters through the dictionary and adds the token to the tokenizer's vocabulary and the embedding to the text - encoder's embedding matrix. - """ - for token, embedding_path in embeddings.items(): - # check if token in tokenizer vocab - # if yes, raise exception - if token in self.tokenizer.get_vocab(): - raise ValueError( - f"Token {token} already in tokenizer vocabulary. Please choose a different token name." - ) - - # load embedding from embedding path then convert it to self.text_encoder's device and dtype - embedding = torch.load(embedding_path) - embedding = embedding.to(self.text_encoder.device) - embedding = embedding.to(self.text_encoder.dtype) - - self.tokenizer.add_tokens([token]) - - token_id = self.tokenizer.convert_tokens_to_ids(token) - self.text_encoder.resize_token_embeddings(len(self.tokenizer) + 1) - self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding - def enable_vae_slicing(self): r""" Enable sliced VAE decoding. diff --git a/src/diffusers/textual_inversion_utils.py b/src/diffusers/textual_inversion_utils.py new file mode 100644 index 000000000000..c22e93fd2ee3 --- /dev/null +++ b/src/diffusers/textual_inversion_utils.py @@ -0,0 +1,31 @@ +import torch + + +class TextualInversionMixin: + def load_textual_inversion_embeddings(self, embeddings): + r""" + Loads textual inversion embeddings. Receives a dictionary with the following keys: + - `token`: name of the token to be added to the tokenizers' vocabulary + - `embedding`: path to the embedding of the token to be added to the text encoder's embedding matrix + + Iters through the dictionary and adds the token to the tokenizer's vocabulary and the embedding to the text + encoder's embedding matrix. + """ + for token, embedding_path in embeddings.items(): + # check if token in tokenizer vocab + # if yes, raise exception + if token in self.tokenizer.get_vocab(): + raise ValueError( + f"Token {token} already in tokenizer vocabulary. Please choose a different token name." 
+ ) + + # load embedding from embedding path then convert it to self.text_encoder's device and dtype + embedding = torch.load(embedding_path) + embedding = embedding.to(self.text_encoder.device) + embedding = embedding.to(self.text_encoder.dtype) + + self.tokenizer.add_tokens([token]) + + token_id = self.tokenizer.convert_tokens_to_ids(token) + self.text_encoder.resize_token_embeddings(len(self.tokenizer) + 1) + self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding From 5ec8feaab7a4d58daf28bff30137c7addd6e6d39 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Tue, 17 Jan 2023 09:45:52 -0300 Subject: [PATCH 06/41] make it accept from sd-concept library --- src/diffusers/textual_inversion_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/textual_inversion_utils.py b/src/diffusers/textual_inversion_utils.py index c22e93fd2ee3..624fa0986dd3 100644 --- a/src/diffusers/textual_inversion_utils.py +++ b/src/diffusers/textual_inversion_utils.py @@ -20,7 +20,11 @@ def load_textual_inversion_embeddings(self, embeddings): ) # load embedding from embedding path then convert it to self.text_encoder's device and dtype - embedding = torch.load(embedding_path) + embedding_dict = torch.load(embedding_path) + + # get the first key from embedding dict, gets its value and assign it to embedding + embedding = list(embedding_dict.values())[0] + embedding = embedding.to(self.text_encoder.device) embedding = embedding.to(self.text_encoder.dtype) From 5d58240ced97bf72ddd5cac4e713f641bf4e0948 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Tue, 17 Jan 2023 09:57:07 -0300 Subject: [PATCH 07/41] accept list of paths to embeddings --- src/diffusers/textual_inversion_utils.py | 70 ++++++++++++++++++------ 1 file changed, 52 insertions(+), 18 deletions(-) diff --git a/src/diffusers/textual_inversion_utils.py b/src/diffusers/textual_inversion_utils.py index 624fa0986dd3..b25be0c22309 100644 --- a/src/diffusers/textual_inversion_utils.py +++ b/src/diffusers/textual_inversion_utils.py @@ -2,34 +2,68 @@ class TextualInversionMixin: + textual_inversion_tokens = [] + def load_textual_inversion_embeddings(self, embeddings): r""" - Loads textual inversion embeddings. Receives a dictionary with the following keys: + Loads textual inversion embeddings. + + Receives a dictionary with the following keys: - `token`: name of the token to be added to the tokenizers' vocabulary - `embedding`: path to the embedding of the token to be added to the text encoder's embedding matrix + Alternatively, it can receive a list of pathes to embedding dictionaries, where the keys are the tokens and the + values are the embeddings. In that case, it will iterate through the list and add the tokens and embeddings to + the tokenizer's vocabulary and the text encoder's embedding matrix. + Iters through the dictionary and adds the token to the tokenizer's vocabulary and the embedding to the text encoder's embedding matrix. """ - for token, embedding_path in embeddings.items(): - # check if token in tokenizer vocab - # if yes, raise exception - if token in self.tokenizer.get_vocab(): - raise ValueError( - f"Token {token} already in tokenizer vocabulary. Please choose a different token name." 
- ) - # load embedding from embedding path then convert it to self.text_encoder's device and dtype - embedding_dict = torch.load(embedding_path) + if isinstance(embeddings, dict): + for token, embedding_path in embeddings.items(): + # check if token in tokenizer vocab + # if yes, raise exception + if token in self.tokenizer.get_vocab(): + raise ValueError( + f"Token {token} already in tokenizer vocabulary. Please choose a different token name." + ) + + embedding_dict = torch.load(embedding_path) + embedding = list(embedding_dict.values())[0] + + self.add_textual_inversion_embedding(token, embedding) + + elif isinstance(embeddings, list): + for embedding_path in embeddings: + embedding_dict = torch.load(embedding_path) + token = list(embedding_dict.keys())[0] + embedding = embedding_dict[token] + + # check if token in tokenizer vocab + # if yes, raise exception + if token in self.tokenizer.get_vocab(): + raise ValueError( + f"Token {token} already in tokenizer vocabulary. Please choose a different token name." + ) + self.add_textual_inversion_embedding(token, embedding) + + def add_textual_inversion_embedding(self, token, embedding): + r""" + Adds a token to the tokenizer's vocabulary and an embedding to the text encoder's embedding matrix. + """ + # check if token in tokenizer vocab + # if yes, raise exception + if token in self.tokenizer.get_vocab(): + raise ValueError(f"Token {token} already in tokenizer vocabulary. Please choose a different token name.") - # get the first key from embedding dict, gets its value and assign it to embedding - embedding = list(embedding_dict.values())[0] + embedding = embedding.to(self.text_encoder.device) + embedding = embedding.to(self.text_encoder.dtype) - embedding = embedding.to(self.text_encoder.device) - embedding = embedding.to(self.text_encoder.dtype) + self.tokenizer.add_tokens([token]) - self.tokenizer.add_tokens([token]) + token_id = self.tokenizer.convert_tokens_to_ids(token) + self.text_encoder.resize_token_embeddings(len(self.tokenizer) + 1) + self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding - token_id = self.tokenizer.convert_tokens_to_ids(token) - self.text_encoder.resize_token_embeddings(len(self.tokenizer) + 1) - self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding + self.textual_inversion_tokens.append(token) From 530a208f471e6dfed4a650ad0e69c97ca6fe1d75 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Tue, 17 Jan 2023 10:01:19 -0300 Subject: [PATCH 08/41] fix styling of stable diffusion pipeline --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index e0854673e0e8..462945971da4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -30,11 +30,11 @@ LMSDiscreteScheduler, PNDMScheduler, ) +from ...textual_inversion_utils import TextualInversionMixin from ...utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker -from ... 
textual_inversion_utils import TextualInversionMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name From 8e50514187c7d507b275bc5c1c51d06ca47a7323 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Tue, 17 Jan 2023 10:02:22 -0300 Subject: [PATCH 09/41] add dummy TextualInversionMixin --- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 62c2bbc2732d..21d0398690da 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -602,6 +602,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class TextualInversionMixin(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class EMAModel(metaclass=DummyObject): _backends = ["torch"] From b73098746dec5bdf699b2098f06dfe890aca8491 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Tue, 17 Jan 2023 10:06:28 -0300 Subject: [PATCH 10/41] add docstring to textualinversionmixin --- src/diffusers/textual_inversion_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/diffusers/textual_inversion_utils.py b/src/diffusers/textual_inversion_utils.py index b25be0c22309..bfbd58e4a5f4 100644 --- a/src/diffusers/textual_inversion_utils.py +++ b/src/diffusers/textual_inversion_utils.py @@ -2,6 +2,15 @@ class TextualInversionMixin: + r""" + Mixin class for adding textual inversion tokens and embeddings to the tokenizer and text encoder with method: + - [`~TextualInversionMixin.load_textual_inversion_embeddings`] + - [`~TextualInversionMixin.add_textual_inversion_embedding`] + + Class attributes: + - **textual_inversion_tokens** (`List[str]`): list of tokens added to the tokenizer's vocabulary and the text + encoder's embedding matrix + """ textual_inversion_tokens = [] def load_textual_inversion_embeddings(self, embeddings): From 65b76f829a88316bd788cd80c0e59783da5eb354 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Mon, 16 Jan 2023 11:15:57 -0300 Subject: [PATCH 11/41] add load textual inversion embeddings draft --- .../pipeline_stable_diffusion.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index b38ca866d58d..c8fafa1cf837 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -164,6 +164,34 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) + def load_textual_inversion_embeddings(self, embeddings): + r""" + Loads textual inversion embeddings. Receives a dictionary with the following keys: + - `token`: name of the token to be added to the tokenizers' vocabulary + - `embedding`: path to the embedding of the token to be added to the text encoder's embedding matrix + + Iters through the dictionary and adds the token to the tokenizer's vocabulary and the embedding to the + text encoder's embedding matrix. 
+ """ + for token, embedding_path in embeddings.items(): + # check if token in tokenizer vocab + # if yes, raise exception + if token in self.tokenizer.get_vocab(): + raise ValueError( + f"Token {token} already in tokenizer vocabulary. Please choose a different token name." + ) + + # load embedding from embedding path then convert it to self.text_encoder's device and dtype + embedding = torch.load(embedding_path) + embedding = embedding.to(self.text_encoder.device) + embedding = embedding.to(self.text_encoder.dtype) + + self.tokenizer.add_tokens([token]) + + token_id = self.tokenizer.convert_tokens_to_ids("token") + self.text_encoder.resize_token_embeddings(len(self.tokenizer) + 1) + self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding + def enable_vae_slicing(self): r""" Enable sliced VAE decoding. From 66a74896a1bbf65b0836e326f219898133c55308 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Mon, 16 Jan 2023 11:18:44 -0300 Subject: [PATCH 12/41] fix quality --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index c8fafa1cf837..430b5343a791 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -169,9 +169,9 @@ def load_textual_inversion_embeddings(self, embeddings): Loads textual inversion embeddings. Receives a dictionary with the following keys: - `token`: name of the token to be added to the tokenizers' vocabulary - `embedding`: path to the embedding of the token to be added to the text encoder's embedding matrix - - Iters through the dictionary and adds the token to the tokenizer's vocabulary and the embedding to the - text encoder's embedding matrix. + + Iters through the dictionary and adds the token to the tokenizer's vocabulary and the embedding to the text + encoder's embedding matrix. 
""" for token, embedding_path in embeddings.items(): # check if token in tokenizer vocab From 82dff2118c168430809419d7eb46c9127f3bb3b6 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Mon, 16 Jan 2023 11:38:23 -0300 Subject: [PATCH 13/41] fix typo --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 430b5343a791..941e9346235d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -188,7 +188,7 @@ def load_textual_inversion_embeddings(self, embeddings): self.tokenizer.add_tokens([token]) - token_id = self.tokenizer.convert_tokens_to_ids("token") + token_id = self.tokenizer.convert_tokens_to_ids(token) self.text_encoder.resize_token_embeddings(len(self.tokenizer) + 1) self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding From bf0424b8365756d5f9d4fe2226f86db218d054bc Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Mon, 16 Jan 2023 11:42:22 -0300 Subject: [PATCH 14/41] make fix copies --- .../alt_diffusion/pipeline_alt_diffusion.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 6978ab8e28b2..90e8a861ab29 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -167,6 +167,34 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) + def load_textual_inversion_embeddings(self, embeddings): + r""" + Loads textual inversion embeddings. Receives a dictionary with the following keys: + - `token`: name of the token to be added to the tokenizers' vocabulary + - `embedding`: path to the embedding of the token to be added to the text encoder's embedding matrix + + Iters through the dictionary and adds the token to the tokenizer's vocabulary and the embedding to the text + encoder's embedding matrix. + """ + for token, embedding_path in embeddings.items(): + # check if token in tokenizer vocab + # if yes, raise exception + if token in self.tokenizer.get_vocab(): + raise ValueError( + f"Token {token} already in tokenizer vocabulary. Please choose a different token name." + ) + + # load embedding from embedding path then convert it to self.text_encoder's device and dtype + embedding = torch.load(embedding_path) + embedding = embedding.to(self.text_encoder.device) + embedding = embedding.to(self.text_encoder.dtype) + + self.tokenizer.add_tokens([token]) + + token_id = self.tokenizer.convert_tokens_to_ids(token) + self.text_encoder.resize_token_embeddings(len(self.tokenizer) + 1) + self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding + def enable_vae_slicing(self): r""" Enable sliced VAE decoding. 
From 22e47515b26e55845477ee4325629a4a3b452369 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Tue, 17 Jan 2023 09:35:54 -0300 Subject: [PATCH 15/41] move to textual inversion mixin --- src/diffusers/__init__.py | 1 + .../alt_diffusion/pipeline_alt_diffusion.py | 28 ----------------- .../pipeline_stable_diffusion.py | 31 ++----------------- src/diffusers/textual_inversion_utils.py | 31 +++++++++++++++++++ 4 files changed, 34 insertions(+), 57 deletions(-) create mode 100644 src/diffusers/textual_inversion_utils.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 4ee671c5aa03..f2e91b973254 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -86,6 +86,7 @@ UnCLIPScheduler, VQDiffusionScheduler, ) + from .textual_inversion_utils import TextualInversionMixin from .training_utils import EMAModel try: diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 90e8a861ab29..6978ab8e28b2 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -167,34 +167,6 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) - def load_textual_inversion_embeddings(self, embeddings): - r""" - Loads textual inversion embeddings. Receives a dictionary with the following keys: - - `token`: name of the token to be added to the tokenizers' vocabulary - - `embedding`: path to the embedding of the token to be added to the text encoder's embedding matrix - - Iters through the dictionary and adds the token to the tokenizer's vocabulary and the embedding to the text - encoder's embedding matrix. - """ - for token, embedding_path in embeddings.items(): - # check if token in tokenizer vocab - # if yes, raise exception - if token in self.tokenizer.get_vocab(): - raise ValueError( - f"Token {token} already in tokenizer vocabulary. Please choose a different token name." - ) - - # load embedding from embedding path then convert it to self.text_encoder's device and dtype - embedding = torch.load(embedding_path) - embedding = embedding.to(self.text_encoder.device) - embedding = embedding.to(self.text_encoder.dtype) - - self.tokenizer.add_tokens([token]) - - token_id = self.tokenizer.convert_tokens_to_ids(token) - self.text_encoder.resize_token_embeddings(len(self.tokenizer) + 1) - self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding - def enable_vae_slicing(self): r""" Enable sliced VAE decoding. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 941e9346235d..7341d3c77514 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -27,6 +27,7 @@ from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker +from ... textual_inversion_utils import TextualInversionMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -46,7 +47,7 @@ """ -class StableDiffusionPipeline(DiffusionPipeline): +class StableDiffusionPipeline(DiffusionPipeline, TextualInversionMixin): r""" Pipeline for text-to-image generation using Stable Diffusion. 
@@ -164,34 +165,6 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) - def load_textual_inversion_embeddings(self, embeddings): - r""" - Loads textual inversion embeddings. Receives a dictionary with the following keys: - - `token`: name of the token to be added to the tokenizers' vocabulary - - `embedding`: path to the embedding of the token to be added to the text encoder's embedding matrix - - Iters through the dictionary and adds the token to the tokenizer's vocabulary and the embedding to the text - encoder's embedding matrix. - """ - for token, embedding_path in embeddings.items(): - # check if token in tokenizer vocab - # if yes, raise exception - if token in self.tokenizer.get_vocab(): - raise ValueError( - f"Token {token} already in tokenizer vocabulary. Please choose a different token name." - ) - - # load embedding from embedding path then convert it to self.text_encoder's device and dtype - embedding = torch.load(embedding_path) - embedding = embedding.to(self.text_encoder.device) - embedding = embedding.to(self.text_encoder.dtype) - - self.tokenizer.add_tokens([token]) - - token_id = self.tokenizer.convert_tokens_to_ids(token) - self.text_encoder.resize_token_embeddings(len(self.tokenizer) + 1) - self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding - def enable_vae_slicing(self): r""" Enable sliced VAE decoding. diff --git a/src/diffusers/textual_inversion_utils.py b/src/diffusers/textual_inversion_utils.py new file mode 100644 index 000000000000..c22e93fd2ee3 --- /dev/null +++ b/src/diffusers/textual_inversion_utils.py @@ -0,0 +1,31 @@ +import torch + + +class TextualInversionMixin: + def load_textual_inversion_embeddings(self, embeddings): + r""" + Loads textual inversion embeddings. Receives a dictionary with the following keys: + - `token`: name of the token to be added to the tokenizers' vocabulary + - `embedding`: path to the embedding of the token to be added to the text encoder's embedding matrix + + Iters through the dictionary and adds the token to the tokenizer's vocabulary and the embedding to the text + encoder's embedding matrix. + """ + for token, embedding_path in embeddings.items(): + # check if token in tokenizer vocab + # if yes, raise exception + if token in self.tokenizer.get_vocab(): + raise ValueError( + f"Token {token} already in tokenizer vocabulary. Please choose a different token name." 
+ ) + + # load embedding from embedding path then convert it to self.text_encoder's device and dtype + embedding = torch.load(embedding_path) + embedding = embedding.to(self.text_encoder.device) + embedding = embedding.to(self.text_encoder.dtype) + + self.tokenizer.add_tokens([token]) + + token_id = self.tokenizer.convert_tokens_to_ids(token) + self.text_encoder.resize_token_embeddings(len(self.tokenizer) + 1) + self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding From f25292c88eceefac8431853dbbaaa3cda4771063 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Tue, 17 Jan 2023 09:45:52 -0300 Subject: [PATCH 16/41] make it accept from sd-concept library --- src/diffusers/textual_inversion_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/textual_inversion_utils.py b/src/diffusers/textual_inversion_utils.py index c22e93fd2ee3..624fa0986dd3 100644 --- a/src/diffusers/textual_inversion_utils.py +++ b/src/diffusers/textual_inversion_utils.py @@ -20,7 +20,11 @@ def load_textual_inversion_embeddings(self, embeddings): ) # load embedding from embedding path then convert it to self.text_encoder's device and dtype - embedding = torch.load(embedding_path) + embedding_dict = torch.load(embedding_path) + + # get the first key from embedding dict, gets its value and assign it to embedding + embedding = list(embedding_dict.values())[0] + embedding = embedding.to(self.text_encoder.device) embedding = embedding.to(self.text_encoder.dtype) From f23185462f5d9b5b5d97a34fcb45aaac21f31e44 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Tue, 17 Jan 2023 09:57:07 -0300 Subject: [PATCH 17/41] accept list of paths to embeddings --- src/diffusers/textual_inversion_utils.py | 70 ++++++++++++++++++------ 1 file changed, 52 insertions(+), 18 deletions(-) diff --git a/src/diffusers/textual_inversion_utils.py b/src/diffusers/textual_inversion_utils.py index 624fa0986dd3..b25be0c22309 100644 --- a/src/diffusers/textual_inversion_utils.py +++ b/src/diffusers/textual_inversion_utils.py @@ -2,34 +2,68 @@ class TextualInversionMixin: + textual_inversion_tokens = [] + def load_textual_inversion_embeddings(self, embeddings): r""" - Loads textual inversion embeddings. Receives a dictionary with the following keys: + Loads textual inversion embeddings. + + Receives a dictionary with the following keys: - `token`: name of the token to be added to the tokenizers' vocabulary - `embedding`: path to the embedding of the token to be added to the text encoder's embedding matrix + Alternatively, it can receive a list of pathes to embedding dictionaries, where the keys are the tokens and the + values are the embeddings. In that case, it will iterate through the list and add the tokens and embeddings to + the tokenizer's vocabulary and the text encoder's embedding matrix. + Iters through the dictionary and adds the token to the tokenizer's vocabulary and the embedding to the text encoder's embedding matrix. """ - for token, embedding_path in embeddings.items(): - # check if token in tokenizer vocab - # if yes, raise exception - if token in self.tokenizer.get_vocab(): - raise ValueError( - f"Token {token} already in tokenizer vocabulary. Please choose a different token name." 
- ) - # load embedding from embedding path then convert it to self.text_encoder's device and dtype - embedding_dict = torch.load(embedding_path) + if isinstance(embeddings, dict): + for token, embedding_path in embeddings.items(): + # check if token in tokenizer vocab + # if yes, raise exception + if token in self.tokenizer.get_vocab(): + raise ValueError( + f"Token {token} already in tokenizer vocabulary. Please choose a different token name." + ) + + embedding_dict = torch.load(embedding_path) + embedding = list(embedding_dict.values())[0] + + self.add_textual_inversion_embedding(token, embedding) + + elif isinstance(embeddings, list): + for embedding_path in embeddings: + embedding_dict = torch.load(embedding_path) + token = list(embedding_dict.keys())[0] + embedding = embedding_dict[token] + + # check if token in tokenizer vocab + # if yes, raise exception + if token in self.tokenizer.get_vocab(): + raise ValueError( + f"Token {token} already in tokenizer vocabulary. Please choose a different token name." + ) + self.add_textual_inversion_embedding(token, embedding) + + def add_textual_inversion_embedding(self, token, embedding): + r""" + Adds a token to the tokenizer's vocabulary and an embedding to the text encoder's embedding matrix. + """ + # check if token in tokenizer vocab + # if yes, raise exception + if token in self.tokenizer.get_vocab(): + raise ValueError(f"Token {token} already in tokenizer vocabulary. Please choose a different token name.") - # get the first key from embedding dict, gets its value and assign it to embedding - embedding = list(embedding_dict.values())[0] + embedding = embedding.to(self.text_encoder.device) + embedding = embedding.to(self.text_encoder.dtype) - embedding = embedding.to(self.text_encoder.device) - embedding = embedding.to(self.text_encoder.dtype) + self.tokenizer.add_tokens([token]) - self.tokenizer.add_tokens([token]) + token_id = self.tokenizer.convert_tokens_to_ids(token) + self.text_encoder.resize_token_embeddings(len(self.tokenizer) + 1) + self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding - token_id = self.tokenizer.convert_tokens_to_ids(token) - self.text_encoder.resize_token_embeddings(len(self.tokenizer) + 1) - self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding + self.textual_inversion_tokens.append(token) From ced8e14e06dacb60069ceabdd6af6d71c54f6418 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Tue, 17 Jan 2023 10:01:19 -0300 Subject: [PATCH 18/41] fix styling of stable diffusion pipeline --- .../stable_diffusion/pipeline_stable_diffusion.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 7341d3c77514..7de59e817f07 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -22,12 +22,20 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import KarrasDiffusionSchedulers +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + KarrasDiffusionSchedulers +) +from ...textual_inversion_utils import TextualInversionMixin from ...utils import deprecate, is_accelerate_available, logging, randn_tensor, 
replace_example_docstring from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker -from ... textual_inversion_utils import TextualInversionMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name From 5d2ef24d9f0c2ccccc54faa918989e25207c777b Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Tue, 17 Jan 2023 10:02:22 -0300 Subject: [PATCH 19/41] add dummy TextualInversionMixin --- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 1e7c0a46a2b2..91d75243d43a 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -617,6 +617,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class TextualInversionMixin(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class EMAModel(metaclass=DummyObject): _backends = ["torch"] From e9284a4dd07f78bef1a6d60d1f10b43c1be5f8e7 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Tue, 17 Jan 2023 10:06:28 -0300 Subject: [PATCH 20/41] add docstring to textualinversionmixin --- src/diffusers/textual_inversion_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/diffusers/textual_inversion_utils.py b/src/diffusers/textual_inversion_utils.py index b25be0c22309..bfbd58e4a5f4 100644 --- a/src/diffusers/textual_inversion_utils.py +++ b/src/diffusers/textual_inversion_utils.py @@ -2,6 +2,15 @@ class TextualInversionMixin: + r""" + Mixin class for adding textual inversion tokens and embeddings to the tokenizer and text encoder with method: + - [`~TextualInversionMixin.load_textual_inversion_embeddings`] + - [`~TextualInversionMixin.add_textual_inversion_embedding`] + + Class attributes: + - **textual_inversion_tokens** (`List[str]`): list of tokens added to the tokenizer's vocabulary and the text + encoder's embedding matrix + """ textual_inversion_tokens = [] def load_textual_inversion_embeddings(self, embeddings): From e6f6d1c07a4f4e3fa176e273391729d021cfc1d0 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Wed, 18 Jan 2023 10:18:11 -0300 Subject: [PATCH 21/41] add case for parsing embedding from auto1111 UI format Co-authored-by: Evan Jones Co-authored-by: Ana Tamais --- src/diffusers/textual_inversion_utils.py | 29 +++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/src/diffusers/textual_inversion_utils.py b/src/diffusers/textual_inversion_utils.py index bfbd58e4a5f4..8490fce415b3 100644 --- a/src/diffusers/textual_inversion_utils.py +++ b/src/diffusers/textual_inversion_utils.py @@ -39,15 +39,15 @@ def load_textual_inversion_embeddings(self, embeddings): ) embedding_dict = torch.load(embedding_path) - embedding = list(embedding_dict.values())[0] + embedding = self.extract_embedding_from_dict(embedding_dict) self.add_textual_inversion_embedding(token, embedding) elif isinstance(embeddings, list): for embedding_path in embeddings: embedding_dict = torch.load(embedding_path) - token = list(embedding_dict.keys())[0] - embedding = embedding_dict[token] + token = self.extract_token_from_dict(embedding_dict) + embedding = 
self.extract_embedding_from_dict(embedding_dict) # check if token in tokenizer vocab # if yes, raise exception @@ -57,6 +57,29 @@ def load_textual_inversion_embeddings(self, embeddings): ) self.add_textual_inversion_embedding(token, embedding) + def extract_embedding_from_dict(self, embedding_dict): + r""" + Extracts the embedding from the embedding dictionary. + """ + # auto1111 embedding case + if "string_to_param" in embedding_dict: + embedding_dict = embedding_dict["string_to_param"] + embedding = embedding_dict["*"] + return embedding + + return list(embedding_dict.values())[0] + + def extract_token_from_dict(self, embedding_dict): + r""" + Extracts the token from the embedding dictionary. + """ + # auto1111 embedding case + if "string_to_param" in embedding_dict: + token = embedding_dict["name"] + return token + + return list(embedding_dict.keys())[0] + def add_textual_inversion_embedding(self, token, embedding): r""" Adds a token to the tokenizer's vocabulary and an embedding to the text encoder's embedding matrix. From bd3b59552327578b06fa06e3763f87b4f17a4d51 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Wed, 18 Jan 2023 10:27:56 -0300 Subject: [PATCH 22/41] fix style after rebase --- .../stable_diffusion/pipeline_stable_diffusion.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 7de59e817f07..c0e4c351367f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -22,15 +22,7 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import ( - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, - KarrasDiffusionSchedulers -) +from ...schedulers import KarrasDiffusionSchedulers from ...textual_inversion_utils import TextualInversionMixin from ...utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline From baaf3dfd25b2e90309f19ceb596c0401ef963602 Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Tue, 24 Jan 2023 14:06:44 -0500 Subject: [PATCH 23/41] move textual inversion mixin to loaders --- src/diffusers/loaders.py | 153 ++++++++++++++++++++++- src/diffusers/textual_inversion_utils.py | 101 --------------- 2 files changed, 152 insertions(+), 102 deletions(-) delete mode 100644 src/diffusers/textual_inversion_utils.py diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 61754ed6d8ed..f817027e48ec 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -13,9 +13,10 @@ # limitations under the License. 
import os from collections import defaultdict -from typing import Callable, Dict, Union +from typing import Callable, Dict, Union, List import torch +from transformers import PreTrainedModel, PreTrainedTokenizer from .models.cross_attention import LoRACrossAttnProcessor from .models.modeling_utils import _get_model_file @@ -241,3 +242,153 @@ def save_attn_procs( save_function(state_dict, os.path.join(save_directory, weights_name)) logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}") + + +class TextualInversionLoaderMixin: + r""" + Mixin class for adding textual inversion tokens and embeddings to the tokenizer and text encoder with method: + - [`~TextualInversionMixin.load_textual_inversion_embeddings`] + - [`~TextualInversionMixin.add_textual_inversion_embedding`] + """ + + def load_textual_inversion_embeddings( + self, embedding_path_dict_or_list: Union[Dict[str, str], List[Dict[str, str]]] + ): + r""" + Loads textual inversion embeddings and adds them to the tokenizer's vocabulary and the text encoder's embeddings. + + Arguments: + embeddings (`Dict[str, str]` or `List[str]`): + Dictionary of token to embedding path or List of embedding paths to embedding dictionaries. + The dictionary must have the following keys: + - `token`: name of the token to be added to the tokenizers' vocabulary + - `embedding`: path to the embedding of the token to be added to the text encoder's embedding matrix + The list must contain paths to embedding dictionaries where the keys are the tokens and the + values are the embeddings (same as above dictionary definition). + + Returns: + None + """ + # Validate that inheriting class instance contains required attributes + self._validate_method_call(self.load_textual_inversion_embeddings) + + if isinstance(embedding_path_dict_or_list, dict): + for token, embedding_path in embedding_path_dict_or_list.items(): + # check if token in tokenizer vocab + # if yes, raise exception + if token in self.tokenizer.get_vocab(): + raise ValueError( + f"Token {token} already in tokenizer vocabulary. Please choose a different token name." + ) + + embedding_dict = torch.load(embedding_path) + embedding = self._extract_embedding_from_dict(embedding_dict) + + self.add_textual_inversion_embedding(token, embedding) + + elif isinstance(embedding_path_dict_or_list, list): + for embedding_path in embedding_path_dict_or_list: + embedding_dict = torch.load(embedding_path) + token = self._extract_token_from_dict(embedding_dict) + embedding = self._extract_embedding_from_dict(embedding_dict) + + # check if token in tokenizer vocab + # if yes, raise exception + if token in self.tokenizer.get_vocab(): + raise ValueError( + f"Token {token} already in tokenizer vocabulary. Please choose a different token name." + ) + self.add_textual_inversion_embedding(token, embedding) + + def add_textual_inversion_embedding(self, token: str, embedding: torch.Tensor): + r""" + Adds a token to the tokenizer's vocabulary and an embedding to the text encoder's embedding matrix. + + Arguments: + token (`str`): + The token to be added to the tokenizers' vocabulary + embedding (`torch.Tensor`): + The embedding of the token to be added to the text encoder's embedding matrix + """ + # NOTE: Not clear to me that we intend for this to be a public/exposed method. 
+ # Validate that inheriting class instance contains required attributes + self._validate_method_call(self.load_textual_inversion_embeddings) + + # check if token in tokenizer vocab + # if yes, raise exception + if token in self.tokenizer.get_vocab(): + raise ValueError(f"Token {token} already in tokenizer vocabulary. Please choose a different token name.") + + embedding = embedding.to(self.text_encoder.device) + embedding = embedding.to(self.text_encoder.dtype) + + self.tokenizer.add_tokens([token]) + + token_id = self.tokenizer.convert_tokens_to_ids(token) + self.text_encoder.resize_token_embeddings(len(self.tokenizer) + 1) + self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding + + def _extract_embedding_from_dict(self, embedding_dict: Dict[str, str]) -> torch.Tensor: + r""" + Extracts the embedding from the embedding dictionary. + + Arguments: + embedding_dict (`Dict[str, str]`): + The embedding dictionary loaded from the embedding path + + Returns: + embedding (`torch.Tensor`): + The embedding to be added to the text encoder's embedding matrix + """ + # auto1111 embedding case + if "string_to_param" in embedding_dict: + embedding_dict = embedding_dict["string_to_param"] + embedding = embedding_dict["*"] + return embedding + + return list(embedding_dict.values())[0] + + def _extract_token_from_dict(self, embedding_dict: Dict[str, str]) -> str: + r""" + Extracts the token from the embedding dictionary. + + Arguments: + embedding_dict (`Dict[str, str]`): + The embedding dictionary loaded from the embedding path + + Returns: + token (`str`): + The token to be added to the tokenizers' vocabulary + """ + # auto1111 embedding case + if "string_to_param" in embedding_dict: + token = embedding_dict["name"] + return token + + return list(embedding_dict.keys())[0] + + def _validate_method_call(self, method: Callable): + r""" + Validates that the method is being called from a class instance that has the required attributes. + + Arguments: + method (`function`): + The class's method being called + + Raises: + ValueError: + If the method is being called from a class instance that does not have + the required attributes, the method will not be callable. + + Returns: + None + """ + if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer): + raise ValueError( + f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling `{method.__name__}`" + ) + + if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel): + raise ValueError( + f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling `{method.__name__}`" + ) diff --git a/src/diffusers/textual_inversion_utils.py b/src/diffusers/textual_inversion_utils.py deleted file mode 100644 index 8490fce415b3..000000000000 --- a/src/diffusers/textual_inversion_utils.py +++ /dev/null @@ -1,101 +0,0 @@ -import torch - - -class TextualInversionMixin: - r""" - Mixin class for adding textual inversion tokens and embeddings to the tokenizer and text encoder with method: - - [`~TextualInversionMixin.load_textual_inversion_embeddings`] - - [`~TextualInversionMixin.add_textual_inversion_embedding`] - - Class attributes: - - **textual_inversion_tokens** (`List[str]`): list of tokens added to the tokenizer's vocabulary and the text - encoder's embedding matrix - """ - textual_inversion_tokens = [] - - def load_textual_inversion_embeddings(self, embeddings): - r""" - Loads textual inversion embeddings. 
- - Receives a dictionary with the following keys: - - `token`: name of the token to be added to the tokenizers' vocabulary - - `embedding`: path to the embedding of the token to be added to the text encoder's embedding matrix - - Alternatively, it can receive a list of pathes to embedding dictionaries, where the keys are the tokens and the - values are the embeddings. In that case, it will iterate through the list and add the tokens and embeddings to - the tokenizer's vocabulary and the text encoder's embedding matrix. - - Iters through the dictionary and adds the token to the tokenizer's vocabulary and the embedding to the text - encoder's embedding matrix. - """ - - if isinstance(embeddings, dict): - for token, embedding_path in embeddings.items(): - # check if token in tokenizer vocab - # if yes, raise exception - if token in self.tokenizer.get_vocab(): - raise ValueError( - f"Token {token} already in tokenizer vocabulary. Please choose a different token name." - ) - - embedding_dict = torch.load(embedding_path) - embedding = self.extract_embedding_from_dict(embedding_dict) - - self.add_textual_inversion_embedding(token, embedding) - - elif isinstance(embeddings, list): - for embedding_path in embeddings: - embedding_dict = torch.load(embedding_path) - token = self.extract_token_from_dict(embedding_dict) - embedding = self.extract_embedding_from_dict(embedding_dict) - - # check if token in tokenizer vocab - # if yes, raise exception - if token in self.tokenizer.get_vocab(): - raise ValueError( - f"Token {token} already in tokenizer vocabulary. Please choose a different token name." - ) - self.add_textual_inversion_embedding(token, embedding) - - def extract_embedding_from_dict(self, embedding_dict): - r""" - Extracts the embedding from the embedding dictionary. - """ - # auto1111 embedding case - if "string_to_param" in embedding_dict: - embedding_dict = embedding_dict["string_to_param"] - embedding = embedding_dict["*"] - return embedding - - return list(embedding_dict.values())[0] - - def extract_token_from_dict(self, embedding_dict): - r""" - Extracts the token from the embedding dictionary. - """ - # auto1111 embedding case - if "string_to_param" in embedding_dict: - token = embedding_dict["name"] - return token - - return list(embedding_dict.keys())[0] - - def add_textual_inversion_embedding(self, token, embedding): - r""" - Adds a token to the tokenizer's vocabulary and an embedding to the text encoder's embedding matrix. - """ - # check if token in tokenizer vocab - # if yes, raise exception - if token in self.tokenizer.get_vocab(): - raise ValueError(f"Token {token} already in tokenizer vocabulary. 
Please choose a different token name.") - - embedding = embedding.to(self.text_encoder.device) - embedding = embedding.to(self.text_encoder.dtype) - - self.tokenizer.add_tokens([token]) - - token_id = self.tokenizer.convert_tokens_to_ids(token) - self.text_encoder.resize_token_embeddings(len(self.tokenizer) + 1) - self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding - - self.textual_inversion_tokens.append(token) From 314c1e2fa38768c4597c53c05beda21d012fea13 Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Tue, 24 Jan 2023 14:07:51 -0500 Subject: [PATCH 24/41] move mixin inheritance to DiffusionPipeline from StableDiffusionPipeline) --- src/diffusers/pipelines/pipeline_utils.py | 3 ++- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index ea28ac875f81..c7b715f5545d 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -32,6 +32,7 @@ from tqdm.auto import tqdm from ..configuration_utils import ConfigMixin +from ..loaders import TextualInversionLoaderMixin from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from ..utils import ( @@ -137,7 +138,7 @@ def is_safetensors_compatible(info) -> bool: return is_safetensors_compatible -class DiffusionPipeline(ConfigMixin): +class DiffusionPipeline(ConfigMixin, TextualInversionLoaderMixin): r""" Base class for all models. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index c0e4c351367f..b38ca866d58d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -23,7 +23,6 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers -from ...textual_inversion_utils import TextualInversionMixin from ...utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput @@ -47,7 +46,7 @@ """ -class StableDiffusionPipeline(DiffusionPipeline, TextualInversionMixin): +class StableDiffusionPipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using Stable Diffusion. 
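For reference, the two on-disk layouts that the `_extract_token_from_dict` and `_extract_embedding_from_dict` helpers above distinguish look roughly as follows; the token name is illustrative, and 768 is only an example width (the hidden size of Stable Diffusion v1's CLIP text encoder):

    import torch

    # sd-concepts-library style: a single {token: embedding} entry
    sd_concepts_dict = {"<my-concept>": torch.zeros(768)}

    # Automatic1111 UI style: token under "name", embedding under "string_to_param"/"*"
    a1111_dict = {"name": "<my-concept>", "string_to_param": {"*": torch.zeros(1, 768)}}

The `"string_to_param" in embedding_dict` check is what routes a loaded file to the Automatic1111 branch; any other dictionary falls through to its first key/value pair.
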
From 719e6a7391407ffdf507f71f7d8f457efcb0a5cf Mon Sep 17 00:00:00 2001
From: Evan Jones
Date: Tue, 24 Jan 2023 14:08:05 -0500
Subject: [PATCH 25/41] update dummy class name

---
 src/diffusers/utils/dummy_pt_objects.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 91d75243d43a..38c215645330 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -617,7 +617,7 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])


-class TextualInversionMixin(metaclass=DummyObject):
+class TextualInversionLoaderMixin(metaclass=DummyObject):
     _backends = ["torch"]

     def __init__(self, *args, **kwargs):

From 3790d31efac19b88c7d0e124ee58ee32ad38bfa8 Mon Sep 17 00:00:00 2001
From: Evan Jones
Date: Wed, 25 Jan 2023 13:13:45 -0500
Subject: [PATCH 26/41] addressed all comments

---
 src/diffusers/loaders.py | 61 ++++++++++++++++++++++++----------------
 1 file changed, 37 insertions(+), 24 deletions(-)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index f817027e48ec..58c989e4d50e 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -252,13 +252,13 @@ class TextualInversionLoaderMixin:
     """

     def load_textual_inversion_embeddings(
-        self, embedding_path_dict_or_list: Union[Dict[str, str], List[Dict[str, str]]]
+        self, embedding_path_dict_or_list: Union[Dict[str, str], List[Dict[str, str]]], allow_replacement: bool = False
     ):
         r"""
         Loads textual inversion embeddings and adds them to the tokenizer's vocabulary and the text encoder's embeddings.

         Arguments:
-            embeddings (`Dict[str, str]` or `List[str]`):
+            embeddings_path_dict_or_list (`Dict[str, str]` or `List[str]`):
                 Dictionary of token to embedding path or List of embedding paths to embedding dictionaries.
                 The dictionary must have the following keys:
                 - `token`: name of the token to be added to the tokenizers' vocabulary
                 - `embedding`: path to the embedding of the token to be added to the text encoder's embedding matrix
                 The list must contain paths to embedding dictionaries where the keys are the tokens and the
                 values are the embeddings (same as above dictionary definition).
@@ -275,29 +275,37 @@ def load_textual_inversion_embeddings(
         if isinstance(embedding_path_dict_or_list, dict):
             for token, embedding_path in embedding_path_dict_or_list.items():
                 # check if token in tokenizer vocab
-                # if yes, raise exception
                 if token in self.tokenizer.get_vocab():
-                    raise ValueError(
-                        f"Token {token} already in tokenizer vocabulary. Please choose a different token name."
-                    )
-
-                embedding_dict = torch.load(embedding_path)
+                    if allow_replacement:
+                        logger.info(
+                            f"Token {token} already in tokenizer vocabulary. Overwriting existing token and embedding with the new one."
+                        )
+                    else:
+                        raise ValueError(
+                            f"Token {token} already in tokenizer vocabulary. Please choose a different token name."
+                        )
+
+                embedding_dict = torch.load(embedding_path, map_location=self.text_encoder.device)
                 embedding = self._extract_embedding_from_dict(embedding_dict)

                 self.add_textual_inversion_embedding(token, embedding)

         elif isinstance(embedding_path_dict_or_list, list):
             for embedding_path in embedding_path_dict_or_list:
-                embedding_dict = torch.load(embedding_path)
+                embedding_dict = torch.load(embedding_path, map_location=self.text_encoder.device)
                 token = self._extract_token_from_dict(embedding_dict)
                 embedding = self._extract_embedding_from_dict(embedding_dict)

                 # check if token in tokenizer vocab
-                # if yes, raise exception
                 if token in self.tokenizer.get_vocab():
-                    raise ValueError(
-                        f"Token {token} already in tokenizer vocabulary. Please choose a different token name."
-                    )
+                    if allow_replacement:
+                        logger.info(
+                            f"Token {token} already in tokenizer vocabulary. Overwriting existing token and embedding with the new one."
+                        )
+                    else:
+                        raise ValueError(
+                            f"Token {token} already in tokenizer vocabulary. Please choose a different token name."
+                        )
                 self.add_textual_inversion_embedding(token, embedding)

     def add_textual_inversion_embedding(self, token: str, embedding: torch.Tensor):
@@ -314,19 +322,24 @@ def add_textual_inversion_embedding(self, token: str, embedding: torch.Tensor):
         # Validate that inheriting class instance contains required attributes
         self._validate_method_call(self.load_textual_inversion_embeddings)

-        # check if token in tokenizer vocab
-        # if yes, raise exception
-        if token in self.tokenizer.get_vocab():
-            raise ValueError(f"Token {token} already in tokenizer vocabulary. Please choose a different token name.")
-
-        embedding = embedding.to(self.text_encoder.device)
         embedding = embedding.to(self.text_encoder.dtype)

-        self.tokenizer.add_tokens([token])
-
-        token_id = self.tokenizer.convert_tokens_to_ids(token)
-        self.text_encoder.resize_token_embeddings(len(self.tokenizer) + 1)
-        self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
+        if token in self.tokenizer.get_vocab():
+            # If user has allowed replacement and the token exists, we only need to
+            # extract the existing id and update the embedding
+            token_id = self.tokenizer.convert_tokens_to_ids(token)
+            self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
+        else:
+            # If the token does not exist, we add it to the tokenizer, then resize and update the
+            # text encoder accordingly
+            self.tokenizer.add_tokens([token])
+
+            token_id = self.tokenizer.convert_tokens_to_ids(token)
+            # NOTE: len() doesn't start at 0, so we shouldn't need to +1
+            # since we already updated the tokenizer and its new length
+            # should be old length + 1
+            self.text_encoder.resize_token_embeddings(len(self.tokenizer))
+            self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding

     def _extract_embedding_from_dict(self, embedding_dict: Dict[str, str]) -> torch.Tensor:
         r"""

From ef8ab030d7af7c1ed58945d6a69ea9ea528ae0a6 Mon Sep 17 00:00:00 2001
From: Evan Jones
Date: Wed, 25 Jan 2023 14:12:53 -0500
Subject: [PATCH 27/41] fix old dangling import

---
 src/diffusers/__init__.py | 2 +-
 src/diffusers/loaders.py  | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 75fcfecfc318..e5a62456c51a 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -86,7 +86,7 @@
     UnCLIPScheduler,
     VQDiffusionScheduler,
 )
-    from .textual_inversion_utils import TextualInversionMixin
+    from .loaders import TextualInversionLoaderMixin
 from .training_utils import EMAModel

 try:

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 58c989e4d50e..d553dc47e6b3 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -247,8 +247,8 @@ class TextualInversionLoaderMixin:
     r"""
     Mixin class for adding textual inversion tokens and embeddings to the tokenizer and text encoder with method:
-    - [`~TextualInversionMixin.load_textual_inversion_embeddings`]
-    - [`~TextualInversionMixin.add_textual_inversion_embedding`]
+    - [`~TextualInversionLoaderMixin.load_textual_inversion_embeddings`]
+    - [`~TextualInversionLoaderMixin.add_textual_inversion_embedding`]
     """

     def load_textual_inversion_embeddings(

From 32c86b54d0056377931c052b2eef46ec84519046 Mon Sep 17 00:00:00 2001
From:
Pi Esposito Date: Sat, 28 Jan 2023 20:58:56 -0300 Subject: [PATCH 28/41] fix style --- src/diffusers/__init__.py | 2 +- src/diffusers/loaders.py | 33 ++++++++++++++++++++------------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 467f81e9f940..4f55a797fb6c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -32,6 +32,7 @@ except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: + from .loaders import TextualInversionLoaderMixin from .models import ( AutoencoderKL, ModelMixin, @@ -86,7 +87,6 @@ UnCLIPScheduler, VQDiffusionScheduler, ) - from .loaders import TextualInversionLoaderMixin from .training_utils import EMAModel try: diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 69d101c74d22..466ceab809f1 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -13,9 +13,10 @@ # limitations under the License. import os from collections import defaultdict -from typing import Callable, Dict, Union, List +from typing import Callable, Dict, List, Union import torch + from transformers import PreTrainedModel, PreTrainedTokenizer from .models.cross_attention import LoRACrossAttnProcessor @@ -255,16 +256,18 @@ def load_textual_inversion_embeddings( self, embedding_path_dict_or_list: Union[Dict[str, str], List[Dict[str, str]]], allow_replacement: bool = False ): r""" - Loads textual inversion embeddings and adds them to the tokenizer's vocabulary and the text encoder's embeddings. + Loads textual inversion embeddings and adds them to the tokenizer's vocabulary and the text encoder's + embeddings. Arguments: embeddings_path_dict_or_list (`Dict[str, str]` or `List[str]`): - Dictionary of token to embedding path or List of embedding paths to embedding dictionaries. - The dictionary must have the following keys: + Dictionary of token to embedding path or List of embedding paths to embedding dictionaries. The + dictionary must have the following keys: - `token`: name of the token to be added to the tokenizers' vocabulary - - `embedding`: path to the embedding of the token to be added to the text encoder's embedding matrix - The list must contain paths to embedding dictionaries where the keys are the tokens and the - values are the embeddings (same as above dictionary definition). + - `embedding`: path to the embedding of the token to be added to the text encoder's embedding + matrix + The list must contain paths to embedding dictionaries where the keys are the tokens and the values are + the embeddings (same as above dictionary definition). Returns: None @@ -278,7 +281,8 @@ def load_textual_inversion_embeddings( if token in self.tokenizer.get_vocab(): if allow_replacement: logger.info( - f"Token {token} already in tokenizer vocabulary. Overwriting existing token and embedding with the new one." + f"Token {token} already in tokenizer vocabulary. Overwriting existing token and embedding" + " with the new one." ) else: raise ValueError( @@ -300,7 +304,8 @@ def load_textual_inversion_embeddings( if token in self.tokenizer.get_vocab(): if allow_replacement: logger.info( - f"Token {token} already in tokenizer vocabulary. Overwriting existing token and embedding with the new one." + f"Token {token} already in tokenizer vocabulary. Overwriting existing token and embedding" + " with the new one." 
) else: raise ValueError( @@ -390,18 +395,20 @@ def _validate_method_call(self, method: Callable): Raises: ValueError: - If the method is being called from a class instance that does not have - the required attributes, the method will not be callable. + If the method is being called from a class instance that does not have the required attributes, the + method will not be callable. Returns: None """ if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer): raise ValueError( - f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling `{method.__name__}`" + f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling" + f" `{method.__name__}`" ) if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel): raise ValueError( - f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling `{method.__name__}`" + f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling" + f" `{method.__name__}`" ) From 23a36effbe731fc91126df9444f2d2bae917dedf Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 23 Mar 2023 12:20:47 +0100 Subject: [PATCH 29/41] proposal --- + | 682 ++++++++++++++++++ src/diffusers/loaders.py | 246 ++++++- src/diffusers/models/modeling_utils.py | 136 +--- src/diffusers/pipelines/pipeline_utils.py | 3 +- .../pipeline_stable_diffusion.py | 11 +- src/diffusers/utils/__init__.py | 2 + src/diffusers/utils/hub_utils.py | 147 +++- 7 files changed, 1074 insertions(+), 153 deletions(-) create mode 100644 + diff --git a/+ b/+ new file mode 100644 index 000000000000..b99f8069e64f --- /dev/null +++ b/+ @@ -0,0 +1,682 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +from collections import defaultdict +from typing import Callable, Dict, List, Union, Optional + +import torch + +from transformers import PreTrainedModel, PreTrainedTokenizer +from .models.attention_processor import LoRAAttnProcessor +from .models.modeling_utils import _get_model_file +from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging + + +if is_safetensors_available(): + import safetensors + + +logger = logging.get_logger(__name__) + + +LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" +LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" + +TEXT_INVERSION_NAME = "learned_embeds.bin" +TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors" + + +class AttnProcsLayers(torch.nn.Module): + def __init__(self, state_dict: Dict[str, torch.Tensor]): + super().__init__() + self.layers = torch.nn.ModuleList(state_dict.values()) + self.mapping = {k: v for k, v in enumerate(state_dict.keys())} + self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} + + # we add a hook to state_dict() and load_state_dict() so that the + # naming fits with `unet.attn_processors` + def map_to(module, state_dict, *args, **kwargs): + new_state_dict = {} + for key, value in state_dict.items(): + num = int(key.split(".")[1]) # 0 is always "layers" + new_key = key.replace(f"layers.{num}", module.mapping[num]) + new_state_dict[new_key] = value + + return new_state_dict + + def map_from(module, state_dict, *args, **kwargs): + all_keys = list(state_dict.keys()) + for key in all_keys: + replace_key = key.split(".processor")[0] + ".processor" + new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") + state_dict[new_key] = state_dict[key] + del state_dict[key] + + self._register_state_dict_hook(map_to) + self._register_load_state_dict_pre_hook(map_from, with_module=True) + + +class UNet2DConditionLoadersMixin: + def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + r""" + Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be + defined in + [cross_attention.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py) + and be a `torch.nn.Module` class. + + + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., + `./my_model_directory/`. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. 
+ proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `diffusers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + """ + + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + if use_safetensors and not is_safetensors_available(): + raise ValueError( + "`use_safetensors`=True but safetensors is not installed. 
Please install safetensors with `pip install safetenstors" + ) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = is_safetensors_available() + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + model_file = None + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except IOError as e: + if not allow_pickle: + raise e + # try loading non-safetensors weights + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path_or_dict + + # fill attn processors + attn_processors = {} + + is_lora = all("lora" in k for k in state_dict.keys()) + + if is_lora: + lora_grouped_dict = defaultdict(dict) + for key, value in state_dict.items(): + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) + lora_grouped_dict[attn_processor_key][sub_key] = value + + for key, value_dict in lora_grouped_dict.items(): + rank = value_dict["to_k_lora.down.weight"].shape[0] + cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] + hidden_size = value_dict["to_k_lora.up.weight"].shape[0] + + attn_processors[key] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank + ) + attn_processors[key].load_state_dict(value_dict) + + else: + raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.") + + # set correct dtype & device + attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()} + + # set layers + self.set_attn_processor(attn_processors) + + def save_attn_procs( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = False, + **kwargs, + ): + r""" + Save an attention processor to a directory, so that it can be re-loaded using the + `[`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`]` method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. 
+ save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + """ + weight_name = weight_name or deprecate( + "weights_name", + "0.18.0", + "`weights_name` is deprecated, please use `weight_name` instead.", + take_from=kwargs, + ) + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + if save_function is None: + if safe_serialization: + + def save_function(weights, filename): + return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) + + else: + save_function = torch.save + + os.makedirs(save_directory, exist_ok=True) + + model_to_save = AttnProcsLayers(self.attn_processors) + + # Save the model + state_dict = model_to_save.state_dict() + + if weight_name is None: + if safe_serialization: + weight_name = LORA_WEIGHT_NAME_SAFE + else: + weight_name = LORA_WEIGHT_NAME + + # Save the model + save_function(state_dict, os.path.join(save_directory, weight_name)) + logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") + + +class TextualInversionLoaderMixin: + r""" + Mixin class for adding textual inversion tokens and embeddings to the tokenizer and text encoder with method: + - [`~TextualInversionLoaderMixin.load_textual_inversion_embeddings`] + - [`~TextualInversionLoaderMixin.add_textual_inversion_embedding`] + """ + + def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): + if not isinstance(prompt, List): + prompts = [prompt] + + prompts = [self._maybe_convert_prompt(p) for p in prompts] + + if not isinstance(prompt, List): + return prompts[0] + + return prompts + + def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): + tokens = tokenizer.tokenize(prompt) + if not any(t in tokenizer.added_tokens_encoder for t in tokens): + return prompt + + for token in tokens: + if token in tokenizer.added_tokens_encoder: + replacement = token + i = 1 + while f"{token}_{i}" in tokenizer.added_tokens_encoder: + replacement += f"{token}_{i}" + i += 1 + + prompt = prompt.replace(token, replacement) + + return prompt + + def load_textual_inversion(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], token: Optional[str] = None, **kwargs): + r""" + Load textual inversion embeddings into the text encoder of stable diffusion pipelines. + + + + This function is experimental and might change in the future. + + + + Parameters: + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., + `./my_model_directory/`. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. 
+ force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `diffusers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + """ + self._validate_method_call(self.load_textual_inversion) + + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + if use_safetensors and not is_safetensors_available(): + raise ValueError( + "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors" + ) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = is_safetensors_available() + allow_pickle = True + + user_agent = { + "file_type": "text_inversion", + "framework": "pytorch", + } + + # 1. 
Load textual inversion file + model_file = None + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or TEXT_INVERSION_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except IOError as e: + if not allow_pickle: + raise e + # try loading non-safetensors weights + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or TEXT_INVERSION_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path_or_dict + + # 2. Load token and embedding correcly from file + if isinstance(state_dict, torch.Tensor): + if token is None: + raise ValueError("You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`.") + embedding = state_dict + + if len(state_dict) == 1: + # diffusers + loaded_token, embedding = next(iter(state_dict))[0] + elif "string_to_param" in state_dict: + # A1111 + loaded_token = self._extract_token_from_dict(embedding_dict) + embedding = self._extract_embedding_from_dict(embedding_dict) + + if token is not None and loaded_token != token: + logger.warn(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.") + else: + token = loaded_token + + # 3. Make sure we don't mess up the tokenizer or text encoder + vocab = self.tokenizer.get_vocab() + if token: + raise ValueError( + f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder." + ) + elif f"{token}_1" in vocab: + multi_vector_tokens = [token] + i = 1 + while f"{token}_{i}" in tokenizer.added_tokens_encoder: + multi_vector_tokens.append(f"{token}_{i}") + i += 1 + + raise ValueError( + f"Multi-vector Token {token} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder." + ) + + is_multi_vector = embedding.ndims > 1 and embedding.shape[0] > 1 + + if is_multi_vector: + tokens = [token] + [f"{token}_i" for i in range(1, embedding.shape[0])] + else: + tokens = [token] + + self.tokenizer.add_tokens(tokens) + token_ids = tokenizer.convert_tokens_to_ids(tokens) + + # 4. 
Load token and embedding into tokenizer + embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device) + self.text_encoder.resize_token_embeddings(len(self.tokenizer)) + + if token in self.tokenizer.get_vocab(): + # If user has allowed replacement and the token exists, we only need to + # extract the existing id and update the embedding + token_id = self.tokenizer.convert_tokens_to_ids(token) + self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding + else: + # If the token does not exist, we add it to the tokenizer, then resize and update the + # text encoder acccordingly + + + def load_textual_inversion_embeddings( + self, embedding_path_dict_or_list: Union[Dict[str, str], List[Dict[str, str]]], allow_replacement: bool = False + ): + r""" + Loads textual inversion embeddings and adds them to the tokenizer's vocabulary and the text encoder's + embeddings. + + Arguments: + embeddings_path_dict_or_list (`Dict[str, str]` or `List[str]`): + Dictionary of token to embedding path or List of embedding paths to embedding dictionaries. The + dictionary must have the following keys: + - `token`: name of the token to be added to the tokenizers' vocabulary + - `embedding`: path to the embedding of the token to be added to the text encoder's embedding + matrix + The list must contain paths to embedding dictionaries where the keys are the tokens and the values are + the embeddings (same as above dictionary definition). + + Returns: + None + """ + # Validate that inheriting class instance contains required attributes + self._validate_method_call(self.load_textual_inversion_embeddings) + + if isinstance(embedding_path_dict_or_list, dict): + for token, embedding_path in embedding_path_dict_or_list.items(): + # check if token in tokenizer vocab + if token in self.tokenizer.get_vocab(): + if allow_replacement: + logger.info( + f"Token {token} already in tokenizer vocabulary. Overwriting existing token and embedding" + " with the new one." + ) + else: + raise ValueError( + f"Token {token} already in tokenizer vocabulary. Please choose a different token name or set `allow_replacement=True`." + ) + + embedding_dict = torch.load(embedding_path, map_location=self.text_encoder.device) + embedding = self._extract_embedding_from_dict(embedding_dict) + + self.add_textual_inversion_embedding(token, embedding) + + elif isinstance(embedding_path_dict_or_list, list): + for embedding_path in embedding_path_dict_or_list: + embedding_dict = torch.load(embedding_path, map_location=self.text_encoder.device) + token = self._extract_token_from_dict(embedding_dict) + embedding = self._extract_embedding_from_dict(embedding_dict) + + # check if token in tokenizer vocab + if token in self.tokenizer.get_vocab(): + if allow_replacement: + logger.info( + f"Token {token} already in tokenizer vocabulary. Overwriting existing token and embedding" + " with the new one." + ) + else: + raise ValueError( + f"Token {token} already in tokenizer vocabulary. Please choose a different token name." + ) + self.add_textual_inversion_embedding(token, embedding) + + def add_textual_inversion_embedding(self, token: str, embedding: torch.Tensor): + r""" + Adds a token to the tokenizer's vocabulary and an embedding to the text encoder's embedding matrix. 
+ + Arguments: + token (`str`): + The token to be added to the tokenizers' vocabulary + embedding (`torch.Tensor`): + The embedding of the token to be added to the text encoder's embedding matrix + """ + # NOTE: Not clear to me that we intend for this to be a public/exposed method. + # Validate that inheriting class instance contains required attributes + self._validate_method_call(self.load_textual_inversion_embeddings) + + embedding = embedding.to(self.text_encoder.dtype) + + if token in self.tokenizer.get_vocab(): + # If user has allowed replacement and the token exists, we only need to + # extract the existing id and update the embedding + token_id = self.tokenizer.convert_tokens_to_ids(token) + self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding + else: + # If the token does not exist, we add it to the tokenizer, then resize and update the + # text encoder acccordingly + self.tokenizer.add_tokens([token]) + + token_id = self.tokenizer.convert_tokens_to_ids(token) + # NOTE: len() does't start at 0, so we shouldn't need to +1 + # since we already updated the tokenizer and it's new length + # should be old length + 1 + self.text_encoder.resize_token_embeddings(len(self.tokenizer)) + self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding + + def _extract_embedding_from_dict(self, embedding_dict: Dict[str, str]) -> torch.Tensor: + r""" + Extracts the embedding from the embedding dictionary. + + Arguments: + embedding_dict (`Dict[str, str]`): + The embedding dictionary loaded from the embedding path + + Returns: + embedding (`torch.Tensor`): + The embedding to be added to the text encoder's embedding matrix + """ + # auto1111 embedding case + if "string_to_param" in embedding_dict: + embedding_dict = embedding_dict["string_to_param"] + embedding = embedding_dict["*"] + return embedding + + return list(embedding_dict.values())[0] + + def _extract_token_from_dict(self, embedding_dict: Dict[str, str]) -> str: + r""" + Extracts the token from the embedding dictionary. + + Arguments: + embedding_dict (`Dict[str, str]`): + The embedding dictionary loaded from the embedding path + + Returns: + token (`str`): + The token to be added to the tokenizers' vocabulary + """ + # auto1111 embedding case + if "string_to_param" in embedding_dict: + token = embedding_dict["name"] + return token + + return list(embedding_dict.keys())[0] + + def _validate_method_call(self, method: Callable): + r""" + Validates that the method is being called from a class instance that has the required attributes. + + Arguments: + method (`function`): + The class's method being called + + Raises: + ValueError: + If the method is being called from a class instance that does not have the required attributes, the + method will not be callable. + + Returns: + None + """ + if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer): + raise ValueError( + f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling" + f" `{method.__name__}`" + ) + + if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel): + raise ValueError( + f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling" + f" `{method.__name__}`" + ) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 286f09fae192..b8eaef5b940a 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -13,14 +13,12 @@ # limitations under the License. 
import os
 from collections import defaultdict
-from typing import Callable, Dict, List, Union
+from typing import Callable, Dict, List, Optional, Union

 import torch
-
 from transformers import PreTrainedModel, PreTrainedTokenizer
-from .models.attention_processor import LoRAAttnProcessor
-from .models.modeling_utils import _get_model_file
-from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging
+
+from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, _get_model_file, deprecate, is_safetensors_available, logging


 if is_safetensors_available():
@@ -33,6 +31,9 @@
 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"

+TEXT_INVERSION_NAME = "learned_embeds.bin"
+TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
+

 class AttnProcsLayers(torch.nn.Module):
     def __init__(self, state_dict: Dict[str, torch.Tensor]):
@@ -124,13 +125,6 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
         models](https://huggingface.co/docs/hub/models-gated#gated-models).
-
-
-
-
-        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
-        this method in a firewalled environment.
-
         """
@@ -160,6 +154,8 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
             "framework": "pytorch",
         }

+        from .models.attention_processor import LoRAAttnProcessor
+
         model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
             # Let's first try to load .safetensors weights
@@ -303,6 +299,229 @@ class TextualInversionLoaderMixin:
     - [`~TextualInversionLoaderMixin.add_textual_inversion_embedding`]
     """

+    def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):
+        if not isinstance(prompt, List):
+            prompts = [prompt]
+        else:
+            prompts = prompt
+
+        prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
+
+        if not isinstance(prompt, List):
+            return prompts[0]
+
+        return prompts
+
+    def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
+        tokens = tokenizer.tokenize(prompt)
+        if not any(t in tokenizer.added_tokens_encoder for t in tokens):
+            return prompt
+
+        for token in tokens:
+            if token in tokenizer.added_tokens_encoder:
+                replacement = token
+                i = 1
+                while f"{token}_{i}" in tokenizer.added_tokens_encoder:
+                    replacement += f" {token}_{i}"
+                    i += 1
+
+                prompt = prompt.replace(token, replacement)
+
+        return prompt
+
+    def load_textual_inversion(
+        self, pretrained_model_name_or_path: Union[str, Dict[str, torch.Tensor]], token: Optional[str] = None, **kwargs
+    ):
+        r"""
+        Load textual inversion embeddings into the text encoder of stable diffusion pipelines.
+
+
+
+        This function is experimental and might change in the future.
+
+
+
+        Parameters:
+            pretrained_model_name_or_path (`str` or `os.PathLike`):
+                Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                      Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
+                    - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
+                      `./my_model_directory/`.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                standard cache should not be used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                file exists.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only(`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            use_auth_token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                when running `diffusers-cli login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                In case the relevant files are located inside a subfolder of the model repo (either remote in
+                huggingface.co or downloaded locally), you can specify the folder name here.
+
+            mirror (`str`, *optional*):
+                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+                Please refer to the mirror site for more information.
+
+
+
+        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+        models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+
+        """
+        self._validate_method_call(self.load_textual_inversion)
+
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        if use_safetensors and not is_safetensors_available():
+            raise ValueError(
+                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors`"
+            )
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = is_safetensors_available()
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "text_inversion",
+            "framework": "pytorch",
+        }
+
+        # 1. Load textual inversion file
+        model_file = None
+        # Let's first try to load .safetensors weights
+        if (use_safetensors and weight_name is None) or (
+            weight_name is not None and weight_name.endswith(".safetensors")
+        ):
+            try:
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                state_dict = safetensors.torch.load_file(model_file, device="cpu")
+            except Exception as e:
+                if not allow_pickle:
+                    raise e
+
+                model_file = None
+                pass
+
+        if model_file is None:
+            model_file = _get_model_file(
+                pretrained_model_name_or_path,
+                weights_name=weight_name or TEXT_INVERSION_NAME,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+            )
+            state_dict = torch.load(model_file, map_location="cpu")
+
+        # 2. Load token and embedding correctly from file
+        if isinstance(state_dict, torch.Tensor):
+            if token is None:
+                raise ValueError(
+                    "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
+                )
+            embedding = state_dict
+        elif len(state_dict) == 1:
+            # diffusers
+            loaded_token, embedding = next(iter(state_dict.items()))
+        elif "string_to_param" in state_dict:
+            # A1111
+            loaded_token = self._extract_token_from_dict(state_dict)
+            embedding = self._extract_embedding_from_dict(state_dict)
+
+        if token is not None and loaded_token != token:
+            logger.warn(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
+        else:
+            token = loaded_token
+
+        embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
+
+        # 3. Make sure we don't mess up the tokenizer or text encoder
+        vocab = self.tokenizer.get_vocab()
+        if token in vocab:
+            raise ValueError(
+                f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
+            )
+        elif f"{token}_1" in vocab:
+            multi_vector_tokens = [token]
+            i = 1
+            while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
+                multi_vector_tokens.append(f"{token}_{i}")
+                i += 1
+
+            raise ValueError(
+                f"Multi-vector Token {token} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
+            )
+
+        is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
+
+        if is_multi_vector:
+            tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
+            embeddings = [e for e in embedding]
+        else:
+            tokens = [token]
+            embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
+
+        # add tokens and get ids
+        self.tokenizer.add_tokens(tokens)
+        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+
+        # resize token embeddings and set new embeddings
+        self.text_encoder.resize_token_embeddings(len(self.tokenizer))
+        for token_id, embedding in zip(token_ids, embeddings):
+            self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
+
+        logger.info(f"Loaded textual inversion embedding for {token}.")
+
+    # TODO(to discuss) think we can remove this one
     def load_textual_inversion_embeddings(
         self, embedding_path_dict_or_list: Union[Dict[str, str], List[Dict[str, str]]], allow_replacement: bool = False
     ):
@@ -337,7 +556,7 @@ def load_textual_inversion_embeddings(
                     )
                 else:
                     raise ValueError(
-                        f"Token {token} already in tokenizer vocabulary. Please choose a different token name."
+                        f"Token {token} already in tokenizer vocabulary. Please choose a different token name or set `allow_replacement=True`."
                     )

                 embedding_dict = torch.load(embedding_path, map_location=self.text_encoder.device)
@@ -364,6 +583,7 @@ def load_textual_inversion_embeddings(
             )
             self.add_textual_inversion_embedding(token, embedding)

+    # TODO(to discuss) think we can remove this one
     def add_textual_inversion_embedding(self, token: str, embedding: torch.Tensor):
         r"""
         Adds a token to the tokenizer's vocabulary and an embedding to the text encoder's embedding matrix.

diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index e51b40ce4509..a48fdeba6027 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -16,27 +16,22 @@

 import inspect
 import os
-import warnings
 from functools import partial
 from typing import Callable, List, Optional, Tuple, Union

 import torch
-from huggingface_hub import hf_hub_download
-from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
-from packaging import version
-from requests import HTTPError
 from torch import Tensor, device

 from .. import __version__
 from ..utils import (
     CONFIG_NAME,
-    DEPRECATED_REVISION_ARGS,
     DIFFUSERS_CACHE,
     FLAX_WEIGHTS_NAME,
     HF_HUB_OFFLINE,
-    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_NAME,
+    _add_variant,
+    _get_model_file,
     is_accelerate_available,
     is_safetensors_available,
     is_torch_version,
@@ -144,15 +139,6 @@ def load(module: torch.nn.Module, prefix=""):

     return error_msgs


-def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
-    if variant is not None:
-        splits = weights_name.split(".")
-        splits = splits[:-1] + [variant] + splits[-1:]
-        weights_name = ".".join(splits)
-
-    return weights_name
-
-
 class ModelMixin(torch.nn.Module):
     r"""
     Base class for all models.
@@ -782,121 +768,3 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) else: return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) - - -def _get_model_file( - pretrained_model_name_or_path, - *, - weights_name, - subfolder, - cache_dir, - force_download, - proxies, - resume_download, - local_files_only, - use_auth_token, - user_agent, - revision, - commit_hash=None, -): - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if os.path.isfile(pretrained_model_name_or_path): - return pretrained_model_name_or_path - elif os.path.isdir(pretrained_model_name_or_path): - if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): - # Load from a PyTorch checkpoint - model_file = os.path.join(pretrained_model_name_or_path, weights_name) - return model_file - elif subfolder is not None and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, weights_name) - ): - model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name) - return model_file - else: - raise EnvironmentError( - f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}." - ) - else: - # 1. First check if deprecated way of loading from branches is used - if ( - revision in DEPRECATED_REVISION_ARGS - and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME) - and version.parse(version.parse(__version__).base_version) >= version.parse("0.17.0") - ): - try: - model_file = hf_hub_download( - pretrained_model_name_or_path, - filename=_add_variant(weights_name, revision), - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - user_agent=user_agent, - subfolder=subfolder, - revision=revision or commit_hash, - ) - warnings.warn( - f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.", - FutureWarning, - ) - return model_file - except: # noqa: E722 - warnings.warn( - f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.", - FutureWarning, - ) - try: - # 2. 
Load model file as usual - model_file = hf_hub_download( - pretrained_model_name_or_path, - filename=weights_name, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - user_agent=user_agent, - subfolder=subfolder, - revision=revision or commit_hash, - ) - return model_file - - except RepositoryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " - "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " - "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " - "login`." - ) - except RevisionNotFoundError: - raise EnvironmentError( - f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " - "this model name. Check the model page at " - f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." - ) - except EntryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}." - ) - except HTTPError as err: - raise EnvironmentError( - f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}" - ) - except ValueError: - raise EnvironmentError( - f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" - f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" - f" directory containing a file named {weights_name} or" - " \nCheckout your internet connection or see how to run the library in" - " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." - ) - except EnvironmentError: - raise EnvironmentError( - f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " - "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " - f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " - f"containing a file named {weights_name}" - ) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index f74a7f747110..8f33b506827a 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -36,7 +36,6 @@ from .. import __version__ from ..configuration_utils import ConfigMixin -from ..loaders import TextualInversionLoaderMixin from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from ..utils import ( @@ -391,7 +390,7 @@ def load_sub_model( return loaded_sub_model -class DiffusionPipeline(ConfigMixin, TextualInversionLoaderMixin): +class DiffusionPipeline(ConfigMixin): r""" Base class for all models. 
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 81b2cfa9bc3e..2e9935c1bd5e 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -20,6 +20,7 @@
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -52,7 +53,7 @@
 """


-class StableDiffusionPipeline(DiffusionPipeline):
+class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.

@@ -315,6 +316,10 @@ def _encode_prompt(
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -375,6 +380,10 @@ def _encode_prompt(
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,

diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index d803b053be71..2cd256aeb8b8 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -37,6 +37,8 @@
 from .dynamic_modules_utils import get_class_from_dynamic_module
 from .hub_utils import (
     HF_HUB_OFFLINE,
+    _add_variant,
+    _get_model_file,
     extract_commit_hash,
     http_user_agent,
 )

diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py
index 916b18d35e7e..511763ec6687 100644
--- a/src/diffusers/utils/hub_utils.py
+++ b/src/diffusers/utils/hub_utils.py
@@ -18,16 +18,30 @@
 import re
 import sys
 import traceback
+import warnings
 from pathlib import Path
 from typing import Dict, Optional, Union
 from uuid import uuid4

-from huggingface_hub import HfFolder, ModelCard, ModelCardData, whoami
+from huggingface_hub import HfFolder, ModelCard, ModelCardData, hf_hub_download, whoami
 from huggingface_hub.file_download import REGEX_COMMIT_HASH
-from huggingface_hub.utils import is_jinja_available
+from huggingface_hub.utils import (
+    EntryNotFoundError,
+    RepositoryNotFoundError,
+    RevisionNotFoundError,
+    is_jinja_available,
+)
+from packaging import version
+from requests import HTTPError

 from .. import __version__
-from .constants import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT
+from .constants import (
+    DEPRECATED_REVISION_ARGS,
+    DIFFUSERS_CACHE,
+    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+    SAFETENSORS_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+)
 from .import_utils import (
     ENV_VARS_TRUE_VALUES,
     _flax_version,
@@ -215,3 +229,130 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str]
             f"There was a problem when trying to write in your cache folder ({DIFFUSERS_CACHE}). Please, ensure "
             "the directory exists and can be written to."
) + + +def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: + if variant is not None: + splits = weights_name.split(".") + splits = splits[:-1] + [variant] + splits[-1:] + weights_name = ".".join(splits) + + return weights_name + + +def _get_model_file( + pretrained_model_name_or_path, + *, + weights_name, + subfolder, + cache_dir, + force_download, + proxies, + resume_download, + local_files_only, + use_auth_token, + user_agent, + revision, + commit_hash=None, +): + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isfile(pretrained_model_name_or_path): + return pretrained_model_name_or_path + elif os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): + # Load from a PyTorch checkpoint + model_file = os.path.join(pretrained_model_name_or_path, weights_name) + return model_file + elif subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, weights_name) + ): + model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name) + return model_file + else: + raise EnvironmentError( + f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}." + ) + else: + # 1. First check if deprecated way of loading from branches is used + if ( + revision in DEPRECATED_REVISION_ARGS + and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME) + and version.parse(version.parse(__version__).base_version) >= version.parse("0.17.0") + ): + try: + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=_add_variant(weights_name, revision), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision or commit_hash, + ) + warnings.warn( + f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.", + FutureWarning, + ) + return model_file + except: # noqa: E722 + warnings.warn( + f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.", + FutureWarning, + ) + try: + # 2. 
Load model file as usual
+            model_file = hf_hub_download(
+                pretrained_model_name_or_path,
+                filename=weights_name,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                resume_download=resume_download,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                user_agent=user_agent,
+                subfolder=subfolder,
+                revision=revision or commit_hash,
+            )
+            return model_file
+
+        except RepositoryNotFoundError:
+            raise EnvironmentError(
+                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
+                "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
+                "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
+                "login`."
+            )
+        except RevisionNotFoundError:
+            raise EnvironmentError(
+                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
+                "this model name. Check the model page at "
+                f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+            )
+        except EntryNotFoundError:
+            raise EnvironmentError(
+                f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
+            )
+        except HTTPError as err:
+            raise EnvironmentError(
+                f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
+            )
+        except ValueError:
+            raise EnvironmentError(
+                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+                f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+                f" directory containing a file named {weights_name} or"
+                " \nCheck your internet connection or see how to run the library in"
+                " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+            )
+        except EnvironmentError:
+            raise EnvironmentError(
+                f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+                "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+                f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+                f"containing a file named {weights_name}"
+            )

From f0908989b855d29ae264c1a9ca17056f76b01c5e Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Thu, 23 Mar 2023 12:21:02 +0100
Subject: [PATCH 30/41] remove bogus

---
 + | 682 --------------------------------------------------------------
 1 file changed, 682 deletions(-)
 delete mode 100644 +

diff --git a/+ b/+
deleted file mode 100644
index b99f8069e64f..000000000000
--- a/+
+++ /dev/null
@@ -1,682 +0,0 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os -from collections import defaultdict -from typing import Callable, Dict, List, Union, Optional - -import torch - -from transformers import PreTrainedModel, PreTrainedTokenizer -from .models.attention_processor import LoRAAttnProcessor -from .models.modeling_utils import _get_model_file -from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging - - -if is_safetensors_available(): - import safetensors - - -logger = logging.get_logger(__name__) - - -LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" -LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" - -TEXT_INVERSION_NAME = "learned_embeds.bin" -TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors" - - -class AttnProcsLayers(torch.nn.Module): - def __init__(self, state_dict: Dict[str, torch.Tensor]): - super().__init__() - self.layers = torch.nn.ModuleList(state_dict.values()) - self.mapping = {k: v for k, v in enumerate(state_dict.keys())} - self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} - - # we add a hook to state_dict() and load_state_dict() so that the - # naming fits with `unet.attn_processors` - def map_to(module, state_dict, *args, **kwargs): - new_state_dict = {} - for key, value in state_dict.items(): - num = int(key.split(".")[1]) # 0 is always "layers" - new_key = key.replace(f"layers.{num}", module.mapping[num]) - new_state_dict[new_key] = value - - return new_state_dict - - def map_from(module, state_dict, *args, **kwargs): - all_keys = list(state_dict.keys()) - for key in all_keys: - replace_key = key.split(".processor")[0] + ".processor" - new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") - state_dict[new_key] = state_dict[key] - del state_dict[key] - - self._register_state_dict_hook(map_to) - self._register_load_state_dict_pre_hook(map_from, with_module=True) - - -class UNet2DConditionLoadersMixin: - def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): - r""" - Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be - defined in - [cross_attention.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py) - and be a `torch.nn.Module` class. - - - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. - - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., - `./my_model_directory/`. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. 
- proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `diffusers-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - subfolder (`str`, *optional*, defaults to `""`): - In case the relevant files are located inside a subfolder of the model repo (either remote in - huggingface.co or downloaded locally), you can specify the folder name here. - - mirror (`str`, *optional*): - Mirror source to accelerate downloads in China. If you are from China and have an accessibility - problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. - Please refer to the mirror site for more information. - - - - It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models). - - - """ - - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - if use_safetensors and not is_safetensors_available(): - raise ValueError( - "`use_safetensors`=True but safetensors is not installed. 
Please install safetensors with `pip install safetenstors" - ) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = is_safetensors_available() - allow_pickle = True - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - model_file = None - if not isinstance(pretrained_model_name_or_path_or_dict, dict): - # Let's first try to load .safetensors weights - if (use_safetensors and weight_name is None) or ( - weight_name is not None and weight_name.endswith(".safetensors") - ): - try: - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = safetensors.torch.load_file(model_file, device="cpu") - except IOError as e: - if not allow_pickle: - raise e - # try loading non-safetensors weights - pass - if model_file is None: - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or LORA_WEIGHT_NAME, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = torch.load(model_file, map_location="cpu") - else: - state_dict = pretrained_model_name_or_path_or_dict - - # fill attn processors - attn_processors = {} - - is_lora = all("lora" in k for k in state_dict.keys()) - - if is_lora: - lora_grouped_dict = defaultdict(dict) - for key, value in state_dict.items(): - attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) - lora_grouped_dict[attn_processor_key][sub_key] = value - - for key, value_dict in lora_grouped_dict.items(): - rank = value_dict["to_k_lora.down.weight"].shape[0] - cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] - hidden_size = value_dict["to_k_lora.up.weight"].shape[0] - - attn_processors[key] = LoRAAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank - ) - attn_processors[key].load_state_dict(value_dict) - - else: - raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.") - - # set correct dtype & device - attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()} - - # set layers - self.set_attn_processor(attn_processors) - - def save_attn_procs( - self, - save_directory: Union[str, os.PathLike], - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = False, - **kwargs, - ): - r""" - Save an attention processor to a directory, so that it can be re-loaded using the - `[`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`]` method. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to which to save. Will be created if it doesn't exist. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful when in distributed training like - TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on - the main process to avoid race conditions. 
- save_function (`Callable`): - The function to use to save the state dictionary. Useful on distributed training like TPUs when one - need to replace `torch.save` by another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - """ - weight_name = weight_name or deprecate( - "weights_name", - "0.18.0", - "`weights_name` is deprecated, please use `weight_name` instead.", - take_from=kwargs, - ) - if os.path.isfile(save_directory): - logger.error(f"Provided path ({save_directory}) should be a directory, not a file") - return - - if save_function is None: - if safe_serialization: - - def save_function(weights, filename): - return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) - - else: - save_function = torch.save - - os.makedirs(save_directory, exist_ok=True) - - model_to_save = AttnProcsLayers(self.attn_processors) - - # Save the model - state_dict = model_to_save.state_dict() - - if weight_name is None: - if safe_serialization: - weight_name = LORA_WEIGHT_NAME_SAFE - else: - weight_name = LORA_WEIGHT_NAME - - # Save the model - save_function(state_dict, os.path.join(save_directory, weight_name)) - logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") - - -class TextualInversionLoaderMixin: - r""" - Mixin class for adding textual inversion tokens and embeddings to the tokenizer and text encoder with method: - - [`~TextualInversionLoaderMixin.load_textual_inversion_embeddings`] - - [`~TextualInversionLoaderMixin.add_textual_inversion_embedding`] - """ - - def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): - if not isinstance(prompt, List): - prompts = [prompt] - - prompts = [self._maybe_convert_prompt(p) for p in prompts] - - if not isinstance(prompt, List): - return prompts[0] - - return prompts - - def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): - tokens = tokenizer.tokenize(prompt) - if not any(t in tokenizer.added_tokens_encoder for t in tokens): - return prompt - - for token in tokens: - if token in tokenizer.added_tokens_encoder: - replacement = token - i = 1 - while f"{token}_{i}" in tokenizer.added_tokens_encoder: - replacement += f"{token}_{i}" - i += 1 - - prompt = prompt.replace(token, replacement) - - return prompt - - def load_textual_inversion(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], token: Optional[str] = None, **kwargs): - r""" - Load textual inversion embeddings into the text encoder of stable diffusion pipelines. - - - - This function is experimental and might change in the future. - - - - Parameters: - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. - - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., - `./my_model_directory/`. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. 
- force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `diffusers-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - subfolder (`str`, *optional*, defaults to `""`): - In case the relevant files are located inside a subfolder of the model repo (either remote in - huggingface.co or downloaded locally), you can specify the folder name here. - - mirror (`str`, *optional*): - Mirror source to accelerate downloads in China. If you are from China and have an accessibility - problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. - Please refer to the mirror site for more information. - - - It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models). - - - """ - self._validate_method_call(self.load_textual_inversion) - - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - if use_safetensors and not is_safetensors_available(): - raise ValueError( - "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors" - ) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = is_safetensors_available() - allow_pickle = True - - user_agent = { - "file_type": "text_inversion", - "framework": "pytorch", - } - - # 1. 
Load textual inversion file - model_file = None - if not isinstance(pretrained_model_name_or_path_or_dict, dict): - # Let's first try to load .safetensors weights - if (use_safetensors and weight_name is None) or ( - weight_name is not None and weight_name.endswith(".safetensors") - ): - try: - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or TEXT_INVERSION_NAME_SAFE, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = safetensors.torch.load_file(model_file, device="cpu") - except IOError as e: - if not allow_pickle: - raise e - # try loading non-safetensors weights - pass - if model_file is None: - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or TEXT_INVERSION_NAME, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = torch.load(model_file, map_location="cpu") - else: - state_dict = pretrained_model_name_or_path_or_dict - - # 2. Load token and embedding correcly from file - if isinstance(state_dict, torch.Tensor): - if token is None: - raise ValueError("You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`.") - embedding = state_dict - - if len(state_dict) == 1: - # diffusers - loaded_token, embedding = next(iter(state_dict))[0] - elif "string_to_param" in state_dict: - # A1111 - loaded_token = self._extract_token_from_dict(embedding_dict) - embedding = self._extract_embedding_from_dict(embedding_dict) - - if token is not None and loaded_token != token: - logger.warn(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.") - else: - token = loaded_token - - # 3. Make sure we don't mess up the tokenizer or text encoder - vocab = self.tokenizer.get_vocab() - if token: - raise ValueError( - f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder." - ) - elif f"{token}_1" in vocab: - multi_vector_tokens = [token] - i = 1 - while f"{token}_{i}" in tokenizer.added_tokens_encoder: - multi_vector_tokens.append(f"{token}_{i}") - i += 1 - - raise ValueError( - f"Multi-vector Token {token} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder." - ) - - is_multi_vector = embedding.ndims > 1 and embedding.shape[0] > 1 - - if is_multi_vector: - tokens = [token] + [f"{token}_i" for i in range(1, embedding.shape[0])] - else: - tokens = [token] - - self.tokenizer.add_tokens(tokens) - token_ids = tokenizer.convert_tokens_to_ids(tokens) - - # 4. 
Load token and embedding into tokenizer - embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device) - self.text_encoder.resize_token_embeddings(len(self.tokenizer)) - - if token in self.tokenizer.get_vocab(): - # If user has allowed replacement and the token exists, we only need to - # extract the existing id and update the embedding - token_id = self.tokenizer.convert_tokens_to_ids(token) - self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding - else: - # If the token does not exist, we add it to the tokenizer, then resize and update the - # text encoder acccordingly - - - def load_textual_inversion_embeddings( - self, embedding_path_dict_or_list: Union[Dict[str, str], List[Dict[str, str]]], allow_replacement: bool = False - ): - r""" - Loads textual inversion embeddings and adds them to the tokenizer's vocabulary and the text encoder's - embeddings. - - Arguments: - embeddings_path_dict_or_list (`Dict[str, str]` or `List[str]`): - Dictionary of token to embedding path or List of embedding paths to embedding dictionaries. The - dictionary must have the following keys: - - `token`: name of the token to be added to the tokenizers' vocabulary - - `embedding`: path to the embedding of the token to be added to the text encoder's embedding - matrix - The list must contain paths to embedding dictionaries where the keys are the tokens and the values are - the embeddings (same as above dictionary definition). - - Returns: - None - """ - # Validate that inheriting class instance contains required attributes - self._validate_method_call(self.load_textual_inversion_embeddings) - - if isinstance(embedding_path_dict_or_list, dict): - for token, embedding_path in embedding_path_dict_or_list.items(): - # check if token in tokenizer vocab - if token in self.tokenizer.get_vocab(): - if allow_replacement: - logger.info( - f"Token {token} already in tokenizer vocabulary. Overwriting existing token and embedding" - " with the new one." - ) - else: - raise ValueError( - f"Token {token} already in tokenizer vocabulary. Please choose a different token name or set `allow_replacement=True`." - ) - - embedding_dict = torch.load(embedding_path, map_location=self.text_encoder.device) - embedding = self._extract_embedding_from_dict(embedding_dict) - - self.add_textual_inversion_embedding(token, embedding) - - elif isinstance(embedding_path_dict_or_list, list): - for embedding_path in embedding_path_dict_or_list: - embedding_dict = torch.load(embedding_path, map_location=self.text_encoder.device) - token = self._extract_token_from_dict(embedding_dict) - embedding = self._extract_embedding_from_dict(embedding_dict) - - # check if token in tokenizer vocab - if token in self.tokenizer.get_vocab(): - if allow_replacement: - logger.info( - f"Token {token} already in tokenizer vocabulary. Overwriting existing token and embedding" - " with the new one." - ) - else: - raise ValueError( - f"Token {token} already in tokenizer vocabulary. Please choose a different token name." - ) - self.add_textual_inversion_embedding(token, embedding) - - def add_textual_inversion_embedding(self, token: str, embedding: torch.Tensor): - r""" - Adds a token to the tokenizer's vocabulary and an embedding to the text encoder's embedding matrix. 
- - Arguments: - token (`str`): - The token to be added to the tokenizers' vocabulary - embedding (`torch.Tensor`): - The embedding of the token to be added to the text encoder's embedding matrix - """ - # NOTE: Not clear to me that we intend for this to be a public/exposed method. - # Validate that inheriting class instance contains required attributes - self._validate_method_call(self.load_textual_inversion_embeddings) - - embedding = embedding.to(self.text_encoder.dtype) - - if token in self.tokenizer.get_vocab(): - # If user has allowed replacement and the token exists, we only need to - # extract the existing id and update the embedding - token_id = self.tokenizer.convert_tokens_to_ids(token) - self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding - else: - # If the token does not exist, we add it to the tokenizer, then resize and update the - # text encoder acccordingly - self.tokenizer.add_tokens([token]) - - token_id = self.tokenizer.convert_tokens_to_ids(token) - # NOTE: len() does't start at 0, so we shouldn't need to +1 - # since we already updated the tokenizer and it's new length - # should be old length + 1 - self.text_encoder.resize_token_embeddings(len(self.tokenizer)) - self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding - - def _extract_embedding_from_dict(self, embedding_dict: Dict[str, str]) -> torch.Tensor: - r""" - Extracts the embedding from the embedding dictionary. - - Arguments: - embedding_dict (`Dict[str, str]`): - The embedding dictionary loaded from the embedding path - - Returns: - embedding (`torch.Tensor`): - The embedding to be added to the text encoder's embedding matrix - """ - # auto1111 embedding case - if "string_to_param" in embedding_dict: - embedding_dict = embedding_dict["string_to_param"] - embedding = embedding_dict["*"] - return embedding - - return list(embedding_dict.values())[0] - - def _extract_token_from_dict(self, embedding_dict: Dict[str, str]) -> str: - r""" - Extracts the token from the embedding dictionary. - - Arguments: - embedding_dict (`Dict[str, str]`): - The embedding dictionary loaded from the embedding path - - Returns: - token (`str`): - The token to be added to the tokenizers' vocabulary - """ - # auto1111 embedding case - if "string_to_param" in embedding_dict: - token = embedding_dict["name"] - return token - - return list(embedding_dict.keys())[0] - - def _validate_method_call(self, method: Callable): - r""" - Validates that the method is being called from a class instance that has the required attributes. - - Arguments: - method (`function`): - The class's method being called - - Raises: - ValueError: - If the method is being called from a class instance that does not have the required attributes, the - method will not be callable. 
- - Returns: - None - """ - if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer): - raise ValueError( - f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling" - f" `{method.__name__}`" - ) - - if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel): - raise ValueError( - f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling" - f" `{method.__name__}`" - ) From f5b6ff16de6beb5b73006acd448ddcd0045e839f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 28 Mar 2023 13:47:35 +0100 Subject: [PATCH 31/41] Apply suggestions from code review Co-authored-by: Sayak Paul Co-authored-by: Will Berman --- src/diffusers/loaders.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index b8eaef5b940a..11ed73b77eef 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -294,7 +294,7 @@ def save_function(weights, filename): class TextualInversionLoaderMixin: r""" - Mixin class for adding textual inversion tokens and embeddings to the tokenizer and text encoder with method: + Mixin class for adding textual inversion tokens and embeddings to the tokenizer and text encoder with methods: - [`~TextualInversionLoaderMixin.load_textual_inversion_embeddings`] - [`~TextualInversionLoaderMixin.add_textual_inversion_embedding`] """ @@ -343,8 +343,6 @@ def load_textual_inversion( Parameters: - - Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`): Can be either: @@ -498,7 +496,7 @@ def load_textual_inversion( i += 1 raise ValueError( - f"Multi-vector Token {token} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder." + f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder." ) is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1 @@ -656,7 +654,7 @@ def _extract_token_from_dict(self, embedding_dict: Dict[str, str]) -> str: return list(embedding_dict.keys())[0] - def _validate_method_call(self, method: Callable): + def _validate_can_load_textual_inversion(self, method: Callable): r""" Validates that the method is being called from a class instance that has the required attributes. From 8a040e8dc5a7135c9673bb4469f6e4e806fd87cf Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 28 Mar 2023 22:04:53 +0200 Subject: [PATCH 32/41] finish --- src/diffusers/loaders.py | 229 +++++------------- .../stable_diffusion/test_stable_diffusion.py | 21 ++ tests/test_pipelines.py | 86 +++++++ 3 files changed, 162 insertions(+), 174 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 11ed73b77eef..2e68a031edf1 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -294,12 +294,27 @@ def save_function(weights, filename): class TextualInversionLoaderMixin: r""" - Mixin class for adding textual inversion tokens and embeddings to the tokenizer and text encoder with methods: - - [`~TextualInversionLoaderMixin.load_textual_inversion_embeddings`] - - [`~TextualInversionLoaderMixin.add_textual_inversion_embedding`] + Mixin class for loading textual inversion tokens and embeddings to the tokenizer and text encoder. 
""" def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): + r""" + Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that + corresponds to a multi-vector textual inversion embedding, this function will process the prompt + so that the special token is replaced with multiple special tokens each corresponding to one of the + vectors. If the prompt has no textual inversion token or a textual inversion token that is a single vector, + the input prompt is simply returned. + + Parameters: + prompt (`str` or list of `str`): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + tokenizer (`PreTrainedTokenizer`): + The tokenizer responsible for encoding the prompt into input tokens. + + Returns: + `str` or list of `str`: The converted prompt + """ if not isinstance(prompt, List): prompts = [prompt] else: @@ -313,6 +328,23 @@ def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTra return prompts def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): + r""" + Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that + corresponds to a multi-vector textual inversion embedding, this function will process the prompt + so that the special token is replaced with multiple special tokens each corresponding to one of the + vectors. If the prompt has no textual inversion token or a textual inversion token that is a single vector, + the input prompt is simply returned. + + Parameters: + prompt (`str`): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + tokenizer (`PreTrainedTokenizer`): + The tokenizer responsible for encoding the prompt into input tokens. + + Returns: + `str`: The converted prompt + """ tokens = tokenizer.tokenize(prompt) if not any(t in tokenizer.added_tokens_encoder for t in tokens): return prompt @@ -341,16 +373,18 @@ def load_textual_inversion( - Parameters: - Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`): Can be either: - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. - - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., - `./my_model_directory/`. + Valid model ids should have an organization name, like `"sd-concepts-library/low-poly-hd-logos-icons"`. + - A path to a *directory* containing textual inversion weights, e.g. `./my_text_inversion_directory/`. + weight_name (`str`, *optional*): + Name of a custom weight file. This should be used in two cases: + + - The saved textual inversion file is in `diffusers` format, but has was saved under a specific weight name, such as `text_inv.bin`. + - The saved textual inversion file is in the "Automatic1111" form. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. 
@@ -388,7 +422,17 @@ def load_textual_inversion( """ - self._validate_method_call(self.load_textual_inversion) + if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer): + raise ValueError( + f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling" + f" `{self.load_textual_inversion.__name__}`" + ) + + if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel): + raise ValueError( + f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling" + f" `{self.load_textual_inversion.__name__}`" + ) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) @@ -472,8 +516,8 @@ def load_textual_inversion( loaded_token, embedding = next(iter(state_dict.items())) elif "string_to_param" in state_dict: # A1111 - loaded_token = self._extract_token_from_dict(state_dict) - embedding = self._extract_embedding_from_dict(state_dict) + loaded_token = state_dict["name"] + embedding = state_dict["string_to_param"]["*"] if token is not None and loaded_token != token: logger.warn(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.") @@ -518,166 +562,3 @@ def load_textual_inversion( self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding logger.info("Loaded textual inversion embedding for {token}.") - - # TODO(to discuss) think we can remove this one - def load_textual_inversion_embeddings( - self, embedding_path_dict_or_list: Union[Dict[str, str], List[Dict[str, str]]], allow_replacement: bool = False - ): - r""" - Loads textual inversion embeddings and adds them to the tokenizer's vocabulary and the text encoder's - embeddings. - - Arguments: - embeddings_path_dict_or_list (`Dict[str, str]` or `List[str]`): - Dictionary of token to embedding path or List of embedding paths to embedding dictionaries. The - dictionary must have the following keys: - - `token`: name of the token to be added to the tokenizers' vocabulary - - `embedding`: path to the embedding of the token to be added to the text encoder's embedding - matrix - The list must contain paths to embedding dictionaries where the keys are the tokens and the values are - the embeddings (same as above dictionary definition). - - Returns: - None - """ - # Validate that inheriting class instance contains required attributes - self._validate_method_call(self.load_textual_inversion_embeddings) - - if isinstance(embedding_path_dict_or_list, dict): - for token, embedding_path in embedding_path_dict_or_list.items(): - # check if token in tokenizer vocab - if token in self.tokenizer.get_vocab(): - if allow_replacement: - logger.info( - f"Token {token} already in tokenizer vocabulary. Overwriting existing token and embedding" - " with the new one." - ) - else: - raise ValueError( - f"Token {token} already in tokenizer vocabulary. Please choose a different token name or set `allow_replacement=True`." 
- ) - - embedding_dict = torch.load(embedding_path, map_location=self.text_encoder.device) - embedding = self._extract_embedding_from_dict(embedding_dict) - - self.add_textual_inversion_embedding(token, embedding) - - elif isinstance(embedding_path_dict_or_list, list): - for embedding_path in embedding_path_dict_or_list: - embedding_dict = torch.load(embedding_path, map_location=self.text_encoder.device) - token = self._extract_token_from_dict(embedding_dict) - embedding = self._extract_embedding_from_dict(embedding_dict) - - # check if token in tokenizer vocab - if token in self.tokenizer.get_vocab(): - if allow_replacement: - logger.info( - f"Token {token} already in tokenizer vocabulary. Overwriting existing token and embedding" - " with the new one." - ) - else: - raise ValueError( - f"Token {token} already in tokenizer vocabulary. Please choose a different token name." - ) - self.add_textual_inversion_embedding(token, embedding) - - # TODO(to discuss) think we can remove this one - def add_textual_inversion_embedding(self, token: str, embedding: torch.Tensor): - r""" - Adds a token to the tokenizer's vocabulary and an embedding to the text encoder's embedding matrix. - - Arguments: - token (`str`): - The token to be added to the tokenizers' vocabulary - embedding (`torch.Tensor`): - The embedding of the token to be added to the text encoder's embedding matrix - """ - # NOTE: Not clear to me that we intend for this to be a public/exposed method. - # Validate that inheriting class instance contains required attributes - self._validate_method_call(self.load_textual_inversion_embeddings) - - embedding = embedding.to(self.text_encoder.dtype) - - if token in self.tokenizer.get_vocab(): - # If user has allowed replacement and the token exists, we only need to - # extract the existing id and update the embedding - token_id = self.tokenizer.convert_tokens_to_ids(token) - self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding - else: - # If the token does not exist, we add it to the tokenizer, then resize and update the - # text encoder acccordingly - self.tokenizer.add_tokens([token]) - - token_id = self.tokenizer.convert_tokens_to_ids(token) - # NOTE: len() does't start at 0, so we shouldn't need to +1 - # since we already updated the tokenizer and it's new length - # should be old length + 1 - self.text_encoder.resize_token_embeddings(len(self.tokenizer)) - self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding - - def _extract_embedding_from_dict(self, embedding_dict: Dict[str, str]) -> torch.Tensor: - r""" - Extracts the embedding from the embedding dictionary. - - Arguments: - embedding_dict (`Dict[str, str]`): - The embedding dictionary loaded from the embedding path - - Returns: - embedding (`torch.Tensor`): - The embedding to be added to the text encoder's embedding matrix - """ - # auto1111 embedding case - if "string_to_param" in embedding_dict: - embedding_dict = embedding_dict["string_to_param"] - embedding = embedding_dict["*"] - return embedding - - return list(embedding_dict.values())[0] - - def _extract_token_from_dict(self, embedding_dict: Dict[str, str]) -> str: - r""" - Extracts the token from the embedding dictionary. 
- - Arguments: - embedding_dict (`Dict[str, str]`): - The embedding dictionary loaded from the embedding path - - Returns: - token (`str`): - The token to be added to the tokenizers' vocabulary - """ - # auto1111 embedding case - if "string_to_param" in embedding_dict: - token = embedding_dict["name"] - return token - - return list(embedding_dict.keys())[0] - - def _validate_can_load_textual_inversion(self, method: Callable): - r""" - Validates that the method is being called from a class instance that has the required attributes. - - Arguments: - method (`function`): - The class's method being called - - Raises: - ValueError: - If the method is being called from a class instance that does not have the required attributes, the - method will not be callable. - - Returns: - None - """ - if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer): - raise ValueError( - f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling" - f" `{method.__name__}`" - ) - - if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel): - raise ValueError( - f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling" - f" `{method.__name__}`" - ) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 33ef9368586e..991c0e26c3d1 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -21,6 +21,7 @@ import numpy as np import torch +from huggingface_hub import hf_hub_download from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import ( @@ -887,6 +888,26 @@ def test_stable_diffusion_pipeline_with_model_offloading(self): assert mem_bytes_slicing < mem_bytes_offloaded assert mem_bytes_slicing < 3 * 10**9 + def test_stable_diffusion_textual_inversion(self): + pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") + pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons") + + a111_file = hf_hub_download("hf-internal-testing/text_inv_embedding_a1111_format", "winter-style") + pipe.load_textual_inversion(a111_file) + pipe.to("cuda") + + generator = torch.Generator(device="cpu").manual_seed(0) + image = pipe("An logo of a turtle in Style-Winter with ", generator=generator, output_type="np").images[0] + # np.save("/home/patrick/diffusers-images/text_inv/winter_logo_style.npy", image) + + expected_image = load_numpy( + "https://huggingface.co/datasets/diffusers/test-images/resolve/main" + "/text_inv/winter_logo_style.npy" + ) + + max_diff = np.abs(expected_image - image).max() + assert max_diff < 1e-3 + @nightly @require_torch_gpu diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 9f0c9b1a4e19..15310eb33559 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -353,6 +353,92 @@ def test_download_broken_variant(self): diffusers.utils.import_utils._safetensors_available = True + def test_text_inversion_download(self): + pipe = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None + ) + pipe = pipe.to(torch_device) + + num_tokens = len(pipe.tokenizer) + + # single token load local + with tempfile.TemporaryDirectory() as tmpdirname: + ten = {"<*>": torch.ones((32,))} + torch.save(ten, os.path.join(tmpdirname, "learned_embeds.bin")) + + 
pipe.load_textual_inversion(tmpdirname) + + token = pipe.tokenizer.convert_tokens_to_ids("<*>") + assert token == num_tokens, "Added token must be at spot `num_tokens`" + assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 32 + assert pipe._maybe_convert_prompt("<*>", pipe.tokenizer) == "<*>" + + prompt = "hey <*>" + out = pipe(prompt, num_inference_steps=1, output_type="numpy").images + assert out.shape == (1, 128, 128, 3) + + # single token load local with weight name + with tempfile.TemporaryDirectory() as tmpdirname: + ten = {"<**>": 2 * torch.ones((1, 32))} + torch.save(ten, os.path.join(tmpdirname, "learned_embeds.bin")) + + pipe.load_textual_inversion(tmpdirname, weight_name="learned_embeds.bin") + + token = pipe.tokenizer.convert_tokens_to_ids("<**>") + assert token == num_tokens + 1, "Added token must be at spot `num_tokens`" + assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 64 + assert pipe._maybe_convert_prompt("<**>", pipe.tokenizer) == "<**>" + + prompt = "hey <**>" + out = pipe(prompt, num_inference_steps=1, output_type="numpy").images + assert out.shape == (1, 128, 128, 3) + + # multi token load + with tempfile.TemporaryDirectory() as tmpdirname: + ten = {"<***>": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))])} + torch.save(ten, os.path.join(tmpdirname, "learned_embeds.bin")) + + pipe.load_textual_inversion(tmpdirname) + + token = pipe.tokenizer.convert_tokens_to_ids("<***>") + token_1 = pipe.tokenizer.convert_tokens_to_ids("<***>_1") + token_2 = pipe.tokenizer.convert_tokens_to_ids("<***>_2") + + assert token == num_tokens + 2, "Added token must be at spot `num_tokens`" + assert token_1 == num_tokens + 3, "Added token must be at spot `num_tokens`" + assert token_2 == num_tokens + 4, "Added token must be at spot `num_tokens`" + assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96 + assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128 + assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160 + assert pipe._maybe_convert_prompt("<***>", pipe.tokenizer) == "<***><***>_1<***>_2" + + prompt = "hey <***>" + out = pipe(prompt, num_inference_steps=1, output_type="numpy").images + assert out.shape == (1, 128, 128, 3) + + # multi token load a1111 + with tempfile.TemporaryDirectory() as tmpdirname: + ten = {"string_to_param": {"*": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))])}, "name": "<****>"} + torch.save(ten, os.path.join(tmpdirname, "a1111.bin")) + + pipe.load_textual_inversion(tmpdirname, weight_name="a1111.bin") + + token = pipe.tokenizer.convert_tokens_to_ids("<****>") + token_1 = pipe.tokenizer.convert_tokens_to_ids("<****>_1") + token_2 = pipe.tokenizer.convert_tokens_to_ids("<****>_2") + + assert token == num_tokens + 5, "Added token must be at spot `num_tokens`" + assert token_1 == num_tokens + 6, "Added token must be at spot `num_tokens`" + assert token_2 == num_tokens + 7, "Added token must be at spot `num_tokens`" + assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96 + assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128 + assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160 + assert pipe._maybe_convert_prompt("<****>", pipe.tokenizer) == "<****><****>_1<****>_2" + + prompt = "hey <****>" + out = pipe(prompt, num_inference_steps=1, output_type="numpy").images + assert out.shape == (1, 128, 
128, 3) + class CustomPipelineTests(unittest.TestCase): def test_load_custom_pipeline(self): From 835a8d0c95740c43ea4f34c5f094313e57a72fdd Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 28 Mar 2023 22:05:39 +0200 Subject: [PATCH 33/41] make style --- src/diffusers/loaders.py | 27 ++++++++++--------- .../stable_diffusion/test_stable_diffusion.py | 7 ++--- tests/test_pipelines.py | 7 ++++- 3 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 2e68a031edf1..86bed16ac80d 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -299,11 +299,10 @@ class TextualInversionLoaderMixin: def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): r""" - Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that - corresponds to a multi-vector textual inversion embedding, this function will process the prompt - so that the special token is replaced with multiple special tokens each corresponding to one of the - vectors. If the prompt has no textual inversion token or a textual inversion token that is a single vector, - the input prompt is simply returned. + Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds + to a multi-vector textual inversion embedding, this function will process the prompt so that the special token + is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual + inversion token or a textual inversion token that is a single vector, the input prompt is simply returned. Parameters: prompt (`str` or list of `str`): @@ -329,11 +328,10 @@ def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTra def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): r""" - Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that - corresponds to a multi-vector textual inversion embedding, this function will process the prompt - so that the special token is replaced with multiple special tokens each corresponding to one of the - vectors. If the prompt has no textual inversion token or a textual inversion token that is a single vector, - the input prompt is simply returned. + Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds + to a multi-vector textual inversion embedding, this function will process the prompt so that the special token + is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual + inversion token or a textual inversion token that is a single vector, the input prompt is simply returned. Parameters: prompt (`str`): @@ -378,12 +376,15 @@ def load_textual_inversion( Can be either: - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids should have an organization name, like `"sd-concepts-library/low-poly-hd-logos-icons"`. - - A path to a *directory* containing textual inversion weights, e.g. `./my_text_inversion_directory/`. + Valid model ids should have an organization name, like + `"sd-concepts-library/low-poly-hd-logos-icons"`. + - A path to a *directory* containing textual inversion weights, e.g. + `./my_text_inversion_directory/`. weight_name (`str`, *optional*): Name of a custom weight file. 
This should be used in two cases:
 
-                    - The saved textual inversion file is in `diffusers` format, but was saved under a specific weight name, such as `text_inv.bin`.
+                    - The saved textual inversion file is in `diffusers` format, but was saved under a specific
+                      weight name, such as `text_inv.bin`.
                     - The saved textual inversion file is in the "Automatic1111" form.
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory in which a downloaded pretrained model configuration should be cached if the
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index 991c0e26c3d1..c4df845790a3 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -897,12 +897,13 @@ def test_stable_diffusion_textual_inversion(self):
         pipe.to("cuda")
 
         generator = torch.Generator(device="cpu").manual_seed(0)
-        image = pipe("An logo of a turtle in Style-Winter with ", generator=generator, output_type="np").images[0]
+        image = pipe(
+            "An logo of a turtle in Style-Winter with ", generator=generator, output_type="np"
+        ).images[0]
         # np.save("/home/patrick/diffusers-images/text_inv/winter_logo_style.npy", image)
 
         expected_image = load_numpy(
-            "https://huggingface.co/datasets/diffusers/test-images/resolve/main"
-            "/text_inv/winter_logo_style.npy"
+            "https://huggingface.co/datasets/diffusers/test-images/resolve/main" "/text_inv/winter_logo_style.npy"
         )
 
         max_diff = np.abs(expected_image - image).max()
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 15310eb33559..29cff1640cb0 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -418,7 +418,12 @@ def test_text_inversion_download(self):
 
         # multi token load a1111
         with tempfile.TemporaryDirectory() as tmpdirname:
-            ten = {"string_to_param": {"*": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))])}, "name": "<****>"}
+            ten = {
+                "string_to_param": {
+                    "*": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))])
+                },
+                "name": "<****>",
+            }
             torch.save(ten, os.path.join(tmpdirname, "a1111.bin"))
 
             pipe.load_textual_inversion(tmpdirname, weight_name="a1111.bin")

From 08a85dc555f74c69758b61012fb8beabe1d73bef Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 28 Mar 2023 22:07:59 +0200
Subject: [PATCH 34/41] up

---
 src/diffusers/loaders.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 86bed16ac80d..3ed09c02ce7d 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -548,7 +548,7 @@ def load_textual_inversion(
 
         if is_multi_vector:
             tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
-            embeddings = [e for e in embedding]
+            embeddings = [e for e in embedding]  # noqa: C416
         else:
             tokens = [token]
             embeddings = [embedding] if len(embedding.shape) > 1 else [embedding[0]]

From d1720999bfa3862113bd19b934377c4501dece3c Mon Sep 17 00:00:00 2001
From: Pi Esposito
Date: Wed, 29 Mar 2023 09:19:52 -0300
Subject: [PATCH 35/41] fix code quality

---
 .../textual_inversion_bf16.py | 5 +---
 .../textual_inversion.py | 5 +---
 .../textual_inversion_flax.py | 5 +---
 .../textual_inversion/textual_inversion.py | 5 +---
 .../textual_inversion/textual_inversion.py | 5 +---
 .../textual_inversion_flax.py | 5 +---
 src/diffusers/utils/dummy_pt_objects.py | 30 +++++++++----------
 7 files changed, 21 insertions(+), 39
deletions(-) diff --git a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py index f446efc0b0c0..99fce231c590 100644 --- a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py +++ b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py @@ -336,10 +336,7 @@ def __getitem__(self, i): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - ( - h, - w, - ) = ( + (h, w,) = ( img.shape[0], img.shape[1], ) diff --git a/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py b/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py index 05f714715fc9..8f84a8358aa7 100644 --- a/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py +++ b/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py @@ -527,10 +527,7 @@ def __getitem__(self, i): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - ( - h, - w, - ) = ( + (h, w,) = ( img.shape[0], img.shape[1], ) diff --git a/examples/research_projects/mulit_token_textual_inversion/textual_inversion_flax.py b/examples/research_projects/mulit_token_textual_inversion/textual_inversion_flax.py index c23fa4f5d38a..5935e4a9f46d 100644 --- a/examples/research_projects/mulit_token_textual_inversion/textual_inversion_flax.py +++ b/examples/research_projects/mulit_token_textual_inversion/textual_inversion_flax.py @@ -306,10 +306,7 @@ def __getitem__(self, i): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - ( - h, - w, - ) = ( + (h, w,) = ( img.shape[0], img.shape[1], ) diff --git a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py index 8d2c4c3c0bd4..57ad1aeb863a 100644 --- a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py +++ b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py @@ -443,10 +443,7 @@ def __getitem__(self, i): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - ( - h, - w, - ) = ( + (h, w,) = ( img.shape[0], img.shape[1], ) diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 92f3d27d4905..743875074346 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -499,10 +499,7 @@ def __getitem__(self, i): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - ( - h, - w, - ) = ( + (h, w,) = ( img.shape[0], img.shape[1], ) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index e988a2552612..1c7747c4cc0b 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -319,10 +319,7 @@ def __getitem__(self, i): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - ( - h, - w, - ) = ( + (h, w,) = ( img.shape[0], img.shape[1], ) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index be4b15cfa23a..143d9a9d00a5 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2,6 +2,21 @@ from ..utils import DummyObject, requires_backends +class TextualInversionLoaderMixin(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, 
**kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKL(metaclass=DummyObject): _backends = ["torch"] @@ -675,21 +690,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class TextualInversionLoaderMixin(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - class EMAModel(metaclass=DummyObject): _backends = ["torch"] From 991d3d70b3f39b35f16dfdd4d562826a18c44fd9 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Wed, 29 Mar 2023 09:23:14 -0300 Subject: [PATCH 36/41] fix code quality - again --- .../intel_opts/textual_inversion/textual_inversion_bf16.py | 5 ++++- .../mulit_token_textual_inversion/textual_inversion.py | 5 ++++- .../mulit_token_textual_inversion/textual_inversion_flax.py | 5 ++++- .../onnxruntime/textual_inversion/textual_inversion.py | 5 ++++- examples/textual_inversion/textual_inversion.py | 5 ++++- examples/textual_inversion/textual_inversion_flax.py | 5 ++++- 6 files changed, 24 insertions(+), 6 deletions(-) diff --git a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py index 99fce231c590..f446efc0b0c0 100644 --- a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py +++ b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py @@ -336,7 +336,10 @@ def __getitem__(self, i): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - (h, w,) = ( + ( + h, + w, + ) = ( img.shape[0], img.shape[1], ) diff --git a/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py b/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py index 8f84a8358aa7..05f714715fc9 100644 --- a/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py +++ b/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py @@ -527,7 +527,10 @@ def __getitem__(self, i): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - (h, w,) = ( + ( + h, + w, + ) = ( img.shape[0], img.shape[1], ) diff --git a/examples/research_projects/mulit_token_textual_inversion/textual_inversion_flax.py b/examples/research_projects/mulit_token_textual_inversion/textual_inversion_flax.py index 5935e4a9f46d..c23fa4f5d38a 100644 --- a/examples/research_projects/mulit_token_textual_inversion/textual_inversion_flax.py +++ b/examples/research_projects/mulit_token_textual_inversion/textual_inversion_flax.py @@ -306,7 +306,10 @@ def __getitem__(self, i): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - (h, w,) = ( + ( + h, + w, + ) = ( img.shape[0], img.shape[1], ) diff --git a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py index 57ad1aeb863a..8d2c4c3c0bd4 100644 --- a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py +++ b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py @@ -443,7 +443,10 @@ def 
__getitem__(self, i): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - (h, w,) = ( + ( + h, + w, + ) = ( img.shape[0], img.shape[1], ) diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 743875074346..92f3d27d4905 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -499,7 +499,10 @@ def __getitem__(self, i): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - (h, w,) = ( + ( + h, + w, + ) = ( img.shape[0], img.shape[1], ) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index 1c7747c4cc0b..e988a2552612 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -319,7 +319,10 @@ def __getitem__(self, i): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - (h, w,) = ( + ( + h, + w, + ) = ( img.shape[0], img.shape[1], ) From 28c425bf4b03d18e186e6b894454ac325947a71a Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Wed, 29 Mar 2023 09:30:38 -0300 Subject: [PATCH 37/41] fix code quality - 3 --- .../pipelines/alt_diffusion/pipeline_alt_diffusion.py | 9 +++++++++ .../stable_diffusion/pipeline_cycle_diffusion.py | 9 +++++++++ .../pipeline_stable_diffusion_attend_and_excite.py | 9 +++++++++ .../pipeline_stable_diffusion_controlnet.py | 9 +++++++++ .../pipeline_stable_diffusion_depth2img.py | 9 +++++++++ .../pipeline_stable_diffusion_img2img.py | 9 +++++++++ .../pipeline_stable_diffusion_inpaint.py | 9 +++++++++ .../pipeline_stable_diffusion_inpaint_legacy.py | 9 +++++++++ .../pipeline_stable_diffusion_k_diffusion.py | 9 +++++++++ .../pipeline_stable_diffusion_panorama.py | 9 +++++++++ .../pipeline_stable_diffusion_pix2pix_zero.py | 9 +++++++++ .../stable_diffusion/pipeline_stable_diffusion_sag.py | 9 +++++++++ .../pipeline_stable_diffusion_upscale.py | 9 +++++++++ .../pipelines/stable_diffusion/pipeline_stable_unclip.py | 9 +++++++++ .../stable_diffusion/pipeline_stable_unclip_img2img.py | 9 +++++++++ .../pipeline_text_to_video_synth.py | 9 +++++++++ 16 files changed, 144 insertions(+) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 1ae82beb54a4..ca6ea7c5695c 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -22,6 +22,7 @@ from diffusers.utils import is_accelerate_available, is_accelerate_version from ...configuration_utils import FrozenDict +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, logging, randn_tensor, replace_example_docstring @@ -312,6 +313,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -372,6 +377,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = 
prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index 76423867add1..86c86e7e367f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -24,6 +24,7 @@ from diffusers.utils import is_accelerate_available, is_accelerate_version from ...configuration_utils import FrozenDict +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDIMScheduler from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor @@ -338,6 +339,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -398,6 +403,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py index 2d32c0ba8b62..572e661f331e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py @@ -21,6 +21,7 @@ from torch.nn import functional as F from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import Attention from ...schedulers import KarrasDiffusionSchedulers @@ -335,6 +336,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -395,6 +400,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index aeb70b1b2234..9b46f6395ebe 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -23,6 +23,7 @@ from torch import nn from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, 
UNet2DConditionModel from ...models.controlnet import ControlNetOutput from ...models.modeling_utils import ModelMixin @@ -354,6 +355,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -414,6 +419,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index b66cfe9b437e..4f17fc3cefbc 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -23,6 +23,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation from ...configuration_utils import FrozenDict +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor @@ -200,6 +201,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -260,6 +265,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 835c88e19448..065a79f8cd36 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -23,6 +23,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import VaeImageProcessor +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -329,6 +330,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -389,6 +394,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + 
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index cee7ace239db..08b9aa51159a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -22,6 +22,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor @@ -373,6 +374,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -433,6 +438,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index cb953a7803b2..e145b3c0c35b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -22,6 +22,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -317,6 +318,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -377,6 +382,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index 2d40390b41d1..0e6ee747a8f5 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -18,6 +18,7 @@ import torch from k_diffusion.external import CompVisDenoiser, 
CompVisVDenoiser +from ...loaders import TextualInversionLoaderMixin from ...pipelines import DiffusionPipeline from ...schedulers import LMSDiscreteScheduler from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor @@ -238,6 +239,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -298,6 +303,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index 3fea4c2d83bb..3fdfdca9dafe 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -17,6 +17,7 @@ import torch from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDIMScheduler, PNDMScheduler from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring @@ -230,6 +231,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -290,6 +295,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 9c928129d0b9..54dfec8dc1a8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -28,6 +28,7 @@ CLIPTokenizer, ) +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import Attention from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler @@ -470,6 +471,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -530,6 +535,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # 
textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py index b24354a8e568..c61792c5bc43 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py @@ -19,6 +19,7 @@ import torch.nn.functional as F from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring @@ -247,6 +248,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -307,6 +312,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 9f8f44a12bb4..d520c9f59610 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -20,6 +20,7 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, logging, randn_tensor @@ -194,6 +195,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -254,6 +259,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index a8ba0b504628..ed5acbf20751 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -19,6 +19,7 @@ from transformers import CLIPTextModel, 
CLIPTextModelWithProjection, CLIPTokenizer from transformers.models.clip.modeling_clip import CLIPTextModelOutput +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding from ...schedulers import KarrasDiffusionSchedulers @@ -342,6 +343,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -402,6 +407,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index 99caa8be65a5..d2dd0ba430cb 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -21,6 +21,7 @@ from diffusers.utils.import_utils import is_accelerate_available +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding from ...schedulers import KarrasDiffusionSchedulers @@ -242,6 +243,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -302,6 +307,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 453809ef6df7..c591a14128dd 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -19,6 +19,7 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet3DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -256,6 +257,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -316,6 +321,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if 
necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, From df9f5799466afa14eacde1b8b9f82044c464b352 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Wed, 29 Mar 2023 09:33:52 -0300 Subject: [PATCH 38/41] fix alt diffusion code quality --- .../alt_diffusion/pipeline_alt_diffusion_img2img.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index b71217a4b3ec..394b5c44c694 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -25,6 +25,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import VaeImageProcessor +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring @@ -322,6 +323,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -382,6 +387,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, From 9dd02676d7e538a9238ee345803b9f403733d8b6 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Wed, 29 Mar 2023 09:44:05 -0300 Subject: [PATCH 39/41] fix model editing pipeline --- .../pipeline_stable_diffusion_model_editing.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py index 0e850b43bd7c..a0bcfd99c4a3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py @@ -18,6 +18,7 @@ import torch from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import PNDMScheduler from ...schedulers.scheduling_utils import SchedulerMixin @@ -266,6 +267,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -326,6 +331,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = 
prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, From 74b1e641b01af0ce1499201d4e99af4bb91f93f8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 30 Mar 2023 16:16:08 +0100 Subject: [PATCH 40/41] Apply suggestions from code review Co-authored-by: Pedro Cuenca --- src/diffusers/loaders.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index b00545d65770..8cd5e8e9625f 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -297,7 +297,7 @@ class TextualInversionLoaderMixin: Mixin class for loading textual inversion tokens and embeddings to the tokenizer and text encoder. """ - def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): + def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: PreTrainedTokenizer): r""" Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds to a multi-vector textual inversion embedding, this function will process the prompt so that the special token @@ -306,8 +306,7 @@ def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTra Parameters: prompt (`str` or list of `str`): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. + The prompt or prompts to guide the image generation. tokenizer (`PreTrainedTokenizer`): The tokenizer responsible for encoding the prompt into input tokens. @@ -326,7 +325,7 @@ def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTra return prompts - def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): + def _maybe_convert_prompt(self, prompt: str, tokenizer: PreTrainedTokenizer): r""" Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds to a multi-vector textual inversion embedding, this function will process the prompt so that the special token @@ -335,8 +334,7 @@ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): Parameters: prompt (`str`): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. + The prompt to guide the image generation. tokenizer (`PreTrainedTokenizer`): The tokenizer responsible for encoding the prompt into input tokens. @@ -364,6 +362,7 @@ def load_textual_inversion( ): r""" Load textual inversion embeddings into the text encoder of stable diffusion pipelines. + Both `diffusers` and `Automatic1111` formats are supported. @@ -383,7 +382,7 @@ def load_textual_inversion( weight_name (`str`, *optional*): Name of a custom weight file. This should be used in two cases: - - The saved textual inversion file is in `diffusers` format, but has was saved under a specific + - The saved textual inversion file is in `diffusers` format, but was saved under a specific weight name, such as `text_inv.bin`. - The saved textual inversion file is in the "Automatic1111" form. 
cache_dir (`Union[str, os.PathLike]`, *optional*): @@ -487,7 +486,6 @@ def load_textual_inversion( raise e model_file = None - pass if model_file is None: model_file = _get_model_file( From b9f53cbf54907d6499f191eceadd91d2f8e0bfc6 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 30 Mar 2023 15:56:07 +0000 Subject: [PATCH 41/41] Finish --- src/diffusers/__init__.py | 2 +- src/diffusers/loaders.py | 28 +++++++++++-------- .../alt_diffusion/pipeline_alt_diffusion.py | 2 +- .../pipeline_alt_diffusion_img2img.py | 2 +- .../pipeline_cycle_diffusion.py | 2 +- ...line_stable_diffusion_attend_and_excite.py | 2 +- .../pipeline_stable_diffusion_controlnet.py | 2 +- .../pipeline_stable_diffusion_depth2img.py | 2 +- .../pipeline_stable_diffusion_img2img.py | 2 +- .../pipeline_stable_diffusion_inpaint.py | 2 +- ...ipeline_stable_diffusion_inpaint_legacy.py | 2 +- ...eline_stable_diffusion_instruct_pix2pix.py | 11 +++++++- .../pipeline_stable_diffusion_k_diffusion.py | 2 +- ...pipeline_stable_diffusion_model_editing.py | 2 +- .../pipeline_stable_diffusion_panorama.py | 2 +- .../pipeline_stable_diffusion_pix2pix_zero.py | 2 +- .../pipeline_stable_diffusion_sag.py | 2 +- .../pipeline_stable_diffusion_upscale.py | 2 +- .../pipeline_stable_unclip.py | 2 +- .../pipeline_stable_unclip_img2img.py | 2 +- .../pipeline_text_to_video_synth.py | 2 +- src/diffusers/utils/dummy_pt_objects.py | 15 ---------- .../dummy_torch_and_transformers_objects.py | 15 ++++++++++ .../stable_diffusion/test_stable_diffusion.py | 21 ++++++++------ 24 files changed, 74 insertions(+), 54 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6c9928786761..bba8d4084636 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -33,7 +33,6 @@ except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: - from .loaders import TextualInversionLoaderMixin from .models import ( AutoencoderKL, ControlNetModel, @@ -110,6 +109,7 @@ except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_objects import * # noqa F403 else: + from .loaders import TextualInversionLoaderMixin from .pipelines import ( AltDiffusionImg2ImgPipeline, AltDiffusionPipeline, diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 8cd5e8e9625f..265ea92625f5 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -16,14 +16,25 @@ from typing import Callable, Dict, List, Optional, Union import torch -from transformers import PreTrainedModel, PreTrainedTokenizer -from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, _get_model_file, deprecate, is_safetensors_available, logging +from .models.attention_processor import LoRAAttnProcessor +from .utils import ( + DIFFUSERS_CACHE, + HF_HUB_OFFLINE, + _get_model_file, + deprecate, + is_safetensors_available, + is_transformers_available, + logging, +) if is_safetensors_available(): import safetensors +if is_transformers_available(): + from transformers import PreTrainedModel, PreTrainedTokenizer + logger = logging.get_logger(__name__) @@ -154,8 +165,6 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict "framework": "pytorch", } - from .models.attention_processor import LoRAAttnProcessor - model_file = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): # Let's first try to load .safetensors weights @@ -342,9 +351,6 @@ def _maybe_convert_prompt(self, prompt: str, tokenizer: PreTrainedTokenizer): `str`: The converted prompt """ tokens = 
tokenizer.tokenize(prompt) - if not any(t in tokenizer.added_tokens_encoder for t in tokens): - return prompt - for token in tokens: if token in tokenizer.added_tokens_encoder: replacement = token @@ -361,8 +367,8 @@ def load_textual_inversion( self, pretrained_model_name_or_path: Union[str, Dict[str, torch.Tensor]], token: Optional[str] = None, **kwargs ): r""" - Load textual inversion embeddings into the text encoder of stable diffusion pipelines. - Both `diffusers` and `Automatic1111` formats are supported. + Load textual inversion embeddings into the text encoder of stable diffusion pipelines. Both `diffusers` and + `Automatic1111` formats are supported. @@ -382,8 +388,8 @@ def load_textual_inversion( weight_name (`str`, *optional*): Name of a custom weight file. This should be used in two cases: - - The saved textual inversion file is in `diffusers` format, but was saved under a specific - weight name, such as `text_inv.bin`. + - The saved textual inversion file is in `diffusers` format, but was saved under a specific weight + name, such as `text_inv.bin`. - The saved textual inversion file is in the "Automatic1111" form. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 86a827de8ee3..c5bb8f9ac7b1 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -50,7 +50,7 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker -class AltDiffusionPipeline(DiffusionPipeline): +class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-to-image generation using Alt Diffusion. diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index a1365648eea2..9af55d1d018a 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -89,7 +89,7 @@ def preprocess(image): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker -class AltDiffusionImg2ImgPipeline(DiffusionPipeline): +class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-guided image to image generation using Alt Diffusion. 
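As a quick illustration of what the simplified `_maybe_convert_prompt` above does, here is a minimal standalone sketch of the multi-vector expansion. It uses `str.split` instead of `tokenizer.tokenize`, and `added_tokens` stands in for `tokenizer.added_tokens_encoder`, so it is an approximation of the idea rather than the library code.

def expand_prompt(prompt: str, added_tokens: set) -> str:
    # every loaded multi-vector embedding registers "tok", "tok_1", ..., "tok_{n-1}"
    for token in prompt.split():
        if token in added_tokens:
            replacement = token
            i = 1
            while f"{token}_{i}" in added_tokens:
                replacement += f" {token}_{i}"
                i += 1
            prompt = prompt.replace(token, replacement)
    return prompt

# expand_prompt("a castle in <winter> style", {"<winter>", "<winter>_1", "<winter>_2"})
# -> "a castle in <winter> <winter>_1 <winter>_2 style"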
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index 24a6d47a4932..dd8e4f16dfc0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -119,7 +119,7 @@ def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta): return noise -class CycleDiffusionPipeline(DiffusionPipeline): +class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-guided image to image generation using Stable Diffusion. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py index 96beccf96586..46adb6967140 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py @@ -160,7 +160,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a return hidden_states -class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline): +class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion and Attend and Excite. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index 8a04ef785066..93cbc03b12ed 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -147,7 +147,7 @@ def forward( return down_block_res_samples, mid_block_res_sample -class StableDiffusionControlNetPipeline(DiffusionPipeline): +class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index 4795f8030ebb..54f00ebc23f2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -55,7 +55,7 @@ def preprocess(image): return image -class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): +class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-guided image to image generation using Stable Diffusion. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 51c9d310154a..e47fae663de3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -92,7 +92,7 @@ def preprocess(image): return image -class StableDiffusionImg2ImgPipeline(DiffusionPipeline): +class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-guided image to image generation using Stable Diffusion. 
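With the mixin now mixed into the main pipelines, end-to-end usage looks roughly like the sketch below. The concept repo id is the one exercised by the tests in this series; the local directory and the `text_inv.bin` file name are hypothetical, echoing the example from the `weight_name` docstring.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# diffusers-format concept hosted on the Hub
pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons")

# a file saved under a custom weight name (hypothetical local path)
pipe.load_textual_inversion("./embeddings", weight_name="text_inv.bin")

# the trigger word is the token the concept was saved with; for sd-concepts
# repos this is conventionally "<concept-name>" (an assumption here)
image = pipe("A turtle logo in <low-poly-hd-logos-icons> style").images[0]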
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index cc0230009c4d..8e0ea5a8d079 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -138,7 +138,7 @@ def prepare_mask_and_masked_image(image, mask): return mask, masked_image -class StableDiffusionInpaintPipeline(DiffusionPipeline): +class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 71463068f5a2..b7a0c942bbe2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -82,7 +82,7 @@ def preprocess_mask(mask, scale_factor=8): return mask -class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): +class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index 40cde74a0596..f7999a08dc9b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -20,6 +20,7 @@ import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -60,7 +61,7 @@ def preprocess(image): return image -class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): +class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion. 
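# Note on the _encode_prompt hunks that follow: _encode_prompt is duplicated
# across pipelines via the repo's "# Copied from" mechanism, so the conversion
# is gated on the class actually inheriting the mixin,
#
#     if isinstance(self, TextualInversionLoaderMixin):
#         prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
#
# which keeps the shared code a no-op for pipelines that never load textual
# inversion embeddings.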
@@ -511,6 +512,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -571,6 +576,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index 87e735b63835..3d10c7d4e8e8 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -42,7 +42,7 @@ def apply_model(self, *args, **kwargs): return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample -class StableDiffusionKDiffusionPipeline(DiffusionPipeline): +class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py index a0bcfd99c4a3..d841bd8a2d26 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py @@ -53,7 +53,7 @@ """ -class StableDiffusionModelEditingPipeline(DiffusionPipeline): +class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-to-image model editing using "Editing Implicit Assumptions in Text-to-Image Diffusion Models". diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index 0deaca3910fe..c47423bdee5b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -48,7 +48,7 @@ """ -class StableDiffusionPanoramaPipeline(DiffusionPipeline): +class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-to-image generation using "MultiDiffusion: Fusing Diffusion Paths for Controlled Image Generation". diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 902c638b7cc0..6af923cb7743 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -51,7 +51,7 @@ @dataclass -class Pix2PixInversionPipelineOutput(BaseOutput): +class Pix2PixInversionPipelineOutput(BaseOutput, TextualInversionLoaderMixin): """ Output class for Stable Diffusion pipelines. 
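For reference, the two on-disk formats that `load_textual_inversion` distinguishes look roughly as follows, modeled on the tensors constructed in `tests/test_pipelines.py` earlier in the series; the token name and the 768-entry hidden size are illustrative assumptions.

import torch

# diffusers format: a plain {token: embedding} mapping; a 2D tensor
# (num_vectors x hidden_size) selects the multi-vector code path
torch.save({"<my-style>": torch.randn(3, 768)}, "learned_embeds.bin")

# Automatic1111 format: the embedding sits under string_to_param["*"]
# and the trigger token is stored separately under "name"
torch.save(
    {"string_to_param": {"*": torch.randn(3, 768)}, "name": "<my-style>"},
    "a1111.bin",
)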
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py index 7429aaa9d5bb..2b08cf662bb4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py @@ -88,7 +88,7 @@ def __call__( # Modified to get self-attention guidance scale in this paper (https://arxiv.org/pdf/2210.00939.pdf) as an input -class StableDiffusionSAGPipeline(DiffusionPipeline): +class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 5218cb507bc6..606202bd3911 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -51,7 +51,7 @@ def preprocess(image): return image -class StableDiffusionUpscalePipeline(DiffusionPipeline): +class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-guided image super-resolution using Stable Diffusion 2. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index a48305f545f0..ce41572e683c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -48,7 +48,7 @@ """ -class StableUnCLIPPipeline(DiffusionPipeline): +class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin): """ Pipeline for text-to-image generation using stable unCLIP. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index e1030b127c4e..b9bf00bc7835 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -61,7 +61,7 @@ """ -class StableUnCLIPImg2ImgPipeline(DiffusionPipeline): +class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin): """ Pipeline for text-guided image to image generation using stable unCLIP. diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index d840db75c707..1cbe78f0c964 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -73,7 +73,7 @@ def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) - return images -class TextToVideoSDPipeline(DiffusionPipeline): +class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-to-video generation. 
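The diff below moves the `TextualInversionLoaderMixin` dummy out of `dummy_pt_objects.py`: since `loaders.py` now imports `PreTrainedTokenizer` from `transformers`, the placeholder has to require both backends. A rough, simplified sketch of the dummy-object pattern (an assumption for illustration, not the actual `diffusers.utils` implementation):

class DummyObject(type):
    # metaclass that turns instantiation into a clear missing-backend error
    def __call__(cls, *args, **kwargs):
        raise ImportError(f"{cls.__name__} requires {cls._backends} to be installed")

class TextualInversionLoaderMixin(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

# TextualInversionLoaderMixin()  -> ImportError naming torch and transformers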
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 1b298eaef0a9..014e193aa32a 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2,21 +2,6 @@ from ..utils import DummyObject, requires_backends -class TextualInversionLoaderMixin(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - class AutoencoderKL(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index ab85566049d8..cf85ff157f57 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2,6 +2,21 @@ from ..utils import DummyObject, requires_backends +class TextualInversionLoaderMixin(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class AltDiffusionImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 4b472ba19b81..c3ad88b34acb 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -891,22 +891,27 @@ def test_stable_diffusion_textual_inversion(self): pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons") - a111_file = hf_hub_download("hf-internal-testing/text_inv_embedding_a1111_format", "winter-style") + a111_file = hf_hub_download("hf-internal-testing/text_inv_embedding_a1111_format", "winter_style.pt") + a111_file_neg = hf_hub_download( + "hf-internal-testing/text_inv_embedding_a1111_format", "winter_style_negative.pt" + ) pipe.load_textual_inversion(a111_file) + pipe.load_textual_inversion(a111_file_neg) pipe.to("cuda") - generator = torch.Generator(device="cpu").manual_seed(0) - image = pipe( - "An logo of a turtle in Style-Winter with ", generator=generator, output_type="np" - ).images[0] - # np.save("/home/patrick/diffusers-images/text_inv/winter_logo_style.npy", image) + generator = torch.Generator(device="cpu").manual_seed(1) + + prompt = "An logo of a turtle in strong Style-Winter with " + neg_prompt = "Style-Winter-neg" + + image = pipe(prompt=prompt, negative_prompt=neg_prompt, generator=generator, output_type="np").images[0] expected_image = load_numpy( - "https://huggingface.co/datasets/diffusers/test-images/resolve/main" "/text_inv/winter_logo_style.npy" + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_inv/winter_logo_style.npy" ) max_diff = np.abs(expected_image - image).max() - assert max_diff < 1e-3 + assert max_diff < 5e-3 @nightly
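Mirroring the updated integration test above, a negative embedding in Automatic1111 format loads exactly like a positive one and is then referenced from `negative_prompt` rather than `prompt`. A hedged sketch, assuming the concept's trigger token is `<low-poly-hd-logos-icons>`:

import torch
from diffusers import StableDiffusionPipeline
from huggingface_hub import hf_hub_download

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons")

neg_file = hf_hub_download(
    "hf-internal-testing/text_inv_embedding_a1111_format", "winter_style_negative.pt"
)
pipe.load_textual_inversion(neg_file)  # the trigger token comes from the file's "name" entry

image = pipe(
    prompt="An logo of a turtle in strong Style-Winter with <low-poly-hd-logos-icons>",
    negative_prompt="Style-Winter-neg",
    output_type="np",
).images[0]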