
Commit fce6afc

feat: add sd3 flash attn support (#815)
1 parent 49d6570 commit fce6afc

File tree

3 files changed: +36 -21 lines

    diffusion_model.hpp
    mmdit.hpp
    stable-diffusion.cpp

diffusion_model.hpp

Lines changed: 2 additions & 1 deletion
@@ -95,8 +95,9 @@ struct MMDiTModel : public DiffusionModel {
 
     MMDiTModel(ggml_backend_t backend,
                bool offload_params_to_cpu,
+               bool flash_attn = false,
                const String2GGMLType& tensor_types = {})
-        : mmdit(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model") {
+        : mmdit(backend, offload_params_to_cpu, flash_attn, tensor_types, "model.diffusion_model") {
     }
 
     std::string get_desc() {
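The new flash_attn parameter sits before tensor_types and defaults to false, so call sites that stop at offload_params_to_cpu keep their old behavior, while callers that passed tensor_types positionally now have to supply the flag first. A minimal sketch of both call styles (not part of this diff; it assumes backend, offload_params_to_cpu and tensor_types are already in scope, as they are in stable-diffusion.cpp):

    // Hedged sketch, not code from the commit.
    auto dm_default = std::make_shared<MMDiTModel>(backend, offload_params_to_cpu);    // flash attention off (default)
    auto dm_flash   = std::make_shared<MMDiTModel>(backend, offload_params_to_cpu,
                                                   /*flash_attn=*/true, tensor_types); // flash attention on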

mmdit.hpp

Lines changed: 33 additions & 20 deletions
@@ -147,14 +147,16 @@ class SelfAttention : public GGMLBlock {
     int64_t num_heads;
     bool pre_only;
     std::string qk_norm;
+    bool flash_attn;
 
 public:
     SelfAttention(int64_t dim,
                   int64_t num_heads = 8,
                   std::string qk_norm = "",
                   bool qkv_bias = false,
-                  bool pre_only = false)
-        : num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm) {
+                  bool pre_only = false,
+                  bool flash_attn = false)
+        : num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm), flash_attn(flash_attn) {
         int64_t d_head = dim / num_heads;
         blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
         if (!pre_only) {
@@ -206,8 +208,8 @@ class SelfAttention : public GGMLBlock {
                                 ggml_backend_t backend,
                                 struct ggml_tensor* x) {
         auto qkv = pre_attention(ctx, x);
-        x = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads);  // [N, n_token, dim]
-        x = post_attention(ctx, x);                                                  // [N, n_token, dim]
+        x = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, NULL, false, false, true);  // [N, n_token, dim]
+        x = post_attention(ctx, x);                                                                            // [N, n_token, dim]
         return x;
     }
 };
@@ -232,6 +234,7 @@ struct DismantledBlock : public GGMLBlock {
     int64_t num_heads;
     bool pre_only;
     bool self_attn;
+    bool flash_attn;
 
 public:
     DismantledBlock(int64_t hidden_size,
@@ -240,16 +243,17 @@ struct DismantledBlock : public GGMLBlock {
                     std::string qk_norm = "",
                     bool qkv_bias = false,
                     bool pre_only = false,
-                    bool self_attn = false)
+                    bool self_attn = false,
+                    bool flash_attn = false)
         : num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) {
         // rmsnorm is always Flase
         // scale_mod_only is always Flase
         // swiglu is always Flase
         blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
-        blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only));
+        blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only, flash_attn));
 
         if (self_attn) {
-            blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false));
+            blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false, flash_attn));
         }
 
         if (!pre_only) {
@@ -435,8 +439,8 @@ struct DismantledBlock : public GGMLBlock {
             auto qkv2 = std::get<1>(qkv_intermediates);
             auto intermediates = std::get<2>(qkv_intermediates);
 
-            auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads);    // [N, n_token, dim]
-            auto attn2_out = ggml_nn_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads);  // [N, n_token, dim]
+            auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, NULL, false, false, flash_attn);    // [N, n_token, dim]
+            auto attn2_out = ggml_nn_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads, NULL, false, false, flash_attn);  // [N, n_token, dim]
             x = post_attention_x(ctx,
                                  attn_out,
                                  attn2_out,
@@ -452,7 +456,7 @@ struct DismantledBlock : public GGMLBlock {
             auto qkv = qkv_intermediates.first;
             auto intermediates = qkv_intermediates.second;
 
-            auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads);  // [N, n_token, dim]
+            auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, NULL, false, false, flash_attn);  // [N, n_token, dim]
             x = post_attention(ctx,
                                attn_out,
                                intermediates[0],
@@ -468,6 +472,7 @@ struct DismantledBlock : public GGMLBlock {
 __STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*>
 block_mixing(struct ggml_context* ctx,
              ggml_backend_t backend,
+             bool flash_attn,
              struct ggml_tensor* context,
              struct ggml_tensor* x,
              struct ggml_tensor* c,
@@ -497,8 +502,8 @@ block_mixing(struct ggml_context* ctx,
         qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
     }
 
-    auto attn = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads);  // [N, n_context + n_token, hidden_size]
-    attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));                                   // [n_context + n_token, N, hidden_size]
+    auto attn = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, NULL, false, false, flash_attn);  // [N, n_context + n_token, hidden_size]
+    attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));                                                                   // [n_context + n_token, N, hidden_size]
     auto context_attn = ggml_view_3d(ctx,
                                      attn,
                                      attn->ne[0],
@@ -556,16 +561,20 @@ block_mixing(struct ggml_context* ctx,
 }
 
 struct JointBlock : public GGMLBlock {
+    bool flash_attn;
+
 public:
     JointBlock(int64_t hidden_size,
                int64_t num_heads,
               float mlp_ratio = 4.0,
               std::string qk_norm = "",
               bool qkv_bias = false,
               bool pre_only = false,
-              bool self_attn_x = false) {
-        blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only));
-        blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x));
+              bool self_attn_x = false,
+              bool flash_attn = false)
+        : flash_attn(flash_attn) {
+        blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only, false, flash_attn));
+        blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x, flash_attn));
     }
 
     std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
@@ -576,7 +585,7 @@ struct JointBlock : public GGMLBlock {
         auto context_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["context_block"]);
         auto x_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["x_block"]);
 
-        return block_mixing(ctx, backend, context, x, c, context_block, x_block);
+        return block_mixing(ctx, backend, flash_attn, context, x, c, context_block, x_block);
     }
 };
 
@@ -634,14 +643,16 @@ struct MMDiT : public GGMLBlock {
     int64_t context_embedder_out_dim = 1536;
     int64_t hidden_size;
     std::string qk_norm;
+    bool flash_attn = false;
 
     void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
         enum ggml_type wtype = GGML_TYPE_F32;
         params["pos_embed"] = ggml_new_tensor_3d(ctx, wtype, hidden_size, num_patchs, 1);
     }
 
 public:
-    MMDiT(const String2GGMLType& tensor_types = {}) {
+    MMDiT(bool flash_attn = false, const String2GGMLType& tensor_types = {})
+        : flash_attn(flash_attn) {
         // input_size is always None
         // learn_sigma is always False
         // register_length is alwalys 0
@@ -709,7 +720,8 @@ struct MMDiT : public GGMLBlock {
                                                                        qk_norm,
                                                                        true,
                                                                        i == depth - 1,
-                                                                       i <= d_self));
+                                                                       i <= d_self,
+                                                                       flash_attn));
         }
 
         blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels));
@@ -856,9 +868,10 @@ struct MMDiTRunner : public GGMLRunner {
 
     MMDiTRunner(ggml_backend_t backend,
                 bool offload_params_to_cpu,
+                bool flash_attn,
                 const String2GGMLType& tensor_types = {},
                 const std::string prefix = "")
-        : GGMLRunner(backend, offload_params_to_cpu), mmdit(tensor_types) {
+        : GGMLRunner(backend, offload_params_to_cpu), mmdit(flash_attn, tensor_types) {
         mmdit.init(params_ctx, tensor_types, prefix);
     }
 
@@ -957,7 +970,7 @@ struct MMDiTRunner : public GGMLRunner {
     // ggml_backend_t backend = ggml_backend_cuda_init(0);
     ggml_backend_t backend = ggml_backend_cpu_init();
     ggml_type model_data_type = GGML_TYPE_F16;
-    std::shared_ptr<MMDiTRunner> mmdit = std::shared_ptr<MMDiTRunner>(new MMDiTRunner(backend, false));
+    std::shared_ptr<MMDiTRunner> mmdit = std::shared_ptr<MMDiTRunner>(new MMDiTRunner(backend, false, false));
     {
         LOG_INFO("loading from '%s'", file_path.c_str());
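Taken together, the mmdit.hpp changes thread a single boolean from the top-level MMDiT module down to every ggml_nn_attention_ext call site. A rough sketch of the construction-time path (argument lists abbreviated, layout mine, not code from the diff):

    // MMDiT(flash_attn, tensor_types)
    //   -> JointBlock(..., self_attn_x, flash_attn)                    // one per depth
    //        -> DismantledBlock(..., pre_only, false, flash_attn)      // "context_block"
    //        -> DismantledBlock(..., false, self_attn_x, flash_attn)   // "x_block"
    //             -> SelfAttention(..., pre_only, flash_attn)          // "attn" (and "attn2" when self_attn is set)
    //
    // block_mixing() is a free function, so JointBlock::forward passes the flag explicitly,
    // and each attention site forwards it as the last argument:
    //   ggml_nn_attention_ext(ctx, backend, q, k, v, num_heads, NULL, false, false, flash_attn);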

stable-diffusion.cpp

Lines changed: 1 addition & 0 deletions
@@ -350,6 +350,7 @@ class StableDiffusionGGML {
                                                                model_loader.tensor_storages_types);
             diffusion_model = std::make_shared<MMDiTModel>(backend,
                                                            offload_params_to_cpu,
+                                                           sd_ctx_params->diffusion_flash_attn,
                                                            model_loader.tensor_storages_types);
         } else if (sd_version_is_flux(version)) {
             bool is_chroma = false;
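On the API side the value comes from sd_ctx_params->diffusion_flash_attn, the field this hunk reads when the SD3 diffusion model is constructed. A hedged caller-side sketch, assuming the sd_ctx_params_t / new_sd_ctx() interface from stable-diffusion.h; apart from diffusion_flash_attn, the field and function names here are assumptions, not part of this diff:

    // Hedged sketch, not code from the commit.
    #include "stable-diffusion.h"

    int main() {
        sd_ctx_params_t params = {};                                                  // assumed aggregate init
        params.model_path           = "sd3_medium_incl_clips_t5xxlfp16.safetensors";  // illustrative path, assumed field
        params.diffusion_flash_attn = true;                                           // routed into MMDiTModel by this change
        sd_ctx_t* sd_ctx = new_sd_ctx(&params);                                       // assumed factory from the public header
        // ... run generation as usual, then release the context
        free_sd_ctx(sd_ctx);
        return 0;
    }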
