@@ -32,8 +32,6 @@ __device__ inline int calculate_input_coord(int out_coord, int kern_coord, int s
3232 return out_coord * stride + kern_coord * dilation - padding;
3333}
3434
35- // ───────────── Memory layout abstractions ─────────────
36-
3735struct whcn_layout {
3836 __device__ static int input_index (int n, int c, int y, int x, const conv_params & params) {
3937 return n * (params.channels * params.in_w * params.in_h ) + c * params.in_w * params.in_h + y * params.in_w + x;
@@ -80,40 +78,12 @@ struct cwhn_layout {
8078 }
8179};
8280
83- // ───────────── Generic convolution computation ─────────────
84-
8581template <typename T, typename Layout>
86- const __device__ inline T compute_conv2d_dw_pixel (const T * __restrict__ input, const T * __restrict__ kernel,
87- const conv_params & params, int batch_idx, int channel_idx,
88- int out_y_idx, int out_x_idx) {
89- T accumulator = 0 ;
90- kernel_bounds bounds = calculate_kernel_bounds (out_x_idx, out_y_idx, params);
91-
92- for (int kern_y = bounds.y_min ; kern_y < bounds.y_max ; ++kern_y) {
93- int in_y_idx = calculate_input_coord (out_y_idx, kern_y, params.stride_y , params.dilation_y , params.padding_y );
94-
95- for (int kern_x = bounds.x_min ; kern_x < bounds.x_max ; ++kern_x) {
96- int in_x_idx =
97- calculate_input_coord (out_x_idx, kern_x, params.stride_x , params.dilation_x , params.padding_x );
98-
99- const T input_val = input[Layout::input_index (batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
100- const T kernel_val = kernel[Layout::kernel_index (channel_idx, kern_y, kern_x, params)];
101-
102- accumulator += input_val * kernel_val;
103- }
104- }
105-
106- return accumulator;
107- }
108-
109- // ───────────── Kernel instantiations ─────────────
110-
111- template <typename T>
112- __global__ void conv2d_dw_whcn_kernel (const T * __restrict__ in, const T * __restrict__ kern, T * __restrict__ out,
113- const int in_w, const int in_h, const int out_w, const int out_h,
114- const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
115- const int padding_x, const int padding_y, const int dilation_x,
116- const int dilation_y, const int channels, const int batches) {
82+ __global__ void conv2d_dw_kernel (const T * __restrict__ input, const T * __restrict__ kernel, T * __restrict__ output,
83+ const int in_w, const int in_h, const int out_w, const int out_h,
84+ const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
85+ const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
86+ const int channels, const int batches) {
11787 int global_idx = blockIdx .x * blockDim .x + threadIdx .x ;
11888 int total_elements = batches * channels * out_h * out_w;
11989
@@ -125,42 +95,31 @@ __global__ void conv2d_dw_whcn_kernel(const T * __restrict__ in, const T * __res
12595 stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };
12696
12797 int batch_idx, channel_idx, out_y_idx, out_x_idx;
128- whcn_layout ::unpack_indices (global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
98+ Layout ::unpack_indices (global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
12999
130- T result = compute_conv2d_dw_pixel<T, whcn_layout>(in, kern, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
131- out[whcn_layout::output_index (batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = result;
132- }
100+ T accumulator = 0 ;
101+ kernel_bounds bounds = calculate_kernel_bounds (out_x_idx, out_y_idx, params);
133102
134- template <typename T>
135- __global__ void conv_2d_dw_cwhn_kernel (const T * __restrict__ in, const T * __restrict__ kern, T * __restrict__ out,
136- const int in_w, const int in_h, const int out_w, const int out_h,
137- const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
138- const int padding_x, const int padding_y, const int dilation_x,
139- const int dilation_y, const int channels, const int batches) {
140- int global_idx = blockIdx .x * blockDim .x + threadIdx .x ;
141- int total_elements = batches * channels * out_h * out_w;
103+ for (int kern_y = bounds.y_min ; kern_y < bounds.y_max ; ++kern_y) {
104+ int in_y_idx = calculate_input_coord (out_y_idx, kern_y, params.stride_y , params.dilation_y , params.padding_y );
142105
143- if (global_idx >= total_elements) {
144- return ;
145- }
106+ for (int kern_x = bounds.x_min ; kern_x < bounds.x_max ; ++kern_x) {
107+ int in_x_idx = calculate_input_coord (out_x_idx, kern_x, params.stride_x , params.dilation_x , params.padding_x );
146108
147- conv_params params = { in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x,
148- stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches } ;
109+ const T input_val = input[ Layout::input_index (batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
110+ const T kernel_val = kernel[ Layout::kernel_index (channel_idx, kern_y, kern_x, params)] ;
149111
150- int batch_idx, channel_idx, out_y_idx, out_x_idx;
151- cwhn_layout::unpack_indices (global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
112+ accumulator += input_val * kernel_val;
113+ }
114+ }
152115
153- const T result =
154- compute_conv2d_dw_pixel<T, cwhn_layout>(in, kern, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
155- out[cwhn_layout::output_index (batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = result;
116+ output[Layout::output_index (batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = accumulator;
156117}
157118
158- // ───────────── dispatcher ─────────────
159119void ggml_cuda_op_conv2d_dw (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
160120 const ggml_tensor * kernel = dst->src [0 ];
161121 const ggml_tensor * input = dst->src [1 ];
162122
163- // Only F32→F32 for now
164123 GGML_ASSERT (kernel->type == GGML_TYPE_F32 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
165124 const float * w_d = (const float *) kernel->data ;
166125 const float * x_d = (const float *) input->data ;
@@ -189,11 +148,11 @@ void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
189148 const int blocks = (total + CUDA_CONV2D_DW_BLOCK_SIZE - 1 ) / CUDA_CONV2D_DW_BLOCK_SIZE;
190149
191150 if (ggml_is_contiguous (input)) {
192- conv2d_dw_whcn_kernel <<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0 , st>>> (
151+ conv2d_dw_kernel< float , whcn_layout> <<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0 , st>>> (
193152 x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
194153 dilation_x, dilation_y, channels, batches);
195154 } else if (ggml_is_contiguous_channels (input)) {
196- conv_2d_dw_cwhn_kernel <<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0 , st>>> (
155+ conv2d_dw_kernel< float , cwhn_layout> <<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0 , st>>> (
197156 x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
198157 dilation_x, dilation_y, channels, batches);
199158 } else {
0 commit comments