@@ -270,10 +270,14 @@ class UnetModelBlock : public GGMLBlock {
270270 n_head = ch / d_head;
271271 }
272272 std::string name = " input_blocks." + std::to_string (input_block_idx) + " .1" ;
273+ int td=transformer_depth[i];
274+ if (version == VERSION_SDXL_SSD1B) {
275+ if (i==2 ) td=4 ;
276+ }
273277 blocks[name] = std::shared_ptr<GGMLBlock>(get_attention_layer (ch,
274278 n_head,
275279 d_head,
276- transformer_depth[i] ,
280+ td ,
277281 context_dim));
278282 }
279283 input_block_chans.push_back (ch);
@@ -296,13 +300,14 @@ class UnetModelBlock : public GGMLBlock {
296300 n_head = ch / d_head;
297301 }
298302 blocks[" middle_block.0" ] = std::shared_ptr<GGMLBlock>(get_resblock (ch, time_embed_dim, ch));
299- blocks[" middle_block.1" ] = std::shared_ptr<GGMLBlock>(get_attention_layer (ch,
300- n_head,
301- d_head,
302- transformer_depth[transformer_depth.size () - 1 ],
303- context_dim));
304- blocks[" middle_block.2" ] = std::shared_ptr<GGMLBlock>(get_resblock (ch, time_embed_dim, ch));
305-
303+ if (version != VERSION_SDXL_SSD1B) {
304+ blocks[" middle_block.1" ] = std::shared_ptr<GGMLBlock>(get_attention_layer (ch,
305+ n_head,
306+ d_head,
307+ transformer_depth[transformer_depth.size () - 1 ],
308+ context_dim));
309+ blocks[" middle_block.2" ] = std::shared_ptr<GGMLBlock>(get_resblock (ch, time_embed_dim, ch));
310+ }
306311 // output_blocks
307312 int output_block_idx = 0 ;
308313 for (int i = (int )len_mults - 1 ; i >= 0 ; i--) {
@@ -324,7 +329,12 @@ class UnetModelBlock : public GGMLBlock {
324329 n_head = ch / d_head;
325330 }
326331 std::string name = " output_blocks." + std::to_string (output_block_idx) + " .1" ;
327- blocks[name] = std::shared_ptr<GGMLBlock>(get_attention_layer (ch, n_head, d_head, transformer_depth[i], context_dim));
332+ int td = transformer_depth[i];
333+ if (version == VERSION_SDXL_SSD1B) {
334+ if (i==2 && (j==0 || j==1 )) td=4 ;
335+ if (i==1 && (j==1 || j==2 )) td=1 ;
336+ }
337+ blocks[name] = std::shared_ptr<GGMLBlock>(get_attention_layer (ch, n_head, d_head, td, context_dim));
328338
329339 up_sample_idx++;
330340 }
@@ -478,9 +488,10 @@ class UnetModelBlock : public GGMLBlock {
478488
479489 // middle_block
480490 h = resblock_forward (" middle_block.0" , ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
481- h = attention_layer_forward (" middle_block.1" , ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
482- h = resblock_forward (" middle_block.2" , ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
483-
491+ if (version != VERSION_SDXL_SSD1B) {
492+ h = attention_layer_forward (" middle_block.1" , ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
493+ h = resblock_forward (" middle_block.2" , ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
494+ }
484495 if (controls.size () > 0 ) {
485496 auto cs = ggml_scale_inplace (ctx, controls[controls.size () - 1 ], control_strength);
486497 h = ggml_add (ctx, h, cs); // middle control
0 commit comments