Skip to content

Commit 1c32fa0

Browse files
authored
fix: avoid generating black images when running T5 on the GPU (#882)
1 parent 9727c6b commit 1c32fa0

File tree

2 files changed

+4
-12
lines changed

2 files changed

+4
-12
lines changed

stable-diffusion.cpp

Lines changed: 1 addition & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -338,17 +338,7 @@ class StableDiffusionGGML {
338 338
bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu;
339 339

340 340
{
341-
clip_backend = backend;
342-
bool use_t5xxl = false;
343-
if (sd_version_is_dit(version) && !sd_version_is_qwen_image(version)) {
344-
use_t5xxl = true;
345-
}
346-
if (!clip_on_cpu && !ggml_backend_is_cpu(backend) && use_t5xxl) {
347-
LOG_WARN(
348-
"!!!It appears that you are using the T5 model. Some backends may encounter issues with it."
349-
"If you notice that the generated images are completely black,"
350-
"try running the T5 model on the CPU using the --clip-on-cpu parameter.");
351-
}
341+
clip_backend = backend;
352 342
if (clip_on_cpu && !ggml_backend_is_cpu(backend)) {
353 343
LOG_INFO("CLIP: Using CPU backend");
354 344
clip_backend = ggml_backend_cpu_init();

t5.hpp

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -504,7 +504,9 @@ struct T5DenseGatedActDense : public UnaryBlock {
504 504
T5DenseGatedActDense(int64_t model_dim, int64_t ff_dim) {
505 505
blocks["wi_0"] = std::shared_ptr<GGMLBlock>(new Linear(model_dim, ff_dim, false));
506 506
blocks["wi_1"] = std::shared_ptr<GGMLBlock>(new Linear(model_dim, ff_dim, false));
507-
blocks["wo"] = std::shared_ptr<GGMLBlock>(new Linear(ff_dim, model_dim, false));
507+
float scale = 1.f / 32.f;
508+
// The purpose of the scale here is to prevent NaN issues on some backends(CUDA, ...).
509+
blocks["wo"] = std::shared_ptr<GGMLBlock>(new Linear(ff_dim, model_dim, false, false, false, scale));
508 510
}
509 511

510 512
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {

0 commit comments

Comments
 (0)