diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 9122fca6cf99f..2fa0e799ce327 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -745,10 +745,14 @@ void launch_fattn( size_t nb23 = V ? V->nb[3] : nb13; if (need_f16_K && K->type != GGML_TYPE_F16) { - GGML_ASSERT(ggml_is_contiguously_allocated(K)); - K_f16.alloc(ggml_nelements(K)); + const int64_t n_seq = K->ne[3]; + const int64_t n_eps = (K->nb[3]/ggml_type_size(K->type))*ggml_blck_size(K->type); // elements per sequence + + K_f16.alloc(n_seq*n_eps); to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); - to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream); + for (int s = 0; s < n_seq; ++s) { + to_fp16(K_data + s*K->nb[3], K_f16.ptr + s*n_eps, n_eps, main_stream); + } K_data = (char *) K_f16.ptr; const size_t bs = ggml_blck_size(K->type); @@ -760,10 +764,14 @@ void launch_fattn( } if (V && need_f16_V && V->type != GGML_TYPE_F16) { - GGML_ASSERT(ggml_is_contiguously_allocated(V)); - V_f16.alloc(ggml_nelements(V)); + const int64_t n_seq = V->ne[3]; + const int64_t n_eps = (V->nb[3]/ggml_type_size(V->type))*ggml_blck_size(V->type); + + V_f16.alloc(n_seq*n_eps); to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); - to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); + for (int s = 0; s < n_seq; ++s) { + to_fp16(V_data + s*V->nb[3], V_f16.ptr + s*n_eps, n_eps, main_stream); + } V_data = (char *) V_f16.ptr; const size_t bs = ggml_blck_size(V->type); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a0ab5b9257e8c..ecfe99aa3961c 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5525,6 +5525,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_timestep_embedding()); test_cases.emplace_back(new test_leaky_relu()); + test_cases.emplace_back(new test_flash_attn_ext(128, 128, 4, {1, 3}, 512, 128, true, 0.0f, 0.0f, GGML_PREC_DEFAULT, GGML_TYPE_Q8_0)); + for (int hsk : { 64, 80, 128, 192, 256, 576 }) { for (int hsv : { 64, 80, 128, 192, 256, 512 }) { if (hsk != 192 && hsk != 576 && hsk != hsv) continue;