Skip to content

Commit b553004

Browse files
committed
Added some more lines to support SD1.x with TINY U-Nets too.
This is a repleacement of my earlier PR leejet#745 (just closed by myself) with much better code. Also updated doc file about TINY U-Nets.
1 parent 66ef705 commit b553004

File tree

6 files changed

+124
-15
lines changed

6 files changed

+124
-15
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ API and command-line option may change frequently.***
1717
- Image Models
1818
- SD1.x, SD2.x, [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo)
1919
- SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
20+
- [some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
2021
- [SD3/SD3.5](./docs/sd3.md)
2122
- [Flux-dev/Flux-schnell](./docs/flux.md)
2223
- [Chroma](./docs/chroma.md)

docs/distilled_sd.md

Lines changed: 73 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1-
# Running distilled SDXL models: SSD1B
1+
# Running distilled models: SSD1B and SD1.x with tiny U-Nets
22

3-
### Preface
3+
## Preface
44

5-
This kind of models has a reduced U-Net part. Unlike other SDXL models the U-Net has only one middle block and lesser attention layers in up and down blocks, resulting in relatively smaller files. Running these models saves more than 33% of the time. For more details, refer to Segmind's paper on https://arxiv.org/abs/2401.02677v1 .
5+
This kind of models have a reduced U-Net part.
6+
Unlike other SDXL models the U-Net of SSD1B has only one middle block and lesser attention layers in up and down blocks, resulting in relatively smaller files. Running these models saves more than 33% of the time. For more details, refer to Segmind's paper on https://arxiv.org/abs/2401.02677v1 .
7+
Unlike other SD 1.x models Tiny-UNet models consist of only 6 U-Net blocks, resulting in relatively smaller files (approximately 1 GB). Running these models saves almost 50% of the time. For more details, refer to the paper: https://arxiv.org/pdf/2305.15798.pdf .
68

7-
### How to Use
9+
## SSD1B
810

911
Unfortunately not all of this models follow the standard model parameter naming mapping.
10-
Anyway there are some useful SSD1B models available online, such as:
12+
Anyway there are some very useful SSD1B models available online, such as:
1113

1214
* https://huggingface.co/segmind/SSD-1B/resolve/main/SSD-1B-A1111.safetensors
1315
* https://huggingface.co/hassenhamdi/SSD-1B-fp8_e4m3fn/resolve/main/SSD-1B_fp8_e4m3fn.safetensors
@@ -16,3 +18,69 @@ Also there are useful LORAs available:
1618

1719
* https://huggingface.co/seungminh/lora-swarovski-SSD-1B/resolve/main/pytorch_lora_weights.safetensors
1820
* https://huggingface.co/kylielee505/mylcmlorassd/resolve/main/pytorch_lora_weights.safetensors
21+
22+
You can use this files **out-of-the-box** - unlike models in next section.
23+
24+
25+
## SD1.x with tiny U-Nets
26+
27+
There are some Tiny SD 1.x models available online, such as:
28+
29+
* https://huggingface.co/segmind/tiny-sd
30+
* https://huggingface.co/segmind/portrait-finetuned
31+
* https://huggingface.co/nota-ai/bk-sdm-tiny
32+
33+
These models need some conversion, for example because partially tensors are **non contiguous** stored. To create a usable checkpoint file, follow these **easy** steps:
34+
35+
### Download model from Hugging Face
36+
37+
Download the model using Python on your computer, for example this way:
38+
39+
```python
40+
import torch
41+
from diffusers import StableDiffusionPipeline
42+
pipe = StableDiffusionPipeline.from_pretrained("segmind/tiny-sd")
43+
unet=pipe.unet
44+
for param in unet.parameters():
45+
param.data = param.data.contiguous() # <- important here
46+
pipe.save_pretrained("segmindtiny-sd", safe_serialization=True)
47+
```
48+
49+
### Convert that to a ckpt file
50+
51+
To convert the downloaded model to a checkpoint file, you need another Python script. Download the conversion script from here:
52+
53+
* https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/scripts/convert_diffusers_to_original_stable_diffusion.py
54+
55+
56+
### Run convert script
57+
58+
Now, run that conversion script:
59+
60+
```bash
61+
python convert_diffusers_to_original_stable_diffusion.py \
62+
--model_path ./segmindtiny-sd \
63+
--checkpoint_path ./segmind_tiny-sd.ckpt --half
64+
```
65+
66+
The file **segmind_tiny-sd.ckpt** will be generated and is now ready to use with sd.cpp
67+
68+
You can follow a similar process for other models mentioned above from Hugging Face.
69+
70+
71+
### Another ckpt file on the net
72+
73+
There is another model file available online:
74+
75+
* https://huggingface.co/ClashSAN/small-sd/resolve/main/tinySDdistilled.ckpt
76+
77+
If you want to use that, you have to adjust some **non-contiguous tensors** first:
78+
79+
```python
80+
import torch
81+
ckpt = torch.load("tinySDdistilled.ckpt", map_location=torch.device('cpu'))
82+
for key, value in ckpt['state_dict'].items():
83+
if isinstance(value, torch.Tensor):
84+
ckpt['state_dict'][key] = value.contiguous()
85+
torch.save(ckpt, "tinySDdistilled_fixed.ckpt")
86+
```

model.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,14 @@ std::string convert_tensor_name(std::string name) {
623623
if (starts_with(name, "diffusion_model")) {
624624
name = "model." + name;
625625
}
626+
if (starts_with(name, "model.diffusion_model.up_blocks.0.attentions.0.")) {
627+
name.replace(0,sizeof("model.diffusion_model.up_blocks.0.attentions.0.") - 1,
628+
"model.diffusion_model.output_blocks.0.1.");
629+
}
630+
if (starts_with(name, "model.diffusion_model.up_blocks.0.attentions.1.")) {
631+
name.replace(0,sizeof("model.diffusion_model.up_blocks.0.attentions.1.") - 1,
632+
"model.diffusion_model.output_blocks.1.1.");
633+
}
626634
// size_t pos = name.find("lora_A");
627635
// if (pos != std::string::npos) {
628636
// name.replace(pos, strlen("lora_A"), "lora_up");
@@ -1887,7 +1895,12 @@ SDVersion ModelLoader::get_sd_version() {
18871895
if (is_ip2p) {
18881896
return VERSION_SD1_PIX2PIX;
18891897
}
1890-
return VERSION_SD1;
1898+
for (auto& tensor_storage : tensor_storages) {
1899+
if (tensor_storage.name.find("model.diffusion_model.middle_block") != std::string::npos) {
1900+
return VERSION_SD1; // found a middle block, so it is SD1
1901+
}
1902+
}
1903+
return VERSION_SD1_TINY_UNET;
18911904
} else if (token_embedding_weight.ne[0] == 1024) {
18921905
if (is_inpaint) {
18931906
return VERSION_SD2_INPAINT;

model.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ enum SDVersion {
2222
VERSION_SD1,
2323
VERSION_SD1_INPAINT,
2424
VERSION_SD1_PIX2PIX,
25+
VERSION_SD1_TINY_UNET,
2526
VERSION_SD2,
2627
VERSION_SD2_INPAINT,
2728
VERSION_SDXL,
@@ -42,7 +43,7 @@ enum SDVersion {
4243
};
4344

4445
static inline bool sd_version_is_sd1(SDVersion version) {
45-
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX) {
46+
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET) {
4647
return true;
4748
}
4849
return false;

stable-diffusion.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ const char* model_version_to_str[] = {
2828
"SD 1.x",
2929
"SD 1.x Inpaint",
3030
"Instruct-Pix2Pix",
31+
"SD 1.x Tiny UNet",
3132
"SD 2.x",
3233
"SD 2.x Inpaint",
3334
"SDXL",

unet.hpp

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@ class UnetModelBlock : public GGMLBlock {
204204
adm_in_channels = 768;
205205
num_head_channels = 64;
206206
num_heads = -1;
207+
} else if (version == VERSION_SD1_TINY_UNET) {
208+
num_res_blocks = 1;
209+
channel_mult = {1, 2, 4};
207210
}
208211
if (sd_version_is_inpaint(version)) {
209212
in_channels = 9;
@@ -281,6 +284,9 @@ class UnetModelBlock : public GGMLBlock {
281284
context_dim));
282285
}
283286
input_block_chans.push_back(ch);
287+
if (version == VERSION_SD1_TINY_UNET) {
288+
input_block_idx++;
289+
}
284290
}
285291
if (i != len_mults - 1) {
286292
input_block_idx += 1;
@@ -299,14 +305,16 @@ class UnetModelBlock : public GGMLBlock {
299305
d_head = num_head_channels;
300306
n_head = ch / d_head;
301307
}
302-
blocks["middle_block.0"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
303-
if (version != VERSION_SDXL_SSD1B) {
304-
blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
308+
if (version != VERSION_SD1_TINY_UNET) {
309+
blocks["middle_block.0"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
310+
if (version != VERSION_SDXL_SSD1B) {
311+
blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
305312
n_head,
306313
d_head,
307314
transformer_depth[transformer_depth.size() - 1],
308315
context_dim));
309-
blocks["middle_block.2"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
316+
blocks["middle_block.2"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
317+
}
310318
}
311319
// output_blocks
312320
int output_block_idx = 0;
@@ -340,6 +348,12 @@ class UnetModelBlock : public GGMLBlock {
340348
}
341349

342350
if (i > 0 && j == num_res_blocks) {
351+
if (version == VERSION_SD1_TINY_UNET) {
352+
output_block_idx++;
353+
if (output_block_idx == 2) {
354+
up_sample_idx=1;
355+
}
356+
}
343357
std::string name = "output_blocks." + std::to_string(output_block_idx) + "." + std::to_string(up_sample_idx);
344358
blocks[name] = std::shared_ptr<GGMLBlock>(new UpSampleBlock(ch, ch));
345359

@@ -473,6 +487,9 @@ class UnetModelBlock : public GGMLBlock {
473487
}
474488
hs.push_back(h);
475489
}
490+
if (version == VERSION_SD1_TINY_UNET) {
491+
input_block_idx++;
492+
}
476493
if (i != len_mults - 1) {
477494
ds *= 2;
478495
input_block_idx += 1;
@@ -487,10 +504,12 @@ class UnetModelBlock : public GGMLBlock {
487504
// [N, 4*model_channels, h/8, w/8]
488505

489506
// middle_block
490-
h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
491-
if (version != VERSION_SDXL_SSD1B) {
492-
h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
493-
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
507+
if (version != VERSION_SD1_TINY_UNET) {
508+
h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
509+
if (version != VERSION_SDXL_SSD1B) {
510+
h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
511+
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
512+
}
494513
}
495514
if (controls.size() > 0) {
496515
auto cs = ggml_scale_inplace(ctx, controls[controls.size() - 1], control_strength);
@@ -527,6 +546,12 @@ class UnetModelBlock : public GGMLBlock {
527546
}
528547

529548
if (i > 0 && j == num_res_blocks) {
549+
if (version == VERSION_SD1_TINY_UNET) {
550+
output_block_idx++;
551+
if (output_block_idx == 2) {
552+
up_sample_idx=1;
553+
}
554+
}
530555
std::string name = "output_blocks." + std::to_string(output_block_idx) + "." + std::to_string(up_sample_idx);
531556
auto block = std::dynamic_pointer_cast<UpSampleBlock>(blocks[name]);
532557

0 commit comments

Comments
 (0)