@@ -147,14 +147,16 @@ class SelfAttention : public GGMLBlock {
     int64_t num_heads;
     bool pre_only;
     std::string qk_norm;
+    bool flash_attn;

 public:
     SelfAttention(int64_t dim,
                   int64_t num_heads   = 8,
                   std::string qk_norm = "",
                   bool qkv_bias       = false,
-                  bool pre_only       = false)
-        : num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm) {
+                  bool pre_only       = false,
+                  bool flash_attn     = false)
+        : num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm), flash_attn(flash_attn) {
         int64_t d_head = dim / num_heads;
         blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
         if (!pre_only) {
@@ -206,8 +208,8 @@ class SelfAttention : public GGMLBlock {
                                 ggml_backend_t backend,
                                 struct ggml_tensor* x) {
         auto qkv = pre_attention(ctx, x);
-        x        = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads);  // [N, n_token, dim]
-        x        = post_attention(ctx, x);                                                  // [N, n_token, dim]
+        x        = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, NULL, false, false, true);  // [N, n_token, dim]
+        x        = post_attention(ctx, x);                                                                            // [N, n_token, dim]
         return x;
     }
 };
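Note on the extended call: every site touched by this patch passes four additional arguments to ggml_nn_attention_ext after num_heads, with the last one selecting flash attention. The annotated sketch below restates the call from the hunk above; the meanings given for the first three extra arguments are assumptions inferred from the call pattern, not taken from the helper's declaration.

    // Sketch only: comments on NULL/false/false are assumptions; the final
    // argument is the flash-attention switch added by this change.
    x = ggml_nn_attention_ext(ctx, backend,
                              qkv[0], qkv[1], qkv[2],  // q, k, v: [N, n_token, dim]
                              num_heads,
                              NULL,   // assumed: optional attention mask (none here)
                              false,  // assumed: boolean option, left at its default
                              false,  // assumed: boolean option, left at its default
                              true);  // run the flash-attention path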
@@ -232,6 +234,7 @@ struct DismantledBlock : public GGMLBlock {
     int64_t num_heads;
     bool pre_only;
     bool self_attn;
+    bool flash_attn;

 public:
     DismantledBlock(int64_t hidden_size,
@@ -240,16 +243,17 @@ struct DismantledBlock : public GGMLBlock {
                     std::string qk_norm = "",
                     bool qkv_bias       = false,
                     bool pre_only       = false,
-                    bool self_attn      = false)
+                    bool self_attn      = false,
+                    bool flash_attn     = false)
         : num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) {
         // rmsnorm is always Flase
         // scale_mod_only is always Flase
         // swiglu is always Flase
         blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
-        blocks["attn"]  = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only));
+        blocks["attn"]  = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only, flash_attn));

         if (self_attn) {
-            blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false));
+            blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false, flash_attn));
         }

         if (!pre_only) {
@@ -435,8 +439,8 @@ struct DismantledBlock : public GGMLBlock {
             auto qkv2          = std::get<1>(qkv_intermediates);
             auto intermediates = std::get<2>(qkv_intermediates);

-            auto attn_out  = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads);     // [N, n_token, dim]
-            auto attn2_out = ggml_nn_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads);  // [N, n_token, dim]
+            auto attn_out  = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, NULL, false, false, flash_attn);     // [N, n_token, dim]
+            auto attn2_out = ggml_nn_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads, NULL, false, false, flash_attn);  // [N, n_token, dim]
             x = post_attention_x(ctx,
                                  attn_out,
                                  attn2_out,
@@ -452,7 +456,7 @@ struct DismantledBlock : public GGMLBlock {
             auto qkv           = qkv_intermediates.first;
             auto intermediates = qkv_intermediates.second;

-            auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads);  // [N, n_token, dim]
+            auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, NULL, false, false, flash_attn);  // [N, n_token, dim]
             x = post_attention(ctx,
                                attn_out,
                                intermediates[0],
@@ -468,6 +472,7 @@ struct DismantledBlock : public GGMLBlock {
 __STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*>
 block_mixing(struct ggml_context* ctx,
              ggml_backend_t backend,
+             bool flash_attn,
              struct ggml_tensor* context,
              struct ggml_tensor* x,
              struct ggml_tensor* c,
@@ -497,8 +502,8 @@ block_mixing(struct ggml_context* ctx,
         qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
     }

-    auto attn = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads);  // [N, n_context + n_token, hidden_size]
-    attn      = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));                              // [n_context + n_token, N, hidden_size]
+    auto attn = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, NULL, false, false, flash_attn);  // [N, n_context + n_token, hidden_size]
+    attn      = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));                                                              // [n_context + n_token, N, hidden_size]
     auto context_attn = ggml_view_3d(ctx,
                                      attn,
                                      attn->ne[0],
@@ -556,16 +561,20 @@ block_mixing(struct ggml_context* ctx,
 }

 struct JointBlock : public GGMLBlock {
+    bool flash_attn;
+
 public:
     JointBlock(int64_t hidden_size,
                int64_t num_heads,
                float mlp_ratio     = 4.0,
                std::string qk_norm = "",
                bool qkv_bias       = false,
                bool pre_only       = false,
-               bool self_attn_x    = false) {
-        blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only));
-        blocks["x_block"]       = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x));
+               bool self_attn_x    = false,
+               bool flash_attn     = false)
+        : flash_attn(flash_attn) {
+        blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only, false, flash_attn));
+        blocks["x_block"]       = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x, flash_attn));
     }

     std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
@@ -576,7 +585,7 @@ struct JointBlock : public GGMLBlock {
         auto context_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["context_block"]);
         auto x_block       = std::dynamic_pointer_cast<DismantledBlock>(blocks["x_block"]);

-        return block_mixing(ctx, backend, context, x, c, context_block, x_block);
+        return block_mixing(ctx, backend, flash_attn, context, x, c, context_block, x_block);
     }
 };

@@ -634,14 +643,16 @@ struct MMDiT : public GGMLBlock {
     int64_t context_embedder_out_dim = 1536;
     int64_t hidden_size;
     std::string qk_norm;
+    bool flash_attn = false;

     void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
         enum ggml_type wtype = GGML_TYPE_F32;
         params["pos_embed"]  = ggml_new_tensor_3d(ctx, wtype, hidden_size, num_patchs, 1);
     }

 public:
-    MMDiT(const String2GGMLType& tensor_types = {}) {
+    MMDiT(bool flash_attn = false, const String2GGMLType& tensor_types = {})
+        : flash_attn(flash_attn) {
         // input_size is always None
         // learn_sigma is always False
         // register_length is alwalys 0
@@ -709,7 +720,8 @@ struct MMDiT : public GGMLBlock {
                                                                            qk_norm,
                                                                            true,
                                                                            i == depth - 1,
-                                                                           i <= d_self));
+                                                                           i <= d_self,
+                                                                           flash_attn));
         }

         blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels));
@@ -856,9 +868,10 @@ struct MMDiTRunner : public GGMLRunner {

     MMDiTRunner(ggml_backend_t backend,
                 bool offload_params_to_cpu,
+                bool flash_attn,
                 const String2GGMLType& tensor_types = {},
                 const std::string prefix            = "")
-        : GGMLRunner(backend, offload_params_to_cpu), mmdit(tensor_types) {
+        : GGMLRunner(backend, offload_params_to_cpu), mmdit(flash_attn, tensor_types) {
         mmdit.init(params_ctx, tensor_types, prefix);
     }

@@ -957,7 +970,7 @@ struct MMDiTRunner : public GGMLRunner {
         // ggml_backend_t backend = ggml_backend_cuda_init(0);
         ggml_backend_t backend    = ggml_backend_cpu_init();
         ggml_type model_data_type = GGML_TYPE_F16;
-        std::shared_ptr<MMDiTRunner> mmdit = std::shared_ptr<MMDiTRunner>(new MMDiTRunner(backend, false));
+        std::shared_ptr<MMDiTRunner> mmdit = std::shared_ptr<MMDiTRunner>(new MMDiTRunner(backend, false, false));
         {
             LOG_INFO("loading from '%s'", file_path.c_str());

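Usage note: with this change, flash attention is opted into through the new third constructor argument of MMDiTRunner, and the flag is threaded down through MMDiT, JointBlock, DismantledBlock, and SelfAttention. A minimal sketch, reusing the test-harness setup from the hunk above and only flipping the new flag:

    // Sketch: same construction as the test code above, with flash_attn enabled.
    ggml_backend_t backend = ggml_backend_cpu_init();
    std::shared_ptr<MMDiTRunner> mmdit = std::shared_ptr<MMDiTRunner>(
        new MMDiTRunner(backend,
                        false,   // offload_params_to_cpu
                        true));  // flash_attn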