From 0f51783f2f04861d51363fdaa25af88aaed2849e Mon Sep 17 00:00:00 2001 From: akleine Date: Fri, 17 Oct 2025 11:24:33 +0200 Subject: [PATCH 1/6] feat: add code and doc for running SSD1B models --- docs/distilled_sd.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 docs/distilled_sd.md diff --git a/docs/distilled_sd.md b/docs/distilled_sd.md new file mode 100644 index 000000000..7e38cb35f --- /dev/null +++ b/docs/distilled_sd.md @@ -0,0 +1,18 @@ +# Running distilled SDXL models: SSD1B + +### Preface + +This kind of models has a reduced U-Net part. Unlike other SDXL models the U-Net has only one middle block and lesser attention layers in up and down blocks, resulting in relatively smaller files. Running these models saves more than 33% of the time. For more details, refer to Segmind's paper on https://arxiv.org/abs/2401.02677v1 . + +### How to Use + +Unfortunately not all of this models follow the standard model parameter naming mapping. +Anyway there are some useful SSD1B models available online, such as: + + * https://huggingface.co/segmind/SSD-1B/resolve/main/SSD-1B-A1111.safetensors + * https://huggingface.co/hassenhamdi/SSD-1B-fp8_e4m3fn/resolve/main/SSD-1B_fp8_e4m3fn.safetensors + +Also there are useful LORAs available: + + * https://huggingface.co/seungminh/lora-swarovski-SSD-1B/resolve/main/pytorch_lora_weights.safetensors + * https://huggingface.co/kylielee505/mylcmlorassd/resolve/main/pytorch_lora_weights.safetensors From 66ef7058d55ddd0a581eb375686ab6012c1266e2 Mon Sep 17 00:00:00 2001 From: akleine Date: Fri, 17 Oct 2025 11:37:45 +0200 Subject: [PATCH 2/6] feat: add code and doc for running SSD1B models --- model.cpp | 7 ++++++- model.h | 3 ++- stable-diffusion.cpp | 1 + unet.hpp | 35 +++++++++++++++++++++++------------ 4 files changed, 32 insertions(+), 14 deletions(-) diff --git a/model.cpp b/model.cpp index b45493cc4..1751cb66b 100644 --- a/model.cpp +++ b/model.cpp @@ -1859,7 +1859,12 @@ SDVersion ModelLoader::get_sd_version() { if (is_ip2p) { return VERSION_SDXL_PIX2PIX; } - return VERSION_SDXL; + for (auto& tensor_storage : tensor_storages) { + if (tensor_storage.name.find("model.diffusion_model.middle_block.1") != std::string::npos) { + return VERSION_SDXL; // found a missing tensor in SSD1B, so it is SDXL + } + } + return VERSION_SDXL_SSD1B; } if (is_flux) { diff --git a/model.h b/model.h index 069bb0c21..c6637dd0b 100644 --- a/model.h +++ b/model.h @@ -27,6 +27,7 @@ enum SDVersion { VERSION_SDXL, VERSION_SDXL_INPAINT, VERSION_SDXL_PIX2PIX, + VERSION_SDXL_SSD1B, VERSION_SVD, VERSION_SD3, VERSION_FLUX, @@ -55,7 +56,7 @@ static inline bool sd_version_is_sd2(SDVersion version) { } static inline bool sd_version_is_sdxl(SDVersion version) { - if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX) { + if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B) { return true; } return false; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 87b6a3779..0429624a3 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -33,6 +33,7 @@ const char* model_version_to_str[] = { "SDXL", "SDXL Inpaint", "SDXL Instruct-Pix2Pix", + "SDXL (SSD1B)", "SVD", "SD3.x", "Flux", diff --git a/unet.hpp b/unet.hpp index 19bedb32b..6dde9bcc8 100644 --- a/unet.hpp +++ b/unet.hpp @@ -270,10 +270,14 @@ class UnetModelBlock : public GGMLBlock { n_head = ch / d_head; } std::string name = "input_blocks." 
+ std::to_string(input_block_idx) + ".1";
+                int td=transformer_depth[i];
+                if (version == VERSION_SDXL_SSD1B) {
+                    if (i==2) td=4;
+                }
                 blocks[name] = std::shared_ptr(get_attention_layer(ch,
                                                                    n_head,
                                                                    d_head,
-                                                                   transformer_depth[i],
+                                                                   td,
                                                                    context_dim));
             }
             input_block_chans.push_back(ch);
@@ -296,13 +300,14 @@ class UnetModelBlock : public GGMLBlock {
             n_head = ch / d_head;
         }
         blocks["middle_block.0"] = std::shared_ptr(get_resblock(ch, time_embed_dim, ch));
-        blocks["middle_block.1"] = std::shared_ptr(get_attention_layer(ch,
-                                                                       n_head,
-                                                                       d_head,
-                                                                       transformer_depth[transformer_depth.size() - 1],
-                                                                       context_dim));
-        blocks["middle_block.2"] = std::shared_ptr(get_resblock(ch, time_embed_dim, ch));
-
+        if (version != VERSION_SDXL_SSD1B) {
+            blocks["middle_block.1"] = std::shared_ptr(get_attention_layer(ch,
+                                                                           n_head,
+                                                                           d_head,
+                                                                           transformer_depth[transformer_depth.size() - 1],
+                                                                           context_dim));
+            blocks["middle_block.2"] = std::shared_ptr(get_resblock(ch, time_embed_dim, ch));
+        }
         // output_blocks
         int output_block_idx = 0;
         for (int i = (int)len_mults - 1; i >= 0; i--) {
@@ -324,7 +329,12 @@ class UnetModelBlock : public GGMLBlock {
                     n_head = ch / d_head;
                 }
                 std::string name = "output_blocks." + std::to_string(output_block_idx) + ".1";
-                blocks[name] = std::shared_ptr(get_attention_layer(ch, n_head, d_head, transformer_depth[i], context_dim));
+                int td = transformer_depth[i];
+                if (version == VERSION_SDXL_SSD1B) {
+                    if (i==2 && (j==0 || j==1)) td=4;
+                    if (i==1 && (j==1 || j==2)) td=1;
+                }
+                blocks[name] = std::shared_ptr(get_attention_layer(ch, n_head, d_head, td, context_dim));

                 up_sample_idx++;
             }
@@ -478,9 +488,10 @@ class UnetModelBlock : public GGMLBlock {

        // middle_block
        h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames);  // [N, 4*model_channels, h/8, w/8]
-       h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames);  // [N, 4*model_channels, h/8, w/8]
-       h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames);  // [N, 4*model_channels, h/8, w/8]
-
+       if (version != VERSION_SDXL_SSD1B) {
+           h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames);  // [N, 4*model_channels, h/8, w/8]
+           h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames);  // [N, 4*model_channels, h/8, w/8]
+       }
        if (controls.size() > 0) {
            auto cs = ggml_scale_inplace(ctx, controls[controls.size() - 1], control_strength);
            h = ggml_add(ctx, h, cs);  // middle control

From b553004dbe4108f83355256e618fb0e6e6e16681 Mon Sep 17 00:00:00 2001
From: akleine
Date: Tue, 21 Oct 2025 20:19:04 +0200
Subject: [PATCH 3/6] Add support for SD1.x models with TINY U-Nets too

This is a replacement of my earlier PR #745 (which I just closed myself) with
much better code. The doc file about TINY U-Nets is updated as well.
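
For orientation, the reduced layout that this patch detects can be inspected
with the diffusers package. The snippet below is only an illustrative sketch
using the segmind/tiny-sd checkpoint mentioned in the docs; it is not part of
this patch:

```python
from diffusers import UNet2DConditionModel

# Load just the U-Net of a tiny SD 1.x model.
unet = UNet2DConditionModel.from_pretrained("segmind/tiny-sd", subfolder="unet")

# Tiny U-Nets drop the middle block (which is how this patch detects them)
# and use fewer down/up blocks than a regular SD 1.x U-Net.
print("mid block  :", unet.mid_block)
print("down blocks:", [type(b).__name__ for b in unet.down_blocks])
print("up blocks  :", [type(b).__name__ for b in unet.up_blocks])
```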
---
 README.md            |  1 +
 docs/distilled_sd.md | 78 +++++++++++++++++++++++++++++++++++++++++---
 model.cpp            | 15 ++++++++-
 model.h              |  3 +-
 stable-diffusion.cpp |  1 +
 unet.hpp             | 41 ++++++++++++++++++-----
 6 files changed, 124 insertions(+), 15 deletions(-)

diff --git a/README.md b/README.md
index b32d3fa62..db4ed4e13 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,7 @@ API and command-line option may change frequently.***
 - Image Models
     - SD1.x, SD2.x, [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo)
     - SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
+    - [some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
     - [SD3/SD3.5](./docs/sd3.md)
     - [Flux-dev/Flux-schnell](./docs/flux.md)
     - [Chroma](./docs/chroma.md)
diff --git a/docs/distilled_sd.md b/docs/distilled_sd.md
index 7e38cb35f..f235f56b8 100644
--- a/docs/distilled_sd.md
+++ b/docs/distilled_sd.md
@@ -1,13 +1,15 @@
-# Running distilled SDXL models: SSD1B
+# Running distilled models: SSD1B and SD1.x with tiny U-Nets

-### Preface
+## Preface

-This kind of models has a reduced U-Net part. Unlike other SDXL models the U-Net has only one middle block and lesser attention layers in up and down blocks, resulting in relatively smaller files. Running these models saves more than 33% of the time. For more details, refer to Segmind's paper on https://arxiv.org/abs/2401.02677v1 .
+Both kinds of models have a reduced U-Net.
+Unlike other SDXL models, the U-Net of SSD1B has only one middle block and fewer attention layers in its up and down blocks, resulting in noticeably smaller files. Running these models saves more than 33% of the time. For more details, refer to Segmind's paper at https://arxiv.org/abs/2401.02677v1 .
+Unlike other SD 1.x models, Tiny-UNet models consist of only 6 U-Net blocks, resulting in much smaller files (approximately 1 GB). Running these models saves almost 50% of the time. For more details, refer to the paper at https://arxiv.org/pdf/2305.15798.pdf .

-### How to Use
+## SSD1B

 Unfortunately not all of this models follow the standard model parameter naming mapping.
-Anyway there are some useful SSD1B models available online, such as:
+Still, there are some very useful SSD1B models available online, such as:

  * https://huggingface.co/segmind/SSD-1B/resolve/main/SSD-1B-A1111.safetensors
  * https://huggingface.co/hassenhamdi/SSD-1B-fp8_e4m3fn/resolve/main/SSD-1B_fp8_e4m3fn.safetensors
@@ -16,3 +18,69 @@ Also there are useful LORAs available:

  * https://huggingface.co/seungminh/lora-swarovski-SSD-1B/resolve/main/pytorch_lora_weights.safetensors
  * https://huggingface.co/kylielee505/mylcmlorassd/resolve/main/pytorch_lora_weights.safetensors
+
+You can use these files **out-of-the-box**, unlike the models in the next section.
+
+
+## SD1.x with tiny U-Nets
+
+There are some Tiny SD 1.x models available online, such as:
+
+ * https://huggingface.co/segmind/tiny-sd
+ * https://huggingface.co/segmind/portrait-finetuned
+ * https://huggingface.co/nota-ai/bk-sdm-tiny
+
+These models need some conversion, for example because some tensors are stored **non-contiguously**.
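+
+If you are curious which tensors are affected, a quick check like the one below lists them. This is only an illustrative sketch (it assumes the `diffusers` package is installed) and is not required for the conversion itself:
+
+```python
+from diffusers import StableDiffusionPipeline
+
+# Load the diffusers pipeline and list every U-Net parameter that is
+# not stored contiguously in memory.
+pipe = StableDiffusionPipeline.from_pretrained("segmind/tiny-sd")
+non_contiguous = [n for n, p in pipe.unet.named_parameters() if not p.data.is_contiguous()]
+print(len(non_contiguous), "non-contiguous U-Net tensors")
+```
+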
+To create a usable checkpoint file, follow these steps:
+
+### Download the model from Hugging Face
+
+Download the model using Python on your computer, for example like this:
+
+```python
+import torch
+from diffusers import StableDiffusionPipeline
+pipe = StableDiffusionPipeline.from_pretrained("segmind/tiny-sd")
+unet = pipe.unet
+for param in unet.parameters():
+    param.data = param.data.contiguous()  # <- important here: store the tensor contiguously
+pipe.save_pretrained("segmindtiny-sd", safe_serialization=True)
+```
+
+### Convert that to a ckpt file
+
+To convert the downloaded model to a checkpoint file, you need another Python script. Download the conversion script from here:
+
+ * https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/scripts/convert_diffusers_to_original_stable_diffusion.py
+
+
+### Run the conversion script
+
+Now, run that conversion script:
+
+```bash
+python convert_diffusers_to_original_stable_diffusion.py \
+    --model_path ./segmindtiny-sd \
+    --checkpoint_path ./segmind_tiny-sd.ckpt --half
+```
+
+The file **segmind_tiny-sd.ckpt** will be generated and is now ready to use with sd.cpp.
+
+You can follow a similar process for the other Hugging Face models mentioned above.
+
+
+### Another ckpt file on the net
+
+There is another model file available online:
+
+ * https://huggingface.co/ClashSAN/small-sd/resolve/main/tinySDdistilled.ckpt
+
+If you want to use it, you have to make its **non-contiguous tensors** contiguous first:
+
+```python
+import torch
+ckpt = torch.load("tinySDdistilled.ckpt", map_location=torch.device('cpu'))
+for key, value in ckpt['state_dict'].items():
+    if isinstance(value, torch.Tensor):
+        ckpt['state_dict'][key] = value.contiguous()
+torch.save(ckpt, "tinySDdistilled_fixed.ckpt")
+```
diff --git a/model.cpp b/model.cpp
index 1751cb66b..c2cd6e9b3 100644
--- a/model.cpp
+++ b/model.cpp
@@ -623,6 +623,14 @@ std::string convert_tensor_name(std::string name) {
     if (starts_with(name, "diffusion_model")) {
         name = "model." 
+ name; } + if (starts_with(name, "model.diffusion_model.up_blocks.0.attentions.0.")) { + name.replace(0,sizeof("model.diffusion_model.up_blocks.0.attentions.0.") - 1, + "model.diffusion_model.output_blocks.0.1."); + } + if (starts_with(name, "model.diffusion_model.up_blocks.0.attentions.1.")) { + name.replace(0,sizeof("model.diffusion_model.up_blocks.0.attentions.1.") - 1, + "model.diffusion_model.output_blocks.1.1."); + } // size_t pos = name.find("lora_A"); // if (pos != std::string::npos) { // name.replace(pos, strlen("lora_A"), "lora_up"); @@ -1887,7 +1895,12 @@ SDVersion ModelLoader::get_sd_version() { if (is_ip2p) { return VERSION_SD1_PIX2PIX; } - return VERSION_SD1; + for (auto& tensor_storage : tensor_storages) { + if (tensor_storage.name.find("model.diffusion_model.middle_block") != std::string::npos) { + return VERSION_SD1; // found a middle block, so it is SD1 + } + } + return VERSION_SD1_TINY_UNET; } else if (token_embedding_weight.ne[0] == 1024) { if (is_inpaint) { return VERSION_SD2_INPAINT; diff --git a/model.h b/model.h index c6637dd0b..52db21602 100644 --- a/model.h +++ b/model.h @@ -22,6 +22,7 @@ enum SDVersion { VERSION_SD1, VERSION_SD1_INPAINT, VERSION_SD1_PIX2PIX, + VERSION_SD1_TINY_UNET, VERSION_SD2, VERSION_SD2_INPAINT, VERSION_SDXL, @@ -42,7 +43,7 @@ enum SDVersion { }; static inline bool sd_version_is_sd1(SDVersion version) { - if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX) { + if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET) { return true; } return false; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 0429624a3..1c91eee23 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -28,6 +28,7 @@ const char* model_version_to_str[] = { "SD 1.x", "SD 1.x Inpaint", "Instruct-Pix2Pix", + "SD 1.x Tiny UNet", "SD 2.x", "SD 2.x Inpaint", "SDXL", diff --git a/unet.hpp b/unet.hpp index 6dde9bcc8..07f43e83b 100644 --- a/unet.hpp +++ b/unet.hpp @@ -204,6 +204,9 @@ class UnetModelBlock : public GGMLBlock { adm_in_channels = 768; num_head_channels = 64; num_heads = -1; + } else if (version == VERSION_SD1_TINY_UNET) { + num_res_blocks = 1; + channel_mult = {1, 2, 4}; } if (sd_version_is_inpaint(version)) { in_channels = 9; @@ -281,6 +284,9 @@ class UnetModelBlock : public GGMLBlock { context_dim)); } input_block_chans.push_back(ch); + if (version == VERSION_SD1_TINY_UNET) { + input_block_idx++; + } } if (i != len_mults - 1) { input_block_idx += 1; @@ -299,14 +305,16 @@ class UnetModelBlock : public GGMLBlock { d_head = num_head_channels; n_head = ch / d_head; } - blocks["middle_block.0"] = std::shared_ptr(get_resblock(ch, time_embed_dim, ch)); - if (version != VERSION_SDXL_SSD1B) { - blocks["middle_block.1"] = std::shared_ptr(get_attention_layer(ch, + if (version != VERSION_SD1_TINY_UNET) { + blocks["middle_block.0"] = std::shared_ptr(get_resblock(ch, time_embed_dim, ch)); + if (version != VERSION_SDXL_SSD1B) { + blocks["middle_block.1"] = std::shared_ptr(get_attention_layer(ch, n_head, d_head, transformer_depth[transformer_depth.size() - 1], context_dim)); - blocks["middle_block.2"] = std::shared_ptr(get_resblock(ch, time_embed_dim, ch)); + blocks["middle_block.2"] = std::shared_ptr(get_resblock(ch, time_embed_dim, ch)); + } } // output_blocks int output_block_idx = 0; @@ -340,6 +348,12 @@ class UnetModelBlock : public GGMLBlock { } if (i > 0 && j == num_res_blocks) { + if (version == VERSION_SD1_TINY_UNET) { + 
output_block_idx++; + if (output_block_idx == 2) { + up_sample_idx=1; + } + } std::string name = "output_blocks." + std::to_string(output_block_idx) + "." + std::to_string(up_sample_idx); blocks[name] = std::shared_ptr(new UpSampleBlock(ch, ch)); @@ -473,6 +487,9 @@ class UnetModelBlock : public GGMLBlock { } hs.push_back(h); } + if (version == VERSION_SD1_TINY_UNET) { + input_block_idx++; + } if (i != len_mults - 1) { ds *= 2; input_block_idx += 1; @@ -487,10 +504,12 @@ class UnetModelBlock : public GGMLBlock { // [N, 4*model_channels, h/8, w/8] // middle_block - h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] - if (version != VERSION_SDXL_SSD1B) { - h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8] - h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] + if (version != VERSION_SD1_TINY_UNET) { + h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] + if (version != VERSION_SDXL_SSD1B) { + h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8] + h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] + } } if (controls.size() > 0) { auto cs = ggml_scale_inplace(ctx, controls[controls.size() - 1], control_strength); @@ -527,6 +546,12 @@ class UnetModelBlock : public GGMLBlock { } if (i > 0 && j == num_res_blocks) { + if (version == VERSION_SD1_TINY_UNET) { + output_block_idx++; + if (output_block_idx == 2) { + up_sample_idx=1; + } + } std::string name = "output_blocks." + std::to_string(output_block_idx) + "." 
+ std::to_string(up_sample_idx); auto block = std::dynamic_pointer_cast(blocks[name]); From 867a925fe1f92b47c83e485a544280df193e7983 Mon Sep 17 00:00:00 2001 From: leejet Date: Sat, 25 Oct 2025 22:51:09 +0800 Subject: [PATCH 4/6] support SSD-1B.safetensors --- model.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/model.cpp b/model.cpp index a05a96be7..2e1ebe061 100644 --- a/model.cpp +++ b/model.cpp @@ -330,6 +330,10 @@ std::string convert_cond_model_name(const std::string& name) { return new_name; } + if (new_name == "model.text_projection.weight") { + new_name = "transformer.text_model.text_projection"; + } + if (open_clip_to_hf_clip_model.find(new_name) != open_clip_to_hf_clip_model.end()) { new_name = open_clip_to_hf_clip_model[new_name]; } From 849ec4ec75471dc3b147baf3921fe61c6dd1a0f5 Mon Sep 17 00:00:00 2001 From: leejet Date: Sat, 25 Oct 2025 23:23:52 +0800 Subject: [PATCH 5/6] fix sdv1.5 diffusers format loader --- model.cpp | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/model.cpp b/model.cpp index 2e1ebe061..36972553e 100644 --- a/model.cpp +++ b/model.cpp @@ -1788,6 +1788,7 @@ SDVersion ModelLoader::get_sd_version() { bool is_wan = false; int64_t patch_embedding_channels = 0; bool has_img_emb = false; + bool has_middle_block_1 = false; for (auto& tensor_storage : tensor_storages) { if (!(is_xl || is_flux)) { @@ -1834,6 +1835,10 @@ SDVersion ModelLoader::get_sd_version() { return VERSION_SVD; } } + if (tensor_storage.name.find("model.diffusion_model.middle_block.1.") != std::string::npos || + tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) { + has_middle_block_1 = true; + } if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" || tensor_storage.name == "cond_stage_model.model.token_embedding.weight" || tensor_storage.name == "text_model.embeddings.token_embedding.weight" || @@ -1846,7 +1851,7 @@ SDVersion ModelLoader::get_sd_version() { if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") { input_block_weight = tensor_storage; input_block_checked = true; - if (is_xl || is_flux) { + if (is_flux) { break; } } @@ -1870,12 +1875,10 @@ SDVersion ModelLoader::get_sd_version() { if (is_ip2p) { return VERSION_SDXL_PIX2PIX; } - for (auto& tensor_storage : tensor_storages) { - if (tensor_storage.name.find("model.diffusion_model.middle_block.1") != std::string::npos) { - return VERSION_SDXL; // found a missing tensor in SSD1B, so it is SDXL - } + if (!has_middle_block_1) { + return VERSION_SDXL_SSD1B; } - return VERSION_SDXL_SSD1B; + return VERSION_SDXL; } if (is_flux) { @@ -1898,12 +1901,10 @@ SDVersion ModelLoader::get_sd_version() { if (is_ip2p) { return VERSION_SD1_PIX2PIX; } - for (auto& tensor_storage : tensor_storages) { - if (tensor_storage.name.find("model.diffusion_model.middle_block") != std::string::npos) { - return VERSION_SD1; // found a middle block, so it is SD1 - } + if (!has_middle_block_1) { + return VERSION_SD1_TINY_UNET; } - return VERSION_SD1_TINY_UNET; + return VERSION_SD1; } else if (token_embedding_weight.ne[0] == 1024) { if (is_inpaint) { return VERSION_SD2_INPAINT; From 9dbe1024e8ae6350cc02a7fd22063fc38baab356 Mon Sep 17 00:00:00 2001 From: leejet Date: Sat, 25 Oct 2025 23:29:40 +0800 Subject: [PATCH 6/6] format code --- model.cpp | 10 +++++----- unet.hpp | 48 
+++++++++++++++++++++++++++--------------------- 2 files changed, 32 insertions(+), 26 deletions(-) diff --git a/model.cpp b/model.cpp index 36972553e..0a03627f9 100644 --- a/model.cpp +++ b/model.cpp @@ -628,12 +628,12 @@ std::string convert_tensor_name(std::string name) { name = "model." + name; } if (starts_with(name, "model.diffusion_model.up_blocks.0.attentions.0.")) { - name.replace(0,sizeof("model.diffusion_model.up_blocks.0.attentions.0.") - 1, - "model.diffusion_model.output_blocks.0.1."); + name.replace(0, sizeof("model.diffusion_model.up_blocks.0.attentions.0.") - 1, + "model.diffusion_model.output_blocks.0.1."); } if (starts_with(name, "model.diffusion_model.up_blocks.0.attentions.1.")) { - name.replace(0,sizeof("model.diffusion_model.up_blocks.0.attentions.1.") - 1, - "model.diffusion_model.output_blocks.1.1."); + name.replace(0, sizeof("model.diffusion_model.up_blocks.0.attentions.1.") - 1, + "model.diffusion_model.output_blocks.1.1."); } // size_t pos = name.find("lora_A"); // if (pos != std::string::npos) { @@ -1878,7 +1878,7 @@ SDVersion ModelLoader::get_sd_version() { if (!has_middle_block_1) { return VERSION_SDXL_SSD1B; } - return VERSION_SDXL; + return VERSION_SDXL; } if (is_flux) { diff --git a/unet.hpp b/unet.hpp index d5fd240e8..318dbc081 100644 --- a/unet.hpp +++ b/unet.hpp @@ -205,8 +205,8 @@ class UnetModelBlock : public GGMLBlock { num_head_channels = 64; num_heads = -1; } else if (version == VERSION_SD1_TINY_UNET) { - num_res_blocks = 1; - channel_mult = {1, 2, 4}; + num_res_blocks = 1; + channel_mult = {1, 2, 4}; } if (sd_version_is_inpaint(version)) { in_channels = 9; @@ -273,15 +273,17 @@ class UnetModelBlock : public GGMLBlock { n_head = ch / d_head; } std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1"; - int td=transformer_depth[i]; + int td = transformer_depth[i]; if (version == VERSION_SDXL_SSD1B) { - if (i==2) td=4; + if (i == 2) { + td = 4; + } } - blocks[name] = std::shared_ptr(get_attention_layer(ch, - n_head, - d_head, - td, - context_dim)); + blocks[name] = std::shared_ptr(get_attention_layer(ch, + n_head, + d_head, + td, + context_dim)); } input_block_chans.push_back(ch); if (version == VERSION_SD1_TINY_UNET) { @@ -309,10 +311,10 @@ class UnetModelBlock : public GGMLBlock { blocks["middle_block.0"] = std::shared_ptr(get_resblock(ch, time_embed_dim, ch)); if (version != VERSION_SDXL_SSD1B) { blocks["middle_block.1"] = std::shared_ptr(get_attention_layer(ch, - n_head, - d_head, - transformer_depth[transformer_depth.size() - 1], - context_dim)); + n_head, + d_head, + transformer_depth[transformer_depth.size() - 1], + context_dim)); blocks["middle_block.2"] = std::shared_ptr(get_resblock(ch, time_embed_dim, ch)); } } @@ -337,12 +339,16 @@ class UnetModelBlock : public GGMLBlock { n_head = ch / d_head; } std::string name = "output_blocks." 
+ std::to_string(output_block_idx) + ".1"; - int td = transformer_depth[i]; + int td = transformer_depth[i]; if (version == VERSION_SDXL_SSD1B) { - if (i==2 && (j==0 || j==1)) td=4; - if (i==1 && (j==1 || j==2)) td=1; - } - blocks[name] = std::shared_ptr(get_attention_layer(ch, n_head, d_head, td, context_dim)); + if (i == 2 && (j == 0 || j == 1)) { + td = 4; + } + if (i == 1 && (j == 1 || j == 2)) { + td = 1; + } + } + blocks[name] = std::shared_ptr(get_attention_layer(ch, n_head, d_head, td, context_dim)); up_sample_idx++; } @@ -351,7 +357,7 @@ class UnetModelBlock : public GGMLBlock { if (version == VERSION_SD1_TINY_UNET) { output_block_idx++; if (output_block_idx == 2) { - up_sample_idx=1; + up_sample_idx = 1; } } std::string name = "output_blocks." + std::to_string(output_block_idx) + "." + std::to_string(up_sample_idx); @@ -505,7 +511,7 @@ class UnetModelBlock : public GGMLBlock { // middle_block if (version != VERSION_SD1_TINY_UNET) { - h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] + h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] if (version != VERSION_SDXL_SSD1B) { h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8] h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] @@ -549,7 +555,7 @@ class UnetModelBlock : public GGMLBlock { if (version == VERSION_SD1_TINY_UNET) { output_block_idx++; if (output_block_idx == 2) { - up_sample_idx=1; + up_sample_idx = 1; } } std::string name = "output_blocks." + std::to_string(output_block_idx) + "." + std::to_string(up_sample_idx);