Skip to content

Commit 66ef705

Browse files
committed
feat: add code and doc for running SSD1B models
1 parent 0f51783 commit 66ef705

File tree

4 files changed

+32
-14
lines changed

4 files changed

+32
-14
lines changed

model.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1859,7 +1859,12 @@ SDVersion ModelLoader::get_sd_version() {
18591859
if (is_ip2p) {
18601860
return VERSION_SDXL_PIX2PIX;
18611861
}
1862-
return VERSION_SDXL;
1862+
for (auto& tensor_storage : tensor_storages) {
1863+
if (tensor_storage.name.find("model.diffusion_model.middle_block.1") != std::string::npos) {
1864+
return VERSION_SDXL; // this tensor exists in full SDXL but is absent from SSD1B, so the model is SDXL
1865+
}
1866+
}
1867+
return VERSION_SDXL_SSD1B;
18631868
}
18641869

18651870
if (is_flux) {

model.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ enum SDVersion {
2727
VERSION_SDXL,
2828
VERSION_SDXL_INPAINT,
2929
VERSION_SDXL_PIX2PIX,
30+
VERSION_SDXL_SSD1B,
3031
VERSION_SVD,
3132
VERSION_SD3,
3233
VERSION_FLUX,
@@ -55,7 +56,7 @@ static inline bool sd_version_is_sd2(SDVersion version) {
5556
}
5657

5758
static inline bool sd_version_is_sdxl(SDVersion version) {
58-
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX) {
59+
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B) {
5960
return true;
6061
}
6162
return false;

stable-diffusion.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ const char* model_version_to_str[] = {
3333
"SDXL",
3434
"SDXL Inpaint",
3535
"SDXL Instruct-Pix2Pix",
36+
"SDXL (SSD1B)",
3637
"SVD",
3738
"SD3.x",
3839
"Flux",

unet.hpp

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -270,10 +270,14 @@ class UnetModelBlock : public GGMLBlock {
270270
n_head = ch / d_head;
271271
}
272272
std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
273+
int td=transformer_depth[i];
274+
if (version == VERSION_SDXL_SSD1B) {
275+
if (i==2) td=4;
276+
}
273277
blocks[name] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
274278
n_head,
275279
d_head,
276-
transformer_depth[i],
280+
td,
277281
context_dim));
278282
}
279283
input_block_chans.push_back(ch);
@@ -296,13 +300,14 @@ class UnetModelBlock : public GGMLBlock {
296300
n_head = ch / d_head;
297301
}
298302
blocks["middle_block.0"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
299-
blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
300-
n_head,
301-
d_head,
302-
transformer_depth[transformer_depth.size() - 1],
303-
context_dim));
304-
blocks["middle_block.2"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
305-
303+
if (version != VERSION_SDXL_SSD1B) {
304+
blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
305+
n_head,
306+
d_head,
307+
transformer_depth[transformer_depth.size() - 1],
308+
context_dim));
309+
blocks["middle_block.2"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
310+
}
306311
// output_blocks
307312
int output_block_idx = 0;
308313
for (int i = (int)len_mults - 1; i >= 0; i--) {
@@ -324,7 +329,12 @@ class UnetModelBlock : public GGMLBlock {
324329
n_head = ch / d_head;
325330
}
326331
std::string name = "output_blocks." + std::to_string(output_block_idx) + ".1";
327-
blocks[name] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch, n_head, d_head, transformer_depth[i], context_dim));
332+
int td = transformer_depth[i];
333+
if (version == VERSION_SDXL_SSD1B) {
334+
if (i==2 && (j==0 || j==1)) td=4;
335+
if (i==1 && (j==1 || j==2)) td=1;
336+
}
337+
blocks[name] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch, n_head, d_head, td, context_dim));
328338

329339
up_sample_idx++;
330340
}
@@ -478,9 +488,10 @@ class UnetModelBlock : public GGMLBlock {
478488

479489
// middle_block
480490
h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
481-
h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
482-
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
483-
491+
if (version != VERSION_SDXL_SSD1B) {
492+
h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
493+
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
494+
}
484495
if (controls.size() > 0) {
485496
auto cs = ggml_scale_inplace(ctx, controls[controls.size() - 1], control_strength);
486497
h = ggml_add(ctx, h, cs); // middle control

0 commit comments

Comments
 (0)