Skip to content

Commit 98d6e71

Browse files
committed
fix the issue that occurs when using CUDA with k-quants weights
1 parent 6ea2a75 commit 98d6e71

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

common.hpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,15 +245,15 @@ class FeedForward : public GGMLBlock {
245245
Activation activation = Activation::GEGLU,
246246
bool force_prec_f32 = false) {
247247
int64_t inner_dim = dim * mult;
248-
248+
SD_UNUSED(force_prec_f32);
249249
if (activation == Activation::GELU) {
250250
blocks["net.0"] = std::shared_ptr<GGMLBlock>(new GELU(dim, inner_dim));
251251
} else {
252252
blocks["net.0"] = std::shared_ptr<GGMLBlock>(new GEGLU(dim, inner_dim));
253253
}
254254

255255
// net_1 is nn.Dropout(), skip for inference
256-
blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out, true, false, force_prec_f32));
256+
blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out));
257257
}
258258

259259
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
@@ -264,7 +264,13 @@ class FeedForward : public GGMLBlock {
264264
auto net_2 = std::dynamic_pointer_cast<Linear>(blocks["net.2"]);
265265

266266
x = net_0->forward(ctx, x); // [ne3, ne2, ne1, inner_dim]
267-
x = net_2->forward(ctx, x); // [ne3, ne2, ne1, dim_out]
267+
// The purpose of the scale here is to prevent NaN issues in certain situations.
268+
// For example, when using Vulkan without enabling force_prec_f32,
269+
// or when using CUDA but the weights are k-quants.
270+
float scale = 1.f / 128.f;
271+
x = ggml_scale(ctx, x, scale);
272+
x = net_2->forward(ctx, x); // [ne3, ne2, ne1, dim_out]
273+
x = ggml_scale(ctx, x, 1.f / scale);
268274
return x;
269275
}
270276
};

ggml_extend.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@
5656
#define __STATIC_INLINE__ static inline
5757
#endif
5858

59+
#ifndef SD_UNUSED
60+
#define SD_UNUSED(x) (void)(x)
61+
#endif
62+
5963
__STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const char* text, void*) {
6064
switch (level) {
6165
case GGML_LOG_LEVEL_DEBUG:

0 commit comments

Comments
 (0)