Skip to content

Commit 9471cfe

Browse files
committed
CUDA: add conv_2d_dw
1 parent d03172c commit 9471cfe

File tree

3 files changed

+213
-0
lines changed

3 files changed

+213
-0
lines changed

ggml/src/ggml-cuda/conv2d-dw.cu

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
#include "conv2d-dw.cuh"
2+
3+
// Geometry and hyper-parameters of one depthwise 2-D convolution.
// NOTE: the kernels build this with aggregate (brace) initialization, so the
// member order below must stay in sync with those initializer lists.
struct conv_params {
    int in_w, in_h;              // input spatial extent (width, height)
    int out_w, out_h;            // output spatial extent (width, height)
    int kernel_w, kernel_h;      // filter spatial extent (width, height)
    int stride_x, stride_y;      // stride along x / y
    int padding_x, padding_y;    // zero-padding along x / y
    int dilation_x, dilation_y;  // dilation along x / y
    int channels, batches;       // number of channels and batch size
};
12+
13+
// Half-open tap ranges [y_min, y_max) x [x_min, x_max) of the filter whose
// corresponding input coordinates fall inside the input tensor (see
// calculate_kernel_bounds); taps outside these ranges would read padding.
struct kernel_bounds {
    int y_min, y_max;
    int x_min, x_max;
};
17+
18+
// Compute the range of filter taps that land inside the input for output
// pixel (out_x, out_y). The input coordinate of tap k along one axis is
//   in = out * stride + k * dilation - padding,
// so the first valid tap satisfies in >= 0, i.e.
//   k >= ceil((padding - out * stride) / dilation),
// and the exclusive upper bound satisfies in < in_size, i.e.
//   k < ceil((in_size + padding - out * stride) / dilation).
// The "+ dilation - 1" terms below implement the ceiling division; clamping
// to [0, kernel_size] keeps the result inside the filter. This makes the
// inner convolution loops branch-free (implicit zero padding).
__device__ inline kernel_bounds calculate_kernel_bounds(int out_x, int out_y, const conv_params & params) {
    kernel_bounds bounds;
    bounds.y_min = max(0, (params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
    bounds.y_max =
        min(params.kernel_h,
            (params.in_h + params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
    bounds.x_min = max(0, (params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
    bounds.x_max =
        min(params.kernel_w,
            (params.in_w + params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
    return bounds;
}
30+
31+
// Map an output coordinate plus a filter tap to the input coordinate it
// reads. The result can lie outside the input when the tap is outside the
// range produced by calculate_kernel_bounds.
__device__ inline int calculate_input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) {
    const int window_origin = out_coord * stride - padding;
    return window_origin + kern_coord * dilation;
}
34+
35+
// ───────────── Memory layout abstractions ─────────────
36+
37+
// Index arithmetic for width-fastest storage: [batch][channel][height][width]
// (ggml's default contiguous layout). All helpers return flat element offsets.
struct WHCNLayout {
    __device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
        const int plane = params.in_w * params.in_h;  // elements per channel plane
        return (n * params.channels + c) * plane + y * params.in_w + x;
    }

    __device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
        return (c * params.kernel_h + ky) * params.kernel_w + kx;
    }

    __device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
        const int plane = params.out_w * params.out_h;
        return (n * params.channels + c) * plane + y * params.out_w + x;
    }

    // Decompose a flat output index into (n, c, out_y, out_x); x varies fastest.
    __device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
                                          int & out_x) {
        int rem = global_idx;
        out_x   = rem % params.out_w;
        rem /= params.out_w;
        out_y = rem % params.out_h;
        rem /= params.out_h;
        c = rem % params.channels;
        n = rem / params.channels;
    }
};
59+
60+
// Index arithmetic for channel-fastest storage: [batch][height][width][channel].
// All helpers return flat element offsets.
struct CWHNLayout {
    __device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
        const int pixel = y * params.in_w + x;
        return n * params.channels * params.in_w * params.in_h + pixel * params.channels + c;
    }

    __device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
        const int tap = ky * params.kernel_w + kx;
        return tap * params.channels + c;
    }

    __device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
        const int pixel = y * params.out_w + x;
        return n * params.channels * params.out_w * params.out_h + pixel * params.channels + c;
    }

    // Decompose a flat output index into (n, c, out_y, out_x); c varies fastest.
    __device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
                                          int & out_x) {
        int rem = global_idx;
        c       = rem % params.channels;
        rem /= params.channels;
        out_x = rem % params.out_w;
        rem /= params.out_w;
        out_y = rem % params.out_h;
        n     = rem / params.out_h;
    }
};
82+
83+
// ───────────── Generic convolution computation ─────────────
84+
85+
// Accumulate one output pixel of a depthwise 2-D convolution: each channel is
// convolved with its own single-channel filter. Only taps whose input
// coordinate is in bounds (per calculate_kernel_bounds) are visited, which
// implements implicit zero padding without a per-tap branch. Layout supplies
// the flat-index mapping (WHCNLayout or CWHNLayout).
// Fix: dropped the top-level `const` that prefixed the by-value return type —
// cv-qualifiers on a value return are ignored by the compiler and only
// generate -Wignored-qualifiers warnings.
template <typename T, typename Layout>
__device__ inline T compute_conv2d_dw_pixel(const T * __restrict__ input, const T * __restrict__ kernel,
                                            const conv_params & params, int batch_idx, int channel_idx,
                                            int out_y_idx, int out_x_idx) {
    T accumulator = 0;
    kernel_bounds bounds = calculate_kernel_bounds(out_x_idx, out_y_idx, params);

    for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) {
        int in_y_idx = calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y);

        for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) {
            int in_x_idx =
                calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x);

            const T input_val  = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
            const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)];

            accumulator += input_val * kernel_val;
        }
    }

    return accumulator;
}
108+
109+
// ───────────── Kernel instantiations ─────────────
110+
111+
// Depthwise conv kernel for WHCN (width-fastest) tensors.
// Launch: 1-D grid, one thread per output element; guarded against the
// partial last block.
template <typename T>
__global__ void conv2d_dw_whcn_kernel(const T * __restrict__ in, const T * __restrict__ kern, T * __restrict__ out,
                                      const int in_w, const int in_h, const int out_w, const int out_h,
                                      const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
                                      const int padding_x, const int padding_y, const int dilation_x,
                                      const int dilation_y, const int channels, const int batches) {
    const int tid     = blockIdx.x * blockDim.x + threadIdx.x;
    const int n_elems = batches * channels * out_h * out_w;
    if (tid >= n_elems) {
        return;
    }

    const conv_params params = { in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x,
                                 stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };

    int n, c, oy, ox;
    WHCNLayout::unpack_indices(tid, params, n, c, oy, ox);

    out[WHCNLayout::output_index(n, c, oy, ox, params)] =
        compute_conv2d_dw_pixel<T, WHCNLayout>(in, kern, params, n, c, oy, ox);
}
133+
134+
// Depthwise conv kernel for CWHN (channel-fastest) tensors.
// Launch: 1-D grid, one thread per output element; guarded against the
// partial last block. Adjacent threads handle adjacent channels.
template <typename T>
__global__ void conv_2d_dw_cwhn_kernel(const T * __restrict__ in, const T * __restrict__ kern, T * __restrict__ out,
                                       const int in_w, const int in_h, const int out_w, const int out_h,
                                       const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
                                       const int padding_x, const int padding_y, const int dilation_x,
                                       const int dilation_y, const int channels, const int batches) {
    const int tid     = blockIdx.x * blockDim.x + threadIdx.x;
    const int n_elems = batches * channels * out_h * out_w;
    if (tid >= n_elems) {
        return;
    }

    const conv_params params = { in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x,
                                 stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };

    int n, c, oy, ox;
    CWHNLayout::unpack_indices(tid, params, n, c, oy, ox);

    out[CWHNLayout::output_index(n, c, oy, ox, params)] =
        compute_conv2d_dw_pixel<T, CWHNLayout>(in, kern, params, n, c, oy, ox);
}
157+
158+
// ───────────── dispatcher ─────────────
159+
// Host-side entry point: launch the depthwise 2-D convolution for dst on the
// context's CUDA stream. dst->src[0] is the filter, dst->src[1] the input.
// Picks the WHCN or CWHN kernel based on the input's memory layout; aborts on
// any other layout.
void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * kernel = dst->src[0];
    const ggml_tensor * input = dst->src[1];

    // Only F32→F32 for now
    GGML_ASSERT(kernel->type == GGML_TYPE_F32 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
    const float * w_d = (const float *) kernel->data;
    const float * x_d = (const float *) input->data;
    float * y_d = (float *) dst->data;

    // Hyper-parameters packed into op_params as six int32 values, read here as
    // stride, padding, dilation (x then y for each pair).
    const int32_t * p = (const int32_t *) dst->op_params;
    const int stride_x = p[0];
    const int stride_y = p[1];
    const int padding_x = p[2];
    const int padding_y = p[3];
    const int dilation_x = p[4];
    const int dilation_y = p[5];

    // Tensor geometry: ne[0]=width, ne[1]=height, ne[2]=channels, ne[3]=batch.
    // NOTE(review): ne[] is narrowed to int — assumes each extent fits in
    // 32 bits; confirm for very large tensors.
    const int in_w = input->ne[0];
    const int in_h = input->ne[1];
    const int kernel_w = kernel->ne[0];
    const int kernel_h = kernel->ne[1];
    const int out_w = dst->ne[0];
    const int out_h = dst->ne[1];
    const int channels = dst->ne[2];
    const int batches = dst->ne[3];

    cudaStream_t st = ctx.stream();

    // One thread per output element; round the block count up.
    const int total = batches * channels * out_h * out_w;
    const int blocks = (total + CUDA_CONV2D_DW_BLOCK_SIZE - 1) / CUDA_CONV2D_DW_BLOCK_SIZE;

    if (ggml_is_contiguous(input)) {
        // Fully contiguous input → width-fastest (WHCN) kernel.
        conv2d_dw_whcn_kernel<<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
            x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
            dilation_x, dilation_y, channels, batches);
    } else if (ggml_is_contiguous_channels(input)) {
        // Channel-contiguous input → channel-fastest (CWHN) kernel.
        conv_2d_dw_cwhn_kernel<<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
            x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
            dilation_x, dilation_y, channels, batches);
    } else {
        // Unsupported memory layout
        GGML_ABORT("Unsupported memory layout for conv_2d_dw");
    }
}

