
Commit cc064a0

optimize the handling of the FeedForward precision fix
1 parent 98d6e71 commit cc064a0

File tree

common.hpp
ggml_extend.hpp

2 files changed: +24 -14 lines changed

common.hpp

Lines changed: 10 additions & 10 deletions
@@ -243,17 +243,23 @@ class FeedForward : public GGMLBlock {
                 int64_t dim_out,
                 int64_t mult          = 4,
                 Activation activation = Activation::GEGLU,
-                bool force_prec_f32   = false) {
+                bool precision_fix    = false) {
         int64_t inner_dim = dim * mult;
-        SD_UNUSED(force_prec_f32);
         if (activation == Activation::GELU) {
             blocks["net.0"] = std::shared_ptr<GGMLBlock>(new GELU(dim, inner_dim));
         } else {
             blocks["net.0"] = std::shared_ptr<GGMLBlock>(new GEGLU(dim, inner_dim));
         }
 
         // net_1 is nn.Dropout(), skip for inference
-        blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out));
+        float scale = 1.f;
+        if (precision_fix) {
+            scale = 1.f / 128.f;
+        }
+        // The purpose of the scale here is to prevent NaN issues in certain situations.
+        // For example, when using Vulkan without enabling force_prec_f32,
+        // or when using CUDA but the weights are k-quants.
+        blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out, true, false, false, scale));
     }
 
     struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
@@ -264,13 +270,7 @@ class FeedForward : public GGMLBlock {
         auto net_2 = std::dynamic_pointer_cast<Linear>(blocks["net.2"]);
 
         x = net_0->forward(ctx, x);  // [ne3, ne2, ne1, inner_dim]
-        // The purpose of the scale here is to prevent NaN issues in certain situations.
-        // For example, when using Vulkan without enabling force_prec_f32,
-        // or when using CUDA but the weights are k-quants.
-        float scale = 1.f / 128.f;
-        x = ggml_scale(ctx, x, scale);
-        x = net_2->forward(ctx, x);  // [ne3, ne2, ne1, dim_out]
-        x = ggml_scale(ctx, x, 1.f / scale);
+        x = net_2->forward(ctx, x);  // [ne3, ne2, ne1, dim_out]
         return x;
     }
 };
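For context on why the 1/128 scale prevents NaNs: when the net.2 mat-mul accumulates in fp16, a large inner_dim can push the accumulator past the fp16 maximum (about 65504), which overflows to inf and later surfaces as NaN. Below is a minimal, self-contained sketch, not part of this commit, with sizes and constant values invented purely for illustration, showing how pre-scaling the input and undoing the scale on the output keeps the accumulator in range without changing the exact-arithmetic result.

// Minimal sketch (not from this commit) of the overflow the 1/128 scale
// guards against; inner_dim and the constant activation/weight values
// are assumptions chosen for illustration.
#include <cstdio>
#include <vector>

int main() {
    const int   inner_dim = 4096;          // e.g. dim = 1024, mult = 4
    const float fp16_max  = 65504.0f;      // largest finite half-precision value

    std::vector<float> x(inner_dim, 4.0f); // activations entering net.2
    std::vector<float> w(inner_dim, 4.0f); // one row of net.2's weight

    // Unscaled dot product: 4096 * 16 = 65536 > fp16_max, so an fp16
    // accumulator would overflow to inf and later produce NaNs.
    float acc = 0.0f;
    for (int i = 0; i < inner_dim; i++) {
        acc += x[i] * w[i];
    }
    std::printf("unscaled accumulator: %g (fits in fp16: %s)\n",
                acc, acc <= fp16_max ? "yes" : "no");

    // With the precision fix: scale the input by 1/128 before the mat-mul,
    // then undo the scale on the result. The accumulator stays at 512,
    // far below fp16_max, and the restored value matches the unscaled one.
    const float scale = 1.0f / 128.0f;
    float acc_scaled = 0.0f;
    for (int i = 0; i < inner_dim; i++) {
        acc_scaled += (x[i] * scale) * w[i];
    }
    std::printf("scaled accumulator: %g, restored result: %g\n",
                acc_scaled, acc_scaled / scale);
    return 0;
}

A side effect of moving the scale into ggml_nn_linear (see the ggml_extend.hpp diff below) is that the inverse scale is now applied between the mat-mul and the bias add, so the bias is added after the output magnitude has been restored; previously the inverse scale was applied to net_2's full forward output.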

ggml_extend.hpp

Lines changed: 14 additions & 4 deletions
@@ -944,11 +944,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx,
                                                      struct ggml_tensor* x,
                                                      struct ggml_tensor* w,
                                                      struct ggml_tensor* b,
-                                                     bool force_prec_f32 = false) {
+                                                     bool force_prec_f32 = false,
+                                                     float scale         = 1.f) {
+    if (scale != 1.f) {
+        x = ggml_scale(ctx, x, scale);
+    }
     x = ggml_mul_mat(ctx, w, x);
     if (force_prec_f32) {
         ggml_mul_mat_set_prec(x, GGML_PREC_F32);
     }
+    if (scale != 1.f) {
+        x = ggml_scale(ctx, x, 1.f / scale);
+    }
     if (b != NULL) {
         x = ggml_add_inplace(ctx, x, b);
     }
@@ -1962,6 +1969,7 @@ class Linear : public UnaryBlock {
     bool bias;
     bool force_f32;
     bool force_prec_f32;
+    float scale;
 
     void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
         enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
@@ -1980,20 +1988,22 @@ class Linear : public UnaryBlock {
            int64_t out_features,
            bool bias           = true,
            bool force_f32      = false,
-           bool force_prec_f32 = false)
+           bool force_prec_f32 = false,
+           float scale         = 1.f)
         : in_features(in_features),
           out_features(out_features),
           bias(bias),
           force_f32(force_f32),
-          force_prec_f32(force_prec_f32) {}
+          force_prec_f32(force_prec_f32),
+          scale(scale) {}
 
     struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
         struct ggml_tensor* w = params["weight"];
         struct ggml_tensor* b = NULL;
         if (bias) {
             b = params["bias"];
         }
-        return ggml_nn_linear(ctx, x, w, b, force_prec_f32);
+        return ggml_nn_linear(ctx, x, w, b, force_prec_f32, scale);
     }
 };
 