
Commit 143492f

Abhishek-Varma and Abhishek Varma authored
[SD] Add support for standalone Vae checkpoints (huggingface#1020)
-- This commit adds support for standalone Vae checkpoints.

Signed-off-by: Abhishek Varma <[email protected]>
Co-authored-by: Abhishek Varma <[email protected]>
1 parent ecc5c66 commit 143492f

File tree

1 file changed: +40 -8 lines changed


apps/stable_diffusion/src/models/model_wrappers.py

Lines changed: 40 additions & 8 deletions
@@ -2,8 +2,8 @@
 from transformers import CLIPTextModel
 from collections import defaultdict
 import torch
+import safetensors.torch
 import traceback
-import re
 import sys
 from apps.stable_diffusion.src.utils import (
     compile_through_fx,
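
The newly imported `safetensors.torch` is what the fallback added below uses to read `.safetensors` files, while `torch.load` keeps handling legacy pickled `.ckpt` files (the dropped `import re` appears to have been unused). A minimal sketch of the two loaders, with a hypothetical file name:

```python
import torch
import safetensors.torch

# Hypothetical standalone VAE file, purely for illustration.
ckpt_path = "my-custom-vae.safetensors"

if ckpt_path.endswith(".safetensors"):
    # Reads raw tensors; never executes pickled code.
    state = safetensors.torch.load_file(ckpt_path, device="cpu")
else:
    # Legacy pickled checkpoint; weights may be nested under "state_dict".
    state = torch.load(ckpt_path, map_location="cpu")

print(f"loaded {len(state)} entries")
```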
@@ -164,10 +164,23 @@ def get_vae(self):
         class VaeModel(torch.nn.Module):
             def __init__(self, model_id=self.model_id, base_vae=self.base_vae, custom_vae=self.custom_vae):
                 super().__init__()
-                self.vae = AutoencoderKL.from_pretrained(
-                    model_id if custom_vae == "" else custom_vae,
-                    subfolder="vae",
-                )
+                self.vae = None
+                if custom_vae == "":
+                    self.vae = AutoencoderKL.from_pretrained(
+                        model_id,
+                        subfolder="vae",
+                    )
+                elif not isinstance(custom_vae, dict):
+                    self.vae = AutoencoderKL.from_pretrained(
+                        custom_vae,
+                        subfolder="vae",
+                    )
+                else:
+                    self.vae = AutoencoderKL.from_pretrained(
+                        model_id,
+                        subfolder="vae",
+                    )
+                    self.vae.load_state_dict(custom_vae)
                 self.base_vae = base_vae
 
             def forward(self, input):
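
With this change `custom_vae` can arrive in three shapes: the empty string (keep the VAE bundled with the base model), a repo id or local diffusers path (load that VAE instead), or a raw state dict produced by the new `process_custom_vae` helper below (build the base VAE, then overwrite its weights). A standalone sketch of the same branching, assuming `diffusers` is installed; `build_vae` and its arguments are illustrative, not part of the patch:

```python
from diffusers import AutoencoderKL

def build_vae(model_id, custom_vae):
    # Mirrors VaeModel.__init__'s new three-way branch.
    if custom_vae == "":
        # No override: the VAE in the base model's "vae" subfolder.
        return AutoencoderKL.from_pretrained(model_id, subfolder="vae")
    if not isinstance(custom_vae, dict):
        # A repo id or local folder that itself contains a "vae" subfolder.
        return AutoencoderKL.from_pretrained(custom_vae, subfolder="vae")
    # A state dict from a standalone checkpoint: instantiate the base
    # architecture, then swap the weights in.
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
    vae.load_state_dict(custom_vae)
    return vae
```

Note that `load_state_dict` is strict by default, so the dict's keys must already match diffusers' `AutoencoderKL` parameter names.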
@@ -254,6 +267,27 @@ def forward(self, input):
         )
         return shark_clip
 
+    def process_custom_vae(self):
+        custom_vae = self.custom_vae.lower()
+        if not custom_vae.endswith((".ckpt", ".safetensors")):
+            return self.custom_vae
+        try:
+            preprocessCKPT(self.custom_vae)
+            return get_path_to_diffusers_checkpoint(self.custom_vae)
+        except:
+            print("Processing standalone Vae checkpoint")
+            vae_checkpoint = None
+            vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
+            if custom_vae.endswith(".ckpt"):
+                vae_checkpoint = torch.load(self.custom_vae, map_location="cpu")
+            else:
+                vae_checkpoint = safetensors.torch.load_file(self.custom_vae, device="cpu")
+            if "state_dict" in vae_checkpoint:
+                vae_checkpoint = vae_checkpoint["state_dict"]
+            vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+            return vae_dict
+
+
     # Compiles Clip, Unet and Vae with `base_model_id` as defining their input
     # configiration.
     def compile_all(self, base_model_id, need_vae_encode):
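
`process_custom_vae` first tries the pre-existing route, converting the file as a full Stable Diffusion checkpoint via `preprocessCKPT`; only when that raises does it treat the file as a standalone VAE checkpoint. The fallback loads the raw weights, unwraps an optional `state_dict` envelope, and keeps only VAE weights, dropping keys that start with `loss` (training-time loss modules) and the two EMA bookkeeping entries. A toy run of that filter, with invented keys:

```python
# Same filtering rule as the fallback, applied to a made-up checkpoint.
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
checkpoint = {
    "encoder.conv_in.weight": 0,  # kept: an actual VAE weight
    "loss.logvar": 0,             # dropped: key starts with "loss"
    "model_ema.decay": 0,         # dropped: in the ignore set
}
# k[0:4] != "loss" in the patch is equivalent to not k.startswith("loss").
vae_dict = {
    k: v
    for k, v in checkpoint.items()
    if not k.startswith("loss") and k not in vae_ignore_keys
}
assert list(vae_dict) == ["encoder.conv_in.weight"]
```

One stylistic nit: the bare `except:` also catches `KeyboardInterrupt` and `SystemExit`; `except Exception:` would be the more conventional guard for "the Diffusers conversion failed".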
@@ -305,9 +339,7 @@ def __call__(self):
         model_to_run = args.hf_model_id
         # For custom Vae user can provide either the repo-id or a checkpoint file,
         # and for a checkpoint file we'd need to process it via Diffusers' script.
-        if self.custom_vae.lower().endswith((".ckpt", ".safetensors")):
-            preprocessCKPT(self.custom_vae)
-            self.custom_vae = get_path_to_diffusers_checkpoint(self.custom_vae)
+        self.custom_vae = self.process_custom_vae()
         base_model_fetched = fetch_and_update_base_model_id(model_to_run)
         if base_model_fetched != "":
             print("Compiling all the models with the fetched base model configuration.")
