From 19596b1793918df7c62be28fa8a682ffc5bf66cb Mon Sep 17 00:00:00 2001 From: mnehete32 Date: Fri, 5 Sep 2025 11:32:58 +0530 Subject: [PATCH 01/12] CUDA: cov2d with tensor core --- ggml/src/ggml-cuda/conv2d.cu | 328 ++++++++++++++++++++++++++-------- ggml/src/ggml-cuda/conv2d.cuh | 9 +- 2 files changed, 265 insertions(+), 72 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu index 142dd66903aaa..4914393acab2f 100644 --- a/ggml/src/ggml-cuda/conv2d.cu +++ b/ggml/src/ggml-cuda/conv2d.cu @@ -1,6 +1,9 @@ #include "conv2d.cuh" #include "convert.cuh" +#include +using namespace nvcuda; + struct conv_params { const int64_t IW, IH; const int64_t OW, OH; @@ -11,112 +14,292 @@ struct conv_params { const int64_t IC, OC; const int64_t B; const int64_t TOTAL; + // helpers + const int64_t IC_KH_KW, N_OH_OW; }; -struct kernel_bounds { - int64_t y_min, y_max; - int64_t x_min, x_max; +auto ceil_div = [](int a, int b) { + return (a + b - 1) / b; }; -__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) { - return (a > b) ? a : b; -} - -__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) { - return (a < b) ? a : b; -} - -__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) { - kernel_bounds bounds; - bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); - bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); - bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); - bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); - return bounds; -} - -__device__ __forceinline__ int calculate_input_coord(int64_t out_coord, - int64_t kern_coord, - int64_t stride, - int64_t dilation, - int64_t padding) { +__device__ __forceinline__ static int calculate_input_coord(int64_t out_coord, + int64_t kern_coord, + int64_t stride, + int64_t dilation, + int64_t padding) { return out_coord * stride + kern_coord * dilation - padding; } struct whcn_layout { - __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { + __device__ __forceinline__ static int64_t input_index(int64_t n, + int64_t c, + int64_t y, + int64_t x, + const conv_params & P) { return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x; } - __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) { + __device__ __forceinline__ static int64_t kernel_index(int64_t c_out, + int64_t c_in, + int64_t ky, + int64_t kx, + const conv_params & P) { return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx; } - __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { + __device__ __forceinline__ static int64_t output_index(int64_t n, + int64_t c, + int64_t y, + int64_t x, + const conv_params & P) { return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x; } - __device__ static void unpack_indices(int64_t global_idx, - const conv_params & P, - int64_t & n, - int64_t & c, - int64_t & out_y, - int64_t & out_x) { - out_x = global_idx % P.OW; - out_y = (global_idx / P.OW) % P.OH; - c = (global_idx / (P.OW * P.OH)) % P.OC; - n = global_idx / (P.OW * P.OH * P.OC); + __device__ __forceinline__ static void unpack_ickhkw(int64_t idx, + int64_t & ic, + int64_t & kh, + int64_t & kw, + const conv_params & P) { + ic = idx / (P.KW * P.KH); + int64_t r = idx - ic * (P.KW * P.KH); + kh = r / 
P.KW; + kw = r - kh * P.KW; + } + + __device__ __forceinline__ static void unpack_nohow(int64_t idx, + int64_t & n, + int64_t & oh, + int64_t & ow, + const conv_params & P) { + n = idx / (P.OH * P.OW); + int64_t r = idx - n * (P.OH * P.OW); + oh = r / P.OW; + ow = r - oh * P.OW; + } +}; + +class float_mma { + public: + float * buf; + + __device__ __forceinline__ float_mma(float * scratch) { + buf = scratch; + const int lane_id = threadIdx.x % warpSize; +#pragma unroll + for (int i = lane_id; i < WMMA_M * WMMA_N; i += warpSize) { + buf[i] = 0.0f; + } + } + + __device__ __forceinline__ void mma(const float * A_sh, const float * B_sh, const int strideA, const int strideB) { + const int lane_id = threadIdx.x % warpSize; +#pragma unroll + for (int e = lane_id; e < (WMMA_M * WMMA_N); e += warpSize) { + int m = e / WMMA_N; + int n = e % WMMA_N; + float sum = buf[m * WMMA_N + n]; +#pragma unroll + for (int k = 0; k < WMMA_K; k++) { + float a = A_sh[m * strideA + k]; + float b = B_sh[k * strideB + n]; + sum = fmaf(a, b, sum); + } + buf[m * WMMA_N + n] = sum; + } } + + __device__ __forceinline__ float * store_result() const { return buf; } }; -template -static __global__ void conv2d_kernel(const float * __restrict__ input, - const T * __restrict__ kernel, - float * __restrict__ output, - const conv_params P) { - const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x; +class half_mma { + private: + wmma::fragment acc; + wmma::fragment a_frag; + wmma::fragment b_frag; + public: + float * buf; + + __device__ __forceinline__ half_mma(float * scratch) { + buf = scratch; + wmma::fill_fragment(acc, 0.0f); + } + + __device__ __forceinline__ void mma(const half * A_sh, const half * B_sh, const int strideA, const int strideB) { + wmma::load_matrix_sync(a_frag, A_sh, strideA); + wmma::load_matrix_sync(b_frag, B_sh, strideB); + wmma::mma_sync(acc, a_frag, b_frag, acc); + } - if (global_idx >= P.TOTAL) { - return; + __device__ __forceinline__ float * store_result() const { + wmma::store_matrix_sync(buf, acc, WMMA_N, wmma::mem_row_major); + return buf; } +}; + +template +static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT, const conv_params P) { + extern __shared__ unsigned char smem_raw[]; + + const int64_t OUTPUT_NUMEL = WMMA_M * WMMA_N; + const int64_t NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW; + + const int64_t WARPS_PER_NOHOW = max(1, BS_NOHOW / WMMA_N); + + const int64_t NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; + const int64_t tile_id = blockIdx.x; + const int64_t tile_oc = tile_id / NUM_BL_NOHOW; + const int64_t tile_nohow = tile_id % NUM_BL_NOHOW; + const int64_t BLOCK_OC_BASE = tile_oc * BS_OC; + const int64_t BLOCK_NOHOW_BASE = tile_nohow * BS_NOHOW; + + const int64_t laneId = threadIdx.x % WARP_SIZE; + const int64_t warpId = threadIdx.x / WARP_SIZE; + + const int64_t WARP_OC = warpId / WARPS_PER_NOHOW; + const int64_t WARP_NOHOW = warpId % WARPS_PER_NOHOW; - int64_t n, c_out, out_y, out_x; - Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x); + const int64_t OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M; + const int64_t NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N; - float acc = 0.0f; + unsigned char * ptr = smem_raw; + T * A_sh = reinterpret_cast(ptr); - for (int64_t c_in = 0; c_in < P.IC; ++c_in) { - kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P); + size_t offsetA = BS_OC * BS_ICKHKW * sizeof(T); + ptr += offsetA; - for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) { - const int64_t in_y = 
calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y); + T * B_sh = reinterpret_cast(ptr); + ptr += BS_ICKHKW * BS_NOHOW * sizeof(T); - for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) { - const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X); + float * shared_scratch = reinterpret_cast(ptr); + float * warp_scratch = shared_scratch + warpId * (WMMA_M * WMMA_N); - const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)]; - const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)]; - acc += (input_val * ggml_cuda_cast(kernel_val)); + const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW; + const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N; + + mma acc(warp_scratch); + + const int64_t A_total = BS_OC * BS_ICKHKW; + const int64_t B_total = BS_ICKHKW * BS_NOHOW; + +#pragma unroll + for (int64_t t = 0; t < NUM_IC_TILES; ++t) { +#pragma unroll + for (int64_t tid = (threadIdx.x); tid < A_total; tid += blockDim.x) { + const int row = tid / BS_ICKHKW; + const int col = tid % BS_ICKHKW; + + int64_t shared_oc = BLOCK_OC_BASE + row; + int64_t shared_ickhkw = t * BS_ICKHKW + col; + + T val = ggml_cuda_cast(0); + if (shared_oc < P.OC && shared_ickhkw < P.IC_KH_KW) { + int64_t ic, kh, kw; + layout::unpack_ickhkw(shared_ickhkw, ic, kh, kw, P); + + const int64_t kidx = layout::kernel_index(shared_oc, ic, kh, kw, P); + val = IK[kidx]; } + A_sh[row * BS_ICKHKW + col] = val; + } + +#pragma unroll + for (int64_t tid = (threadIdx.x); tid < B_total; tid += blockDim.x) { + const int brow = tid / BS_NOHOW; + const int bcol = tid % BS_NOHOW; + + int64_t IC_KH_KW_IDX = t * BS_ICKHKW + brow; + int64_t N_OH_OW_IDX = BLOCK_NOHOW_BASE + bcol; + + T val = ggml_cuda_cast(0); + if (N_OH_OW_IDX < P.N_OH_OW && IC_KH_KW_IDX < P.IC_KH_KW) { + int64_t n, oh, ow; + layout::unpack_nohow(N_OH_OW_IDX, n, oh, ow, P); + int64_t ic, kh, kw; + layout::unpack_ickhkw(IC_KH_KW_IDX, ic, kh, kw, P); + int in_y = calculate_input_coord(oh, kh, P.ST_Y, P.DL_Y, P.PD_Y); + int in_x = calculate_input_coord(ow, kw, P.ST_X, P.DL_X, P.PD_X); + if (in_y >= 0 && in_y < P.IH && in_x >= 0 && in_x < P.IW) { + const int64_t in_idx = layout::input_index(n, ic, in_y, in_x, P); + val = ggml_cuda_cast(IN[in_idx]); + } + } + B_sh[brow * BS_NOHOW + bcol] = val; + } + + __syncthreads(); + +#pragma unroll + for (int k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) { + const T * A_k_ptr = A_warp_base + k_tile; + const T * B_k_ptr = B_warp_base + k_tile * BS_NOHOW; + + acc.mma(A_k_ptr, B_k_ptr, BS_ICKHKW, BS_NOHOW); } + __syncthreads(); } - // [N, OC, OH, OW] - output[Layout::output_index(n, c_out, out_y, out_x, P)] = acc; + const float * out_buf = acc.store_result(); +#pragma unroll + for (int e = laneId; e < OUTPUT_NUMEL; e += WARP_SIZE) { + const int m = e / WMMA_N; + const int n = e % WMMA_N; + + const int64_t oc = OC_BASE + m; + const int64_t nohow = NOHOW_BASE + n; + + if (oc < P.OC && nohow < (P.N_OH_OW)) { + int64_t n, oh, ow; + layout::unpack_nohow(nohow, n, oh, ow, P); + const int64_t out_idx = layout::output_index(n, oc, oh, ow, P); + OUT[out_idx] = out_buf[e]; + } + } } -template -static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) { - const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE; - conv2d_kernel<<>>(X_D, K_D, Y_D, P); +template +static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, conv_params P, cudaStream_t st) + +{ + const int64_t NUM_BL_OC = (P.OC + BS_OC - 1) / 
BS_OC; + const int64_t NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; + + int64_t TOTAL_TILES = NUM_BL_OC * NUM_BL_NOHOW; + TOTAL_TILES = std::min(TOTAL_TILES, (int64_t) INT_MAX); + + const int WARPS_PER_OC = std::max(1, ceil_div(BS_OC, WMMA_M)); + const int WARPS_PER_NOHOW = std::max(1, ceil_div(BS_NOHOW, WMMA_N)); + const int EXPECTED_WARPS = WARPS_PER_OC * WARPS_PER_NOHOW; + int N_THREADS = EXPECTED_WARPS * WARP_SIZE; + + const int MAX_TPB = 1024; + if (N_THREADS > MAX_TPB) { + N_THREADS = (MAX_TPB / WARP_SIZE) * WARP_SIZE; + } + + if (N_THREADS < WARP_SIZE) { + N_THREADS = WARP_SIZE; + } + + const int N_WARPS = N_THREADS / WARP_SIZE; + + // scratch_buff to store output, can't store directly using wmma, + // output mapping is unknown + const int64_t scratch_bytes = N_WARPS * (WMMA_M * WMMA_N) * sizeof(float); + + const int64_t A_bytes = BS_OC * BS_ICKHKW * sizeof(T); + const int64_t B_bytes = BS_ICKHKW * BS_NOHOW * sizeof(T); + const int64_t shared_bytes = A_bytes + B_bytes + scratch_bytes; + + dim3 grid(TOTAL_TILES, 1, 1); + conv2d_kernel<<>>(X_D, K_D, Y_D, P); } -static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) { - conv2d_cuda(X_D, K_D, Y_D, P, st); +static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, conv_params & P, cudaStream_t st) { + conv2d_cuda(X_D, K_D, Y_D, P, st); } -static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) { - conv2d_cuda(X_D, K_D, Y_D, P, st); +static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, conv_params & P, cudaStream_t st) { + conv2d_cuda(X_D, K_D, Y_D, P, st); } void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -155,11 +338,14 @@ void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int OC = kernel->ne[3]; // ouptut_chanles const int B = input->ne[3]; // n_batches - const int64_t total = B * OC * OH * OW; - conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total }; + const int64_t TOTAL = B * OC * OH * OW; + const int64_t IC_KH_KW = IC * KH * KW; + const int64_t N_OH_OW = B * OH * OW; + conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, + PD_Y, DL_X, DL_Y, IC, OC, B, TOTAL, IC_KH_KW, N_OH_OW }; if (kernel->type == GGML_TYPE_F16) { - conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st); + conv2d_cuda_f16(X_D, (const half *) K_D, Y_D, params, st); } else { conv2d_cuda_f32(X_D, K_D, Y_D, params, st); } diff --git a/ggml/src/ggml-cuda/conv2d.cuh b/ggml/src/ggml-cuda/conv2d.cuh index ce4802c7ed797..ccf5b6192ed08 100644 --- a/ggml/src/ggml-cuda/conv2d.cuh +++ b/ggml/src/ggml-cuda/conv2d.cuh @@ -1,5 +1,12 @@ #pragma once #include "common.cuh" -#define CUDA_CONV2D_BLOCK_SIZE 256 +constexpr int BS_OC = 128; +constexpr int BS_ICKHKW = 16; +constexpr int BS_NOHOW = 128; + +constexpr int WMMA_M = 16; +constexpr int WMMA_N = 16; +constexpr int WMMA_K = 16; + void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From 96db6275398f6a9301c206d1ffdf302dae853a90 Mon Sep 17 00:00:00 2001 From: mnehete32 Date: Fri, 5 Sep 2025 11:50:52 +0530 Subject: [PATCH 02/12] CUDA: conv2d added comment --- ggml/src/ggml-cuda/conv2d.cuh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-cuda/conv2d.cuh b/ggml/src/ggml-cuda/conv2d.cuh index ccf5b6192ed08..28c8b9bab6f98 100644 --- a/ggml/src/ggml-cuda/conv2d.cuh +++ b/ggml/src/ggml-cuda/conv2d.cuh 
@@ -5,6 +5,8 @@ constexpr int BS_OC = 128; constexpr int BS_ICKHKW = 16; constexpr int BS_NOHOW = 128; +// supported configuration +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes constexpr int WMMA_M = 16; constexpr int WMMA_N = 16; constexpr int WMMA_K = 16; From 2cd9fb0f56441fff24d67a48c60e1f9432612d31 Mon Sep 17 00:00:00 2001 From: mnehete32 Date: Fri, 5 Sep 2025 13:31:01 +0530 Subject: [PATCH 03/12] CUDA: conv2d support fp16 without wmma * removed flash-attenion definition --- ggml/src/ggml-cuda/conv2d.cu | 54 ++++++++++++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu index 4914393acab2f..9802883f5aeda 100644 --- a/ggml/src/ggml-cuda/conv2d.cu +++ b/ggml/src/ggml-cuda/conv2d.cu @@ -1,9 +1,19 @@ #include "conv2d.cuh" #include "convert.cuh" -#include -using namespace nvcuda; - +#ifdef FP16_MMA_AVAILABLE +# if !defined(GGML_USE_HIP) +# include +# ifdef GGML_USE_MUSA +namespace wmma = mtmusa::wmma; +# else +namespace wmma = nvcuda::wmma; +# endif +# else +# include +namespace wmma = rocwmma; +# endif +#endif struct conv_params { const int64_t IW, IH; const int64_t OW, OH; @@ -111,6 +121,8 @@ class float_mma { __device__ __forceinline__ float * store_result() const { return buf; } }; +#if (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(FP16_MMA_AVAILABLE))) + class half_mma { private: wmma::fragment acc; @@ -136,6 +148,42 @@ class half_mma { } }; +#else + +class half_mma { + public: + float * buf; + + __device__ __forceinline__ half_mma(float * scratch) { + buf = scratch; + const int lane_id = threadIdx.x % warpSize; +# pragma unroll + for (int i = lane_id; i < WMMA_M * WMMA_N; i += warpSize) { + buf[i] = 0.0f; + } + } + + __device__ __forceinline__ void mma(const half * A_sh, const half * B_sh, const int strideA, const int strideB) { + const int lane_id = threadIdx.x % warpSize; +# pragma unroll + for (int e = lane_id; e < (WMMA_M * WMMA_N); e += warpSize) { + int m = e / WMMA_N; + int n = e % WMMA_N; + float sum = buf[m * WMMA_N + n]; +# pragma unroll + for (int k = 0; k < WMMA_K; k++) { + float a = A_sh[m * strideA + k]; + float b = B_sh[k * strideB + n]; + sum = fmaf(__half2float(a), __half2float(b), sum); + } + buf[m * WMMA_N + n] = sum; + } + } + + __device__ __forceinline__ float * store_result() const { return buf; } +}; +#endif // defined((__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || defined(FP16_MMA_AVAILABLE)) + template static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT, const conv_params P) { extern __shared__ unsigned char smem_raw[]; From d633cee19ced956a088ca3313dbc31b8698a0e16 Mon Sep 17 00:00:00 2001 From: mnehete32 Date: Fri, 12 Sep 2025 09:52:02 +0530 Subject: [PATCH 04/12] CUDA: conv2d using mma.cuh --- ggml/src/ggml-cuda/conv2d.cu | 381 ++++++++++++++++++++-------------- ggml/src/ggml-cuda/conv2d.cuh | 14 +- 2 files changed, 230 insertions(+), 165 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu index 9802883f5aeda..99799ac6db6f8 100644 --- a/ggml/src/ggml-cuda/conv2d.cu +++ b/ggml/src/ggml-cuda/conv2d.cu @@ -1,19 +1,6 @@ #include "conv2d.cuh" #include "convert.cuh" -#ifdef FP16_MMA_AVAILABLE -# if !defined(GGML_USE_HIP) -# include -# ifdef GGML_USE_MUSA -namespace wmma = mtmusa::wmma; -# else -namespace wmma = nvcuda::wmma; -# endif -# else -# include -namespace wmma = rocwmma; -# endif -#endif struct conv_params { const int64_t IW, IH; const int64_t OW, OH; 
@@ -28,10 +15,6 @@ struct conv_params { const int64_t IC_KH_KW, N_OH_OW; }; -auto ceil_div = [](int a, int b) { - return (a + b - 1) / b; -}; - __device__ __forceinline__ static int calculate_input_coord(int64_t out_coord, int64_t kern_coord, int64_t stride, @@ -88,151 +71,227 @@ struct whcn_layout { } }; -class float_mma { +template class float_mma { public: - float * buf; + static constexpr int num_acc = (WMMA_M * WMMA_N + WARP_SIZE - 1) / WARP_SIZE; - __device__ __forceinline__ float_mma(float * scratch) { - buf = scratch; - const int lane_id = threadIdx.x % warpSize; + float acc[num_acc]; + + __device__ __forceinline__ float_mma() { #pragma unroll - for (int i = lane_id; i < WMMA_M * WMMA_N; i += warpSize) { - buf[i] = 0.0f; + for (int i = 0; i < num_acc; i++) { + acc[i] = 0.0f; } } - __device__ __forceinline__ void mma(const float * A_sh, const float * B_sh, const int strideA, const int strideB) { - const int lane_id = threadIdx.x % warpSize; + __device__ __forceinline__ void clear() { #pragma unroll - for (int e = lane_id; e < (WMMA_M * WMMA_N); e += warpSize) { - int m = e / WMMA_N; - int n = e % WMMA_N; - float sum = buf[m * WMMA_N + n]; + for (int i = 0; i < num_acc; i++) { + acc[i] = 0.0f; + } + } + + __device__ __forceinline__ void mma(const float * __restrict__ A_sh, + const float * __restrict__ B_sh, + const int strideA, + const int strideB) { + const int lane_id = threadIdx.x % WARP_SIZE; + +#pragma unroll + for (int e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { + const int m = e / WMMA_N; + const int n = e % WMMA_N; + #pragma unroll for (int k = 0; k < WMMA_K; k++) { - float a = A_sh[m * strideA + k]; - float b = B_sh[k * strideB + n]; - sum = fmaf(a, b, sum); + const float a = A_sh[m * strideA + k]; + const float b = B_sh[k * strideB + n]; + acc[i] = fmaf(a, b, acc[i]); } - buf[m * WMMA_N + n] = sum; } } - __device__ __forceinline__ float * store_result() const { return buf; } + __device__ __forceinline__ void store_result(const int64_t OC_BASE, + const int64_t NOHOW_BASE, + float * __restrict__ OUT, + const conv_params & P) const { + const int lane_id = threadIdx.x % WARP_SIZE; + +#pragma unroll + for (int e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { + const int m = e / WMMA_N; + const int n = e % WMMA_N; + + const int64_t oc = OC_BASE + m; + const int64_t nohow = NOHOW_BASE + n; + + if (oc < P.OC && nohow < P.N_OH_OW) { + int64_t n_, oh, ow; + layout::unpack_nohow(nohow, n_, oh, ow, P); + const int64_t out_idx = layout::output_index(n_, oc, oh, ow, P); + OUT[out_idx] = acc[i]; + } + } + } }; #if (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(FP16_MMA_AVAILABLE))) +# include "mma.cuh" +using namespace ggml_cuda_mma; + +typedef ggml_cuda_mma::tile tile_a; +typedef ggml_cuda_mma::tile tile_b; +typedef ggml_cuda_mma::tile tile_acc; -class half_mma { +template class half_mma { private: - wmma::fragment acc; - wmma::fragment a_frag; - wmma::fragment b_frag; + tile_a a_frag; + tile_b b_frag; + tile_acc c_frag; public: - float * buf; + __device__ __forceinline__ half_mma() {} - __device__ __forceinline__ half_mma(float * scratch) { - buf = scratch; - wmma::fill_fragment(acc, 0.0f); + __device__ __forceinline__ void clear() { +# pragma unroll + for (int l = 0; l < c_frag.ne; ++l) { + c_frag.x[l] = 0.0f; + } } __device__ __forceinline__ void mma(const half * A_sh, const half * B_sh, const int strideA, const int strideB) { - wmma::load_matrix_sync(a_frag, A_sh, strideA); - wmma::load_matrix_sync(b_frag, B_sh, strideB); - wmma::mma_sync(acc, 
a_frag, b_frag, acc); + ggml_cuda_mma::load_ldmatrix(a_frag, (const half2 *) A_sh, strideA / 2); + ggml_cuda_mma::load_ldmatrix_trans(b_frag, (const half2 *) B_sh, strideB / 2); + ggml_cuda_mma::mma(c_frag, a_frag, b_frag); } - __device__ __forceinline__ float * store_result() const { - wmma::store_matrix_sync(buf, acc, WMMA_N, wmma::mem_row_major); - return buf; + __device__ __forceinline__ void store_result(const int64_t OC_BASE, + const int64_t NOHOW_BASE, + float * OUT, + const conv_params & P) const { +# pragma unroll + for (int l = 0; l < tile_acc::ne; ++l) { + const int64_t e = tile_acc::get_i(l) * WMMA_N + tile_acc::get_j(l); + const int m = e / WMMA_N; + const int n = e % WMMA_N; + + const int64_t oc = OC_BASE + m; + const int64_t nohow = NOHOW_BASE + n; + + if (oc < P.OC && nohow < (P.N_OH_OW)) { + int64_t n, oh, ow; + layout::unpack_nohow(nohow, n, oh, ow, P); + OUT[layout::output_index(n, oc, oh, ow, P)] = c_frag.x[l]; + } + } } }; #else -class half_mma { +template class half_mma { public: - float * buf; + static constexpr int num_acc = (WMMA_M * WMMA_N + WARP_SIZE - 1) / WARP_SIZE; + + float acc[num_acc]; - __device__ __forceinline__ half_mma(float * scratch) { - buf = scratch; - const int lane_id = threadIdx.x % warpSize; + __device__ __forceinline__ half_mma() { # pragma unroll - for (int i = lane_id; i < WMMA_M * WMMA_N; i += warpSize) { - buf[i] = 0.0f; + for (int i = 0; i < num_acc; i++) { + acc[i] = 0.0f; } } - __device__ __forceinline__ void mma(const half * A_sh, const half * B_sh, const int strideA, const int strideB) { - const int lane_id = threadIdx.x % warpSize; + __device__ __forceinline__ void clear() { +# pragma unroll + for (int i = 0; i < num_acc; i++) { + acc[i] = 0.0f; + } + } + + __device__ __forceinline__ void mma(const half * __restrict__ A_sh, + const half * __restrict__ B_sh, + const int strideA, + const int strideB) { + const int lane_id = threadIdx.x % WARP_SIZE; + # pragma unroll - for (int e = lane_id; e < (WMMA_M * WMMA_N); e += warpSize) { - int m = e / WMMA_N; - int n = e % WMMA_N; - float sum = buf[m * WMMA_N + n]; + for (int e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { + const int m = e / WMMA_N; + const int n = e % WMMA_N; + # pragma unroll for (int k = 0; k < WMMA_K; k++) { - float a = A_sh[m * strideA + k]; - float b = B_sh[k * strideB + n]; - sum = fmaf(__half2float(a), __half2float(b), sum); + const half a = A_sh[m * strideA + k]; + const half b = B_sh[k * strideB + n]; + acc[i] = fmaf(__half2float(a), __half2float(b), acc[i]); } - buf[m * WMMA_N + n] = sum; } } - __device__ __forceinline__ float * store_result() const { return buf; } + __device__ __forceinline__ void store_result(const int64_t OC_BASE, + const int64_t NOHOW_BASE, + float * __restrict__ OUT, + const conv_params & P) const { + const int lane_id = threadIdx.x % WARP_SIZE; + +# pragma unroll + for (int e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { + const int m = e / WMMA_N; + const int n = e % WMMA_N; + + const int64_t oc = OC_BASE + m; + const int64_t nohow = NOHOW_BASE + n; + + if (oc < P.OC && nohow < P.N_OH_OW) { + int64_t n_, oh, ow; + layout::unpack_nohow(nohow, n_, oh, ow, P); + const int64_t out_idx = layout::output_index(n_, oc, oh, ow, P); + OUT[out_idx] = acc[i]; + } + } + } }; + #endif // defined((__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || defined(FP16_MMA_AVAILABLE)) -template -static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT, const conv_params P) { +template +__global__ void conv2d_kernel(const 
float * IN, const T * IK, float * Out, const conv_params P) { extern __shared__ unsigned char smem_raw[]; - const int64_t OUTPUT_NUMEL = WMMA_M * WMMA_N; const int64_t NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW; + const int64_t warpId = threadIdx.y; - const int64_t WARPS_PER_NOHOW = max(1, BS_NOHOW / WMMA_N); + const int64_t WARPS_PER_NOHOW = max(1, BS_NOHOW / WMMA_N); + const int64_t total_warps_need = (((BS_OC * BS_NOHOW) + (WMMA_M * WMMA_N) - 1) / (WMMA_M * WMMA_N)); + const int64_t num_work_per_warps = (total_warps_need + num_warps - 1) / num_warps; - const int64_t NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; - const int64_t tile_id = blockIdx.x; - const int64_t tile_oc = tile_id / NUM_BL_NOHOW; - const int64_t tile_nohow = tile_id % NUM_BL_NOHOW; - const int64_t BLOCK_OC_BASE = tile_oc * BS_OC; - const int64_t BLOCK_NOHOW_BASE = tile_nohow * BS_NOHOW; + mma acc[num_work_per_warps]; - const int64_t laneId = threadIdx.x % WARP_SIZE; - const int64_t warpId = threadIdx.x / WARP_SIZE; + const int64_t num_block_nohow = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; + const int64_t BL_IDX_OC = blockIdx.x / num_block_nohow; + const int64_t BL_IDX_NOHOW = blockIdx.x % num_block_nohow; - const int64_t WARP_OC = warpId / WARPS_PER_NOHOW; - const int64_t WARP_NOHOW = warpId % WARPS_PER_NOHOW; + const int64_t BLOCK_OC_BASE = BL_IDX_OC * BS_OC; + const int64_t BLOCK_NOHOW_BASE = BL_IDX_NOHOW * BS_NOHOW; - const int64_t OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M; - const int64_t NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N; + unsigned char * ptr = smem_raw; - unsigned char * ptr = smem_raw; - T * A_sh = reinterpret_cast(ptr); + const int64_t A_total = BS_OC * BS_ICKHKW; + const int64_t B_total = BS_ICKHKW * BS_NOHOW; - size_t offsetA = BS_OC * BS_ICKHKW * sizeof(T); + size_t offsetA = (size_t) A_total * sizeof(T); + T * A_sh = reinterpret_cast(ptr); ptr += offsetA; - T * B_sh = reinterpret_cast(ptr); - ptr += BS_ICKHKW * BS_NOHOW * sizeof(T); - - float * shared_scratch = reinterpret_cast(ptr); - float * warp_scratch = shared_scratch + warpId * (WMMA_M * WMMA_N); - - const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW; - const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N; + size_t offsetB = (size_t) B_total * sizeof(T); + T * B_sh = reinterpret_cast(ptr); + ptr += offsetB; - mma acc(warp_scratch); - - const int64_t A_total = BS_OC * BS_ICKHKW; - const int64_t B_total = BS_ICKHKW * BS_NOHOW; - -#pragma unroll + int64_t ic, kh, kw; + int64_t n, oh, ow; for (int64_t t = 0; t < NUM_IC_TILES; ++t) { #pragma unroll - for (int64_t tid = (threadIdx.x); tid < A_total; tid += blockDim.x) { + for (int64_t tid = (threadIdx.y * blockDim.x + threadIdx.x); tid < A_total; tid += (blockDim.x * blockDim.y)) { const int row = tid / BS_ICKHKW; const int col = tid % BS_ICKHKW; @@ -241,7 +300,6 @@ static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT T val = ggml_cuda_cast(0); if (shared_oc < P.OC && shared_ickhkw < P.IC_KH_KW) { - int64_t ic, kh, kw; layout::unpack_ickhkw(shared_ickhkw, ic, kh, kw, P); const int64_t kidx = layout::kernel_index(shared_oc, ic, kh, kw, P); @@ -249,9 +307,8 @@ static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT } A_sh[row * BS_ICKHKW + col] = val; } - #pragma unroll - for (int64_t tid = (threadIdx.x); tid < B_total; tid += blockDim.x) { + for (int64_t tid = (threadIdx.y * blockDim.x + threadIdx.x); tid < B_total; tid += (blockDim.x * blockDim.y)) { const int brow = tid / BS_NOHOW; const int bcol = tid % 
BS_NOHOW; @@ -260,9 +317,7 @@ static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT T val = ggml_cuda_cast(0); if (N_OH_OW_IDX < P.N_OH_OW && IC_KH_KW_IDX < P.IC_KH_KW) { - int64_t n, oh, ow; layout::unpack_nohow(N_OH_OW_IDX, n, oh, ow, P); - int64_t ic, kh, kw; layout::unpack_ickhkw(IC_KH_KW_IDX, ic, kh, kw, P); int in_y = calculate_input_coord(oh, kh, P.ST_Y, P.DL_Y, P.PD_Y); int in_x = calculate_input_coord(ow, kw, P.ST_X, P.DL_X, P.PD_X); @@ -277,76 +332,88 @@ static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT __syncthreads(); #pragma unroll - for (int k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) { - const T * A_k_ptr = A_warp_base + k_tile; - const T * B_k_ptr = B_warp_base + k_tile * BS_NOHOW; - - acc.mma(A_k_ptr, B_k_ptr, BS_ICKHKW, BS_NOHOW); + for (int warp = warpId, i = 0; warp < total_warps_need; warp += num_warps, i++) { + const int64_t WARP_OC = warp / WARPS_PER_NOHOW; + const int64_t WARP_NOHOW = warp % WARPS_PER_NOHOW; + const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW; + const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N; +#pragma unroll + for (int k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) { + const T * A_k_ptr = A_warp_base + k_tile; + const T * B_k_ptr = B_warp_base + k_tile * BS_NOHOW; + acc[i].mma(A_k_ptr, B_k_ptr, BS_ICKHKW, BS_NOHOW); + } } __syncthreads(); } - const float * out_buf = acc.store_result(); #pragma unroll - for (int e = laneId; e < OUTPUT_NUMEL; e += WARP_SIZE) { - const int m = e / WMMA_N; - const int n = e % WMMA_N; - - const int64_t oc = OC_BASE + m; - const int64_t nohow = NOHOW_BASE + n; - - if (oc < P.OC && nohow < (P.N_OH_OW)) { - int64_t n, oh, ow; - layout::unpack_nohow(nohow, n, oh, ow, P); - const int64_t out_idx = layout::output_index(n, oc, oh, ow, P); - OUT[out_idx] = out_buf[e]; - } + for (int warp = warpId, i = 0; warp < total_warps_need; warp += num_warps, i++) { + const int64_t WARP_OC = warp / WARPS_PER_NOHOW; + const int64_t WARP_NOHOW = warp % WARPS_PER_NOHOW; + const int64_t OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M; + const int64_t NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N; + acc[i].store_result(OC_BASE, NOHOW_BASE, Out, P); } } -template -static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, conv_params P, cudaStream_t st) - -{ - const int64_t NUM_BL_OC = (P.OC + BS_OC - 1) / BS_OC; - const int64_t NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; +template class mma> +static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + const int warp_size = 32; + const int max_block_size = 256; - int64_t TOTAL_TILES = NUM_BL_OC * NUM_BL_NOHOW; - TOTAL_TILES = std::min(TOTAL_TILES, (int64_t) INT_MAX); + GGML_ASSERT(BS_OC >= WMMA_M && BS_ICKHKW >= WMMA_K && BS_NOHOW >= WMMA_N); - const int WARPS_PER_OC = std::max(1, ceil_div(BS_OC, WMMA_M)); - const int WARPS_PER_NOHOW = std::max(1, ceil_div(BS_NOHOW, WMMA_N)); - const int EXPECTED_WARPS = WARPS_PER_OC * WARPS_PER_NOHOW; - int N_THREADS = EXPECTED_WARPS * WARP_SIZE; + const int num_block_oc = (P.OC + BS_OC - 1) / BS_OC; + const int num_block_nohow = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; + const int num_blocks = num_block_oc * num_block_nohow; - const int MAX_TPB = 1024; - if (N_THREADS > MAX_TPB) { - N_THREADS = (MAX_TPB / WARP_SIZE) * WARP_SIZE; + int nwarps_best = 1; + int niter_best = (BS_OC * BS_NOHOW + warp_size - 1) / (warp_size); + for (int nwarps = 2; nwarps <= max_block_size / warp_size; ++nwarps) { + const int niter 
= (BS_OC * BS_NOHOW + nwarps * warp_size - 1) / (nwarps * warp_size); + if (niter < niter_best) { + niter_best = niter; + nwarps_best = nwarps; + } } - if (N_THREADS < WARP_SIZE) { - N_THREADS = WARP_SIZE; + const size_t A_bytes = BS_OC * BS_ICKHKW * sizeof(T); + const size_t B_bytes = BS_ICKHKW * BS_NOHOW * sizeof(T); + const size_t shared_bytes = A_bytes + B_bytes; + + dim3 grid(num_blocks, 1, 1); + dim3 block(warp_size, nwarps_best); + + switch (nwarps_best) { + case 1: + conv2d_kernel, 1><<>>(X_D, K_D, Y_D, P); + break; + case 2: + conv2d_kernel, 2><<>>(X_D, K_D, Y_D, P); + break; + case 4: + conv2d_kernel, 4><<>>(X_D, K_D, Y_D, P); + break; + case 8: + conv2d_kernel, 8><<>>(X_D, K_D, Y_D, P); + break; + case 16: + conv2d_kernel, 16><<>>(X_D, K_D, Y_D, P); + break; + case 32: + conv2d_kernel, 32><<>>(X_D, K_D, Y_D, P); + break; + default: + GGML_ABORT("UNSUPPROTED NWARPS_BEST"); } - - const int N_WARPS = N_THREADS / WARP_SIZE; - - // scratch_buff to store output, can't store directly using wmma, - // output mapping is unknown - const int64_t scratch_bytes = N_WARPS * (WMMA_M * WMMA_N) * sizeof(float); - - const int64_t A_bytes = BS_OC * BS_ICKHKW * sizeof(T); - const int64_t B_bytes = BS_ICKHKW * BS_NOHOW * sizeof(T); - const int64_t shared_bytes = A_bytes + B_bytes + scratch_bytes; - - dim3 grid(TOTAL_TILES, 1, 1); - conv2d_kernel<<>>(X_D, K_D, Y_D, P); } -static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, conv_params & P, cudaStream_t st) { +static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params & P, cudaStream_t st) { conv2d_cuda(X_D, K_D, Y_D, P, st); } -static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, conv_params & P, cudaStream_t st) { +static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params & P, cudaStream_t st) { conv2d_cuda(X_D, K_D, Y_D, P, st); } diff --git a/ggml/src/ggml-cuda/conv2d.cuh b/ggml/src/ggml-cuda/conv2d.cuh index 28c8b9bab6f98..a1de712b54a66 100644 --- a/ggml/src/ggml-cuda/conv2d.cuh +++ b/ggml/src/ggml-cuda/conv2d.cuh @@ -1,14 +1,12 @@ #pragma once #include "common.cuh" -constexpr int BS_OC = 128; -constexpr int BS_ICKHKW = 16; -constexpr int BS_NOHOW = 128; +#define BS_OC 64 +#define BS_ICKHKW 16 +#define BS_NOHOW 64 -// supported configuration -// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes -constexpr int WMMA_M = 16; -constexpr int WMMA_N = 16; -constexpr int WMMA_K = 16; +#define WMMA_M 16 +#define WMMA_N 16 +#define WMMA_K 16 void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From ac5e0c023c5ff96c9946f82ef33204bcf6c3a45e Mon Sep 17 00:00:00 2001 From: mnehete32 Date: Fri, 12 Sep 2025 13:30:04 +0530 Subject: [PATCH 05/12] CUDA: conv2d convert int64_t to int --- ggml/src/ggml-cuda/conv2d.cu | 173 ++++++++++++++++------------------- 1 file changed, 81 insertions(+), 92 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu index 99799ac6db6f8..f9cdd7786069d 100644 --- a/ggml/src/ggml-cuda/conv2d.cu +++ b/ggml/src/ggml-cuda/conv2d.cu @@ -2,56 +2,44 @@ #include "convert.cuh" struct conv_params { - const int64_t IW, IH; - const int64_t OW, OH; - const int64_t KW, KH; - const int64_t ST_X, ST_Y; - const int64_t PD_X, PD_Y; - const int64_t DL_X, DL_Y; - const int64_t IC, OC; - const int64_t B; + const int IW, IH; + const int OW, OH; + const int KW, KH; + const int ST_X, ST_Y; + const int PD_X, PD_Y; + const int DL_X, 
DL_Y; + const int IC, OC; + const int B; const int64_t TOTAL; // helpers - const int64_t IC_KH_KW, N_OH_OW; + const int IC_KH_KW, N_OH_OW; }; -__device__ __forceinline__ static int calculate_input_coord(int64_t out_coord, - int64_t kern_coord, - int64_t stride, - int64_t dilation, - int64_t padding) { +__device__ __forceinline__ static int calculate_input_coord(int out_coord, + int kern_coord, + int stride, + int dilation, + int padding) { return out_coord * stride + kern_coord * dilation - padding; } struct whcn_layout { - __device__ __forceinline__ static int64_t input_index(int64_t n, - int64_t c, - int64_t y, - int64_t x, - const conv_params & P) { + __device__ __forceinline__ static int64_t input_index(int n, int c, int y, int x, const conv_params & P) { return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x; } - __device__ __forceinline__ static int64_t kernel_index(int64_t c_out, - int64_t c_in, - int64_t ky, - int64_t kx, - const conv_params & P) { + __device__ __forceinline__ static int64_t kernel_index(int c_out, int c_in, int ky, int kx, const conv_params & P) { return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx; } - __device__ __forceinline__ static int64_t output_index(int64_t n, - int64_t c, - int64_t y, - int64_t x, - const conv_params & P) { + __device__ __forceinline__ static int64_t output_index(int n, int c, int y, int x, const conv_params & P) { return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x; } __device__ __forceinline__ static void unpack_ickhkw(int64_t idx, - int64_t & ic, - int64_t & kh, - int64_t & kw, + int & ic, + int & kh, + int & kw, const conv_params & P) { ic = idx / (P.KW * P.KH); int64_t r = idx - ic * (P.KW * P.KH); @@ -60,9 +48,9 @@ struct whcn_layout { } __device__ __forceinline__ static void unpack_nohow(int64_t idx, - int64_t & n, - int64_t & oh, - int64_t & ow, + int & n, + int & oh, + int & ow, const conv_params & P) { n = idx / (P.OH * P.OW); int64_t r = idx - n * (P.OH * P.OW); @@ -111,8 +99,8 @@ template class float_mma { } } - __device__ __forceinline__ void store_result(const int64_t OC_BASE, - const int64_t NOHOW_BASE, + __device__ __forceinline__ void store_result(const int OC_BASE, + const int NOHOW_BASE, float * __restrict__ OUT, const conv_params & P) const { const int lane_id = threadIdx.x % WARP_SIZE; @@ -122,14 +110,13 @@ template class float_mma { const int m = e / WMMA_N; const int n = e % WMMA_N; - const int64_t oc = OC_BASE + m; - const int64_t nohow = NOHOW_BASE + n; + const int oc = OC_BASE + m; + const int nohow = NOHOW_BASE + n; if (oc < P.OC && nohow < P.N_OH_OW) { - int64_t n_, oh, ow; + int n_, oh, ow; layout::unpack_nohow(nohow, n_, oh, ow, P); - const int64_t out_idx = layout::output_index(n_, oc, oh, ow, P); - OUT[out_idx] = acc[i]; + OUT[layout::output_index(n_, oc, oh, ow, P)] = acc[i]; } } } @@ -158,27 +145,30 @@ template class half_mma { } } - __device__ __forceinline__ void mma(const half * A_sh, const half * B_sh, const int strideA, const int strideB) { + __device__ __forceinline__ void mma(const half * __restrict__ A_sh, + const half * __restrict__ B_sh, + const int strideA, + const int strideB) { ggml_cuda_mma::load_ldmatrix(a_frag, (const half2 *) A_sh, strideA / 2); ggml_cuda_mma::load_ldmatrix_trans(b_frag, (const half2 *) B_sh, strideB / 2); ggml_cuda_mma::mma(c_frag, a_frag, b_frag); } - __device__ __forceinline__ void store_result(const int64_t OC_BASE, - const int64_t NOHOW_BASE, - float * OUT, + __device__ __forceinline__ void store_result(const int 
OC_BASE, + const int NOHOW_BASE, + float * __restrict__ OUT, const conv_params & P) const { # pragma unroll for (int l = 0; l < tile_acc::ne; ++l) { - const int64_t e = tile_acc::get_i(l) * WMMA_N + tile_acc::get_j(l); - const int m = e / WMMA_N; - const int n = e % WMMA_N; + const int e = tile_acc::get_i(l) * WMMA_N + tile_acc::get_j(l); + const int m = e / WMMA_N; + const int n = e % WMMA_N; - const int64_t oc = OC_BASE + m; - const int64_t nohow = NOHOW_BASE + n; + const int oc = OC_BASE + m; + const int nohow = NOHOW_BASE + n; if (oc < P.OC && nohow < (P.N_OH_OW)) { - int64_t n, oh, ow; + int n, oh, ow; layout::unpack_nohow(nohow, n, oh, ow, P); OUT[layout::output_index(n, oc, oh, ow, P)] = c_frag.x[l]; } @@ -228,8 +218,8 @@ template class half_mma { } } - __device__ __forceinline__ void store_result(const int64_t OC_BASE, - const int64_t NOHOW_BASE, + __device__ __forceinline__ void store_result(const int OC_BASE, + const int NOHOW_BASE, float * __restrict__ OUT, const conv_params & P) const { const int lane_id = threadIdx.x % WARP_SIZE; @@ -239,14 +229,13 @@ template class half_mma { const int m = e / WMMA_N; const int n = e % WMMA_N; - const int64_t oc = OC_BASE + m; - const int64_t nohow = NOHOW_BASE + n; + const int oc = OC_BASE + m; + const int nohow = NOHOW_BASE + n; if (oc < P.OC && nohow < P.N_OH_OW) { - int64_t n_, oh, ow; + int n_, oh, ow; layout::unpack_nohow(nohow, n_, oh, ow, P); - const int64_t out_idx = layout::output_index(n_, oc, oh, ow, P); - OUT[out_idx] = acc[i]; + OUT[layout::output_index(n_, oc, oh, ow, P)] = acc[i]; } } } @@ -258,26 +247,26 @@ template __global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const conv_params P) { extern __shared__ unsigned char smem_raw[]; - const int64_t NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW; - const int64_t warpId = threadIdx.y; + const int NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW; + const int warpId = threadIdx.y; - const int64_t WARPS_PER_NOHOW = max(1, BS_NOHOW / WMMA_N); - const int64_t total_warps_need = (((BS_OC * BS_NOHOW) + (WMMA_M * WMMA_N) - 1) / (WMMA_M * WMMA_N)); - const int64_t num_work_per_warps = (total_warps_need + num_warps - 1) / num_warps; + const int WARPS_PER_NOHOW = max(1, BS_NOHOW / WMMA_N); + const int total_warps_need = (((BS_OC * BS_NOHOW) + (WMMA_M * WMMA_N) - 1) / (WMMA_M * WMMA_N)); + const int num_work_per_warps = (total_warps_need + num_warps - 1) / num_warps; mma acc[num_work_per_warps]; - const int64_t num_block_nohow = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; - const int64_t BL_IDX_OC = blockIdx.x / num_block_nohow; - const int64_t BL_IDX_NOHOW = blockIdx.x % num_block_nohow; + const int num_block_nohow = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; + const int BL_IDX_OC = blockIdx.x / num_block_nohow; + const int BL_IDX_NOHOW = blockIdx.x % num_block_nohow; - const int64_t BLOCK_OC_BASE = BL_IDX_OC * BS_OC; - const int64_t BLOCK_NOHOW_BASE = BL_IDX_NOHOW * BS_NOHOW; + const int BLOCK_OC_BASE = BL_IDX_OC * BS_OC; + const int BLOCK_NOHOW_BASE = BL_IDX_NOHOW * BS_NOHOW; unsigned char * ptr = smem_raw; - const int64_t A_total = BS_OC * BS_ICKHKW; - const int64_t B_total = BS_ICKHKW * BS_NOHOW; + const int A_total = BS_OC * BS_ICKHKW; + const int B_total = BS_ICKHKW * BS_NOHOW; size_t offsetA = (size_t) A_total * sizeof(T); T * A_sh = reinterpret_cast(ptr); @@ -287,33 +276,33 @@ __global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const T * B_sh = reinterpret_cast(ptr); ptr += offsetB; - int64_t ic, kh, kw; - int64_t n, oh, ow; - 
for (int64_t t = 0; t < NUM_IC_TILES; ++t) { + int ic, kh, kw; + int n, oh, ow; + for (int t = 0; t < NUM_IC_TILES; ++t) { #pragma unroll - for (int64_t tid = (threadIdx.y * blockDim.x + threadIdx.x); tid < A_total; tid += (blockDim.x * blockDim.y)) { + for (int tid = (threadIdx.y * blockDim.x + threadIdx.x); tid < A_total; tid += (blockDim.x * blockDim.y)) { const int row = tid / BS_ICKHKW; const int col = tid % BS_ICKHKW; - int64_t shared_oc = BLOCK_OC_BASE + row; - int64_t shared_ickhkw = t * BS_ICKHKW + col; + int shared_oc = BLOCK_OC_BASE + row; + int shared_ickhkw = t * BS_ICKHKW + col; T val = ggml_cuda_cast(0); if (shared_oc < P.OC && shared_ickhkw < P.IC_KH_KW) { layout::unpack_ickhkw(shared_ickhkw, ic, kh, kw, P); - const int64_t kidx = layout::kernel_index(shared_oc, ic, kh, kw, P); - val = IK[kidx]; + const int kidx = layout::kernel_index(shared_oc, ic, kh, kw, P); + val = IK[kidx]; } A_sh[row * BS_ICKHKW + col] = val; } #pragma unroll - for (int64_t tid = (threadIdx.y * blockDim.x + threadIdx.x); tid < B_total; tid += (blockDim.x * blockDim.y)) { + for (int tid = (threadIdx.y * blockDim.x + threadIdx.x); tid < B_total; tid += (blockDim.x * blockDim.y)) { const int brow = tid / BS_NOHOW; const int bcol = tid % BS_NOHOW; - int64_t IC_KH_KW_IDX = t * BS_ICKHKW + brow; - int64_t N_OH_OW_IDX = BLOCK_NOHOW_BASE + bcol; + int IC_KH_KW_IDX = t * BS_ICKHKW + brow; + int N_OH_OW_IDX = BLOCK_NOHOW_BASE + bcol; T val = ggml_cuda_cast(0); if (N_OH_OW_IDX < P.N_OH_OW && IC_KH_KW_IDX < P.IC_KH_KW) { @@ -333,10 +322,10 @@ __global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const #pragma unroll for (int warp = warpId, i = 0; warp < total_warps_need; warp += num_warps, i++) { - const int64_t WARP_OC = warp / WARPS_PER_NOHOW; - const int64_t WARP_NOHOW = warp % WARPS_PER_NOHOW; - const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW; - const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N; + const int WARP_OC = warp / WARPS_PER_NOHOW; + const int WARP_NOHOW = warp % WARPS_PER_NOHOW; + const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW; + const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N; #pragma unroll for (int k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) { const T * A_k_ptr = A_warp_base + k_tile; @@ -349,10 +338,10 @@ __global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const #pragma unroll for (int warp = warpId, i = 0; warp < total_warps_need; warp += num_warps, i++) { - const int64_t WARP_OC = warp / WARPS_PER_NOHOW; - const int64_t WARP_NOHOW = warp % WARPS_PER_NOHOW; - const int64_t OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M; - const int64_t NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N; + const int WARP_OC = warp / WARPS_PER_NOHOW; + const int WARP_NOHOW = warp % WARPS_PER_NOHOW; + const int OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M; + const int NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N; acc[i].store_result(OC_BASE, NOHOW_BASE, Out, P); } } @@ -454,8 +443,8 @@ void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int B = input->ne[3]; // n_batches const int64_t TOTAL = B * OC * OH * OW; - const int64_t IC_KH_KW = IC * KH * KW; - const int64_t N_OH_OW = B * OH * OW; + const int IC_KH_KW = IC * KH * KW; + const int N_OH_OW = B * OH * OW; conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, TOTAL, IC_KH_KW, N_OH_OW }; From 410171ae113962a1cf687f85052fe855ac6cc551 Mon Sep 17 00:00:00 2001 From: mnehete32 Date: Sat, 13 Sep 2025 15:02:07 +0530 
Subject: [PATCH 06/12] CUDA: conv2d update block size --- ggml/src/ggml-cuda/conv2d.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d.cuh b/ggml/src/ggml-cuda/conv2d.cuh index a1de712b54a66..3dcce2b4a2e3b 100644 --- a/ggml/src/ggml-cuda/conv2d.cuh +++ b/ggml/src/ggml-cuda/conv2d.cuh @@ -1,9 +1,9 @@ #pragma once #include "common.cuh" -#define BS_OC 64 +#define BS_OC 16 #define BS_ICKHKW 16 -#define BS_NOHOW 64 +#define BS_NOHOW 128 #define WMMA_M 16 #define WMMA_N 16 From 51f85ff57ad1c82fb69d0cd1faac92bfa95b8763 Mon Sep 17 00:00:00 2001 From: mnehete32 Date: Tue, 16 Sep 2025 03:22:36 +0530 Subject: [PATCH 07/12] CUDA: conv2d performance optimization --- ggml/src/ggml-cuda/conv2d.cu | 383 ++++++++++++++++------------------ ggml/src/ggml-cuda/conv2d.cuh | 6 +- 2 files changed, 185 insertions(+), 204 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu index f9cdd7786069d..db92a40cd0dd0 100644 --- a/ggml/src/ggml-cuda/conv2d.cu +++ b/ggml/src/ggml-cuda/conv2d.cu @@ -1,97 +1,98 @@ #include "conv2d.cuh" #include "convert.cuh" +#include + struct conv_params { - const int IW, IH; - const int OW, OH; - const int KW, KH; - const int ST_X, ST_Y; - const int PD_X, PD_Y; - const int DL_X, DL_Y; - const int IC, OC; - const int B; - const int64_t TOTAL; + const uint IW, IH; + const uint OW, OH; + const uint KW, KH; + const uint ST_X, ST_Y; + const uint PD_X, PD_Y; + const uint DL_X, DL_Y; + const uint IC, OC; + const uint B; // helpers - const int IC_KH_KW, N_OH_OW; + const uint IC_KH_KW, N_OH_OW; }; -__device__ __forceinline__ static int calculate_input_coord(int out_coord, - int kern_coord, - int stride, - int dilation, - int padding) { +__device__ __forceinline__ static uint64_t calculate_input_coord(uint out_coord, + uint kern_coord, + uint stride, + uint dilation, + uint padding) { return out_coord * stride + kern_coord * dilation - padding; } struct whcn_layout { - __device__ __forceinline__ static int64_t input_index(int n, int c, int y, int x, const conv_params & P) { + __device__ __forceinline__ static uint64_t input_index(int n, int c, int y, int x, const conv_params & P) { return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x; } - __device__ __forceinline__ static int64_t kernel_index(int c_out, int c_in, int ky, int kx, const conv_params & P) { + __device__ __forceinline__ static uint64_t kernel_index(uint c_out, + uint c_in, + uint ky, + uint kx, + const conv_params & P) { return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx; } - __device__ __forceinline__ static int64_t output_index(int n, int c, int y, int x, const conv_params & P) { + __device__ __forceinline__ static uint64_t output_index(uint n, uint c, uint y, uint x, const conv_params & P) { return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x; } - __device__ __forceinline__ static void unpack_ickhkw(int64_t idx, - int & ic, - int & kh, - int & kw, + __device__ __forceinline__ static void unpack_ickhkw(uint64_t idx, + uint & ic, + uint & kh, + uint & kw, const conv_params & P) { - ic = idx / (P.KW * P.KH); - int64_t r = idx - ic * (P.KW * P.KH); - kh = r / P.KW; - kw = r - kh * P.KW; + ic = idx / (P.KW * P.KH); + uint r = idx - ic * (P.KW * P.KH); + kh = r / P.KW; + kw = r - kh * P.KW; } - __device__ __forceinline__ static void unpack_nohow(int64_t idx, - int & n, - int & oh, - int & ow, + __device__ __forceinline__ static void unpack_nohow(uint64_t idx, + uint & n, + uint & oh, + uint & ow, const 
conv_params & P) { - n = idx / (P.OH * P.OW); - int64_t r = idx - n * (P.OH * P.OW); - oh = r / P.OW; - ow = r - oh * P.OW; + n = idx / (P.OH * P.OW); + uint r = idx - n * (P.OH * P.OW); + oh = r / P.OW; + ow = r - oh * P.OW; } }; -template class float_mma { - public: - static constexpr int num_acc = (WMMA_M * WMMA_N + WARP_SIZE - 1) / WARP_SIZE; - - float acc[num_acc]; +template __device__ class float_mma { + private: + static constexpr uint num_acc = (WMMA_M * WMMA_N + WARP_SIZE - 1) / WARP_SIZE; + // for tile [16,16], lane 0 will store and compute for [0,0], [2,0], [4,0] ... [14,0] + // lane 1 will store and compute for [0,1], [2,1], [4,1] ... [14,1] + float acc[num_acc]; + public: __device__ __forceinline__ float_mma() { #pragma unroll - for (int i = 0; i < num_acc; i++) { - acc[i] = 0.0f; - } - } - - __device__ __forceinline__ void clear() { -#pragma unroll - for (int i = 0; i < num_acc; i++) { + for (uint i = 0; i < num_acc; i++) { acc[i] = 0.0f; } } __device__ __forceinline__ void mma(const float * __restrict__ A_sh, const float * __restrict__ B_sh, - const int strideA, - const int strideB) { - const int lane_id = threadIdx.x % WARP_SIZE; + const uint strideA, + const uint strideB) { + const uint lane_id = threadIdx.x % WARP_SIZE; #pragma unroll - for (int e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { - const int m = e / WMMA_N; - const int n = e % WMMA_N; + for (uint i = 0; i < num_acc; i++) { + const uint e = lane_id + i * WARP_SIZE; + const uint m = e / WMMA_N; + const uint n = e % WMMA_N; #pragma unroll - for (int k = 0; k < WMMA_K; k++) { + for (uint k = 0; k < WMMA_K; k++) { const float a = A_sh[m * strideA + k]; const float b = B_sh[k * strideB + n]; acc[i] = fmaf(a, b, acc[i]); @@ -99,22 +100,23 @@ template class float_mma { } } - __device__ __forceinline__ void store_result(const int OC_BASE, - const int NOHOW_BASE, + __device__ __forceinline__ void store_result(const uint OC_BASE, + const uint NOHOW_BASE, float * __restrict__ OUT, const conv_params & P) const { - const int lane_id = threadIdx.x % WARP_SIZE; + const uint lane_id = threadIdx.x % WARP_SIZE; #pragma unroll - for (int e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { - const int m = e / WMMA_N; - const int n = e % WMMA_N; + for (uint i = 0; i < num_acc; i++) { + const uint e = lane_id + i * WARP_SIZE; + const uint m = e / WMMA_N; + const uint n = e % WMMA_N; - const int oc = OC_BASE + m; - const int nohow = NOHOW_BASE + n; + const uint oc = OC_BASE + m; + const uint nohow = NOHOW_BASE + n; if (oc < P.OC && nohow < P.N_OH_OW) { - int n_, oh, ow; + uint n_, oh, ow; layout::unpack_nohow(nohow, n_, oh, ow, P); OUT[layout::output_index(n_, oc, oh, ow, P)] = acc[i]; } @@ -126,15 +128,16 @@ template class float_mma { # include "mma.cuh" using namespace ggml_cuda_mma; -typedef ggml_cuda_mma::tile tile_a; -typedef ggml_cuda_mma::tile tile_b; -typedef ggml_cuda_mma::tile tile_acc; +typedef tile tile_a; +typedef tile tile_b; +typedef tile tile_acc; template class half_mma { private: tile_a a_frag; tile_b b_frag; tile_acc c_frag; + public: __device__ __forceinline__ half_mma() {} @@ -147,30 +150,30 @@ template class half_mma { __device__ __forceinline__ void mma(const half * __restrict__ A_sh, const half * __restrict__ B_sh, - const int strideA, - const int strideB) { - ggml_cuda_mma::load_ldmatrix(a_frag, (const half2 *) A_sh, strideA / 2); - ggml_cuda_mma::load_ldmatrix_trans(b_frag, (const half2 *) B_sh, strideB / 2); + const uint strideA, + const uint strideB) { + load_ldmatrix(a_frag, 
(const half2 *) A_sh, strideA / 2); + load_ldmatrix_trans(b_frag, (const half2 *) B_sh, strideB / 2); ggml_cuda_mma::mma(c_frag, a_frag, b_frag); } - __device__ __forceinline__ void store_result(const int OC_BASE, - const int NOHOW_BASE, + __device__ __forceinline__ void store_result(const uint OC_BASE, + const uint NOHOW_BASE, float * __restrict__ OUT, const conv_params & P) const { # pragma unroll - for (int l = 0; l < tile_acc::ne; ++l) { - const int e = tile_acc::get_i(l) * WMMA_N + tile_acc::get_j(l); - const int m = e / WMMA_N; - const int n = e % WMMA_N; + for (uint l = 0; l < tile_acc::ne; ++l) { + const uint e = tile_acc::get_i(l) * WMMA_N + tile_acc::get_j(l); + const uint m = e / WMMA_N; + const uint n = e % WMMA_N; - const int oc = OC_BASE + m; - const int nohow = NOHOW_BASE + n; + const uint oc = OC_BASE + m; + const uint nohow = NOHOW_BASE + n; if (oc < P.OC && nohow < (P.N_OH_OW)) { - int n, oh, ow; - layout::unpack_nohow(nohow, n, oh, ow, P); - OUT[layout::output_index(n, oc, oh, ow, P)] = c_frag.x[l]; + uint n_, oh, ow; + layout::unpack_nohow(nohow, n_, oh, ow, P); + OUT[layout::output_index(n_, oc, oh, ow, P)] = c_frag.x[l]; } } } @@ -181,8 +184,8 @@ template class half_mma { template class half_mma { public: static constexpr int num_acc = (WMMA_M * WMMA_N + WARP_SIZE - 1) / WARP_SIZE; - - float acc[num_acc]; + // eg. for tile [16,16], lane 0 will store and compute for [0,0], [2,0], [4,0] .. [14,0] + float acc[num_acc]; __device__ __forceinline__ half_mma() { # pragma unroll @@ -191,13 +194,6 @@ template class half_mma { } } - __device__ __forceinline__ void clear() { -# pragma unroll - for (int i = 0; i < num_acc; i++) { - acc[i] = 0.0f; - } - } - __device__ __forceinline__ void mma(const half * __restrict__ A_sh, const half * __restrict__ B_sh, const int strideA, @@ -218,22 +214,22 @@ template class half_mma { } } - __device__ __forceinline__ void store_result(const int OC_BASE, - const int NOHOW_BASE, + __device__ __forceinline__ void store_result(const uint OC_BASE, + const uint NOHOW_BASE, float * __restrict__ OUT, const conv_params & P) const { - const int lane_id = threadIdx.x % WARP_SIZE; + const uint lane_id = threadIdx.x % WARP_SIZE; # pragma unroll - for (int e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { - const int m = e / WMMA_N; - const int n = e % WMMA_N; + for (uint e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { + const uint m = e / WMMA_N; + const uint n = e % WMMA_N; - const int oc = OC_BASE + m; - const int nohow = NOHOW_BASE + n; + const uint oc = OC_BASE + m; + const uint nohow = NOHOW_BASE + n; if (oc < P.OC && nohow < P.N_OH_OW) { - int n_, oh, ow; + uint n_, oh, ow; layout::unpack_nohow(nohow, n_, oh, ow, P); OUT[layout::output_index(n_, oc, oh, ow, P)] = acc[i]; } @@ -243,30 +239,35 @@ template class half_mma { #endif // defined((__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || defined(FP16_MMA_AVAILABLE)) -template -__global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const conv_params P) { +template class mma, int num_warps> +__global__ void __launch_bounds__(num_warps * WARP_SIZE) conv2d_kernel(const float * __restrict__ IN, + const T * __restrict__ IK, + float * __restrict__ Out, + const conv_params P) { extern __shared__ unsigned char smem_raw[]; - const int NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW; - const int warpId = threadIdx.y; + const uint warpId = threadIdx.y; + const uint linear_tid = threadIdx.y * blockDim.x + threadIdx.x; - const int WARPS_PER_NOHOW = max(1, BS_NOHOW / 
WMMA_N); - const int total_warps_need = (((BS_OC * BS_NOHOW) + (WMMA_M * WMMA_N) - 1) / (WMMA_M * WMMA_N)); - const int num_work_per_warps = (total_warps_need + num_warps - 1) / num_warps; + const uint NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW; + const uint NUM_WARPS_NOHOW = max(1, BS_NOHOW / WMMA_N); + const uint NUM_WARPS_NEED = (((BS_OC * BS_NOHOW) + (WMMA_M * WMMA_N) - 1) / (WMMA_M * WMMA_N)); - mma acc[num_work_per_warps]; + const uint NUM_TILES_PER_WARP = (NUM_WARPS_NEED + num_warps - 1) / num_warps; - const int num_block_nohow = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; - const int BL_IDX_OC = blockIdx.x / num_block_nohow; - const int BL_IDX_NOHOW = blockIdx.x % num_block_nohow; + mma acc[NUM_TILES_PER_WARP]; - const int BLOCK_OC_BASE = BL_IDX_OC * BS_OC; - const int BLOCK_NOHOW_BASE = BL_IDX_NOHOW * BS_NOHOW; + const uint NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; + const uint BL_IDX_OC = blockIdx.x / NUM_BL_NOHOW; + const uint BL_IDX_NOHOW = blockIdx.x % NUM_BL_NOHOW; + + const uint BLOCK_OC_BASE = BL_IDX_OC * BS_OC; + const uint BLOCK_NOHOW_BASE = BL_IDX_NOHOW * BS_NOHOW; unsigned char * ptr = smem_raw; - const int A_total = BS_OC * BS_ICKHKW; - const int B_total = BS_ICKHKW * BS_NOHOW; + const uint A_total = BS_OC * BS_ICKHKW; + const uint B_total = BS_ICKHKW * BS_NOHOW; size_t offsetA = (size_t) A_total * sizeof(T); T * A_sh = reinterpret_cast(ptr); @@ -276,40 +277,41 @@ __global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const T * B_sh = reinterpret_cast(ptr); ptr += offsetB; - int ic, kh, kw; - int n, oh, ow; - for (int t = 0; t < NUM_IC_TILES; ++t) { + for (uint t = 0; t < NUM_IC_TILES; ++t) { #pragma unroll - for (int tid = (threadIdx.y * blockDim.x + threadIdx.x); tid < A_total; tid += (blockDim.x * blockDim.y)) { - const int row = tid / BS_ICKHKW; - const int col = tid % BS_ICKHKW; + for (uint tid = linear_tid; tid < A_total; tid += (blockDim.x * blockDim.y)) { + const uint row = tid / BS_ICKHKW; + const uint col = tid % BS_ICKHKW; - int shared_oc = BLOCK_OC_BASE + row; - int shared_ickhkw = t * BS_ICKHKW + col; + const uint shared_oc = BLOCK_OC_BASE + row; + const uint shared_ickhkw = t * BS_ICKHKW + col; T val = ggml_cuda_cast(0); if (shared_oc < P.OC && shared_ickhkw < P.IC_KH_KW) { + uint ic, kh, kw; layout::unpack_ickhkw(shared_ickhkw, ic, kh, kw, P); - const int kidx = layout::kernel_index(shared_oc, ic, kh, kw, P); - val = IK[kidx]; + const uint kidx = layout::kernel_index(shared_oc, ic, kh, kw, P); + val = IK[kidx]; } A_sh[row * BS_ICKHKW + col] = val; } #pragma unroll - for (int tid = (threadIdx.y * blockDim.x + threadIdx.x); tid < B_total; tid += (blockDim.x * blockDim.y)) { - const int brow = tid / BS_NOHOW; - const int bcol = tid % BS_NOHOW; + for (uint tid = linear_tid; tid < B_total; tid += (blockDim.x * blockDim.y)) { + const uint brow = tid / BS_NOHOW; + const uint bcol = tid % BS_NOHOW; - int IC_KH_KW_IDX = t * BS_ICKHKW + brow; - int N_OH_OW_IDX = BLOCK_NOHOW_BASE + bcol; + const uint IC_KH_KW_IDX = t * BS_ICKHKW + brow; + const uint N_OH_OW_IDX = BLOCK_NOHOW_BASE + bcol; T val = ggml_cuda_cast(0); if (N_OH_OW_IDX < P.N_OH_OW && IC_KH_KW_IDX < P.IC_KH_KW) { + uint n, oh, ow; + uint ic, kh, kw; layout::unpack_nohow(N_OH_OW_IDX, n, oh, ow, P); layout::unpack_ickhkw(IC_KH_KW_IDX, ic, kh, kw, P); - int in_y = calculate_input_coord(oh, kh, P.ST_Y, P.DL_Y, P.PD_Y); - int in_x = calculate_input_coord(ow, kw, P.ST_X, P.DL_X, P.PD_X); + const int in_y = calculate_input_coord(oh, kh, P.ST_Y, P.DL_Y, P.PD_Y); + const int 
in_x = calculate_input_coord(ow, kw, P.ST_X, P.DL_X, P.PD_X); if (in_y >= 0 && in_y < P.IH && in_x >= 0 && in_x < P.IW) { const int64_t in_idx = layout::input_index(n, ic, in_y, in_x, P); val = ggml_cuda_cast(IN[in_idx]); @@ -321,13 +323,19 @@ __global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const __syncthreads(); #pragma unroll - for (int warp = warpId, i = 0; warp < total_warps_need; warp += num_warps, i++) { - const int WARP_OC = warp / WARPS_PER_NOHOW; - const int WARP_NOHOW = warp % WARPS_PER_NOHOW; + for (uint i = 0; i < NUM_TILES_PER_WARP; i++) { + const uint warp = warpId + i * num_warps; + if (warp >= NUM_WARPS_NEED) { + continue; + } + const uint WARP_OC = warp / NUM_WARPS_NOHOW; + const uint WARP_NOHOW = warp % NUM_WARPS_NOHOW; + const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW; const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N; + #pragma unroll - for (int k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) { + for (uint k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) { const T * A_k_ptr = A_warp_base + k_tile; const T * B_k_ptr = B_warp_base + k_tile * BS_NOHOW; acc[i].mma(A_k_ptr, B_k_ptr, BS_ICKHKW, BS_NOHOW); @@ -337,65 +345,37 @@ __global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const } #pragma unroll - for (int warp = warpId, i = 0; warp < total_warps_need; warp += num_warps, i++) { - const int WARP_OC = warp / WARPS_PER_NOHOW; - const int WARP_NOHOW = warp % WARPS_PER_NOHOW; - const int OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M; - const int NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N; + for (uint i = 0; i < NUM_TILES_PER_WARP; i++) { + const uint warp = warpId + i * num_warps; + if (warp >= NUM_WARPS_NEED) { + continue; + } + const uint WARP_OC = warp / NUM_WARPS_NOHOW; + const uint WARP_NOHOW = warp % NUM_WARPS_NOHOW; + const uint OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M; + const uint NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N; acc[i].store_result(OC_BASE, NOHOW_BASE, Out, P); } } template class mma> static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) { - const int warp_size = 32; - const int max_block_size = 256; - GGML_ASSERT(BS_OC >= WMMA_M && BS_ICKHKW >= WMMA_K && BS_NOHOW >= WMMA_N); - const int num_block_oc = (P.OC + BS_OC - 1) / BS_OC; - const int num_block_nohow = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; - const int num_blocks = num_block_oc * num_block_nohow; - - int nwarps_best = 1; - int niter_best = (BS_OC * BS_NOHOW + warp_size - 1) / (warp_size); - for (int nwarps = 2; nwarps <= max_block_size / warp_size; ++nwarps) { - const int niter = (BS_OC * BS_NOHOW + nwarps * warp_size - 1) / (nwarps * warp_size); - if (niter < niter_best) { - niter_best = niter; - nwarps_best = nwarps; - } - } + const uint NUM_BL_OC = (P.OC + BS_OC - 1) / BS_OC; + const uint NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; + const uint NUM_BL = NUM_BL_OC * NUM_BL_NOHOW; + + constexpr uint NUM_WARPS = (CUDA_CONV2D_BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; const size_t A_bytes = BS_OC * BS_ICKHKW * sizeof(T); const size_t B_bytes = BS_ICKHKW * BS_NOHOW * sizeof(T); const size_t shared_bytes = A_bytes + B_bytes; - dim3 grid(num_blocks, 1, 1); - dim3 block(warp_size, nwarps_best); - - switch (nwarps_best) { - case 1: - conv2d_kernel, 1><<>>(X_D, K_D, Y_D, P); - break; - case 2: - conv2d_kernel, 2><<>>(X_D, K_D, Y_D, P); - break; - case 4: - conv2d_kernel, 4><<>>(X_D, K_D, Y_D, P); - break; - case 8: - conv2d_kernel, 8><<>>(X_D, K_D, Y_D, P); - 
break; - case 16: - conv2d_kernel, 16><<>>(X_D, K_D, Y_D, P); - break; - case 32: - conv2d_kernel, 32><<>>(X_D, K_D, Y_D, P); - break; - default: - GGML_ABORT("UNSUPPROTED NWARPS_BEST"); - } + dim3 grid(NUM_BL, 1, 1); + dim3 block(WARP_SIZE, NUM_WARPS, 1); + + conv2d_kernel<<>>(X_D, K_D, Y_D, P); } static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params & P, cudaStream_t st) { @@ -422,31 +402,30 @@ void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { cudaStream_t st = ctx.stream(); const int32_t * p = (const int32_t *) dst->op_params; - const int ST_X = p[0]; // stride_x - const int ST_Y = p[1]; // stride_y - const int PD_X = p[2]; // padding_x - const int PD_Y = p[3]; // padding_y - const int DL_X = p[4]; // dilation_x - const int DL_Y = p[5]; // dilation_y + const uint ST_X = p[0]; // stride_x + const uint ST_Y = p[1]; // stride_y + const uint PD_X = p[2]; // padding_x + const uint PD_Y = p[3]; // padding_y + const uint DL_X = p[4]; // dilation_x + const uint DL_Y = p[5]; // dilation_y // No cwhn GGML_ASSERT(p[6] == false); - const int IW = input->ne[0]; // input_w - const int IH = input->ne[1]; // input_h - const int OW = dst->ne[0]; // output_w - const int OH = dst->ne[1]; // output_h - const int KW = kernel->ne[0]; // kernel_w - const int KH = kernel->ne[1]; // kernel_h - const int IC = input->ne[2]; // input_channels - const int OC = kernel->ne[3]; // ouptut_chanles - const int B = input->ne[3]; // n_batches - - const int64_t TOTAL = B * OC * OH * OW; - const int IC_KH_KW = IC * KH * KW; - const int N_OH_OW = B * OH * OW; - conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, - PD_Y, DL_X, DL_Y, IC, OC, B, TOTAL, IC_KH_KW, N_OH_OW }; + const uint IW = input->ne[0]; // input_w + const uint IH = input->ne[1]; // input_h + const uint OW = dst->ne[0]; // output_w + const uint OH = dst->ne[1]; // output_h + const uint KW = kernel->ne[0]; // kernel_w + const uint KH = kernel->ne[1]; // kernel_h + const uint IC = input->ne[2]; // input_channels + const uint OC = kernel->ne[3]; // ouptut_chanles + const uint B = input->ne[3]; // n_batches + + const uint IC_KH_KW = IC * KH * KW; + const uint N_OH_OW = B * OH * OW; + const conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, + PD_Y, DL_X, DL_Y, IC, OC, B, IC_KH_KW, N_OH_OW }; if (kernel->type == GGML_TYPE_F16) { conv2d_cuda_f16(X_D, (const half *) K_D, Y_D, params, st); diff --git a/ggml/src/ggml-cuda/conv2d.cuh b/ggml/src/ggml-cuda/conv2d.cuh index 3dcce2b4a2e3b..3a1a5f28b572c 100644 --- a/ggml/src/ggml-cuda/conv2d.cuh +++ b/ggml/src/ggml-cuda/conv2d.cuh @@ -1,12 +1,14 @@ #pragma once #include "common.cuh" -#define BS_OC 16 +#define BS_OC 32 #define BS_ICKHKW 16 -#define BS_NOHOW 128 +#define BS_NOHOW 32 #define WMMA_M 16 #define WMMA_N 16 #define WMMA_K 16 +#define CUDA_CONV2D_BLOCK_SIZE 128 + void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From 604957644fe5751c961ca1222ffed6b759751830 Mon Sep 17 00:00:00 2001 From: mnehete32 Date: Tue, 16 Sep 2025 04:58:58 +0530 Subject: [PATCH 08/12] CUDA: conv2d minor fixes CUDA: uint to int and added assertion --- ggml/src/ggml-cuda/conv2d.cu | 297 +++++++++++++++++------------------ 1 file changed, 148 insertions(+), 149 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu index db92a40cd0dd0..deaca3d648d5a 100644 --- a/ggml/src/ggml-cuda/conv2d.cu +++ b/ggml/src/ggml-cuda/conv2d.cu @@ -4,95 +4,94 @@ #include struct conv_params { - const uint IW, IH; - 
const uint OW, OH; - const uint KW, KH; - const uint ST_X, ST_Y; - const uint PD_X, PD_Y; - const uint DL_X, DL_Y; - const uint IC, OC; - const uint B; + const int IW, IH; + const int OW, OH; + const int KW, KH; + const int ST_X, ST_Y; + const int PD_X, PD_Y; + const int DL_X, DL_Y; + const int IC, OC; + const int B; // helpers - const uint IC_KH_KW, N_OH_OW; + const int IC_KH_KW, N_OH_OW; }; -__device__ __forceinline__ static uint64_t calculate_input_coord(uint out_coord, - uint kern_coord, - uint stride, - uint dilation, - uint padding) { +__device__ __forceinline__ static int calculate_input_coord(int out_coord, + int kern_coord, + int stride, + int dilation, + int padding) { return out_coord * stride + kern_coord * dilation - padding; } struct whcn_layout { - __device__ __forceinline__ static uint64_t input_index(int n, int c, int y, int x, const conv_params & P) { + __device__ __forceinline__ static int64_t input_index(int n, int c, int y, int x, const conv_params & P) { return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x; } - __device__ __forceinline__ static uint64_t kernel_index(uint c_out, - uint c_in, - uint ky, - uint kx, - const conv_params & P) { + __device__ __forceinline__ static int64_t kernel_index(int c_out, int c_in, int ky, int kx, const conv_params & P) { return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx; } - __device__ __forceinline__ static uint64_t output_index(uint n, uint c, uint y, uint x, const conv_params & P) { + __device__ __forceinline__ static int64_t output_index(int n, int c, int y, int x, const conv_params & P) { return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x; } - __device__ __forceinline__ static void unpack_ickhkw(uint64_t idx, - uint & ic, - uint & kh, - uint & kw, + __device__ __forceinline__ static void unpack_ickhkw(int64_t idx, + int & ic, + int & kh, + int & kw, const conv_params & P) { - ic = idx / (P.KW * P.KH); - uint r = idx - ic * (P.KW * P.KH); - kh = r / P.KW; - kw = r - kh * P.KW; + ic = idx / (P.KW * P.KH); + int r = idx - ic * (P.KW * P.KH); + kh = r / P.KW; + kw = r - kh * P.KW; } - __device__ __forceinline__ static void unpack_nohow(uint64_t idx, - uint & n, - uint & oh, - uint & ow, + __device__ __forceinline__ static void unpack_nohow(int64_t idx, + int & n, + int & oh, + int & ow, const conv_params & P) { - n = idx / (P.OH * P.OW); - uint r = idx - n * (P.OH * P.OW); - oh = r / P.OW; - ow = r - oh * P.OW; + n = idx / (P.OH * P.OW); + int r = idx - n * (P.OH * P.OW); + oh = r / P.OW; + ow = r - oh * P.OW; } }; -template __device__ class float_mma { +template class float_mma { private: - static constexpr uint num_acc = (WMMA_M * WMMA_N + WARP_SIZE - 1) / WARP_SIZE; + static constexpr int num_acc = (WMMA_M * WMMA_N + WARP_SIZE - 1) / WARP_SIZE; // for tile [16,16], lane 0 will store and compute for [0,0], [2,0], [4,0] ... [14,0] // lane 1 will store and compute for [0,1], [2,1], [4,1] ... 
[14,1] - float acc[num_acc]; + float acc[num_acc]; public: __device__ __forceinline__ float_mma() { #pragma unroll - for (uint i = 0; i < num_acc; i++) { + for (int i = 0; i < num_acc; i++) { acc[i] = 0.0f; } } __device__ __forceinline__ void mma(const float * __restrict__ A_sh, const float * __restrict__ B_sh, - const uint strideA, - const uint strideB) { - const uint lane_id = threadIdx.x % WARP_SIZE; + const int strideA, + const int strideB) { + const int lane_id = threadIdx.x % WARP_SIZE; #pragma unroll - for (uint i = 0; i < num_acc; i++) { - const uint e = lane_id + i * WARP_SIZE; - const uint m = e / WMMA_N; - const uint n = e % WMMA_N; + for (int i = 0; i < num_acc; i++) { + const int e = lane_id + i * WARP_SIZE; + if (e >= WMMA_M * WMMA_N) { + continue; + } + const int m = e / WMMA_N; + const int n = e % WMMA_N; #pragma unroll - for (uint k = 0; k < WMMA_K; k++) { + for (int k = 0; k < WMMA_K; k++) { const float a = A_sh[m * strideA + k]; const float b = B_sh[k * strideB + n]; acc[i] = fmaf(a, b, acc[i]); @@ -100,23 +99,26 @@ template __device__ class float_mma { } } - __device__ __forceinline__ void store_result(const uint OC_BASE, - const uint NOHOW_BASE, + __device__ __forceinline__ void store_result(const int OC_BASE, + const int NOHOW_BASE, float * __restrict__ OUT, const conv_params & P) const { - const uint lane_id = threadIdx.x % WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; #pragma unroll - for (uint i = 0; i < num_acc; i++) { - const uint e = lane_id + i * WARP_SIZE; - const uint m = e / WMMA_N; - const uint n = e % WMMA_N; + for (int i = 0; i < num_acc; i++) { + const int e = lane_id + i * WARP_SIZE; + if (e >= WMMA_M * WMMA_N) { + continue; + } + const int m = e / WMMA_N; + const int n = e % WMMA_N; - const uint oc = OC_BASE + m; - const uint nohow = NOHOW_BASE + n; + const int oc = OC_BASE + m; + const int nohow = NOHOW_BASE + n; if (oc < P.OC && nohow < P.N_OH_OW) { - uint n_, oh, ow; + int n_, oh, ow; layout::unpack_nohow(nohow, n_, oh, ow, P); OUT[layout::output_index(n_, oc, oh, ow, P)] = acc[i]; } @@ -141,37 +143,33 @@ template class half_mma { public: __device__ __forceinline__ half_mma() {} - __device__ __forceinline__ void clear() { -# pragma unroll - for (int l = 0; l < c_frag.ne; ++l) { - c_frag.x[l] = 0.0f; - } - } - __device__ __forceinline__ void mma(const half * __restrict__ A_sh, const half * __restrict__ B_sh, - const uint strideA, - const uint strideB) { + const int strideA, + const int strideB) { load_ldmatrix(a_frag, (const half2 *) A_sh, strideA / 2); load_ldmatrix_trans(b_frag, (const half2 *) B_sh, strideB / 2); ggml_cuda_mma::mma(c_frag, a_frag, b_frag); } - __device__ __forceinline__ void store_result(const uint OC_BASE, - const uint NOHOW_BASE, + __device__ __forceinline__ void store_result(const int OC_BASE, + const int NOHOW_BASE, float * __restrict__ OUT, const conv_params & P) const { # pragma unroll - for (uint l = 0; l < tile_acc::ne; ++l) { - const uint e = tile_acc::get_i(l) * WMMA_N + tile_acc::get_j(l); - const uint m = e / WMMA_N; - const uint n = e % WMMA_N; + for (int l = 0; l < tile_acc::ne; ++l) { + const int e = tile_acc::get_i(l) * WMMA_N + tile_acc::get_j(l); + if (e >= WMMA_M * WMMA_N) { + continue; + } + const int m = e / WMMA_N; + const int n = e % WMMA_N; - const uint oc = OC_BASE + m; - const uint nohow = NOHOW_BASE + n; + const int oc = OC_BASE + m; + const int nohow = NOHOW_BASE + n; if (oc < P.OC && nohow < (P.N_OH_OW)) { - uint n_, oh, ow; + int n_, oh, ow; layout::unpack_nohow(nohow, n_, oh, ow, P); 
OUT[layout::output_index(n_, oc, oh, ow, P)] = c_frag.x[l]; } @@ -214,22 +212,22 @@ template class half_mma { } } - __device__ __forceinline__ void store_result(const uint OC_BASE, - const uint NOHOW_BASE, + __device__ __forceinline__ void store_result(const int OC_BASE, + const int NOHOW_BASE, float * __restrict__ OUT, const conv_params & P) const { - const uint lane_id = threadIdx.x % WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; # pragma unroll - for (uint e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { - const uint m = e / WMMA_N; - const uint n = e % WMMA_N; + for (int e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { + const int m = e / WMMA_N; + const int n = e % WMMA_N; - const uint oc = OC_BASE + m; - const uint nohow = NOHOW_BASE + n; + const int oc = OC_BASE + m; + const int nohow = NOHOW_BASE + n; if (oc < P.OC && nohow < P.N_OH_OW) { - uint n_, oh, ow; + int n_, oh, ow; layout::unpack_nohow(nohow, n_, oh, ow, P); OUT[layout::output_index(n_, oc, oh, ow, P)] = acc[i]; } @@ -246,28 +244,28 @@ __global__ void __launch_bounds__(num_warps * WARP_SIZE) conv2d_kernel(const flo const conv_params P) { extern __shared__ unsigned char smem_raw[]; - const uint warpId = threadIdx.y; - const uint linear_tid = threadIdx.y * blockDim.x + threadIdx.x; + const int warpId = threadIdx.y; + const int linear_tid = threadIdx.y * blockDim.x + threadIdx.x; - const uint NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW; - const uint NUM_WARPS_NOHOW = max(1, BS_NOHOW / WMMA_N); - const uint NUM_WARPS_NEED = (((BS_OC * BS_NOHOW) + (WMMA_M * WMMA_N) - 1) / (WMMA_M * WMMA_N)); + const int NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW; + const int NUM_WARPS_NOHOW = max(1, BS_NOHOW / WMMA_N); + const int NUM_WARPS_NEED = (((BS_OC * BS_NOHOW) + (WMMA_M * WMMA_N) - 1) / (WMMA_M * WMMA_N)); - const uint NUM_TILES_PER_WARP = (NUM_WARPS_NEED + num_warps - 1) / num_warps; + const int NUM_TILES_PER_WARP = (NUM_WARPS_NEED + num_warps - 1) / num_warps; mma acc[NUM_TILES_PER_WARP]; - const uint NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; - const uint BL_IDX_OC = blockIdx.x / NUM_BL_NOHOW; - const uint BL_IDX_NOHOW = blockIdx.x % NUM_BL_NOHOW; + const int NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; + const int BL_IDX_OC = blockIdx.x / NUM_BL_NOHOW; + const int BL_IDX_NOHOW = blockIdx.x % NUM_BL_NOHOW; - const uint BLOCK_OC_BASE = BL_IDX_OC * BS_OC; - const uint BLOCK_NOHOW_BASE = BL_IDX_NOHOW * BS_NOHOW; + const int BLOCK_OC_BASE = BL_IDX_OC * BS_OC; + const int BLOCK_NOHOW_BASE = BL_IDX_NOHOW * BS_NOHOW; unsigned char * ptr = smem_raw; - const uint A_total = BS_OC * BS_ICKHKW; - const uint B_total = BS_ICKHKW * BS_NOHOW; + const int A_total = BS_OC * BS_ICKHKW; + const int B_total = BS_ICKHKW * BS_NOHOW; size_t offsetA = (size_t) A_total * sizeof(T); T * A_sh = reinterpret_cast(ptr); @@ -277,37 +275,37 @@ __global__ void __launch_bounds__(num_warps * WARP_SIZE) conv2d_kernel(const flo T * B_sh = reinterpret_cast(ptr); ptr += offsetB; - for (uint t = 0; t < NUM_IC_TILES; ++t) { + for (int t = 0; t < NUM_IC_TILES; ++t) { #pragma unroll - for (uint tid = linear_tid; tid < A_total; tid += (blockDim.x * blockDim.y)) { - const uint row = tid / BS_ICKHKW; - const uint col = tid % BS_ICKHKW; + for (int tid = linear_tid; tid < A_total; tid += (blockDim.x * blockDim.y)) { + const int row = tid / BS_ICKHKW; + const int col = tid % BS_ICKHKW; - const uint shared_oc = BLOCK_OC_BASE + row; - const uint shared_ickhkw = t * BS_ICKHKW + col; + const int 
shared_oc = BLOCK_OC_BASE + row; + const int shared_ickhkw = t * BS_ICKHKW + col; T val = ggml_cuda_cast(0); if (shared_oc < P.OC && shared_ickhkw < P.IC_KH_KW) { - uint ic, kh, kw; + int ic, kh, kw; layout::unpack_ickhkw(shared_ickhkw, ic, kh, kw, P); - const uint kidx = layout::kernel_index(shared_oc, ic, kh, kw, P); - val = IK[kidx]; + const int kidx = layout::kernel_index(shared_oc, ic, kh, kw, P); + val = IK[kidx]; } A_sh[row * BS_ICKHKW + col] = val; } #pragma unroll - for (uint tid = linear_tid; tid < B_total; tid += (blockDim.x * blockDim.y)) { - const uint brow = tid / BS_NOHOW; - const uint bcol = tid % BS_NOHOW; + for (int tid = linear_tid; tid < B_total; tid += (blockDim.x * blockDim.y)) { + const int brow = tid / BS_NOHOW; + const int bcol = tid % BS_NOHOW; - const uint IC_KH_KW_IDX = t * BS_ICKHKW + brow; - const uint N_OH_OW_IDX = BLOCK_NOHOW_BASE + bcol; + const int IC_KH_KW_IDX = t * BS_ICKHKW + brow; + const int N_OH_OW_IDX = BLOCK_NOHOW_BASE + bcol; T val = ggml_cuda_cast(0); if (N_OH_OW_IDX < P.N_OH_OW && IC_KH_KW_IDX < P.IC_KH_KW) { - uint n, oh, ow; - uint ic, kh, kw; + int n, oh, ow; + int ic, kh, kw; layout::unpack_nohow(N_OH_OW_IDX, n, oh, ow, P); layout::unpack_ickhkw(IC_KH_KW_IDX, ic, kh, kw, P); const int in_y = calculate_input_coord(oh, kh, P.ST_Y, P.DL_Y, P.PD_Y); @@ -323,19 +321,19 @@ __global__ void __launch_bounds__(num_warps * WARP_SIZE) conv2d_kernel(const flo __syncthreads(); #pragma unroll - for (uint i = 0; i < NUM_TILES_PER_WARP; i++) { - const uint warp = warpId + i * num_warps; + for (int i = 0; i < NUM_TILES_PER_WARP; i++) { + const int warp = warpId + i * num_warps; if (warp >= NUM_WARPS_NEED) { continue; } - const uint WARP_OC = warp / NUM_WARPS_NOHOW; - const uint WARP_NOHOW = warp % NUM_WARPS_NOHOW; + const int WARP_OC = warp / NUM_WARPS_NOHOW; + const int WARP_NOHOW = warp % NUM_WARPS_NOHOW; const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW; const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N; #pragma unroll - for (uint k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) { + for (int k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) { const T * A_k_ptr = A_warp_base + k_tile; const T * B_k_ptr = B_warp_base + k_tile * BS_NOHOW; acc[i].mma(A_k_ptr, B_k_ptr, BS_ICKHKW, BS_NOHOW); @@ -345,28 +343,29 @@ __global__ void __launch_bounds__(num_warps * WARP_SIZE) conv2d_kernel(const flo } #pragma unroll - for (uint i = 0; i < NUM_TILES_PER_WARP; i++) { - const uint warp = warpId + i * num_warps; + for (int i = 0; i < NUM_TILES_PER_WARP; i++) { + const int warp = warpId + i * num_warps; if (warp >= NUM_WARPS_NEED) { continue; } - const uint WARP_OC = warp / NUM_WARPS_NOHOW; - const uint WARP_NOHOW = warp % NUM_WARPS_NOHOW; - const uint OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M; - const uint NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N; + const int WARP_OC = warp / NUM_WARPS_NOHOW; + const int WARP_NOHOW = warp % NUM_WARPS_NOHOW; + const int OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M; + const int NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N; acc[i].store_result(OC_BASE, NOHOW_BASE, Out, P); } } template class mma> -static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) { +static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params & P, cudaStream_t st) { GGML_ASSERT(BS_OC >= WMMA_M && BS_ICKHKW >= WMMA_K && BS_NOHOW >= WMMA_N); + GGML_ASSERT(BS_ICKHKW % WMMA_K == 0); - const uint NUM_BL_OC = (P.OC + BS_OC - 1) / BS_OC; - const uint NUM_BL_NOHOW = 
(P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; - const uint NUM_BL = NUM_BL_OC * NUM_BL_NOHOW; + const int NUM_BL_OC = (P.OC + BS_OC - 1) / BS_OC; + const int NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; + const int NUM_BL = NUM_BL_OC * NUM_BL_NOHOW; - constexpr uint NUM_WARPS = (CUDA_CONV2D_BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; + constexpr int NUM_WARPS = (CUDA_CONV2D_BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; const size_t A_bytes = BS_OC * BS_ICKHKW * sizeof(T); const size_t B_bytes = BS_ICKHKW * BS_NOHOW * sizeof(T); @@ -402,28 +401,28 @@ void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { cudaStream_t st = ctx.stream(); const int32_t * p = (const int32_t *) dst->op_params; - const uint ST_X = p[0]; // stride_x - const uint ST_Y = p[1]; // stride_y - const uint PD_X = p[2]; // padding_x - const uint PD_Y = p[3]; // padding_y - const uint DL_X = p[4]; // dilation_x - const uint DL_Y = p[5]; // dilation_y + const int ST_X = p[0]; // stride_x + const int ST_Y = p[1]; // stride_y + const int PD_X = p[2]; // padding_x + const int PD_Y = p[3]; // padding_y + const int DL_X = p[4]; // dilation_x + const int DL_Y = p[5]; // dilation_y // No cwhn GGML_ASSERT(p[6] == false); - const uint IW = input->ne[0]; // input_w - const uint IH = input->ne[1]; // input_h - const uint OW = dst->ne[0]; // output_w - const uint OH = dst->ne[1]; // output_h - const uint KW = kernel->ne[0]; // kernel_w - const uint KH = kernel->ne[1]; // kernel_h - const uint IC = input->ne[2]; // input_channels - const uint OC = kernel->ne[3]; // ouptut_chanles - const uint B = input->ne[3]; // n_batches - - const uint IC_KH_KW = IC * KH * KW; - const uint N_OH_OW = B * OH * OW; + const int IW = input->ne[0]; // input_w + const int IH = input->ne[1]; // input_h + const int OW = dst->ne[0]; // output_w + const int OH = dst->ne[1]; // output_h + const int KW = kernel->ne[0]; // kernel_w + const int KH = kernel->ne[1]; // kernel_h + const int IC = input->ne[2]; // input_channels + const int OC = kernel->ne[3]; // ouptut_chanles + const int B = input->ne[3]; // n_batches + + const int IC_KH_KW = IC * KH * KW; + const int N_OH_OW = B * OH * OW; const conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, IC_KH_KW, N_OH_OW }; From cc3d366e75132628dde9fa3c0e907911a59b7646 Mon Sep 17 00:00:00 2001 From: Ervin Tasnadi Date: Thu, 18 Sep 2025 16:51:44 +0200 Subject: [PATCH 09/12] Adds CUDA version of Vulkan direct conv2d. 
* Extra: reduces bank conflicts --- ggml/src/ggml-cuda/conv2d-mm.cu | 450 +++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/conv2d-mm.cuh | 9 + ggml/src/ggml-cuda/ggml-cuda.cu | 9 +- 3 files changed, 467 insertions(+), 1 deletion(-) create mode 100644 ggml/src/ggml-cuda/conv2d-mm.cu create mode 100644 ggml/src/ggml-cuda/conv2d-mm.cuh diff --git a/ggml/src/ggml-cuda/conv2d-mm.cu b/ggml/src/ggml-cuda/conv2d-mm.cu new file mode 100644 index 0000000000000..7de78b4372eda --- /dev/null +++ b/ggml/src/ggml-cuda/conv2d-mm.cu @@ -0,0 +1,450 @@ +#include "conv2d-mm.cuh" + +#include + +// If defined, indices are computed once and re-used by each thread +#if __CUDA_ARCH__ < 700 +# define USE_COLLECTIVES +#endif + +//#define A_TRANS // Transposes the A matrix in shmem +//#define A_OPT // Optimizes A for reducing bank conflicts +#define B_OPT // Optimizes B for reducing bank conflicts + +#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N)) + +uint32_t ceil_div(uint32_t M, uint32_t N); +int get_sm_count(); + +uint32_t ceil_div(uint32_t M, uint32_t N) { + return (M + N - 1) / N; +} + +__align__(16) struct Params { + uint32_t Cout; + uint32_t Cin; + uint32_t N; + + uint32_t KW; + uint32_t KH; + uint32_t W; + uint32_t H; + uint32_t OW; + uint32_t OH; + + uint32_t s0; + uint32_t s1; + uint32_t p0; + uint32_t p1; + uint32_t d0; + uint32_t d1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; + + uint32_t KWmp; + uint32_t KWL; + uint32_t KWKHmp; + uint32_t KWKHL; + uint32_t OWmp; + uint32_t OWL; + uint32_t OWOHmp; + uint32_t OWOHL; +}; + +__constant__ __device__ Params dp; + +// see init_fastdiv_values in ggml-vulkan.cpp +__inline__ __device__ uint fastdiv(uint n, uint mp, uint L) { + return (__umulhi(n, mp) + n) >> L; +} + +// --> conv_2d kernel modified to function as a matmul +template +__global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K, + uint NPQ, + uint CRS, + const float * knl_data, + const float * src_data, + float * dst_data) { + // Each block computes a tile of the result of size BS_K*BS_NPQ + const uint B_idx_K = blockIdx.x; + const uint B_idx_NPQ = blockIdx.y; + assert(gridDim.z == 1); + + // T_y, T_x: the tile position this thread is resposible for computing. + assert(BS_NPQ % TS_NPQ == 0); + assert(TS_NPQ <= BS_NPQ); + const uint NT_x = BS_NPQ / TS_NPQ; + assert(BS_K % TS_K == 0); + assert(TS_K <= BS_K); + // const uint NT_y = BS_K / TS_K; // unused + + // Ensure that the kernel is properly called + // 1. each thread processes a threadtile of size TS_K*TS_NPQ, that is exactly the WG_SIZE + assert((BS_K / TS_K) * (BS_NPQ / TS_NPQ) == WG_SIZE); + // 2. 
the number of threads is exactly the WG_SIZE + assert(blockDim.x == WG_SIZE && blockDim.y == 1 && blockDim.z == 1); + + const uint T_y = threadIdx.x / NT_x; + const uint T_x = threadIdx.x % NT_x; + + __shared__ float Ash[BS_K * BS_CRS]; + __shared__ float Bsh[BS_CRS * BS_NPQ]; + + const uint Ar = threadIdx.x / BS_CRS; + const uint Ac = threadIdx.x % BS_CRS; + assert(WG_SIZE >= BS_CRS); + const uint ArpWg = WG_SIZE / BS_CRS; + + const uint Br = threadIdx.x / BS_NPQ; + const uint Bc = threadIdx.x % BS_NPQ; + assert(WG_SIZE >= BS_NPQ); + const uint BrpWg = WG_SIZE / BS_NPQ; + + float regA[TS_K] = { 0.0 }; + float regB[TS_NPQ] = { 0.0 }; + float regC[TS_K * TS_NPQ] = { 0.0 }; + + /* Advance block in CRS dim */ + for (uint idx_CRS = 0; idx_CRS < CRS; idx_CRS += BS_CRS) { +/* Load kernel to A_block: (BS_K x BS_CRS)*/ +#ifdef USE_COLLECTIVES + const int laneId = threadIdx.x & 0x1f; + // Each thread in CRS dim computes a result that will be broadcast among them + assert(CRS <= warpSize); + const uint32_t cached_CRS_idx = idx_CRS + laneId; + const uint32_t cached_Cin_idx = cached_CRS_idx / (dp.KW * dp.KH); + uint32_t rem = (cached_CRS_idx - cached_Cin_idx * dp.KW * dp.KH); + const uint32_t cached_KH_idx = rem / dp.KW; + const uint32_t cached_KW_idx = rem - cached_KH_idx * dp.KW; + + const uint32_t CRS_idx_a = __shfl_sync(0xffffffff, cached_CRS_idx, Ac); + const uint32_t KH_idx_a = __shfl_sync(0xffffffff, cached_KH_idx, Ac); + //const uint32_t KW_idx_a = __shfl_sync(0xffffffff, cached_KW_idx, Ac); // unused + const uint32_t Cin_idx_a = __shfl_sync(0xffffffff, cached_Cin_idx, Ac); +#else + uint32_t CRS_idx_a = idx_CRS + Ac; //Global CRS_idx (column index of A) + //uint32_t Cin_idx_a = CRS_idx_a / (dp.KW*dp.KH); + uint32_t Cin_idx_a = fastdiv(CRS_idx_a, dp.KWKHmp, dp.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH); + uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * dp.KW * dp.KH; + //uint32_t KH_idx_a = (CRS_idx_a - Cin_idx_a*dp.KW*dp.KH) / dp.KW; + uint32_t KH_idx_a = fastdiv(CRS_remainder, dp.KWmp, dp.KWL); // divide by p.KW; +//uint32_t KW_idx_a = CRS_idx_a - Cin_idx_a*dp.KW*dp.KH - KH_idx_a*dp.KW; // unused +#endif + +#pragma unroll + for (uint r_offset = 0; r_offset < BS_K; r_offset += ArpWg) { + const uint32_t K_idx_a = B_idx_K * BS_K + r_offset + Ar; /* Global K_idx (row index of A)*/ + // General addressing (does not assume contiguity) + //const uint32_t knl_idx = KW_idx_a + KH_idx_a*dp.nb01 + Cin_idx_a*dp.nb02 + K_idx_a*dp.nb03; + // Contiguous addressing + float val = knl_data[min(CRS_idx_a + K_idx_a * dp.nb03, K * CRS - 1)]; + if (CRS_idx_a >= CRS || K_idx_a >= K) { + val = 0.0; + } + +#ifdef A_TRANS +# ifdef A_OPT + uint32_t T_id = (r_offset + Ar) / TS_K; // E.g.: 41/16 = 2 + uint32_t vec_in_TT = ((r_offset + Ar) - T_id * TS_K) / VEC_SIZE; // E.g.: 41-2*16 = 9 -> 9/4 = 2 + uint32_t elem_in_vec = ((r_offset + Ar) - T_id * TS_K) % VEC_SIZE; // E.g.: 9 -> 9%4 = 1 + uint32_t col_offset = vec_in_TT * (NT_y * VEC_SIZE) + T_id * VEC_SIZE + elem_in_vec; +# else + uint32_t col_offset = (r_offset + Ar); +# endif + Ash[Ac * BS_K + col_offset] = val; +#else + Ash[(r_offset + Ar) * BS_CRS + Ac] = val; +#endif + } + +#pragma unroll + for (uint r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) { + // Compute indices for N, OH, OW from NPQ_idx + const uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + Bc; /* Global NPQ index (column index of B) */ + //const uint32_t N_idx = NPQ_idx / (dp.OH*dp.OW); + uint32_t N_idx = fastdiv(NPQ_idx, dp.OWOHmp, dp.OWOHL); // divide by p.OH * p.OW; + uint32_t NPQ_remainder = 
NPQ_idx - N_idx * dp.OH * dp.OW; + //const uint32_t OH_idx = (NPQ_idx - N_idx*dp.OH*dp.OW) / dp.OW; + uint32_t OH_idx = fastdiv(NPQ_remainder, dp.OWmp, dp.OWL); // divide by p.OW; + const uint32_t OW_idx = NPQ_idx - N_idx * dp.OH * dp.OW - OH_idx * dp.OW; + +#ifdef USE_COLLECTIVES + const uint32_t CRS_idx_b = __shfl_sync(0xffffffff, cached_CRS_idx, r_offset + Br); + const uint32_t KH_idx_b = __shfl_sync(0xffffffff, cached_KH_idx, r_offset + Br); + const uint32_t KW_idx_b = __shfl_sync(0xffffffff, cached_KW_idx, r_offset + Br); + const uint32_t Cin_idx_b = __shfl_sync(0xffffffff, cached_Cin_idx, r_offset + Br); +#else + // Compute indices KH, KW, Cin from CRS_idx + uint32_t CRS_idx_b = idx_CRS + r_offset + Br; + //uint32_t Cin_idx_b = CRS_idx_b / (dp.KW*dp.KH); + uint32_t Cin_idx_b = fastdiv(CRS_idx_b, dp.KWKHmp, dp.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH); + uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * dp.KW * dp.KH; + //uint32_t KH_idx_b = (CRS_idx_b - Cin_idx_b*dp.KW*dp.KH) / dp.KW; + uint32_t KH_idx_b = fastdiv(CRS_remainder, dp.KWmp, dp.KWL); // divide by p.KW; + uint32_t KW_idx_b = CRS_idx_b - Cin_idx_b * dp.KW * dp.KH - KH_idx_b * dp.KW; +#endif + + // Compute indices for W, H from OH, OW, KH, KW + const int32_t H_idx = OH_idx * dp.s1 + KH_idx_b * dp.d1 - dp.p1; + const int32_t W_idx = OW_idx * dp.s0 + KW_idx_b * dp.d0 - dp.p0; + const uint32_t src_idx = min(max(W_idx + H_idx * dp.nb11 + Cin_idx_b * dp.nb12 + N_idx * dp.nb13, 0), + dp.Cin * dp.N * dp.W * dp.H - 1); + float val; + if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= dp.H || W_idx < 0 || W_idx >= dp.W) { + val = 0.0; + } else { + val = src_data[src_idx]; + } + +#ifdef B_OPT + assert(VEC_SIZE <= TS_NPQ); + const uint32_t T_id = Bc / TS_NPQ; // E.g.: 41/16 = 2 + const uint32_t vec_in_TT = (Bc - T_id * TS_NPQ) / VEC_SIZE; // E.g.: 41-2*16 = 9 -> 9/4 = 2 + const uint32_t elem_in_vec = (Bc - T_id * TS_NPQ) % VEC_SIZE; // E.g.: 9 -> 9%4 = 1 + const uint32_t col_offset = vec_in_TT * (NT_x * VEC_SIZE) + T_id * VEC_SIZE + elem_in_vec; +#else + uint32_t col_offset = Bc; +#endif + Bsh[(r_offset + Br) * BS_NPQ + col_offset] = val; + } + + __syncthreads(); + + if (T_y * TS_K < K) { +#pragma unroll + for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; ++CRS_lidx) { +#pragma unroll + for (uint32_t T_ly = 0; T_ly < TS_K; ++T_ly) { +#ifdef A_TRANS +# ifdef A_OPT + uint32_t T_id = T_y; + uint32_t vec_in_TT = T_ly / VEC_SIZE; + uint32_t elem_in_vec = T_ly % VEC_SIZE; + uint32_t col_offset = vec_in_TT * (NT_y * VEC_SIZE) + T_id * VEC_SIZE + elem_in_vec; +# else + uint32_t col_offset = (T_y * TS_K + T_ly); +# endif + regA[T_ly] = Ash[CRS_lidx * BS_K + col_offset]; +#else + regA[T_ly] = Ash[(T_y * TS_K + T_ly) * BS_CRS + CRS_lidx]; +#endif + } + for (uint32_t T_lx = 0; T_lx < TS_NPQ; ++T_lx) { +#ifdef B_OPT + const uint32_t T_id = T_x; + const uint32_t vec_in_TT = T_lx / VEC_SIZE; + const uint32_t elem_in_vec = T_lx % VEC_SIZE; + const uint32_t col_offset = vec_in_TT * (NT_x * VEC_SIZE) + T_id * VEC_SIZE + elem_in_vec; +#else + const uint32_t col_offset = T_x * TS_NPQ + T_lx; +#endif + regB[T_lx] = Bsh[CRS_lidx * BS_NPQ + col_offset]; + } + for (uint32_t T_ly = 0; T_ly < TS_K; ++T_ly) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; ++T_lx) { + regC[T_ly * TS_NPQ + T_lx] = fmaf(regA[T_ly], regB[T_lx], regC[T_ly * TS_NPQ + T_lx]); + } + } + } + } + __syncthreads(); + } + + /* Save C* */ + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + const uint32_t K_idx = B_idx_K * BS_K + 
T_y * TS_K + T_ly; + const uint32_t NPQ_idx_c = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx; + //const uint32_t N_idx_c = NPQ_idx_c / (dp.OH*dp.OW); + const uint32_t N_idx_c = fastdiv(NPQ_idx_c, dp.OWOHmp, dp.OWOHL); // divide by p.OH * p.OW; + //const uint32_t OH_idx_c = (NPQ_idx_c - N_idx_c*dp.OH*dp.OW) / dp.OW; + const uint32_t OH_idx_c = fastdiv(NPQ_idx_c - N_idx_c * dp.OH * dp.OW, dp.OWmp, dp.OWL); // divide by p.OW; + const uint32_t OW_idx_c = NPQ_idx_c - N_idx_c * dp.OH * dp.OW - OH_idx_c * dp.OW; + const uint32_t dst_idx = OW_idx_c + OH_idx_c * dp.nb1 + K_idx * dp.nb2 + N_idx_c * dp.nb3; + if (K_idx < K && NPQ_idx_c < NPQ) { + dst_data[dst_idx] = regC[T_ly * TS_NPQ + T_lx]; + } + } + } +} + +// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. +// Precompute mp (m' in the paper) and L such that division +// can be computed using a multiply (high 32b of 64b result) +// and a shift: +// +// n/d = (mulhi(n, mp) + n) >> L; +static void init_fastdiv_values(uint32_t d, uint32_t & mp, uint32_t & L) { + // compute L = ceil(log2(d)); + L = 0; + while (L < 32 && (uint32_t{ 1 } << L) < d) { + L++; + } + + mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1); +} + +constexpr int conv_shapes[][NUM_VARIANTS] = { + { 128, 64, 32 }, // BS_K + { 16, 32, 16 }, // BS_CRS + { 128, 32, 256 }, // BS_NPQ + { 8, 4, 8 } // TS_K + //{8, 8, 8} // TS_NPQ // Option 2 +}; + +int get_sm_count() { + int device; + cudaGetDevice(&device); + + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device); + return sm_count; +} + +template +void ggml_cuda_op_conv_2d_variant(ggml_backend_cuda_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst, + const Params & p) { + // Tile size calculation options: + // Option 1: fix block size and all tile sizes except TS_NPQ as it is the free parameter (used in the Vulkan backend). + // Option 2: fix all tile sizes and block size is the free parameter. + const uint32_t WG_SIZE = 256; // Option 1 + + const uint32_t BS_K = conv_shapes[0][CONV_SHAPE]; + const uint32_t BS_CRS = conv_shapes[1][CONV_SHAPE]; + const uint32_t BS_NPQ = conv_shapes[2][CONV_SHAPE]; + const uint32_t TS_K = conv_shapes[3][CONV_SHAPE]; + //const uint32_t TS_NPQ = sh[4][CONV_SHAPE]; // Option 2 + const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K; + + // Some architectures can use 128-bit loads that might be more efficient. + const uint32_t VEC_SIZE = TS_NPQ >= 4 ? 
4 : 1; + + //const uint32_t WG_SIZE = (BS_K*BS_NPQ) / (TS_K*TS_NPQ); // Option 2 + + // Kernel runtime parameters + int64_t NPQ = p.N * p.OW * p.OH; + uint32_t NB_K = CEIL_DIV(p.Cout, BS_K); + uint32_t NB_NPQ = CEIL_DIV(NPQ, BS_NPQ); + + cudaMemcpyToSymbol(dp, &p, sizeof(Params)); + + // Kernel arguments + float * src0_data = (float *) src0->data; + float * src1_data = (float *) src1->data; + float * dst_data = (float *) dst->data; + + dim3 gridDim(NB_K, NB_NPQ); + dim3 blockDim(WG_SIZE); + cudaStream_t stream = ctx.stream(); + + mm + <<>>(p.Cout, NPQ, p.Cin * p.KW * p.KH, src0_data, src1_data, dst_data); +} + +void ggml_cuda_op_conv2d_mm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + // Initialize kernel variants + + using Conv2DFuncPtr = + void (*)(ggml_backend_cuda_context &, ggml_tensor *, ggml_tensor *, ggml_tensor *, const Params &); + + Conv2DFuncPtr conv2d_variants[NUM_VARIANTS]; + + conv2d_variants[CONV_SHAPE_128x128] = &ggml_cuda_op_conv_2d_variant; + conv2d_variants[CONV_SHAPE_64x32] = &ggml_cuda_op_conv_2d_variant; + conv2d_variants[CONV_SHAPE_32x256] = &ggml_cuda_op_conv_2d_variant; + + // Parse op input, prepare kernel input + + ggml_tensor * src0 = dst->src[0]; + ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb0 == sizeof(float)); + + Params p{}; + p.Cout = static_cast(ne03); + p.Cin = static_cast(ne02); + p.N = static_cast(ne13); + + p.KW = static_cast(ne00); + p.KH = static_cast(ne01); + p.W = static_cast(ne10); + p.H = static_cast(ne11); + p.OW = static_cast(ne0); + p.OH = static_cast(ne1); + + p.s0 = static_cast(dst->op_params[0]); + p.s1 = static_cast(dst->op_params[1]); + p.p0 = static_cast(dst->op_params[2]); + p.p1 = static_cast(dst->op_params[3]); + p.d0 = static_cast(dst->op_params[4]); + p.d1 = static_cast(dst->op_params[5]); + + p.nb01 = static_cast(nb01 / nb00); + p.nb02 = static_cast(nb02 / nb00); + p.nb03 = static_cast(nb03 / nb00); + + p.nb11 = static_cast(nb11 / nb10); + p.nb12 = static_cast(nb12 / nb10); + p.nb13 = static_cast(nb13 / nb10); + + p.nb1 = static_cast(nb1 / nb0); + p.nb2 = static_cast(nb2 / nb0); + p.nb3 = static_cast(nb3 / nb0); + + init_fastdiv_values(p.KW, p.KWmp, p.KWL); + init_fastdiv_values(p.KW * p.KH, p.KWKHmp, p.KWKHL); + init_fastdiv_values(p.OW, p.OWmp, p.OWL); + init_fastdiv_values(p.OW * p.OH, p.OWOHmp, p.OWOHL); + + GGML_ASSERT(ne03 == ne2); + GGML_ASSERT(ne02 == ne12); + + // Select the proper variant based on problem size and device parameters (sm count) + + // Problem size (Cout x NPQ) + std::array elements = { p.Cout, p.N * p.OW * p.OH, 1 }; + + const uint32_t sm_count = get_sm_count(); + + uint32_t variant_ntiles[NUM_VARIANTS]; + + for (int var_id = 0; var_id < NUM_VARIANTS; var_id++) { + const uint32_t ntilesy = ceil_div(elements[0], conv_shapes[var_id][0]); // CEIL_DIV(Cout, NB_K) + const uint32_t ntilesx = ceil_div(elements[1], conv_shapes[var_id][2]); // CEIL_DIV(NPQ, NB_NPQ) + variant_ntiles[var_id] = ntilesy * ntilesx; + } + + uint32_t selected_variant_id = CONV_SHAPE_128x128; + + if (elements[0] > 64 && variant_ntiles[CONV_SHAPE_128x128] >= sm_count * 2) { + selected_variant_id = CONV_SHAPE_128x128; + } else if (elements[0] <= 32 && variant_ntiles[CONV_SHAPE_32x256] >= sm_count * 2) { + selected_variant_id = CONV_SHAPE_32x256; + } else { + selected_variant_id = 
CONV_SHAPE_64x32; + } + + conv2d_variants[selected_variant_id](ctx, src0, src1, dst, p); +} diff --git a/ggml/src/ggml-cuda/conv2d-mm.cuh b/ggml/src/ggml-cuda/conv2d-mm.cuh new file mode 100644 index 0000000000000..fc547397e53df --- /dev/null +++ b/ggml/src/ggml-cuda/conv2d-mm.cuh @@ -0,0 +1,9 @@ +#include "common.cuh" + +#define CONV_SHAPE_128x128 0 +#define CONV_SHAPE_64x32 1 +#define CONV_SHAPE_32x256 2 + +#define NUM_VARIANTS 3 + +void ggml_cuda_op_conv2d_mm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 9ea8f4589d71d..a9fe51875af46 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -13,6 +13,7 @@ #include "ggml-cuda/concat.cuh" #include "ggml-cuda/conv-transpose-1d.cuh" #include "ggml-cuda/conv2d.cuh" +#include "ggml-cuda/conv2d-mm.cuh" #include "ggml-cuda/conv2d-dw.cuh" #include "ggml-cuda/conv2d-transpose.cuh" #include "ggml-cuda/convert.cuh" @@ -2461,7 +2462,13 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_im2col_3d(ctx, dst); break; case GGML_OP_CONV_2D: - ggml_cuda_op_conv2d(ctx, dst); + if (!getenv("GGML_CUDA_USE_LEGACY_CONV") && + (dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32 && + dst->type == GGML_TYPE_F32)) { + ggml_cuda_op_conv2d_mm(ctx, dst); + } else { + ggml_cuda_op_conv2d(ctx, dst); + } break; case GGML_OP_CONV_2D_DW: ggml_cuda_op_conv2d_dw(ctx, dst); From e3f94c684a1efcfaff406c32ec178ac2007749ae Mon Sep 17 00:00:00 2001 From: mnehete32 Date: Wed, 29 Oct 2025 01:29:54 +0530 Subject: [PATCH 10/12] adding vulkan code like tensor code conv2d --- ggml/src/ggml-cuda/conv2d-mm.cu | 450 ---------------------- ggml/src/ggml-cuda/conv2d-mm.cuh | 9 - ggml/src/ggml-cuda/conv2d-tensor-core.cu | 373 ++++++++++++++++++ ggml/src/ggml-cuda/conv2d-tensor-core.cuh | 31 ++ ggml/src/ggml-cuda/conv2d.cu | 437 +++++---------------- ggml/src/ggml-cuda/conv2d.cuh | 11 +- ggml/src/ggml-cuda/ggml-cuda.cu | 9 +- 7 files changed, 492 insertions(+), 828 deletions(-) delete mode 100644 ggml/src/ggml-cuda/conv2d-mm.cu delete mode 100644 ggml/src/ggml-cuda/conv2d-mm.cuh create mode 100644 ggml/src/ggml-cuda/conv2d-tensor-core.cu create mode 100644 ggml/src/ggml-cuda/conv2d-tensor-core.cuh diff --git a/ggml/src/ggml-cuda/conv2d-mm.cu b/ggml/src/ggml-cuda/conv2d-mm.cu deleted file mode 100644 index 7de78b4372eda..0000000000000 --- a/ggml/src/ggml-cuda/conv2d-mm.cu +++ /dev/null @@ -1,450 +0,0 @@ -#include "conv2d-mm.cuh" - -#include - -// If defined, indices are computed once and re-used by each thread -#if __CUDA_ARCH__ < 700 -# define USE_COLLECTIVES -#endif - -//#define A_TRANS // Transposes the A matrix in shmem -//#define A_OPT // Optimizes A for reducing bank conflicts -#define B_OPT // Optimizes B for reducing bank conflicts - -#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N)) - -uint32_t ceil_div(uint32_t M, uint32_t N); -int get_sm_count(); - -uint32_t ceil_div(uint32_t M, uint32_t N) { - return (M + N - 1) / N; -} - -__align__(16) struct Params { - uint32_t Cout; - uint32_t Cin; - uint32_t N; - - uint32_t KW; - uint32_t KH; - uint32_t W; - uint32_t H; - uint32_t OW; - uint32_t OH; - - uint32_t s0; - uint32_t s1; - uint32_t p0; - uint32_t p1; - uint32_t d0; - uint32_t d1; - - uint32_t nb01; - uint32_t nb02; - uint32_t nb03; - - uint32_t nb11; - uint32_t nb12; - uint32_t nb13; - - uint32_t nb1; - uint32_t nb2; - uint32_t nb3; - - uint32_t KWmp; - uint32_t KWL; - uint32_t KWKHmp; - uint32_t KWKHL; - 
uint32_t OWmp; - uint32_t OWL; - uint32_t OWOHmp; - uint32_t OWOHL; -}; - -__constant__ __device__ Params dp; - -// see init_fastdiv_values in ggml-vulkan.cpp -__inline__ __device__ uint fastdiv(uint n, uint mp, uint L) { - return (__umulhi(n, mp) + n) >> L; -} - -// --> conv_2d kernel modified to function as a matmul -template -__global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K, - uint NPQ, - uint CRS, - const float * knl_data, - const float * src_data, - float * dst_data) { - // Each block computes a tile of the result of size BS_K*BS_NPQ - const uint B_idx_K = blockIdx.x; - const uint B_idx_NPQ = blockIdx.y; - assert(gridDim.z == 1); - - // T_y, T_x: the tile position this thread is resposible for computing. - assert(BS_NPQ % TS_NPQ == 0); - assert(TS_NPQ <= BS_NPQ); - const uint NT_x = BS_NPQ / TS_NPQ; - assert(BS_K % TS_K == 0); - assert(TS_K <= BS_K); - // const uint NT_y = BS_K / TS_K; // unused - - // Ensure that the kernel is properly called - // 1. each thread processes a threadtile of size TS_K*TS_NPQ, that is exactly the WG_SIZE - assert((BS_K / TS_K) * (BS_NPQ / TS_NPQ) == WG_SIZE); - // 2. the number of threads is exactly the WG_SIZE - assert(blockDim.x == WG_SIZE && blockDim.y == 1 && blockDim.z == 1); - - const uint T_y = threadIdx.x / NT_x; - const uint T_x = threadIdx.x % NT_x; - - __shared__ float Ash[BS_K * BS_CRS]; - __shared__ float Bsh[BS_CRS * BS_NPQ]; - - const uint Ar = threadIdx.x / BS_CRS; - const uint Ac = threadIdx.x % BS_CRS; - assert(WG_SIZE >= BS_CRS); - const uint ArpWg = WG_SIZE / BS_CRS; - - const uint Br = threadIdx.x / BS_NPQ; - const uint Bc = threadIdx.x % BS_NPQ; - assert(WG_SIZE >= BS_NPQ); - const uint BrpWg = WG_SIZE / BS_NPQ; - - float regA[TS_K] = { 0.0 }; - float regB[TS_NPQ] = { 0.0 }; - float regC[TS_K * TS_NPQ] = { 0.0 }; - - /* Advance block in CRS dim */ - for (uint idx_CRS = 0; idx_CRS < CRS; idx_CRS += BS_CRS) { -/* Load kernel to A_block: (BS_K x BS_CRS)*/ -#ifdef USE_COLLECTIVES - const int laneId = threadIdx.x & 0x1f; - // Each thread in CRS dim computes a result that will be broadcast among them - assert(CRS <= warpSize); - const uint32_t cached_CRS_idx = idx_CRS + laneId; - const uint32_t cached_Cin_idx = cached_CRS_idx / (dp.KW * dp.KH); - uint32_t rem = (cached_CRS_idx - cached_Cin_idx * dp.KW * dp.KH); - const uint32_t cached_KH_idx = rem / dp.KW; - const uint32_t cached_KW_idx = rem - cached_KH_idx * dp.KW; - - const uint32_t CRS_idx_a = __shfl_sync(0xffffffff, cached_CRS_idx, Ac); - const uint32_t KH_idx_a = __shfl_sync(0xffffffff, cached_KH_idx, Ac); - //const uint32_t KW_idx_a = __shfl_sync(0xffffffff, cached_KW_idx, Ac); // unused - const uint32_t Cin_idx_a = __shfl_sync(0xffffffff, cached_Cin_idx, Ac); -#else - uint32_t CRS_idx_a = idx_CRS + Ac; //Global CRS_idx (column index of A) - //uint32_t Cin_idx_a = CRS_idx_a / (dp.KW*dp.KH); - uint32_t Cin_idx_a = fastdiv(CRS_idx_a, dp.KWKHmp, dp.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH); - uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * dp.KW * dp.KH; - //uint32_t KH_idx_a = (CRS_idx_a - Cin_idx_a*dp.KW*dp.KH) / dp.KW; - uint32_t KH_idx_a = fastdiv(CRS_remainder, dp.KWmp, dp.KWL); // divide by p.KW; -//uint32_t KW_idx_a = CRS_idx_a - Cin_idx_a*dp.KW*dp.KH - KH_idx_a*dp.KW; // unused -#endif - -#pragma unroll - for (uint r_offset = 0; r_offset < BS_K; r_offset += ArpWg) { - const uint32_t K_idx_a = B_idx_K * BS_K + r_offset + Ar; /* Global K_idx (row index of A)*/ - // General addressing (does not assume contiguity) - //const uint32_t knl_idx = KW_idx_a + 
KH_idx_a*dp.nb01 + Cin_idx_a*dp.nb02 + K_idx_a*dp.nb03; - // Contiguous addressing - float val = knl_data[min(CRS_idx_a + K_idx_a * dp.nb03, K * CRS - 1)]; - if (CRS_idx_a >= CRS || K_idx_a >= K) { - val = 0.0; - } - -#ifdef A_TRANS -# ifdef A_OPT - uint32_t T_id = (r_offset + Ar) / TS_K; // E.g.: 41/16 = 2 - uint32_t vec_in_TT = ((r_offset + Ar) - T_id * TS_K) / VEC_SIZE; // E.g.: 41-2*16 = 9 -> 9/4 = 2 - uint32_t elem_in_vec = ((r_offset + Ar) - T_id * TS_K) % VEC_SIZE; // E.g.: 9 -> 9%4 = 1 - uint32_t col_offset = vec_in_TT * (NT_y * VEC_SIZE) + T_id * VEC_SIZE + elem_in_vec; -# else - uint32_t col_offset = (r_offset + Ar); -# endif - Ash[Ac * BS_K + col_offset] = val; -#else - Ash[(r_offset + Ar) * BS_CRS + Ac] = val; -#endif - } - -#pragma unroll - for (uint r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) { - // Compute indices for N, OH, OW from NPQ_idx - const uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + Bc; /* Global NPQ index (column index of B) */ - //const uint32_t N_idx = NPQ_idx / (dp.OH*dp.OW); - uint32_t N_idx = fastdiv(NPQ_idx, dp.OWOHmp, dp.OWOHL); // divide by p.OH * p.OW; - uint32_t NPQ_remainder = NPQ_idx - N_idx * dp.OH * dp.OW; - //const uint32_t OH_idx = (NPQ_idx - N_idx*dp.OH*dp.OW) / dp.OW; - uint32_t OH_idx = fastdiv(NPQ_remainder, dp.OWmp, dp.OWL); // divide by p.OW; - const uint32_t OW_idx = NPQ_idx - N_idx * dp.OH * dp.OW - OH_idx * dp.OW; - -#ifdef USE_COLLECTIVES - const uint32_t CRS_idx_b = __shfl_sync(0xffffffff, cached_CRS_idx, r_offset + Br); - const uint32_t KH_idx_b = __shfl_sync(0xffffffff, cached_KH_idx, r_offset + Br); - const uint32_t KW_idx_b = __shfl_sync(0xffffffff, cached_KW_idx, r_offset + Br); - const uint32_t Cin_idx_b = __shfl_sync(0xffffffff, cached_Cin_idx, r_offset + Br); -#else - // Compute indices KH, KW, Cin from CRS_idx - uint32_t CRS_idx_b = idx_CRS + r_offset + Br; - //uint32_t Cin_idx_b = CRS_idx_b / (dp.KW*dp.KH); - uint32_t Cin_idx_b = fastdiv(CRS_idx_b, dp.KWKHmp, dp.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH); - uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * dp.KW * dp.KH; - //uint32_t KH_idx_b = (CRS_idx_b - Cin_idx_b*dp.KW*dp.KH) / dp.KW; - uint32_t KH_idx_b = fastdiv(CRS_remainder, dp.KWmp, dp.KWL); // divide by p.KW; - uint32_t KW_idx_b = CRS_idx_b - Cin_idx_b * dp.KW * dp.KH - KH_idx_b * dp.KW; -#endif - - // Compute indices for W, H from OH, OW, KH, KW - const int32_t H_idx = OH_idx * dp.s1 + KH_idx_b * dp.d1 - dp.p1; - const int32_t W_idx = OW_idx * dp.s0 + KW_idx_b * dp.d0 - dp.p0; - const uint32_t src_idx = min(max(W_idx + H_idx * dp.nb11 + Cin_idx_b * dp.nb12 + N_idx * dp.nb13, 0), - dp.Cin * dp.N * dp.W * dp.H - 1); - float val; - if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= dp.H || W_idx < 0 || W_idx >= dp.W) { - val = 0.0; - } else { - val = src_data[src_idx]; - } - -#ifdef B_OPT - assert(VEC_SIZE <= TS_NPQ); - const uint32_t T_id = Bc / TS_NPQ; // E.g.: 41/16 = 2 - const uint32_t vec_in_TT = (Bc - T_id * TS_NPQ) / VEC_SIZE; // E.g.: 41-2*16 = 9 -> 9/4 = 2 - const uint32_t elem_in_vec = (Bc - T_id * TS_NPQ) % VEC_SIZE; // E.g.: 9 -> 9%4 = 1 - const uint32_t col_offset = vec_in_TT * (NT_x * VEC_SIZE) + T_id * VEC_SIZE + elem_in_vec; -#else - uint32_t col_offset = Bc; -#endif - Bsh[(r_offset + Br) * BS_NPQ + col_offset] = val; - } - - __syncthreads(); - - if (T_y * TS_K < K) { -#pragma unroll - for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; ++CRS_lidx) { -#pragma unroll - for (uint32_t T_ly = 0; T_ly < TS_K; ++T_ly) { -#ifdef A_TRANS -# ifdef A_OPT - uint32_t T_id = T_y; - uint32_t 
vec_in_TT = T_ly / VEC_SIZE; - uint32_t elem_in_vec = T_ly % VEC_SIZE; - uint32_t col_offset = vec_in_TT * (NT_y * VEC_SIZE) + T_id * VEC_SIZE + elem_in_vec; -# else - uint32_t col_offset = (T_y * TS_K + T_ly); -# endif - regA[T_ly] = Ash[CRS_lidx * BS_K + col_offset]; -#else - regA[T_ly] = Ash[(T_y * TS_K + T_ly) * BS_CRS + CRS_lidx]; -#endif - } - for (uint32_t T_lx = 0; T_lx < TS_NPQ; ++T_lx) { -#ifdef B_OPT - const uint32_t T_id = T_x; - const uint32_t vec_in_TT = T_lx / VEC_SIZE; - const uint32_t elem_in_vec = T_lx % VEC_SIZE; - const uint32_t col_offset = vec_in_TT * (NT_x * VEC_SIZE) + T_id * VEC_SIZE + elem_in_vec; -#else - const uint32_t col_offset = T_x * TS_NPQ + T_lx; -#endif - regB[T_lx] = Bsh[CRS_lidx * BS_NPQ + col_offset]; - } - for (uint32_t T_ly = 0; T_ly < TS_K; ++T_ly) { - for (uint32_t T_lx = 0; T_lx < TS_NPQ; ++T_lx) { - regC[T_ly * TS_NPQ + T_lx] = fmaf(regA[T_ly], regB[T_lx], regC[T_ly * TS_NPQ + T_lx]); - } - } - } - } - __syncthreads(); - } - - /* Save C* */ - for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { - for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { - const uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly; - const uint32_t NPQ_idx_c = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx; - //const uint32_t N_idx_c = NPQ_idx_c / (dp.OH*dp.OW); - const uint32_t N_idx_c = fastdiv(NPQ_idx_c, dp.OWOHmp, dp.OWOHL); // divide by p.OH * p.OW; - //const uint32_t OH_idx_c = (NPQ_idx_c - N_idx_c*dp.OH*dp.OW) / dp.OW; - const uint32_t OH_idx_c = fastdiv(NPQ_idx_c - N_idx_c * dp.OH * dp.OW, dp.OWmp, dp.OWL); // divide by p.OW; - const uint32_t OW_idx_c = NPQ_idx_c - N_idx_c * dp.OH * dp.OW - OH_idx_c * dp.OW; - const uint32_t dst_idx = OW_idx_c + OH_idx_c * dp.nb1 + K_idx * dp.nb2 + N_idx_c * dp.nb3; - if (K_idx < K && NPQ_idx_c < NPQ) { - dst_data[dst_idx] = regC[T_ly * TS_NPQ + T_lx]; - } - } - } -} - -// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. -// Precompute mp (m' in the paper) and L such that division -// can be computed using a multiply (high 32b of 64b result) -// and a shift: -// -// n/d = (mulhi(n, mp) + n) >> L; -static void init_fastdiv_values(uint32_t d, uint32_t & mp, uint32_t & L) { - // compute L = ceil(log2(d)); - L = 0; - while (L < 32 && (uint32_t{ 1 } << L) < d) { - L++; - } - - mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1); -} - -constexpr int conv_shapes[][NUM_VARIANTS] = { - { 128, 64, 32 }, // BS_K - { 16, 32, 16 }, // BS_CRS - { 128, 32, 256 }, // BS_NPQ - { 8, 4, 8 } // TS_K - //{8, 8, 8} // TS_NPQ // Option 2 -}; - -int get_sm_count() { - int device; - cudaGetDevice(&device); - - int sm_count; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device); - return sm_count; -} - -template -void ggml_cuda_op_conv_2d_variant(ggml_backend_cuda_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst, - const Params & p) { - // Tile size calculation options: - // Option 1: fix block size and all tile sizes except TS_NPQ as it is the free parameter (used in the Vulkan backend). - // Option 2: fix all tile sizes and block size is the free parameter. 
- const uint32_t WG_SIZE = 256; // Option 1 - - const uint32_t BS_K = conv_shapes[0][CONV_SHAPE]; - const uint32_t BS_CRS = conv_shapes[1][CONV_SHAPE]; - const uint32_t BS_NPQ = conv_shapes[2][CONV_SHAPE]; - const uint32_t TS_K = conv_shapes[3][CONV_SHAPE]; - //const uint32_t TS_NPQ = sh[4][CONV_SHAPE]; // Option 2 - const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K; - - // Some architectures can use 128-bit loads that might be more efficient. - const uint32_t VEC_SIZE = TS_NPQ >= 4 ? 4 : 1; - - //const uint32_t WG_SIZE = (BS_K*BS_NPQ) / (TS_K*TS_NPQ); // Option 2 - - // Kernel runtime parameters - int64_t NPQ = p.N * p.OW * p.OH; - uint32_t NB_K = CEIL_DIV(p.Cout, BS_K); - uint32_t NB_NPQ = CEIL_DIV(NPQ, BS_NPQ); - - cudaMemcpyToSymbol(dp, &p, sizeof(Params)); - - // Kernel arguments - float * src0_data = (float *) src0->data; - float * src1_data = (float *) src1->data; - float * dst_data = (float *) dst->data; - - dim3 gridDim(NB_K, NB_NPQ); - dim3 blockDim(WG_SIZE); - cudaStream_t stream = ctx.stream(); - - mm - <<>>(p.Cout, NPQ, p.Cin * p.KW * p.KH, src0_data, src1_data, dst_data); -} - -void ggml_cuda_op_conv2d_mm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - // Initialize kernel variants - - using Conv2DFuncPtr = - void (*)(ggml_backend_cuda_context &, ggml_tensor *, ggml_tensor *, ggml_tensor *, const Params &); - - Conv2DFuncPtr conv2d_variants[NUM_VARIANTS]; - - conv2d_variants[CONV_SHAPE_128x128] = &ggml_cuda_op_conv_2d_variant; - conv2d_variants[CONV_SHAPE_64x32] = &ggml_cuda_op_conv_2d_variant; - conv2d_variants[CONV_SHAPE_32x256] = &ggml_cuda_op_conv_2d_variant; - - // Parse op input, prepare kernel input - - ggml_tensor * src0 = dst->src[0]; - ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT(nb00 == sizeof(float)); - GGML_ASSERT(nb10 == sizeof(float)); - GGML_ASSERT(nb0 == sizeof(float)); - - Params p{}; - p.Cout = static_cast(ne03); - p.Cin = static_cast(ne02); - p.N = static_cast(ne13); - - p.KW = static_cast(ne00); - p.KH = static_cast(ne01); - p.W = static_cast(ne10); - p.H = static_cast(ne11); - p.OW = static_cast(ne0); - p.OH = static_cast(ne1); - - p.s0 = static_cast(dst->op_params[0]); - p.s1 = static_cast(dst->op_params[1]); - p.p0 = static_cast(dst->op_params[2]); - p.p1 = static_cast(dst->op_params[3]); - p.d0 = static_cast(dst->op_params[4]); - p.d1 = static_cast(dst->op_params[5]); - - p.nb01 = static_cast(nb01 / nb00); - p.nb02 = static_cast(nb02 / nb00); - p.nb03 = static_cast(nb03 / nb00); - - p.nb11 = static_cast(nb11 / nb10); - p.nb12 = static_cast(nb12 / nb10); - p.nb13 = static_cast(nb13 / nb10); - - p.nb1 = static_cast(nb1 / nb0); - p.nb2 = static_cast(nb2 / nb0); - p.nb3 = static_cast(nb3 / nb0); - - init_fastdiv_values(p.KW, p.KWmp, p.KWL); - init_fastdiv_values(p.KW * p.KH, p.KWKHmp, p.KWKHL); - init_fastdiv_values(p.OW, p.OWmp, p.OWL); - init_fastdiv_values(p.OW * p.OH, p.OWOHmp, p.OWOHL); - - GGML_ASSERT(ne03 == ne2); - GGML_ASSERT(ne02 == ne12); - - // Select the proper variant based on problem size and device parameters (sm count) - - // Problem size (Cout x NPQ) - std::array elements = { p.Cout, p.N * p.OW * p.OH, 1 }; - - const uint32_t sm_count = get_sm_count(); - - uint32_t variant_ntiles[NUM_VARIANTS]; - - for (int var_id = 0; var_id < NUM_VARIANTS; var_id++) { - const uint32_t ntilesy = ceil_div(elements[0], conv_shapes[var_id][0]); // CEIL_DIV(Cout, NB_K) - const 
uint32_t ntilesx = ceil_div(elements[1], conv_shapes[var_id][2]); // CEIL_DIV(NPQ, NB_NPQ) - variant_ntiles[var_id] = ntilesy * ntilesx; - } - - uint32_t selected_variant_id = CONV_SHAPE_128x128; - - if (elements[0] > 64 && variant_ntiles[CONV_SHAPE_128x128] >= sm_count * 2) { - selected_variant_id = CONV_SHAPE_128x128; - } else if (elements[0] <= 32 && variant_ntiles[CONV_SHAPE_32x256] >= sm_count * 2) { - selected_variant_id = CONV_SHAPE_32x256; - } else { - selected_variant_id = CONV_SHAPE_64x32; - } - - conv2d_variants[selected_variant_id](ctx, src0, src1, dst, p); -} diff --git a/ggml/src/ggml-cuda/conv2d-mm.cuh b/ggml/src/ggml-cuda/conv2d-mm.cuh deleted file mode 100644 index fc547397e53df..0000000000000 --- a/ggml/src/ggml-cuda/conv2d-mm.cuh +++ /dev/null @@ -1,9 +0,0 @@ -#include "common.cuh" - -#define CONV_SHAPE_128x128 0 -#define CONV_SHAPE_64x32 1 -#define CONV_SHAPE_32x256 2 - -#define NUM_VARIANTS 3 - -void ggml_cuda_op_conv2d_mm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/conv2d-tensor-core.cu b/ggml/src/ggml-cuda/conv2d-tensor-core.cu new file mode 100644 index 0000000000000..728c332e1e5cd --- /dev/null +++ b/ggml/src/ggml-cuda/conv2d-tensor-core.cu @@ -0,0 +1,373 @@ +#include "common.cuh" +#include "conv2d-tensor-core.cuh" +#include "convert.cuh" +#include "mma.cuh" + +#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N)) + +static uint32_t ceil_div(uint32_t M, uint32_t N); +static int get_sm_count(); + +uint32_t ceil_div(uint32_t M, uint32_t N) { + return (M + N - 1) / N; +} + +__align__(16) struct Params { + uint32_t IW, IH; + uint32_t OW, OH; + uint32_t KW, KH; + uint32_t ST_X, ST_Y; + uint32_t PD_X, PD_Y; + uint32_t DL_X, DL_Y; + uint32_t Cin, Cout; + uint32_t B; + // helpers + uint32_t IC_KH_KW, N_OH_OW; + uint32_t IK_TOTAL, IN_TOTAL; + + uint32_t KWmp; + uint32_t KWL; + uint32_t KWKHmp; + uint32_t KWKHL; + uint32_t OWmp; + uint32_t OWL; + uint32_t OWOHmp; + uint32_t OWOHL; +}; + +__constant__ __device__ Params P; + +// see init_fastdiv_values in ggml-vulkan.cpp +__inline__ __device__ uint fastdiv(uint n, uint mp, uint L) { + return (__umulhi(n, mp) + n) >> L; +} + +__device__ struct T_ICKHKW { + const uint32_t ic, kh, kw; +}; + +__device__ struct T_NOHOW { + const uint32_t B, OH, OW; +}; + +__device__ __forceinline__ static int32_t calculate_input_coord(const uint32_t & out_coord, + const uint32_t & kern_coord, + const uint32_t & stride, + const uint32_t & dilation, + const uint32_t & padding) { + return out_coord * stride + kern_coord * dilation - padding; +} + +struct whcn_layout { + __device__ __forceinline__ static uint32_t input_index(const uint32_t & n, + const uint32_t & c, + const uint32_t & y, + const uint32_t & x) { + return n * (P.Cin * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x; + } + + __device__ __forceinline__ static uint32_t kernel_index(const uint32_t & c_out, + const uint32_t & c_in, + const uint32_t & ky, + const uint32_t & kx) { + return c_out * (P.Cin * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx; + } + + __device__ __forceinline__ static uint32_t output_index(const uint32_t & n, + const uint32_t & c, + const uint32_t & y, + const uint32_t & x) { + return n * (P.Cout * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x; + } + + __device__ __forceinline__ static T_ICKHKW unpack_ickhkw(const uint32_t & idx) { + // const uint32_t ic = idx / (P.KW * P.KH); + const uint32_t ic = fastdiv(idx, P.KWKHmp, P.KWKHL); + const uint32_t r = idx - ic * (P.KW * P.KH); + // const uint32_t kh = r / P.KW; + const uint32_t kh 
= fastdiv(r, P.KWmp, P.KWL); + const uint32_t kw = r - kh * P.KW; + return T_ICKHKW{ ic, kh, kw }; + } + + __device__ __forceinline__ static T_NOHOW unpack_nohow(const uint32_t & idx) { + // const uint32_t n = idx / (P.OH * P.OW); + const uint32_t n = fastdiv(idx, P.OWOHmp, P.OWOHL); + const uint32_t r = idx - n * (P.OH * P.OW); + // const uint32_t oh = r / P.OW; + const uint32_t oh = fastdiv(r, P.OWmp, P.OWL); + const uint32_t ow = r - oh * P.OW; + return T_NOHOW{ n, oh, ow }; + } +}; + +using namespace ggml_cuda_mma; + +typedef tile tile_a; +typedef tile tile_b; +typedef tile tile_acc; + +// --> conv_2d kernel modified to function as a matmul +template +__global__ void __launch_bounds__(NUM_WARPS * WARP_SIZE) conv2d_tensor_cores_kernel(const float * __restrict__ IN, + const half * __restrict__ IK, + float * __restrict__ Out) { + const uint32_t warpId = threadIdx.y; + const uint32_t block_tid = threadIdx.y * blockDim.x + threadIdx.x; + + const uint32_t OC_BASE = blockIdx.x * BS_OC; + const uint32_t NOHOW_BASE = blockIdx.y * BS_NOHOW; + + __shared__ half A_sh[BS_OC * BS_ICKHKW]; + __shared__ half B_sh[BS_NOHOW * BS_ICKHKW]; + + const uint32_t Ar = block_tid / BS_ICKHKW; + const uint32_t Ac = block_tid % BS_ICKHKW; + + constexpr uint32_t ArpWg = WG_SIZE / BS_ICKHKW; + + const uint32_t Br = block_tid / BS_ICKHKW; + const uint32_t Bc = block_tid % BS_ICKHKW; + + constexpr uint32_t BrpWg = WG_SIZE / BS_ICKHKW; + + tile_a a_frag; + tile_b b_frag; + tile_acc c_frag[NUM_TILES_PER_WARP]; + +#pragma unroll + for (uint32_t id_ickhkw = 0; id_ickhkw < P.IC_KH_KW; id_ickhkw += BS_ICKHKW) { + const uint32_t cached_ickhkw_idx = id_ickhkw + Ac; + + const T_ICKHKW ickhkw = layout::unpack_ickhkw(cached_ickhkw_idx); + +#pragma unroll + for (uint32_t i = 0; i < BS_OC; i += ArpWg) { + const uint32_t gOC = OC_BASE + (Ar + i); + half val = IK[min(cached_ickhkw_idx + (gOC * P.IC_KH_KW), P.IK_TOTAL - 1)]; + + if (((cached_ickhkw_idx) >= P.IC_KH_KW) || (gOC >= P.Cout)) { + val = 0.0f; + } + A_sh[(i + Ar) * BS_ICKHKW + Ac] = val; + } +#pragma unroll + for (uint32_t i = 0; i < BS_NOHOW; i += BrpWg) { + const uint32_t gNOHOW = NOHOW_BASE + (i + Br); + half val = 0.0f; + const T_NOHOW nohow = layout::unpack_nohow(gNOHOW); + + const int32_t in_y = calculate_input_coord(nohow.OH, ickhkw.kh, P.ST_Y, P.DL_Y, P.PD_Y); + const int32_t in_x = calculate_input_coord(nohow.OW, ickhkw.kw, P.ST_X, P.DL_X, P.PD_X); + + val = ggml_cuda_cast( + IN[min(max(layout::input_index(nohow.B, ickhkw.ic, in_y, in_x), 0), P.IN_TOTAL - 1)]); + if (in_y < 0 || in_y >= P.IH || in_x < 0 || in_x >= P.IW) { + val = 0.0f; + } + B_sh[(i + Br) * BS_ICKHKW + Bc] = val; + } + __syncthreads(); + +#pragma unroll + for (uint32_t i = 0; i < NUM_TILES_PER_WARP; i++) { + const uint32_t warp = warpId * NUM_TILES_PER_WARP + i; + const uint32_t WARP_OC = warp / NUM_WARPS_NOHOW; + const uint32_t WARP_NOHOW = warp % NUM_WARPS_NOHOW; + + const half * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW; + const half * B_warp_base = B_sh + WARP_NOHOW * WMMA_N * BS_ICKHKW; + +#pragma unroll + for (uint32_t k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) { + const half * A_k_ptr = A_warp_base + k_tile; + const half * B_k_ptr = B_warp_base + k_tile; + load_ldmatrix(a_frag, (const half2 *) A_k_ptr, BS_ICKHKW / 2); + load_ldmatrix(b_frag, (const half2 *) B_k_ptr, BS_ICKHKW / 2); + ggml_cuda_mma::mma(c_frag[i], a_frag, b_frag); + } + } + + __syncthreads(); + } + +#pragma unroll + for (uint32_t i = 0; i < NUM_TILES_PER_WARP; i++) { + const uint32_t warp = warpId * 
NUM_TILES_PER_WARP + i; + const uint32_t WARP_OC = warp / NUM_WARPS_NOHOW; + const uint32_t WARP_NOHOW = warp % NUM_WARPS_NOHOW; + const uint32_t OC_WARP_BASE = OC_BASE + WARP_OC * WMMA_M; + const uint32_t NOHOW_WARP_BASE = NOHOW_BASE + WARP_NOHOW * WMMA_N; +#pragma unroll + for (uint32_t l = 0; l < tile_acc::ne; ++l) { + const uint32_t e = tile_acc::get_i(l) * WMMA_N + tile_acc::get_j(l); + const uint32_t m = e / WMMA_N; + const uint32_t n = e % WMMA_N; + + const uint32_t oc = OC_WARP_BASE + m; + const uint32_t nohow = NOHOW_WARP_BASE + n; + const T_NOHOW out_nohow = layout::unpack_nohow(nohow); + if (oc < P.Cout && nohow < (P.N_OH_OW)) { + Out[layout::output_index(out_nohow.B, oc, out_nohow.OH, out_nohow.OW)] = c_frag[i].x[l]; + } + } + } +} + +// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. +// Precompute mp (m' in the paper) and L such that division +// can be computed using a multiply (high 32b of 64b result) +// and a shift: +// +// n/d = (mulhi(n, mp) + n) >> L; +static void init_fastdiv_values(uint32_t d, uint32_t & mp, uint32_t & L) { + // compute L = ceil(log2(d)); + L = 0; + while (L < 32 && (uint32_t{ 1 } << L) < d) { + L++; + } + + mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1); +} + +constexpr int conv_shapes[][NUM_VARIANTS] = { + { 128, 64, 32 }, // BS_OC + { 16, 32, 16 }, // BS_ICKHKW + { 128, 32, 256 }, // BS_NOHOW +}; + +int get_sm_count() { + int device; + cudaGetDevice(&device); + + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device); + return sm_count; +} + +template +void conv_2d_tensor_core(const float * src0, + const half * src1, + float * dst, + const Params & p, + const cudaStream_t & st) { + constexpr uint32_t WG_SIZE = 256; + static_assert(WG_SIZE % WARP_SIZE == 0); + + constexpr uint32_t NUM_WARPS = WG_SIZE / WARP_SIZE; + + constexpr uint32_t BS_OC = conv_shapes[0][CONV_SHAPE]; + constexpr uint32_t BS_ICKHKW = conv_shapes[1][CONV_SHAPE]; + constexpr uint32_t BS_NOHOW = conv_shapes[2][CONV_SHAPE]; + + static_assert(BS_OC % WMMA_M == 0 && BS_NOHOW % WMMA_N == 0); + + constexpr uint32_t NUM_WARPS_NEED = (BS_OC * BS_NOHOW) / (WMMA_M * WMMA_N); + constexpr uint32_t NUM_WARPS_NOHOW = BS_NOHOW / WMMA_N; + + static_assert(NUM_WARPS_NEED % NUM_WARPS == 0); + + constexpr uint32_t NUM_TILES_PER_WARP = NUM_WARPS_NEED / NUM_WARPS; + + const int64_t NOHOW = p.B * p.OW * p.OH; + const uint32_t NB_OC = CEIL_DIV(p.Cout, BS_OC); + const uint32_t NB_NOHOW = CEIL_DIV(NOHOW, BS_NOHOW); + + cudaMemcpyToSymbolAsync(P, &p, sizeof(Params), 0, cudaMemcpyHostToDevice, st); + + dim3 gridDim(NB_OC, NB_NOHOW); + constexpr dim3 blockDim(WARP_SIZE, NUM_WARPS); + + conv2d_tensor_cores_kernel<<>>(src0, src1, dst); +} + +void ggml_cuda_op_conv2d_tensor_core(const uint32_t & IW, + const uint32_t & IH, + const uint32_t & OW, + const uint32_t & OH, + const uint32_t & KW, + const uint32_t & KH, + const uint32_t & ST_X, + const uint32_t & ST_Y, + const uint32_t & PD_X, + const uint32_t & PD_Y, + const uint32_t & DL_X, + const uint32_t & DL_Y, + const uint32_t & IC, + const uint32_t & OC, + const uint32_t & B, + const float * IN, + const half * IK, + float * output, + const cudaStream_t & st) { + using Conv2DFuncPtr = void (*)(const float *, const half *, float *, const Params &, const cudaStream_t &); + + Conv2DFuncPtr conv2d_variants[NUM_VARIANTS]; + + conv2d_variants[CONV_SHAPE_128x128] = &conv_2d_tensor_core; + conv2d_variants[CONV_SHAPE_64x32] = &conv_2d_tensor_core; + conv2d_variants[CONV_SHAPE_32x256] = 
&conv_2d_tensor_core; + + Params p{}; + p.Cout = OC; + p.Cin = IC; + p.B = B; + + p.KW = KW; + p.KH = KH; + p.IW = IW; + p.IH = IH; + p.OW = OW; + p.OH = OH; + + p.ST_X = ST_X; + p.ST_Y = ST_Y; + p.PD_X = PD_X; + p.PD_Y = PD_Y; + p.DL_X = DL_X; + p.DL_Y = DL_Y; + p.IC_KH_KW = IC * KH * KW; + p.IK_TOTAL = p.IC_KH_KW * p.Cout; + + p.N_OH_OW = B * OH * OW; + p.IN_TOTAL = B * IC * IH * IW; + + init_fastdiv_values(p.KW, p.KWmp, p.KWL); + init_fastdiv_values(p.KW * p.KH, p.KWKHmp, p.KWKHL); + init_fastdiv_values(p.OW, p.OWmp, p.OWL); + init_fastdiv_values(p.OW * p.OH, p.OWOHmp, p.OWOHL); + + // Problem size (Cout x NOHOW) + std::array elements = { p.Cout, p.B * p.OW * p.OH, 1 }; + + const uint32_t sm_count = get_sm_count(); + + uint32_t variant_ntiles[NUM_VARIANTS]; + + for (int var_id = 0; var_id < NUM_VARIANTS; var_id++) { + const uint32_t ntilesy = ceil_div(elements[0], conv_shapes[var_id][0]); // CEIL_DIV(Cout, NB_OC) + const uint32_t ntilesx = ceil_div(elements[1], conv_shapes[var_id][2]); // CEIL_DIV(NOHOW, NB_NOHOW) + variant_ntiles[var_id] = ntilesy * ntilesx; + } + + uint32_t selected_variant_id = CONV_SHAPE_128x128; + + if (elements[0] > 64 && variant_ntiles[CONV_SHAPE_128x128] >= sm_count * 2) { + selected_variant_id = CONV_SHAPE_128x128; + } else if (elements[0] <= 32 && variant_ntiles[CONV_SHAPE_32x256] >= sm_count * 2) { + selected_variant_id = CONV_SHAPE_32x256; + } else { + selected_variant_id = CONV_SHAPE_64x32; + } + + conv2d_variants[selected_variant_id](IN, IK, output, p, st); +} diff --git a/ggml/src/ggml-cuda/conv2d-tensor-core.cuh b/ggml/src/ggml-cuda/conv2d-tensor-core.cuh new file mode 100644 index 0000000000000..4e1073b377637 --- /dev/null +++ b/ggml/src/ggml-cuda/conv2d-tensor-core.cuh @@ -0,0 +1,31 @@ +#include "common.cuh" + +#define CONV_SHAPE_128x128 0 +#define CONV_SHAPE_64x32 1 +#define CONV_SHAPE_32x256 2 + +#define NUM_VARIANTS 3 + +#define WMMA_M 16 +#define WMMA_N 16 +#define WMMA_K 16 + +void ggml_cuda_op_conv2d_tensor_core(const uint32_t & IW, + const uint32_t & IH, + const uint32_t & OW, + const uint32_t & OH, + const uint32_t & KW, + const uint32_t & KH, + const uint32_t & ST_X, + const uint32_t & ST_Y, + const uint32_t & PD_X, + const uint32_t & PD_Y, + const uint32_t & DL_X, + const uint32_t & DL_Y, + const uint32_t & IC, + const uint32_t & OC, + const uint32_t & B, + const float * IN, + const half * IK, + float * output, + const cudaStream_t & st); diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu index deaca3d648d5a..43310d967d0a2 100644 --- a/ggml/src/ggml-cuda/conv2d.cu +++ b/ggml/src/ggml-cuda/conv2d.cu @@ -1,388 +1,123 @@ +#include "conv2d-tensor-core.cuh" #include "conv2d.cuh" #include "convert.cuh" -#include - struct conv_params { - const int IW, IH; - const int OW, OH; - const int KW, KH; - const int ST_X, ST_Y; - const int PD_X, PD_Y; - const int DL_X, DL_Y; - const int IC, OC; - const int B; - // helpers - const int IC_KH_KW, N_OH_OW; + const int64_t IW, IH; + const int64_t OW, OH; + const int64_t KW, KH; + const int64_t ST_X, ST_Y; + const int64_t PD_X, PD_Y; + const int64_t DL_X, DL_Y; + const int64_t IC, OC; + const int64_t B; + const int64_t TOTAL; +}; + +struct kernel_bounds { + int64_t y_min, y_max; + int64_t x_min, x_max; }; -__device__ __forceinline__ static int calculate_input_coord(int out_coord, - int kern_coord, - int stride, - int dilation, - int padding) { +__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) { + return (a > b) ? 
a : b; +} + +__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) { + return (a < b) ? a : b; +} + +__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) { + kernel_bounds bounds; + bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); + bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); + bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); + bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); + return bounds; +} + +__device__ __forceinline__ int calculate_input_coord(int64_t out_coord, + int64_t kern_coord, + int64_t stride, + int64_t dilation, + int64_t padding) { return out_coord * stride + kern_coord * dilation - padding; } struct whcn_layout { - __device__ __forceinline__ static int64_t input_index(int n, int c, int y, int x, const conv_params & P) { + __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x; } - __device__ __forceinline__ static int64_t kernel_index(int c_out, int c_in, int ky, int kx, const conv_params & P) { + __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) { return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx; } - __device__ __forceinline__ static int64_t output_index(int n, int c, int y, int x, const conv_params & P) { + __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x; } - __device__ __forceinline__ static void unpack_ickhkw(int64_t idx, - int & ic, - int & kh, - int & kw, - const conv_params & P) { - ic = idx / (P.KW * P.KH); - int r = idx - ic * (P.KW * P.KH); - kh = r / P.KW; - kw = r - kh * P.KW; - } - - __device__ __forceinline__ static void unpack_nohow(int64_t idx, - int & n, - int & oh, - int & ow, - const conv_params & P) { - n = idx / (P.OH * P.OW); - int r = idx - n * (P.OH * P.OW); - oh = r / P.OW; - ow = r - oh * P.OW; - } -}; - -template class float_mma { - private: - static constexpr int num_acc = (WMMA_M * WMMA_N + WARP_SIZE - 1) / WARP_SIZE; - // for tile [16,16], lane 0 will store and compute for [0,0], [2,0], [4,0] ... [14,0] - // lane 1 will store and compute for [0,1], [2,1], [4,1] ... 
[14,1] - float acc[num_acc]; - - public: - __device__ __forceinline__ float_mma() { -#pragma unroll - for (int i = 0; i < num_acc; i++) { - acc[i] = 0.0f; - } - } - - __device__ __forceinline__ void mma(const float * __restrict__ A_sh, - const float * __restrict__ B_sh, - const int strideA, - const int strideB) { - const int lane_id = threadIdx.x % WARP_SIZE; - -#pragma unroll - for (int i = 0; i < num_acc; i++) { - const int e = lane_id + i * WARP_SIZE; - if (e >= WMMA_M * WMMA_N) { - continue; - } - const int m = e / WMMA_N; - const int n = e % WMMA_N; - -#pragma unroll - for (int k = 0; k < WMMA_K; k++) { - const float a = A_sh[m * strideA + k]; - const float b = B_sh[k * strideB + n]; - acc[i] = fmaf(a, b, acc[i]); - } - } - } - - __device__ __forceinline__ void store_result(const int OC_BASE, - const int NOHOW_BASE, - float * __restrict__ OUT, - const conv_params & P) const { - const int lane_id = threadIdx.x % WARP_SIZE; - -#pragma unroll - for (int i = 0; i < num_acc; i++) { - const int e = lane_id + i * WARP_SIZE; - if (e >= WMMA_M * WMMA_N) { - continue; - } - const int m = e / WMMA_N; - const int n = e % WMMA_N; - - const int oc = OC_BASE + m; - const int nohow = NOHOW_BASE + n; - - if (oc < P.OC && nohow < P.N_OH_OW) { - int n_, oh, ow; - layout::unpack_nohow(nohow, n_, oh, ow, P); - OUT[layout::output_index(n_, oc, oh, ow, P)] = acc[i]; - } - } - } -}; - -#if (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(FP16_MMA_AVAILABLE))) -# include "mma.cuh" -using namespace ggml_cuda_mma; - -typedef tile tile_a; -typedef tile tile_b; -typedef tile tile_acc; - -template class half_mma { - private: - tile_a a_frag; - tile_b b_frag; - tile_acc c_frag; - - public: - __device__ __forceinline__ half_mma() {} - - __device__ __forceinline__ void mma(const half * __restrict__ A_sh, - const half * __restrict__ B_sh, - const int strideA, - const int strideB) { - load_ldmatrix(a_frag, (const half2 *) A_sh, strideA / 2); - load_ldmatrix_trans(b_frag, (const half2 *) B_sh, strideB / 2); - ggml_cuda_mma::mma(c_frag, a_frag, b_frag); - } - - __device__ __forceinline__ void store_result(const int OC_BASE, - const int NOHOW_BASE, - float * __restrict__ OUT, - const conv_params & P) const { -# pragma unroll - for (int l = 0; l < tile_acc::ne; ++l) { - const int e = tile_acc::get_i(l) * WMMA_N + tile_acc::get_j(l); - if (e >= WMMA_M * WMMA_N) { - continue; - } - const int m = e / WMMA_N; - const int n = e % WMMA_N; - - const int oc = OC_BASE + m; - const int nohow = NOHOW_BASE + n; - - if (oc < P.OC && nohow < (P.N_OH_OW)) { - int n_, oh, ow; - layout::unpack_nohow(nohow, n_, oh, ow, P); - OUT[layout::output_index(n_, oc, oh, ow, P)] = c_frag.x[l]; - } - } + __device__ static void unpack_indices(int64_t global_idx, + const conv_params & P, + int64_t & n, + int64_t & c, + int64_t & out_y, + int64_t & out_x) { + out_x = global_idx % P.OW; + out_y = (global_idx / P.OW) % P.OH; + c = (global_idx / (P.OW * P.OH)) % P.OC; + n = global_idx / (P.OW * P.OH * P.OC); } }; -#else +template +static __global__ void conv2d_kernel(const float * __restrict__ input, + const T * __restrict__ kernel, + float * __restrict__ output, + const conv_params P) { + const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x; -template class half_mma { - public: - static constexpr int num_acc = (WMMA_M * WMMA_N + WARP_SIZE - 1) / WARP_SIZE; - // eg. for tile [16,16], lane 0 will store and compute for [0,0], [2,0], [4,0] .. 
[14,0] - float acc[num_acc]; - - __device__ __forceinline__ half_mma() { -# pragma unroll - for (int i = 0; i < num_acc; i++) { - acc[i] = 0.0f; - } + if (global_idx >= P.TOTAL) { + return; } - __device__ __forceinline__ void mma(const half * __restrict__ A_sh, - const half * __restrict__ B_sh, - const int strideA, - const int strideB) { - const int lane_id = threadIdx.x % WARP_SIZE; - -# pragma unroll - for (int e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { - const int m = e / WMMA_N; - const int n = e % WMMA_N; + int64_t n, c_out, out_y, out_x; + Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x); -# pragma unroll - for (int k = 0; k < WMMA_K; k++) { - const half a = A_sh[m * strideA + k]; - const half b = B_sh[k * strideB + n]; - acc[i] = fmaf(__half2float(a), __half2float(b), acc[i]); - } - } - } + float acc = 0.0f; - __device__ __forceinline__ void store_result(const int OC_BASE, - const int NOHOW_BASE, - float * __restrict__ OUT, - const conv_params & P) const { - const int lane_id = threadIdx.x % WARP_SIZE; + for (int64_t c_in = 0; c_in < P.IC; ++c_in) { + kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P); -# pragma unroll - for (int e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { - const int m = e / WMMA_N; - const int n = e % WMMA_N; + for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) { + const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y); - const int oc = OC_BASE + m; - const int nohow = NOHOW_BASE + n; + for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) { + const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X); - if (oc < P.OC && nohow < P.N_OH_OW) { - int n_, oh, ow; - layout::unpack_nohow(nohow, n_, oh, ow, P); - OUT[layout::output_index(n_, oc, oh, ow, P)] = acc[i]; + const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)]; + const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)]; + acc += (input_val * ggml_cuda_cast(kernel_val)); } } } -}; - -#endif // defined((__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || defined(FP16_MMA_AVAILABLE)) - -template class mma, int num_warps> -__global__ void __launch_bounds__(num_warps * WARP_SIZE) conv2d_kernel(const float * __restrict__ IN, - const T * __restrict__ IK, - float * __restrict__ Out, - const conv_params P) { - extern __shared__ unsigned char smem_raw[]; - - const int warpId = threadIdx.y; - const int linear_tid = threadIdx.y * blockDim.x + threadIdx.x; - - const int NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW; - const int NUM_WARPS_NOHOW = max(1, BS_NOHOW / WMMA_N); - const int NUM_WARPS_NEED = (((BS_OC * BS_NOHOW) + (WMMA_M * WMMA_N) - 1) / (WMMA_M * WMMA_N)); - - const int NUM_TILES_PER_WARP = (NUM_WARPS_NEED + num_warps - 1) / num_warps; - mma acc[NUM_TILES_PER_WARP]; - - const int NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; - const int BL_IDX_OC = blockIdx.x / NUM_BL_NOHOW; - const int BL_IDX_NOHOW = blockIdx.x % NUM_BL_NOHOW; - - const int BLOCK_OC_BASE = BL_IDX_OC * BS_OC; - const int BLOCK_NOHOW_BASE = BL_IDX_NOHOW * BS_NOHOW; - - unsigned char * ptr = smem_raw; - - const int A_total = BS_OC * BS_ICKHKW; - const int B_total = BS_ICKHKW * BS_NOHOW; - - size_t offsetA = (size_t) A_total * sizeof(T); - T * A_sh = reinterpret_cast(ptr); - ptr += offsetA; - - size_t offsetB = (size_t) B_total * sizeof(T); - T * B_sh = reinterpret_cast(ptr); - ptr += offsetB; - - for (int t = 0; t < NUM_IC_TILES; ++t) { -#pragma unroll - for (int tid = linear_tid; tid < 
A_total; tid += (blockDim.x * blockDim.y)) { - const int row = tid / BS_ICKHKW; - const int col = tid % BS_ICKHKW; - - const int shared_oc = BLOCK_OC_BASE + row; - const int shared_ickhkw = t * BS_ICKHKW + col; - - T val = ggml_cuda_cast(0); - if (shared_oc < P.OC && shared_ickhkw < P.IC_KH_KW) { - int ic, kh, kw; - layout::unpack_ickhkw(shared_ickhkw, ic, kh, kw, P); - - const int kidx = layout::kernel_index(shared_oc, ic, kh, kw, P); - val = IK[kidx]; - } - A_sh[row * BS_ICKHKW + col] = val; - } -#pragma unroll - for (int tid = linear_tid; tid < B_total; tid += (blockDim.x * blockDim.y)) { - const int brow = tid / BS_NOHOW; - const int bcol = tid % BS_NOHOW; - - const int IC_KH_KW_IDX = t * BS_ICKHKW + brow; - const int N_OH_OW_IDX = BLOCK_NOHOW_BASE + bcol; - - T val = ggml_cuda_cast(0); - if (N_OH_OW_IDX < P.N_OH_OW && IC_KH_KW_IDX < P.IC_KH_KW) { - int n, oh, ow; - int ic, kh, kw; - layout::unpack_nohow(N_OH_OW_IDX, n, oh, ow, P); - layout::unpack_ickhkw(IC_KH_KW_IDX, ic, kh, kw, P); - const int in_y = calculate_input_coord(oh, kh, P.ST_Y, P.DL_Y, P.PD_Y); - const int in_x = calculate_input_coord(ow, kw, P.ST_X, P.DL_X, P.PD_X); - if (in_y >= 0 && in_y < P.IH && in_x >= 0 && in_x < P.IW) { - const int64_t in_idx = layout::input_index(n, ic, in_y, in_x, P); - val = ggml_cuda_cast(IN[in_idx]); - } - } - B_sh[brow * BS_NOHOW + bcol] = val; - } - - __syncthreads(); - -#pragma unroll - for (int i = 0; i < NUM_TILES_PER_WARP; i++) { - const int warp = warpId + i * num_warps; - if (warp >= NUM_WARPS_NEED) { - continue; - } - const int WARP_OC = warp / NUM_WARPS_NOHOW; - const int WARP_NOHOW = warp % NUM_WARPS_NOHOW; - - const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW; - const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N; - -#pragma unroll - for (int k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) { - const T * A_k_ptr = A_warp_base + k_tile; - const T * B_k_ptr = B_warp_base + k_tile * BS_NOHOW; - acc[i].mma(A_k_ptr, B_k_ptr, BS_ICKHKW, BS_NOHOW); - } - } - __syncthreads(); - } - -#pragma unroll - for (int i = 0; i < NUM_TILES_PER_WARP; i++) { - const int warp = warpId + i * num_warps; - if (warp >= NUM_WARPS_NEED) { - continue; - } - const int WARP_OC = warp / NUM_WARPS_NOHOW; - const int WARP_NOHOW = warp % NUM_WARPS_NOHOW; - const int OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M; - const int NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N; - acc[i].store_result(OC_BASE, NOHOW_BASE, Out, P); - } + // [N, OC, OH, OW] + output[Layout::output_index(n, c_out, out_y, out_x, P)] = acc; } -template class mma> -static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params & P, cudaStream_t st) { - GGML_ASSERT(BS_OC >= WMMA_M && BS_ICKHKW >= WMMA_K && BS_NOHOW >= WMMA_N); - GGML_ASSERT(BS_ICKHKW % WMMA_K == 0); - - const int NUM_BL_OC = (P.OC + BS_OC - 1) / BS_OC; - const int NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; - const int NUM_BL = NUM_BL_OC * NUM_BL_NOHOW; - - constexpr int NUM_WARPS = (CUDA_CONV2D_BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; - - const size_t A_bytes = BS_OC * BS_ICKHKW * sizeof(T); - const size_t B_bytes = BS_ICKHKW * BS_NOHOW * sizeof(T); - const size_t shared_bytes = A_bytes + B_bytes; - - dim3 grid(NUM_BL, 1, 1); - dim3 block(WARP_SIZE, NUM_WARPS, 1); - - conv2d_kernel<<>>(X_D, K_D, Y_D, P); +template +static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE; + conv2d_kernel<<>>(X_D, 
K_D, Y_D, P); } -static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params & P, cudaStream_t st) { - conv2d_cuda(X_D, K_D, Y_D, P, st); -} +// static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) { +// conv2d_cuda(X_D, K_D, Y_D, P, st); +// } -static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params & P, cudaStream_t st) { - conv2d_cuda(X_D, K_D, Y_D, P, st); +static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + conv2d_cuda(X_D, K_D, Y_D, P, st); } void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -421,13 +156,13 @@ void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int OC = kernel->ne[3]; // ouptut_chanles const int B = input->ne[3]; // n_batches - const int IC_KH_KW = IC * KH * KW; - const int N_OH_OW = B * OH * OW; - const conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, - PD_Y, DL_X, DL_Y, IC, OC, B, IC_KH_KW, N_OH_OW }; + const int64_t total = B * OC * OH * OW; + conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total }; if (kernel->type == GGML_TYPE_F16) { - conv2d_cuda_f16(X_D, (const half *) K_D, Y_D, params, st); + ggml_cuda_op_conv2d_tensor_core(IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, X_D, + (const half *) K_D, Y_D, st); + // conv2d_cuda_f16(X_D, (const half *) K_D, Y_D, params, st); } else { conv2d_cuda_f32(X_D, K_D, Y_D, params, st); } diff --git a/ggml/src/ggml-cuda/conv2d.cuh b/ggml/src/ggml-cuda/conv2d.cuh index 3a1a5f28b572c..ce4802c7ed797 100644 --- a/ggml/src/ggml-cuda/conv2d.cuh +++ b/ggml/src/ggml-cuda/conv2d.cuh @@ -1,14 +1,5 @@ #pragma once #include "common.cuh" -#define BS_OC 32 -#define BS_ICKHKW 16 -#define BS_NOHOW 32 - -#define WMMA_M 16 -#define WMMA_N 16 -#define WMMA_K 16 - -#define CUDA_CONV2D_BLOCK_SIZE 128 - +#define CUDA_CONV2D_BLOCK_SIZE 256 void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 424faae02877f..94ab1ec0f5a90 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -13,7 +13,6 @@ #include "ggml-cuda/concat.cuh" #include "ggml-cuda/conv-transpose-1d.cuh" #include "ggml-cuda/conv2d.cuh" -#include "ggml-cuda/conv2d-mm.cuh" #include "ggml-cuda/conv2d-dw.cuh" #include "ggml-cuda/conv2d-transpose.cuh" #include "ggml-cuda/convert.cuh" @@ -2623,13 +2622,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_im2col_3d(ctx, dst); break; case GGML_OP_CONV_2D: - if (!getenv("GGML_CUDA_USE_LEGACY_CONV") && - (dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32 && - dst->type == GGML_TYPE_F32)) { - ggml_cuda_op_conv2d_mm(ctx, dst); - } else { - ggml_cuda_op_conv2d(ctx, dst); - } + ggml_cuda_op_conv2d(ctx, dst); break; case GGML_OP_CONV_2D_DW: ggml_cuda_op_conv2d_dw(ctx, dst); From e1ab1f044d270b79f7d55c36e9bebfff503a4821 Mon Sep 17 00:00:00 2001 From: mnehete32 Date: Sat, 1 Nov 2025 22:41:09 +0530 Subject: [PATCH 11/12] using internal fastdiv and sm count function --- ggml/src/ggml-cuda/conv2d-tensor-core.cu | 97 +++++++----------------- 1 file changed, 29 insertions(+), 68 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-tensor-core.cu b/ggml/src/ggml-cuda/conv2d-tensor-core.cu index 
728c332e1e5cd..f6798002aedf3 100644 --- a/ggml/src/ggml-cuda/conv2d-tensor-core.cu +++ b/ggml/src/ggml-cuda/conv2d-tensor-core.cu @@ -3,13 +3,8 @@ #include "convert.cuh" #include "mma.cuh" -#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N)) - -static uint32_t ceil_div(uint32_t M, uint32_t N); -static int get_sm_count(); - -uint32_t ceil_div(uint32_t M, uint32_t N) { - return (M + N - 1) / N; +constexpr static size_t ceil_div(const size_t m, const size_t n) { + return (m + n - 1) / n; } __align__(16) struct Params { @@ -25,23 +20,15 @@ __align__(16) struct Params { uint32_t IC_KH_KW, N_OH_OW; uint32_t IK_TOTAL, IN_TOTAL; - uint32_t KWmp; - uint32_t KWL; - uint32_t KWKHmp; - uint32_t KWKHL; - uint32_t OWmp; - uint32_t OWL; - uint32_t OWOHmp; - uint32_t OWOHL; + // fastdiv + uint3 KW_fastdiv; + uint3 KWKH_fastdiv; + uint3 OW_fastdiv; + uint3 OWOH_fastdiv; }; __constant__ __device__ Params P; -// see init_fastdiv_values in ggml-vulkan.cpp -__inline__ __device__ uint fastdiv(uint n, uint mp, uint L) { - return (__umulhi(n, mp) + n) >> L; -} - __device__ struct T_ICKHKW { const uint32_t ic, kh, kw; }; @@ -82,20 +69,20 @@ struct whcn_layout { __device__ __forceinline__ static T_ICKHKW unpack_ickhkw(const uint32_t & idx) { // const uint32_t ic = idx / (P.KW * P.KH); - const uint32_t ic = fastdiv(idx, P.KWKHmp, P.KWKHL); + const uint32_t ic = fastdiv(idx, P.KWKH_fastdiv); const uint32_t r = idx - ic * (P.KW * P.KH); // const uint32_t kh = r / P.KW; - const uint32_t kh = fastdiv(r, P.KWmp, P.KWL); + const uint32_t kh = fastdiv(r, P.KW_fastdiv); const uint32_t kw = r - kh * P.KW; return T_ICKHKW{ ic, kh, kw }; } __device__ __forceinline__ static T_NOHOW unpack_nohow(const uint32_t & idx) { // const uint32_t n = idx / (P.OH * P.OW); - const uint32_t n = fastdiv(idx, P.OWOHmp, P.OWOHL); + const uint32_t n = fastdiv(idx, P.OWOH_fastdiv); const uint32_t r = idx - n * (P.OH * P.OW); // const uint32_t oh = r / P.OW; - const uint32_t oh = fastdiv(r, P.OWmp, P.OWL); + const uint32_t oh = fastdiv(r, P.OW_fastdiv); const uint32_t ow = r - oh * P.OW; return T_NOHOW{ n, oh, ow }; } @@ -113,7 +100,6 @@ template @@ -222,43 +208,18 @@ __global__ void __launch_bounds__(NUM_WARPS * WARP_SIZE) conv2d_tensor_cores_ker } } -// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. 
-// Precompute mp (m' in the paper) and L such that division -// can be computed using a multiply (high 32b of 64b result) -// and a shift: -// -// n/d = (mulhi(n, mp) + n) >> L; -static void init_fastdiv_values(uint32_t d, uint32_t & mp, uint32_t & L) { - // compute L = ceil(log2(d)); - L = 0; - while (L < 32 && (uint32_t{ 1 } << L) < d) { - L++; - } - - mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1); -} - constexpr int conv_shapes[][NUM_VARIANTS] = { { 128, 64, 32 }, // BS_OC { 16, 32, 16 }, // BS_ICKHKW { 128, 32, 256 }, // BS_NOHOW }; -int get_sm_count() { - int device; - cudaGetDevice(&device); - - int sm_count; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device); - return sm_count; -} - template -void conv_2d_tensor_core(const float * src0, - const half * src1, - float * dst, - const Params & p, - const cudaStream_t & st) { +static void conv_2d_tensor_core(const float * src0, + const half * src1, + float * dst, + const Params & p, + const cudaStream_t & st) { constexpr uint32_t WG_SIZE = 256; static_assert(WG_SIZE % WARP_SIZE == 0); @@ -270,24 +231,24 @@ void conv_2d_tensor_core(const float * src0, static_assert(BS_OC % WMMA_M == 0 && BS_NOHOW % WMMA_N == 0); - constexpr uint32_t NUM_WARPS_NEED = (BS_OC * BS_NOHOW) / (WMMA_M * WMMA_N); + constexpr uint32_t NUM_TILES_TOTAL = (BS_OC * BS_NOHOW) / (WMMA_M * WMMA_N); constexpr uint32_t NUM_WARPS_NOHOW = BS_NOHOW / WMMA_N; - static_assert(NUM_WARPS_NEED % NUM_WARPS == 0); + static_assert(NUM_TILES_TOTAL % NUM_WARPS == 0); - constexpr uint32_t NUM_TILES_PER_WARP = NUM_WARPS_NEED / NUM_WARPS; + constexpr uint32_t NUM_TILES_PER_WARP = NUM_TILES_TOTAL / NUM_WARPS; const int64_t NOHOW = p.B * p.OW * p.OH; - const uint32_t NB_OC = CEIL_DIV(p.Cout, BS_OC); - const uint32_t NB_NOHOW = CEIL_DIV(NOHOW, BS_NOHOW); + const uint32_t NB_OC = ceil_div(p.Cout, BS_OC); + const uint32_t NB_NOHOW = ceil_div(NOHOW, BS_NOHOW); cudaMemcpyToSymbolAsync(P, &p, sizeof(Params), 0, cudaMemcpyHostToDevice, st); dim3 gridDim(NB_OC, NB_NOHOW); constexpr dim3 blockDim(WARP_SIZE, NUM_WARPS); - conv2d_tensor_cores_kernel<<>>(src0, src1, dst); + conv2d_tensor_cores_kernel<<>>(src0, src1, dst); } void ggml_cuda_op_conv2d_tensor_core(const uint32_t & IW, @@ -341,15 +302,15 @@ void ggml_cuda_op_conv2d_tensor_core(const uint32_t & IW, p.N_OH_OW = B * OH * OW; p.IN_TOTAL = B * IC * IH * IW; - init_fastdiv_values(p.KW, p.KWmp, p.KWL); - init_fastdiv_values(p.KW * p.KH, p.KWKHmp, p.KWKHL); - init_fastdiv_values(p.OW, p.OWmp, p.OWL); - init_fastdiv_values(p.OW * p.OH, p.OWOHmp, p.OWOHL); + p.KW_fastdiv = init_fastdiv_values(p.KW); + p.KWKH_fastdiv = init_fastdiv_values(p.KW * p.KH); + p.OW_fastdiv = init_fastdiv_values(p.OW); + p.OWOH_fastdiv = init_fastdiv_values(p.OW * p.OH); // Problem size (Cout x NOHOW) - std::array elements = { p.Cout, p.B * p.OW * p.OH, 1 }; + std::array elements = { p.Cout, p.B * p.OW * p.OH }; - const uint32_t sm_count = get_sm_count(); + const uint32_t sm_count = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; uint32_t variant_ntiles[NUM_VARIANTS]; From 2e1c8819bcb2ac70686d5d8e70d62d88d97ae066 Mon Sep 17 00:00:00 2001 From: mnehete32 Date: Tue, 4 Nov 2025 23:42:02 +0530 Subject: [PATCH 12/12] resolves @Green-Sky suggestions --- ggml/src/ggml-cuda/conv2d-tensor-core.cu | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-tensor-core.cu b/ggml/src/ggml-cuda/conv2d-tensor-core.cu index f6798002aedf3..1241368385fde 100644 --- 
a/ggml/src/ggml-cuda/conv2d-tensor-core.cu +++ b/ggml/src/ggml-cuda/conv2d-tensor-core.cu @@ -103,9 +103,8 @@ template -__global__ void __launch_bounds__(NUM_WARPS * WARP_SIZE) conv2d_tensor_cores_kernel(const float * __restrict__ IN, - const half * __restrict__ IK, - float * __restrict__ Out) { +__global__ void __launch_bounds__(NUM_WARPS * WARP_SIZE) + conv2d_tensor_cores_kernel(const float * __restrict__ IN, const half * __restrict__ IK, float * __restrict__ Out) { const uint32_t warpId = threadIdx.y; const uint32_t block_tid = threadIdx.y * blockDim.x + threadIdx.x;
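A quick sanity check on the fastdiv scheme used above (editor's sketch, not part of the patch): PATCH 11/12 swaps the local init_fastdiv_values/fastdiv pair for ggml's internal helpers (init_fastdiv_values returning a uint3, plus the matching device-side fastdiv), but both follow the Granlund-Montgomery recipe the removed comment cites: for a fixed divisor d, precompute L = ceil(log2(d)) and mp = floor(2^32 * (2^L - d) / d) + 1; after that, n / d == (umulhi(n, mp) + n) >> L holds for 32-bit n. The host-only program below re-derives the constants exactly the way the removed code did and checks them against plain integer division. The names precompute_fastdiv and fastdiv_ref are invented for the sketch, and the final add is widened to 64 bits so the check itself cannot wrap.

// Standalone host-side check of the fastdiv constants used by the kernel.
// precompute_fastdiv mirrors the removed init_fastdiv_values; fastdiv_ref
// mirrors the device-side (__umulhi(n, mp) + n) >> L with a 64-bit add.
#include <cassert>
#include <cstdint>
#include <cstdio>

static void precompute_fastdiv(uint32_t d, uint32_t & mp, uint32_t & L) {
    L = 0;                                            // L = ceil(log2(d))
    while (L < 32 && (uint32_t{ 1 } << L) < d) {
        L++;
    }
    mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
}

static uint32_t fastdiv_ref(uint32_t n, uint32_t mp, uint32_t L) {
    const uint32_t hi = (uint32_t) (((uint64_t) n * mp) >> 32);  // __umulhi(n, mp)
    return (uint32_t) (((uint64_t) hi + n) >> L);
}

int main() {
    // Representative divisors for this kernel: KW, KW*KH, OW, OW*OH.
    const uint32_t divisors[] = { 1, 3, 7, 9, 224, 224 * 224 };
    for (uint32_t d : divisors) {
        uint32_t mp, L;
        precompute_fastdiv(d, mp, L);
        for (uint32_t n = 0; n < 100000; n++) {
            assert(fastdiv_ref(n, mp, L) == n / d);
        }
    }
    printf("fastdiv constants match integer division\n");
    return 0;
}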
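For readers following the kernel itself: the tensor-core path treats the convolution as one GEMM, with A the weight tensor viewed as an [OC, IC*KH*KW] matrix, B the implicit im2col matrix of shape [IC*KH*KW, B*OH*OW] (taps that land outside the input contribute zero), and unpack_ickhkw / unpack_nohow decoding the flattened row and column indices. The plain C++ routine below only illustrates that mapping in WHCN layout, single-threaded and fp32; it is not the PR's code, and ConvDims / conv2d_gemm_reference are names invented for the sketch.

// Naive host reference for the flattening the kernel relies on:
//   OUT[n][oc][oh][ow] = sum over k = (ic, kh, kw) of A[oc][k] * B[k][(n, oh, ow)]
// where A is the [OC, IC*KH*KW] view of the weights and B is the implicit
// im2col matrix (zero wherever the tap falls outside the input).
#include <vector>

struct ConvDims {
    int IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B;
};

// WHCN layout, same index order as the kernel's whcn_layout helpers.
static void conv2d_gemm_reference(const std::vector<float> & in,
                                  const std::vector<float> & kern,
                                  std::vector<float> &       out,
                                  const ConvDims &           d) {
    const int K = d.IC * d.KH * d.KW;   // reduction dimension (IC_KH_KW)
    const int N = d.B * d.OH * d.OW;    // output columns (N_OH_OW)
    for (int oc = 0; oc < d.OC; ++oc) {
        for (int col = 0; col < N; ++col) {
            // unpack_nohow
            const int n  = col / (d.OH * d.OW);
            const int r  = col - n * (d.OH * d.OW);
            const int oh = r / d.OW;
            const int ow = r - oh * d.OW;
            float acc = 0.0f;
            for (int k = 0; k < K; ++k) {
                // unpack_ickhkw
                const int ic = k / (d.KW * d.KH);
                const int rk = k - ic * (d.KW * d.KH);
                const int kh = rk / d.KW;
                const int kw = rk - kh * d.KW;
                const int iy = oh * d.ST_Y + kh * d.DL_Y - d.PD_Y;
                const int ix = ow * d.ST_X + kw * d.DL_X - d.PD_X;
                if (iy < 0 || iy >= d.IH || ix < 0 || ix >= d.IW) {
                    continue;           // implicit zero in the im2col matrix
                }
                const float a = kern[((oc * d.IC + ic) * d.KH + kh) * d.KW + kw];
                const float b = in[((n * d.IC + ic) * d.IH + iy) * d.IW + ix];
                acc += a * b;
            }
            out[((n * d.OC + oc) * d.OH + oh) * d.OW + ow] = acc;
        }
    }
}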
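One last note on the shape-variant heuristic at the end of conv2d-tensor-core.cu: conv_shapes is declared with one row per parameter ({BS_OC}, {BS_ICKHKW}, {BS_NOHOW}) and one column per variant, which is how conv_2d_tensor_core reads it (conv_shapes[0][CONV_SHAPE] and so on), yet the selection loop indexes conv_shapes[var_id][0] and conv_shapes[var_id][2], i.e. with the two indices swapped, so the tile counts compared against 2 * sm_count do not appear to match the block shapes of the variants being scored. The standalone snippet below (hypothetical problem size, not part of the patch) shows the per-variant count with the indexing that matches the declaration.

// Tile counts per variant when the table is indexed as
// conv_shapes[param][variant], matching how conv_2d_tensor_core reads
// BS_OC / BS_NOHOW. Problem size here is hypothetical.
#include <cstdint>
#include <cstdio>

int main() {
    constexpr int NUM_VARIANTS = 3;
    // rows: BS_OC, BS_ICKHKW, BS_NOHOW; columns: 128x128, 64x32, 32x256
    constexpr uint32_t conv_shapes[3][NUM_VARIANTS] = {
        { 128, 64, 32 },
        { 16, 32, 16 },
        { 128, 32, 256 },
    };
    auto ceil_div = [](uint32_t m, uint32_t n) { return (m + n - 1) / n; };

    const uint32_t Cout = 256, NOHOW = 1 * 56 * 56;  // e.g. B=1, OH=OW=56
    for (int v = 0; v < NUM_VARIANTS; v++) {
        const uint32_t ntiles = ceil_div(Cout, conv_shapes[0][v]) * ceil_div(NOHOW, conv_shapes[2][v]);
        printf("variant %d: %u tiles of %ux%u\n", v, ntiles, conv_shapes[0][v], conv_shapes[2][v]);
    }
    return 0;
}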