1 change: 1 addition & 0 deletions README.md
@@ -35,6 +35,7 @@ API and command-line option may change frequently.***
- Image Models
- SD1.x, SD2.x, [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo)
- SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
- [some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
- [SD3/SD3.5](./docs/sd3.md)
- [Flux-dev/Flux-schnell](./docs/flux.md)
- [Chroma](./docs/chroma.md)
86 changes: 86 additions & 0 deletions docs/distilled_sd.md
@@ -0,0 +1,86 @@
# Running distilled models: SSD1B and SD1.x with tiny U-Nets

## Preface

These models have a reduced U-Net.
Unlike other SDXL models, the U-Net of SSD1B has only one middle block and fewer attention layers in its up and down blocks, resulting in relatively smaller files. Inference with these models takes more than 33% less time. For more details, refer to Segmind's paper at https://arxiv.org/abs/2401.02677v1 .
Unlike other SD 1.x models, Tiny-UNet models consist of only 6 U-Net blocks, resulting in relatively smaller files (approximately 1 GB). Inference with these models takes almost 50% less time. For more details, refer to the paper: https://arxiv.org/pdf/2305.15798.pdf .

## SSD1B

Unfortunately, not all of these models follow the standard parameter naming convention.
Nevertheless, some very useful SSD1B models are available online, such as:

* https://huggingface.co/segmind/SSD-1B/resolve/main/SSD-1B-A1111.safetensors
* https://huggingface.co/hassenhamdi/SSD-1B-fp8_e4m3fn/resolve/main/SSD-1B_fp8_e4m3fn.safetensors

Useful LoRAs are also available:

* https://huggingface.co/seungminh/lora-swarovski-SSD-1B/resolve/main/pytorch_lora_weights.safetensors
* https://huggingface.co/kylielee505/mylcmlorassd/resolve/main/pytorch_lora_weights.safetensors

You can use these files **out of the box**, unlike the models in the next section.


## SD1.x with tiny U-Nets

There are some Tiny SD 1.x models available online, such as:

* https://huggingface.co/segmind/tiny-sd
* https://huggingface.co/segmind/portrait-finetuned
* https://huggingface.co/nota-ai/bk-sdm-tiny

These models need some conversion, for example because some of their tensors are stored **non-contiguously**. To create a usable checkpoint file, follow these **easy** steps:

### Download model from Hugging Face

Download the model using Python on your computer, for example like this:

```python
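import torch
from diffusers import StableDiffusionPipeline

# Download the pipeline from the Hugging Face Hub.
pipe = StableDiffusionPipeline.from_pretrained("segmind/tiny-sd")
unet = pipe.unet
# Make every U-Net parameter contiguous in memory before saving.
for param in unet.parameters():
    param.data = param.data.contiguous()  # <- important here
# Save the pipeline in diffusers format as safetensors files.
pipe.save_pretrained("segmindtiny-sd", safe_serialization=True)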
```
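
The `.contiguous()` call matters because a tensor's memory layout can differ from its logical shape, for example after a transpose. The following pure-Python sketch illustrates the idea using shapes and strides only. `row_major_strides` and `is_contiguous` are hypothetical helper names, and the check is a simplification of what PyTorch's `Tensor.is_contiguous()` reports (PyTorch additionally ignores the strides of size-1 dimensions).

```python
def row_major_strides(shape):
    """Return the element strides of a densely packed, row-major tensor."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def is_contiguous(shape, strides):
    """True if the given strides match the dense row-major layout."""
    return tuple(strides) == row_major_strides(shape)


# A 2x3 tensor stored row by row has strides (3, 1).
print(is_contiguous((2, 3), (3, 1)))  # True
# Its transpose reuses the same memory: shape (3, 2), strides (1, 3).
print(is_contiguous((3, 2), (1, 3)))  # False
```

Calling `.contiguous()` on such a tensor copies the data into the dense layout, so the saved file holds the values in the order the shape implies.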

### Convert that to a ckpt file

To convert the downloaded model to a checkpoint file, you need another Python script. Download the conversion script from here:

* https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/scripts/convert_diffusers_to_original_stable_diffusion.py


### Run convert script

Now, run that conversion script:

```bash
python convert_diffusers_to_original_stable_diffusion.py \
    --model_path ./segmindtiny-sd \
    --checkpoint_path ./segmind_tiny-sd.ckpt --half
```

The file **segmind_tiny-sd.ckpt** will be generated and is then ready to use with sd.cpp.

You can follow a similar process for other models mentioned above from Hugging Face.
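
As a rough sketch of how the same steps could be repeated for the other repositories, the helper below builds the conversion command for a given model ID. The function name `conversion_command` and the directory and file naming scheme are assumptions made for illustration. Only the script name and the `--model_path`, `--checkpoint_path`, and `--half` options come from the steps above.

```python
import shlex


def conversion_command(model_id):
    """Build the conversion command for a diffusers model saved locally.

    Assumes the pipeline was saved to a directory named after the model ID
    with "/" replaced by "_" (a hypothetical naming scheme).
    """
    local_name = model_id.replace("/", "_")
    return [
        "python",
        "convert_diffusers_to_original_stable_diffusion.py",
        "--model_path", f"./{local_name}",
        "--checkpoint_path", f"./{local_name}.ckpt",
        "--half",
    ]


for model_id in ["segmind/tiny-sd", "segmind/portrait-finetuned", "nota-ai/bk-sdm-tiny"]:
    # shlex.join quotes the arguments so the line can be pasted into a shell.
    print(shlex.join(conversion_command(model_id)))
```

Each pipeline still has to be downloaded and saved with `save_pretrained` to the matching directory before the printed command is run.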


### Another ckpt file on the net

There is another model file available online:

* https://huggingface.co/ClashSAN/small-sd/resolve/main/tinySDdistilled.ckpt

If you want to use it, you first have to make some **non-contiguous tensors** contiguous:

```python
import torch

# Load on the CPU so that no GPU is required.
ckpt = torch.load("tinySDdistilled.ckpt", map_location=torch.device('cpu'))
# Replace every tensor in the state dict with a contiguous copy.
for key, value in ckpt['state_dict'].items():
    if isinstance(value, torch.Tensor):
        ckpt['state_dict'][key] = value.contiguous()
torch.save(ckpt, "tinySDdistilled_fixed.ckpt")
```
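
To confirm that a fix like this worked, one option is to list the entries that are still non-contiguous. The sketch below is only an illustration: `find_non_contiguous` is a hypothetical helper that accepts any `(name, value)` pairs whose tensor values offer an `is_contiguous()` method, as PyTorch tensors do. The `FakeTensor` class exists only so the example runs without PyTorch installed.

```python
def find_non_contiguous(named_tensors):
    """Return the names of entries whose is_contiguous() reports False.

    Values without an is_contiguous method (e.g. ints) are skipped.
    """
    return [
        name
        for name, value in named_tensors
        if hasattr(value, "is_contiguous") and not value.is_contiguous()
    ]


class FakeTensor:
    """Stand-in for torch.Tensor, used only for this demonstration."""

    def __init__(self, contiguous):
        self._contiguous = contiguous

    def is_contiguous(self):
        return self._contiguous


state_dict = {"a.weight": FakeTensor(True), "b.weight": FakeTensor(False), "step": 3}
print(find_non_contiguous(state_dict.items()))  # ['b.weight']
```

With a real checkpoint, `find_non_contiguous(ckpt['state_dict'].items())` should return an empty list once the tensors have been fixed.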
25 changes: 24 additions & 1 deletion model.cpp
@@ -330,6 +330,10 @@ std::string convert_cond_model_name(const std::string& name) {
return new_name;
}

if (new_name == "model.text_projection.weight") {
new_name = "transformer.text_model.text_projection";
}

if (open_clip_to_hf_clip_model.find(new_name) != open_clip_to_hf_clip_model.end()) {
new_name = open_clip_to_hf_clip_model[new_name];
}
@@ -623,6 +627,14 @@ std::string convert_tensor_name(std::string name) {
if (starts_with(name, "diffusion_model")) {
name = "model." + name;
}
if (starts_with(name, "model.diffusion_model.up_blocks.0.attentions.0.")) {
name.replace(0, sizeof("model.diffusion_model.up_blocks.0.attentions.0.") - 1,
"model.diffusion_model.output_blocks.0.1.");
}
if (starts_with(name, "model.diffusion_model.up_blocks.0.attentions.1.")) {
name.replace(0, sizeof("model.diffusion_model.up_blocks.0.attentions.1.") - 1,
"model.diffusion_model.output_blocks.1.1.");
}
// size_t pos = name.find("lora_A");
// if (pos != std::string::npos) {
// name.replace(pos, strlen("lora_A"), "lora_up");
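
The two prefix rewrites added in this hunk can be read as a small lookup table. The Python sketch below mirrors only those two rules to show their effect on a tensor name. `remap_up_block_attention` is a hypothetical name, and the real `convert_tensor_name` performs many other conversions that are not shown here.

```python
# Prefix rewrites copied from the hunk above.
PREFIX_MAP = {
    "model.diffusion_model.up_blocks.0.attentions.0.": "model.diffusion_model.output_blocks.0.1.",
    "model.diffusion_model.up_blocks.0.attentions.1.": "model.diffusion_model.output_blocks.1.1.",
}


def remap_up_block_attention(name):
    """Replace a known prefix; leave any other name unchanged."""
    for old, new in PREFIX_MAP.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name


print(remap_up_block_attention("model.diffusion_model.up_blocks.0.attentions.1.norm.weight"))
# model.diffusion_model.output_blocks.1.1.norm.weight
```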
@@ -1776,6 +1788,7 @@ SDVersion ModelLoader::get_sd_version() {
bool is_wan = false;
int64_t patch_embedding_channels = 0;
bool has_img_emb = false;
bool has_middle_block_1 = false;

for (auto& tensor_storage : tensor_storages) {
if (!(is_xl || is_flux)) {
@@ -1822,6 +1835,10 @@ SDVersion ModelLoader::get_sd_version() {
return VERSION_SVD;
}
}
if (tensor_storage.name.find("model.diffusion_model.middle_block.1.") != std::string::npos ||
tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) {
has_middle_block_1 = true;
}
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
@@ -1834,7 +1851,7 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") {
input_block_weight = tensor_storage;
input_block_checked = true;
-if (is_xl || is_flux) {
+if (is_flux) {
break;
}
}
@@ -1858,6 +1875,9 @@ SDVersion ModelLoader::get_sd_version() {
if (is_ip2p) {
return VERSION_SDXL_PIX2PIX;
}
if (!has_middle_block_1) {
return VERSION_SDXL_SSD1B;
}
return VERSION_SDXL;
}

@@ -1881,6 +1901,9 @@ SDVersion ModelLoader::get_sd_version() {
if (is_ip2p) {
return VERSION_SD1_PIX2PIX;
}
if (!has_middle_block_1) {
return VERSION_SD1_TINY_UNET;
}
return VERSION_SD1;
} else if (token_embedding_weight.ne[0] == 1024) {
if (is_inpaint) {
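
The detection added here boils down to one question: does any tensor name refer to a second entry in the U-Net middle block? A simplified Python sketch of that decision follows. It assumes the SD1.x-versus-SDXL decision has already been made (passed in as `is_xl`), ignores the inpaint and pix2pix branches, and returns plain strings instead of `SDVersion` values. `detect_distilled_variant` is a hypothetical name.

```python
# Substrings taken from the has_middle_block_1 check in get_sd_version().
MIDDLE_BLOCK_MARKERS = (
    "model.diffusion_model.middle_block.1.",
    "unet.mid_block.resnets.1.",
)


def detect_distilled_variant(tensor_names, is_xl):
    """Pick the full or distilled variant based on the middle-block check."""
    has_middle_block_1 = any(
        marker in name for name in tensor_names for marker in MIDDLE_BLOCK_MARKERS
    )
    if is_xl:
        return "VERSION_SDXL" if has_middle_block_1 else "VERSION_SDXL_SSD1B"
    return "VERSION_SD1" if has_middle_block_1 else "VERSION_SD1_TINY_UNET"


full = ["model.diffusion_model.middle_block.1.norm.weight"]
tiny = ["model.diffusion_model.input_blocks.0.0.weight"]
print(detect_distilled_variant(full, is_xl=False))  # VERSION_SD1
print(detect_distilled_variant(tiny, is_xl=False))  # VERSION_SD1_TINY_UNET
print(detect_distilled_variant(tiny, is_xl=True))   # VERSION_SDXL_SSD1B
```

The scan can only answer this after seeing every tensor name, which is presumably why the early `break` for SDXL was dropped in the hunk above.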
6 changes: 4 additions & 2 deletions model.h
@@ -23,11 +23,13 @@ enum SDVersion {
VERSION_SD1,
VERSION_SD1_INPAINT,
VERSION_SD1_PIX2PIX,
VERSION_SD1_TINY_UNET,
VERSION_SD2,
VERSION_SD2_INPAINT,
VERSION_SDXL,
VERSION_SDXL_INPAINT,
VERSION_SDXL_PIX2PIX,
VERSION_SDXL_SSD1B,
VERSION_SVD,
VERSION_SD3,
VERSION_FLUX,
@@ -42,7 +44,7 @@ enum SDVersion {
};

static inline bool sd_version_is_sd1(SDVersion version) {
-if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX) {
+if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET) {
return true;
}
return false;
@@ -56,7 +58,7 @@ static inline bool sd_version_is_sd2(SDVersion version) {
}

static inline bool sd_version_is_sdxl(SDVersion version) {
-if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX) {
+if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B) {
return true;
}
return false;
2 changes: 2 additions & 0 deletions stable-diffusion.cpp
@@ -28,11 +28,13 @@ const char* model_version_to_str[] = {
"SD 1.x",
"SD 1.x Inpaint",
"Instruct-Pix2Pix",
"SD 1.x Tiny UNet",
"SD 2.x",
"SD 2.x Inpaint",
"SDXL",
"SDXL Inpaint",
"SDXL Instruct-Pix2Pix",
"SDXL (SSD1B)",
"SVD",
"SD3.x",
"Flux",
78 changes: 60 additions & 18 deletions unet.hpp
@@ -204,6 +204,9 @@ class UnetModelBlock : public GGMLBlock {
adm_in_channels = 768;
num_head_channels = 64;
num_heads = -1;
} else if (version == VERSION_SD1_TINY_UNET) {
num_res_blocks = 1;
channel_mult = {1, 2, 4};
}
if (sd_version_is_inpaint(version)) {
in_channels = 9;
@@ -270,13 +273,22 @@ class UnetModelBlock : public GGMLBlock {
n_head = ch / d_head;
}
std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
-blocks[name] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
-n_head,
-d_head,
-transformer_depth[i],
-context_dim));
+int td = transformer_depth[i];
+if (version == VERSION_SDXL_SSD1B) {
+if (i == 2) {
+td = 4;
+}
+}
+blocks[name] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
+n_head,
+d_head,
+td,
+context_dim));
}
input_block_chans.push_back(ch);
+if (version == VERSION_SD1_TINY_UNET) {
+input_block_idx++;
+}
}
if (i != len_mults - 1) {
input_block_idx += 1;
@@ -295,14 +307,17 @@ class UnetModelBlock : public GGMLBlock {
d_head = num_head_channels;
n_head = ch / d_head;
}
-blocks["middle_block.0"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
-blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
-n_head,
-d_head,
-transformer_depth[transformer_depth.size() - 1],
-context_dim));
-blocks["middle_block.2"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
-
+if (version != VERSION_SD1_TINY_UNET) {
+blocks["middle_block.0"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
+if (version != VERSION_SDXL_SSD1B) {
+blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
+n_head,
+d_head,
+transformer_depth[transformer_depth.size() - 1],
+context_dim));
+blocks["middle_block.2"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
+}
+}
// output_blocks
int output_block_idx = 0;
for (int i = (int)len_mults - 1; i >= 0; i--) {
@@ -324,12 +339,27 @@ class UnetModelBlock : public GGMLBlock {
n_head = ch / d_head;
}
std::string name = "output_blocks." + std::to_string(output_block_idx) + ".1";
-blocks[name] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch, n_head, d_head, transformer_depth[i], context_dim));
+int td = transformer_depth[i];
+if (version == VERSION_SDXL_SSD1B) {
+if (i == 2 && (j == 0 || j == 1)) {
+td = 4;
+}
+if (i == 1 && (j == 1 || j == 2)) {
+td = 1;
+}
+}
+blocks[name] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch, n_head, d_head, td, context_dim));

up_sample_idx++;
}

if (i > 0 && j == num_res_blocks) {
+if (version == VERSION_SD1_TINY_UNET) {
+output_block_idx++;
+if (output_block_idx == 2) {
+up_sample_idx = 1;
+}
+}
std::string name = "output_blocks." + std::to_string(output_block_idx) + "." + std::to_string(up_sample_idx);
blocks[name] = std::shared_ptr<GGMLBlock>(new UpSampleBlock(ch, ch));
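
The SSD1B special cases in the input and output blocks amount to overriding the transformer depth at a few positions. A Python restatement of the two rules, meant as a reading aid rather than the actual implementation, could look like this. The function names are hypothetical, `default_depth` stands for `transformer_depth[i]`, and the value 10 in the demo loop is a placeholder chosen for illustration.

```python
def ssd1b_input_depth(i, default_depth):
    """Depth for input-block level i: level 2 is forced to 4."""
    return 4 if i == 2 else default_depth


def ssd1b_output_depth(i, j, default_depth):
    """Depth for output-block level i, position j within that level."""
    if i == 2 and j in (0, 1):
        return 4
    if i == 1 and j in (1, 2):
        return 1
    return default_depth


# Show which (i, j) positions are overridden for a placeholder default of 10.
for i in (2, 1):
    for j in range(3):
        print(i, j, ssd1b_output_depth(i, j, default_depth=10))
```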

@@ -463,6 +493,9 @@ class UnetModelBlock : public GGMLBlock {
}
hs.push_back(h);
}
if (version == VERSION_SD1_TINY_UNET) {
input_block_idx++;
}
if (i != len_mults - 1) {
ds *= 2;
input_block_idx += 1;
@@ -477,10 +510,13 @@ class UnetModelBlock : public GGMLBlock {
// [N, 4*model_channels, h/8, w/8]

// middle_block
-h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
-h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
-h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
-
+if (version != VERSION_SD1_TINY_UNET) {
+h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
+if (version != VERSION_SDXL_SSD1B) {
+h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
+h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
+}
+}
if (controls.size() > 0) {
auto cs = ggml_scale_inplace(ctx, controls[controls.size() - 1], control_strength);
h = ggml_add(ctx, h, cs); // middle control
@@ -516,6 +552,12 @@ class UnetModelBlock : public GGMLBlock {
}

if (i > 0 && j == num_res_blocks) {
if (version == VERSION_SD1_TINY_UNET) {
output_block_idx++;
if (output_block_idx == 2) {
up_sample_idx = 1;
}
}
std::string name = "output_blocks." + std::to_string(output_block_idx) + "." + std::to_string(up_sample_idx);
auto block = std::dynamic_pointer_cast<UpSampleBlock>(blocks[name]);
