diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu
index 142dd66903aaa..deaca3d648d5a 100644
--- a/ggml/src/ggml-cuda/conv2d.cu
+++ b/ggml/src/ggml-cuda/conv2d.cu
@@ -1,122 +1,388 @@
 #include "conv2d.cuh"
 #include "convert.cuh"
 
-struct conv_params {
-    const int64_t IW, IH;
-    const int64_t OW, OH;
-    const int64_t KW, KH;
-    const int64_t ST_X, ST_Y;
-    const int64_t PD_X, PD_Y;
-    const int64_t DL_X, DL_Y;
-    const int64_t IC, OC;
-    const int64_t B;
-    const int64_t TOTAL;
-};
+#include <cstdint>
 
-struct kernel_bounds {
-    int64_t y_min, y_max;
-    int64_t x_min, x_max;
+struct conv_params {
+    const int IW, IH;
+    const int OW, OH;
+    const int KW, KH;
+    const int ST_X, ST_Y;
+    const int PD_X, PD_Y;
+    const int DL_X, DL_Y;
+    const int IC, OC;
+    const int B;
+    // helpers
+    const int IC_KH_KW, N_OH_OW;
 };
 
-__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) {
-    return (a > b) ? a : b;
-}
-
-__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) {
-    return (a < b) ? a : b;
-}
-
-__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) {
-    kernel_bounds bounds;
-    bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
-    bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
-    bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
-    bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
-    return bounds;
-}
-
-__device__ __forceinline__ int calculate_input_coord(int64_t out_coord,
-                                                     int64_t kern_coord,
-                                                     int64_t stride,
-                                                     int64_t dilation,
-                                                     int64_t padding) {
+__device__ __forceinline__ static int calculate_input_coord(int out_coord,
+                                                            int kern_coord,
+                                                            int stride,
+                                                            int dilation,
+                                                            int padding) {
     return out_coord * stride + kern_coord * dilation - padding;
 }
 
 struct whcn_layout {
-    __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
+    __device__ __forceinline__ static int64_t input_index(int n, int c, int y, int x, const conv_params & P) {
         return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x;
     }
 
-    __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) {
+    __device__ __forceinline__ static int64_t kernel_index(int c_out, int c_in, int ky, int kx, const conv_params & P) {
         return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx;
     }
 
-    __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
+    __device__ __forceinline__ static int64_t output_index(int n, int c, int y, int x, const conv_params & P) {
         return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x;
     }
 
-    __device__ static void unpack_indices(int64_t global_idx,
-                                          const conv_params & P,
-                                          int64_t & n,
-                                          int64_t & c,
-                                          int64_t & out_y,
-                                          int64_t & out_x) {
-        out_x = global_idx % P.OW;
-        out_y = (global_idx / P.OW) % P.OH;
-        c     = (global_idx / (P.OW * P.OH)) % P.OC;
-        n     = global_idx / (P.OW * P.OH * P.OC);
+    __device__ __forceinline__ static void unpack_ickhkw(int64_t idx,
+                                                         int & ic,
+                                                         int & kh,
+                                                         int & kw,
+                                                         const conv_params & P) {
+        ic    = idx / (P.KW * P.KH);
+        int r = idx - ic * (P.KW * P.KH);
+        kh    = r / P.KW;
+        kw    = r - kh * P.KW;
+    }
+
+    __device__ __forceinline__ static void unpack_nohow(int64_t idx,
+                                                        int & n,
+                                                        int & oh,
+                                                        int & ow,
+                                                        const conv_params & P) {
+        n     = idx / (P.OH * P.OW);
+        int r = idx - n * (P.OH * P.OW);
+        oh    = r / P.OW;
+        ow    = r - oh * P.OW;
     }
 };
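+
+// The new kernel computes conv2d as an implicit GEMM:
+//   out[OC, N*OH*OW] = kernel[OC, IC*KH*KW] * im2col(input)[IC*KH*KW, N*OH*OW]
+// unpack_ickhkw()/unpack_nohow() map flattened GEMM coordinates back to tensor
+// coordinates, so the im2col matrix is never materialized in global memory.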
 
-template <typename T, typename Layout>
-static __global__ void conv2d_kernel(const float * __restrict__ input,
-                                     const T * __restrict__ kernel,
-                                     float * __restrict__ output,
-                                     const conv_params P) {
-    const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+template <typename layout> class float_mma {
+  private:
+    static constexpr int num_acc = (WMMA_M * WMMA_N + WARP_SIZE - 1) / WARP_SIZE;
+    // for tile [16,16], lane 0 will store and compute for [0,0], [2,0], [4,0] ... [14,0]
+    // lane 1 will store and compute for [0,1], [2,1], [4,1] ... [14,1]
+    float acc[num_acc];
 
-    if (global_idx >= P.TOTAL) {
-        return;
+  public:
+    __device__ __forceinline__ float_mma() {
+#pragma unroll
+        for (int i = 0; i < num_acc; i++) {
+            acc[i] = 0.0f;
+        }
     }
 
-    int64_t n, c_out, out_y, out_x;
-    Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);
+    __device__ __forceinline__ void mma(const float * __restrict__ A_sh,
+                                        const float * __restrict__ B_sh,
+                                        const int strideA,
+                                        const int strideB) {
+        const int lane_id = threadIdx.x % WARP_SIZE;
 
-    float acc = 0.0f;
+#pragma unroll
+        for (int i = 0; i < num_acc; i++) {
+            const int e = lane_id + i * WARP_SIZE;
+            if (e >= WMMA_M * WMMA_N) {
+                continue;
+            }
+            const int m = e / WMMA_N;
+            const int n = e % WMMA_N;
+
+#pragma unroll
+            for (int k = 0; k < WMMA_K; k++) {
+                const float a = A_sh[m * strideA + k];
+                const float b = B_sh[k * strideB + n];
+                acc[i]        = fmaf(a, b, acc[i]);
+            }
+        }
+    }
 
-    for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
-        kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);
+    __device__ __forceinline__ void store_result(const int OC_BASE,
+                                                 const int NOHOW_BASE,
+                                                 float * __restrict__ OUT,
+                                                 const conv_params & P) const {
+        const int lane_id = threadIdx.x % WARP_SIZE;
 
-        for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) {
-            const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y);
+#pragma unroll
+        for (int i = 0; i < num_acc; i++) {
+            const int e = lane_id + i * WARP_SIZE;
+            if (e >= WMMA_M * WMMA_N) {
+                continue;
+            }
+            const int m = e / WMMA_N;
+            const int n = e % WMMA_N;
 
-            for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
-                const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
+            const int oc    = OC_BASE + m;
+            const int nohow = NOHOW_BASE + n;
 
-                const float input_val  = input[Layout::input_index(n, c_in, in_y, in_x, P)];
-                const T     kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
-                acc += (input_val * ggml_cuda_cast<float>(kernel_val));
+            if (oc < P.OC && nohow < P.N_OH_OW) {
+                int n_, oh, ow;
+                layout::unpack_nohow(nohow, n_, oh, ow, P);
+                OUT[layout::output_index(n_, oc, oh, ow, P)] = acc[i];
             }
         }
     }
+};
+
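+// half_mma mirrors the float_mma interface for F16 kernels: on Volta, or whenever
+// FP16_MMA_AVAILABLE, the fragments are fed to the tensor cores via mma.cuh;
+// otherwise the scalar FMA fallback below is used.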
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || defined(FP16_MMA_AVAILABLE)
+#    include "mma.cuh"
+using namespace ggml_cuda_mma;
+
+typedef tile<WMMA_M, WMMA_K / 2, half2> tile_a;
+typedef tile<WMMA_N, WMMA_K / 2, half2> tile_b;
+typedef tile<WMMA_M, WMMA_N, float>     tile_acc;
+
+template <typename layout> class half_mma {
+  private:
+    tile_a   a_frag;
+    tile_b   b_frag;
+    tile_acc c_frag;
+
+  public:
+    __device__ __forceinline__ half_mma() {}
+
+    __device__ __forceinline__ void mma(const half * __restrict__ A_sh,
+                                        const half * __restrict__ B_sh,
+                                        const int strideA,
+                                        const int strideB) {
+        load_ldmatrix(a_frag, (const half2 *) A_sh, strideA / 2);
+        load_ldmatrix_trans(b_frag, (const half2 *) B_sh, strideB / 2);
+        ggml_cuda_mma::mma(c_frag, a_frag, b_frag);
+    }
+
+    __device__ __forceinline__ void store_result(const int OC_BASE,
+                                                 const int NOHOW_BASE,
+                                                 float * __restrict__ OUT,
+                                                 const conv_params & P) const {
+#    pragma unroll
+        for (int l = 0; l < tile_acc::ne; ++l) {
+            const int e = tile_acc::get_i(l) * WMMA_N + tile_acc::get_j(l);
+            if (e >= WMMA_M * WMMA_N) {
+                continue;
+            }
+            const int m = e / WMMA_N;
+            const int n = e % WMMA_N;
+
+            const int oc    = OC_BASE + m;
+            const int nohow = NOHOW_BASE + n;
+
+            if (oc < P.OC && nohow < P.N_OH_OW) {
+                int n_, oh, ow;
+                layout::unpack_nohow(nohow, n_, oh, ow, P);
+                OUT[layout::output_index(n_, oc, oh, ow, P)] = c_frag.x[l];
+            }
+        }
+    }
+};
+
+#else
+
+template <typename layout> class half_mma {
+  public:
+    static constexpr int num_acc = (WMMA_M * WMMA_N + WARP_SIZE - 1) / WARP_SIZE;
+    // e.g. for tile [16,16], lane 0 will store and compute for [0,0], [2,0], [4,0] .. [14,0]
+    float acc[num_acc];
+
+    __device__ __forceinline__ half_mma() {
+#    pragma unroll
+        for (int i = 0; i < num_acc; i++) {
+            acc[i] = 0.0f;
+        }
+    }
+
+    __device__ __forceinline__ void mma(const half * __restrict__ A_sh,
+                                        const half * __restrict__ B_sh,
+                                        const int strideA,
+                                        const int strideB) {
+        const int lane_id = threadIdx.x % WARP_SIZE;
+
+#    pragma unroll
+        for (int e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) {
+            const int m = e / WMMA_N;
+            const int n = e % WMMA_N;
+
+#    pragma unroll
+            for (int k = 0; k < WMMA_K; k++) {
+                const half a = A_sh[m * strideA + k];
+                const half b = B_sh[k * strideB + n];
+                acc[i]       = fmaf(__half2float(a), __half2float(b), acc[i]);
+            }
+        }
+    }
+
+    __device__ __forceinline__ void store_result(const int OC_BASE,
+                                                 const int NOHOW_BASE,
+                                                 float * __restrict__ OUT,
+                                                 const conv_params & P) const {
+        const int lane_id = threadIdx.x % WARP_SIZE;
 
-    // [N, OC, OH, OW]
-    output[Layout::output_index(n, c_out, out_y, out_x, P)] = acc;
+#    pragma unroll
+        for (int e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) {
+            const int m = e / WMMA_N;
+            const int n = e % WMMA_N;
+
+            const int oc    = OC_BASE + m;
+            const int nohow = NOHOW_BASE + n;
+
+            if (oc < P.OC && nohow < P.N_OH_OW) {
+                int n_, oh, ow;
+                layout::unpack_nohow(nohow, n_, oh, ow, P);
+                OUT[layout::output_index(n_, oc, oh, ow, P)] = acc[i];
+            }
+        }
+    }
+};
+
+#endif  // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || defined(FP16_MMA_AVAILABLE)
+
+template <typename T, typename layout, template <typename> class mma, int num_warps>
+__global__ void __launch_bounds__(num_warps * WARP_SIZE) conv2d_kernel(const float * __restrict__ IN,
+                                                                       const T * __restrict__ IK,
+                                                                       float * __restrict__ Out,
+                                                                       const conv_params P) {
+    extern __shared__ unsigned char smem_raw[];
+
+    const int warpId     = threadIdx.y;
+    const int linear_tid = threadIdx.y * blockDim.x + threadIdx.x;
+
+    const int NUM_IC_TILES    = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW;
+    const int NUM_WARPS_NOHOW = max(1, BS_NOHOW / WMMA_N);
+    const int NUM_WARPS_NEED  = ((BS_OC * BS_NOHOW) + (WMMA_M * WMMA_N) - 1) / (WMMA_M * WMMA_N);
+
+    const int NUM_TILES_PER_WARP = (NUM_WARPS_NEED + num_warps - 1) / num_warps;
+
+    mma<layout> acc[NUM_TILES_PER_WARP];
+
+    const int NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW;
+    const int BL_IDX_OC    = blockIdx.x / NUM_BL_NOHOW;
+    const int BL_IDX_NOHOW = blockIdx.x % NUM_BL_NOHOW;
+
+    const int BLOCK_OC_BASE    = BL_IDX_OC * BS_OC;
+    const int BLOCK_NOHOW_BASE = BL_IDX_NOHOW * BS_NOHOW;
+
+    unsigned char * ptr = smem_raw;
+
+    const int A_total = BS_OC * BS_ICKHKW;
+    const int B_total = BS_ICKHKW * BS_NOHOW;
+
+    size_t offsetA = (size_t) A_total * sizeof(T);
+    T *    A_sh    = reinterpret_cast<T *>(ptr);
+    ptr += offsetA;
+
+    size_t offsetB = (size_t) B_total * sizeof(T);
+    T *    B_sh    = reinterpret_cast<T *>(ptr);
+    ptr += offsetB;
+
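+    // Each iteration stages one BS_OC x BS_ICKHKW kernel tile in A_sh and one
+    // BS_ICKHKW x BS_NOHOW im2col tile of the input in B_sh (out-of-range
+    // elements are zero-padded), then each warp multiplies WMMA-sized sub-tiles.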
+    for (int t = 0; t < NUM_IC_TILES; ++t) {
+#pragma unroll
+        for (int tid = linear_tid; tid < A_total; tid += (blockDim.x * blockDim.y)) {
+            const int row = tid / BS_ICKHKW;
+            const int col = tid % BS_ICKHKW;
+
+            const int shared_oc     = BLOCK_OC_BASE + row;
+            const int shared_ickhkw = t * BS_ICKHKW + col;
+
+            T val = ggml_cuda_cast<T>(0);
+            if (shared_oc < P.OC && shared_ickhkw < P.IC_KH_KW) {
+                int ic, kh, kw;
+                layout::unpack_ickhkw(shared_ickhkw, ic, kh, kw, P);
+
+                const int64_t kidx = layout::kernel_index(shared_oc, ic, kh, kw, P);
+                val                = IK[kidx];
+            }
+            A_sh[row * BS_ICKHKW + col] = val;
+        }
+#pragma unroll
+        for (int tid = linear_tid; tid < B_total; tid += (blockDim.x * blockDim.y)) {
+            const int brow = tid / BS_NOHOW;
+            const int bcol = tid % BS_NOHOW;
+
+            const int IC_KH_KW_IDX = t * BS_ICKHKW + brow;
+            const int N_OH_OW_IDX  = BLOCK_NOHOW_BASE + bcol;
+
+            T val = ggml_cuda_cast<T>(0);
+            if (N_OH_OW_IDX < P.N_OH_OW && IC_KH_KW_IDX < P.IC_KH_KW) {
+                int n, oh, ow;
+                int ic, kh, kw;
+                layout::unpack_nohow(N_OH_OW_IDX, n, oh, ow, P);
+                layout::unpack_ickhkw(IC_KH_KW_IDX, ic, kh, kw, P);
+                const int in_y = calculate_input_coord(oh, kh, P.ST_Y, P.DL_Y, P.PD_Y);
+                const int in_x = calculate_input_coord(ow, kw, P.ST_X, P.DL_X, P.PD_X);
+                if (in_y >= 0 && in_y < P.IH && in_x >= 0 && in_x < P.IW) {
+                    const int64_t in_idx = layout::input_index(n, ic, in_y, in_x, P);
+                    val                  = ggml_cuda_cast<T>(IN[in_idx]);
+                }
+            }
+            B_sh[brow * BS_NOHOW + bcol] = val;
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int i = 0; i < NUM_TILES_PER_WARP; i++) {
+            const int warp = warpId + i * num_warps;
+            if (warp >= NUM_WARPS_NEED) {
+                continue;
+            }
+            const int WARP_OC    = warp / NUM_WARPS_NOHOW;
+            const int WARP_NOHOW = warp % NUM_WARPS_NOHOW;
+
+            const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW;
+            const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N;
+
+#pragma unroll
+            for (int k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) {
+                const T * A_k_ptr = A_warp_base + k_tile;
+                const T * B_k_ptr = B_warp_base + k_tile * BS_NOHOW;
+                acc[i].mma(A_k_ptr, B_k_ptr, BS_ICKHKW, BS_NOHOW);
+            }
+        }
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int i = 0; i < NUM_TILES_PER_WARP; i++) {
+        const int warp = warpId + i * num_warps;
+        if (warp >= NUM_WARPS_NEED) {
+            continue;
+        }
+        const int WARP_OC    = warp / NUM_WARPS_NOHOW;
+        const int WARP_NOHOW = warp % NUM_WARPS_NOHOW;
+        const int OC_BASE    = BLOCK_OC_BASE + WARP_OC * WMMA_M;
+        const int NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N;
+        acc[i].store_result(OC_BASE, NOHOW_BASE, Out, P);
+    }
 }
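+
+// One thread block computes one BS_OC x BS_NOHOW output tile; blockIdx.x
+// enumerates all (OC tile, N*OH*OW tile) pairs, and dynamic shared memory
+// holds one A tile plus one B tile.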
 
-template <typename T, typename Layout>
-static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
-    const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
-    conv2d_kernel<T, Layout><<<blocks, CUDA_CONV2D_BLOCK_SIZE, 0, st>>>(X_D, K_D, Y_D, P);
+template <typename T, template <typename> class mma>
+static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params & P, cudaStream_t st) {
+    GGML_ASSERT(BS_OC >= WMMA_M && BS_ICKHKW >= WMMA_K && BS_NOHOW >= WMMA_N);
+    GGML_ASSERT(BS_ICKHKW % WMMA_K == 0);
+
+    const int NUM_BL_OC    = (P.OC + BS_OC - 1) / BS_OC;
+    const int NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW;
+    const int NUM_BL       = NUM_BL_OC * NUM_BL_NOHOW;
+
+    constexpr int NUM_WARPS = (CUDA_CONV2D_BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
+
+    const size_t A_bytes      = BS_OC * BS_ICKHKW * sizeof(T);
+    const size_t B_bytes      = BS_ICKHKW * BS_NOHOW * sizeof(T);
+    const size_t shared_bytes = A_bytes + B_bytes;
+
+    dim3 grid(NUM_BL, 1, 1);
+    dim3 block(WARP_SIZE, NUM_WARPS, 1);
+
+    conv2d_kernel<T, whcn_layout, mma, NUM_WARPS><<<grid, block, shared_bytes, st>>>(X_D, K_D, Y_D, P);
 }
 
-static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
-    conv2d_cuda<half, whcn_layout>(X_D, K_D, Y_D, P, st);
+static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params & P, cudaStream_t st) {
+    conv2d_cuda<half, half_mma>(X_D, K_D, Y_D, P, st);
 }
 
-static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
-    conv2d_cuda<float, whcn_layout>(X_D, K_D, Y_D, P, st);
+static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params & P, cudaStream_t st) {
+    conv2d_cuda<float, float_mma>(X_D, K_D, Y_D, P, st);
 }
 
 void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -155,11 +421,13 @@ void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int OC = kernel->ne[3];  // output channels
     const int B  = input->ne[3];   // n_batches
 
-    const int64_t total  = B * OC * OH * OW;
-    conv_params   params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };
+    const int IC_KH_KW = IC * KH * KW;
+    const int N_OH_OW  = B * OH * OW;
+    const conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X,
+                                 PD_Y, DL_X, DL_Y, IC, OC, B, IC_KH_KW, N_OH_OW };
 
     if (kernel->type == GGML_TYPE_F16) {
-        conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st);
+        conv2d_cuda_f16(X_D, (const half *) K_D, Y_D, params, st);
     } else {
         conv2d_cuda_f32(X_D, K_D, Y_D, params, st);
     }
diff --git a/ggml/src/ggml-cuda/conv2d.cuh b/ggml/src/ggml-cuda/conv2d.cuh
index ce4802c7ed797..3a1a5f28b572c 100644
--- a/ggml/src/ggml-cuda/conv2d.cuh
+++ b/ggml/src/ggml-cuda/conv2d.cuh
@@ -1,5 +1,14 @@
 #pragma once
 #include "common.cuh"
 
-#define CUDA_CONV2D_BLOCK_SIZE 256
+#define BS_OC     32  // block tile: output channels
+#define BS_ICKHKW 16  // block tile: flattened IC*KH*KW (the GEMM K dimension)
+#define BS_NOHOW  32  // block tile: flattened N*OH*OW
+
+#define WMMA_M 16
+#define WMMA_N 16
+#define WMMA_K 16
+
+#define CUDA_CONV2D_BLOCK_SIZE 128
+
 void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);