@@ -15,7 +15,7 @@ struct kernel_bounds {
1515 int x_min, x_max;
1616};
1717
18- __device__ inline kernel_bounds calculate_kernel_bounds (int out_x, int out_y, const conv_params & params) {
18+ __device__ __forceinline__ kernel_bounds calculate_kernel_bounds (int out_x, int out_y, const conv_params & params) {
1919 kernel_bounds bounds;
2020 bounds.y_min = max (0 , (params.padding_y - out_y * params.stride_y + params.dilation_y - 1 ) / params.dilation_y );
2121 bounds.y_max =
@@ -28,7 +28,7 @@ __device__ inline kernel_bounds calculate_kernel_bounds(int out_x, int out_y, co
2828 return bounds;
2929}
3030
31- __device__ inline int calculate_input_coord (int out_coord, int kern_coord, int stride, int dilation, int padding) {
31+ __device__ __forceinline__ int calculate_input_coord (int out_coord, int kern_coord, int stride, int dilation, int padding) {
3232 return out_coord * stride + kern_coord * dilation - padding;
3333}
3434
@@ -84,8 +84,8 @@ __global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restr
8484 const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
8585 const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
8686 const int channels, const int batches) {
87- int global_idx = blockIdx .x * blockDim .x + threadIdx .x ;
88- int total_elements = batches * channels * out_h * out_w;
87+ const int global_idx = blockIdx .x * blockDim .x + threadIdx .x ;
88+ const int total_elements = batches * channels * out_h * out_w;
8989
9090 if (global_idx >= total_elements) {
9191 return ;
0 commit comments