|
2 | 2 | from transformers import CLIPTextModel |
3 | 3 | from collections import defaultdict |
4 | 4 | import torch |
| 5 | +import safetensors.torch |
5 | 6 | import traceback |
6 | | -import re |
7 | 7 | import sys |
8 | 8 | from apps.stable_diffusion.src.utils import ( |
9 | 9 | compile_through_fx, |
@@ -164,10 +164,23 @@ def get_vae(self): |
class VaeModel(torch.nn.Module):
    # Wraps the diffusers VAE so it can be compiled as a standalone module.
    # NOTE(review): the default arguments bind `self` from the enclosing
    # method's scope at class-definition time (this class appears to be
    # declared inside a method) — confirm against the surrounding file.
    def __init__(self, model_id=self.model_id, base_vae=self.base_vae, custom_vae=self.custom_vae):
        super().__init__()
        if isinstance(custom_vae, dict):
            # custom_vae is an in-memory state dict: load the base VAE
            # weights first, then overwrite them with the custom weights.
            self.vae = AutoencoderKL.from_pretrained(
                model_id,
                subfolder="vae",
            )
            self.vae.load_state_dict(custom_vae)
        else:
            # Either the default VAE that ships with `model_id`, or a
            # custom repo-id / diffusers checkpoint path.
            self.vae = AutoencoderKL.from_pretrained(
                model_id if custom_vae == "" else custom_vae,
                subfolder="vae",
            )
        self.base_vae = base_vae
172 | 185 |
|
173 | 186 | def forward(self, input): |
@@ -254,6 +267,27 @@ def forward(self, input): |
254 | 267 | ) |
255 | 268 | return shark_clip |
256 | 269 |
|
def process_custom_vae(self):
    """Normalize ``self.custom_vae`` into something ``AutoencoderKL`` can use.

    Returns one of:
      * the original value unchanged, when it is not a ``.ckpt`` /
        ``.safetensors`` file (e.g. a repo-id or an empty string);
      * a path to a diffusers-converted checkpoint, when the file is a
        full model checkpoint that ``preprocessCKPT`` can handle;
      * a plain state dict, when the file is a standalone VAE checkpoint
        (loss-tracking and EMA bookkeeping keys stripped out).
    """
    custom_vae = self.custom_vae.lower()
    if not custom_vae.endswith((".ckpt", ".safetensors")):
        # Repo-id (or empty string) — nothing to preprocess.
        return self.custom_vae
    try:
        # First assume a full SD checkpoint convertible by diffusers.
        preprocessCKPT(self.custom_vae)
        return get_path_to_diffusers_checkpoint(self.custom_vae)
    except Exception:
        # Conversion failed: fall back to treating the file as a bare
        # VAE state dict. (Bare `except:` would also trap SystemExit /
        # KeyboardInterrupt, so catch Exception explicitly.)
        print("Processing standalone Vae checkpoint")
        vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
        if custom_vae.endswith(".ckpt"):
            # SECURITY: torch.load unpickles arbitrary objects — only
            # load checkpoints from trusted sources.
            vae_checkpoint = torch.load(self.custom_vae, map_location="cpu")
        else:
            vae_checkpoint = safetensors.torch.load_file(
                self.custom_vae, device="cpu"
            )
        if "state_dict" in vae_checkpoint:
            vae_checkpoint = vae_checkpoint["state_dict"]
        return {
            k: v
            for k, v in vae_checkpoint.items()
            if not k.startswith("loss") and k not in vae_ignore_keys
        }
| 289 | + |
| 290 | + |
257 | 291 | # Compiles Clip, Unet and Vae with `base_model_id` as defining their input |
258 | 292 | # configuration. |
259 | 293 | def compile_all(self, base_model_id, need_vae_encode): |
@@ -305,9 +339,7 @@ def __call__(self): |
305 | 339 | model_to_run = args.hf_model_id |
306 | 340 | # For custom Vae user can provide either the repo-id or a checkpoint file, |
307 | 341 | # and for a checkpoint file we'd need to process it via Diffusers' script. |
308 | | - if self.custom_vae.lower().endswith((".ckpt", ".safetensors")): |
309 | | - preprocessCKPT(self.custom_vae) |
310 | | - self.custom_vae = get_path_to_diffusers_checkpoint(self.custom_vae) |
| 342 | + self.custom_vae = self.process_custom_vae() |
311 | 343 | base_model_fetched = fetch_and_update_base_model_id(model_to_run) |
312 | 344 | if base_model_fetched != "": |
313 | 345 | print("Compiling all the models with the fetched base model configuration.") |
|
0 commit comments