@@ -243,17 +243,23 @@ class FeedForward : public GGMLBlock {
                 int64_t dim_out,
                 int64_t mult          = 4,
                 Activation activation = Activation::GEGLU,
-                bool force_prec_f32   = false) {
+                bool precision_fix    = false) {
         int64_t inner_dim = dim * mult;
-        SD_UNUSED(force_prec_f32);
         if (activation == Activation::GELU) {
             blocks["net.0"] = std::shared_ptr<GGMLBlock>(new GELU(dim, inner_dim));
         } else {
             blocks["net.0"] = std::shared_ptr<GGMLBlock>(new GEGLU(dim, inner_dim));
         }
 
         // net_1 is nn.Dropout(), skip for inference
-        blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out));
+        float scale = 1.f;
+        if (precision_fix) {
+            scale = 1.f / 128.f;
+        }
+        // The purpose of the scale here is to prevent NaN issues in certain situations.
+        // For example, when using Vulkan without enabling force_prec_f32,
+        // or when using CUDA but the weights are k-quants.
+        blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out, true, false, false, scale));
     }
 
     struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
@@ -264,13 +270,7 @@ class FeedForward : public GGMLBlock {
         auto net_2 = std::dynamic_pointer_cast<Linear>(blocks["net.2"]);
 
         x = net_0->forward(ctx, x);  // [ne3, ne2, ne1, inner_dim]
-        // The purpose of the scale here is to prevent NaN issues in certain situations.
-        // For example, when using Vulkan without enabling force_prec_f32,
-        // or when using CUDA but the weights are k-quants.
-        float scale = 1.f / 128.f;
-        x = ggml_scale(ctx, x, scale);
-        x = net_2->forward(ctx, x);  // [ne3, ne2, ne1, dim_out]
-        x = ggml_scale(ctx, x, 1.f / scale);
+        x = net_2->forward(ctx, x);  // [ne3, ne2, ne1, dim_out]
         return x;
     }
 };
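
The deleted forward() lines above show the trick this change moves into the Linear block itself: shrink the activations before the matmul and undo the scaling afterwards, so that low-precision accumulation (f16 Vulkan paths, k-quant weights on CUDA) stays in range without changing the layer's output. Below is a minimal standalone sketch of that pattern using the public ggml API (ggml_scale, ggml_mul_mat, ggml_add). The helper name scaled_linear, its signature, and the bias handling are illustrative assumptions, not the actual Linear implementation touched by this diff.

// Illustrative helper (hypothetical, not part of this PR): applies the same
// scale -> matmul -> unscale pattern that the old forward() did inline.
static struct ggml_tensor* scaled_linear(struct ggml_context* ctx,
                                         struct ggml_tensor* x,   // input activations
                                         struct ggml_tensor* w,   // weight tensor
                                         struct ggml_tensor* b,   // optional bias, may be NULL
                                         float scale) {           // e.g. 1.f / 128.f
    x = ggml_scale(ctx, x, scale);        // shrink activations before the matmul
    x = ggml_mul_mat(ctx, w, x);          // low-precision accumulation happens here
    x = ggml_scale(ctx, x, 1.f / scale);  // undo the scaling on the result
    if (b != NULL) {
        x = ggml_add(ctx, x, b);          // add bias after unscaling so it is not distorted
    }
    return x;
}

Passing the scale into the Linear constructor, as the new call in the diff does, presumably lets the block build this into its own graph, so callers no longer need to wrap forward() with a pair of ggml_scale calls.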