ggml/src/ggml-cuda/conv2d-dw.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#pragma once
#include "common.cuh"

// Threads per block for the conv_2d_dw kernels (one thread per output element).
#define CUDA_CONV2D_DW_BLOCK_SIZE 256

// Depthwise 2-D convolution (F32 only): dst->src[0] = filter, dst->src[1] = input.
void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "ggml-cuda/clamp.cuh"
1212
#include "ggml-cuda/concat.cuh"
1313
#include "ggml-cuda/conv-transpose-1d.cuh"
14+
#include "ggml-cuda/conv2d-dw.cuh"
1415
#include "ggml-cuda/convert.cuh"
1516
#include "ggml-cuda/count-equal.cuh"
1617
#include "ggml-cuda/cpy.cuh"
@@ -2352,6 +2353,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
23522353
case GGML_OP_OPT_STEP_ADAMW:
23532354
ggml_cuda_opt_step_adamw(ctx, dst);
23542355
break;
2356+
case GGML_OP_CONV_2D_DW:
2357+
ggml_cuda_op_conv2d_dw(ctx, dst);
2358+
break;
23552359
default:
23562360
return false;
23572361
}
@@ -3263,6 +3267,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
32633267
case GGML_OP_CROSS_ENTROPY_LOSS:
32643268
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
32653269
case GGML_OP_OPT_STEP_ADAMW:
3270+
case GGML_OP_CONV_2D_DW:
32663271
return true;
32673272
default:
32683273
return false;

0 commit comments

Comments
 (0)