Skip to content

Commit 1a48542

Browse files
am17an authored and fm240223 committed
CUDA: use registers instead of smem in topk-moe (ggml-org#16647)
Uses the technique used in the vulkan PR ggml-org#16641. Neat trick!
1 parent 770aef6 commit 1a48542

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

ggml/src/ggml-cuda/topk-moe.cu

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
7373

7474
float wt_sum = 0.f;
7575

76-
extern __shared__ float data_topk_shared[];
77-
float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;
76+
float output_weights[experts_per_thread];
7877

7978
for (int k = 0; k < n_expert_used; k++) {
8079
float max_val = wt[0];
@@ -99,11 +98,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
9998
}
10099
}
101100

101+
if ((k & (WARP_SIZE - 1)) == threadIdx.x) {
102+
output_weights[k / WARP_SIZE] = max_val;
103+
}
104+
102105
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
103106
wt[max_expert / WARP_SIZE] = -INFINITY;
104107

105-
wt_shared_ptr[k] = max_val;
106-
ids[k] = max_expert;
108+
ids[k] = max_expert;
107109
if constexpr (with_norm) {
108110
wt_sum += max_val;
109111
}
@@ -115,12 +117,16 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
115117
const float inv_sum = 1.0f / wt_sum;
116118

117119
for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
118-
wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
120+
output_weights[i] *= inv_sum;
119121
}
120122
}
121123

122-
for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
123-
weights[i] = wt_shared_ptr[i];
124+
#pragma unroll
125+
for (int i = 0; i < experts_per_thread; i++) {
126+
const int idx = i * WARP_SIZE + threadIdx.x;
127+
if (idx < n_expert_used) {
128+
weights[idx] = output_weights[i];
129+
}
124130
}
125131
}
126132

@@ -137,48 +143,46 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
137143
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
138144
cudaStream_t stream = ctx.stream();
139145

140-
const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);
141-
142146
switch (n_expert) {
143147
case 1:
144148
topk_moe_cuda<1, with_norm>
145-
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
149+
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
146150
break;
147151
case 2:
148152
topk_moe_cuda<2, with_norm>
149-
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
153+
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
150154
break;
151155
case 4:
152156
topk_moe_cuda<4, with_norm>
153-
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
157+
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
154158
break;
155159
case 8:
156160
topk_moe_cuda<8, with_norm>
157-
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
161+
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
158162
break;
159163
case 16:
160164
topk_moe_cuda<16, with_norm>
161-
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
165+
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
162166
break;
163167
case 32:
164168
topk_moe_cuda<32, with_norm>
165-
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
169+
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
166170
break;
167171
case 64:
168172
topk_moe_cuda<64, with_norm>
169-
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
173+
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
170174
break;
171175
case 128:
172176
topk_moe_cuda<128, with_norm>
173-
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
177+
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
174178
break;
175179
case 256:
176180
topk_moe_cuda<256, with_norm>
177-
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
181+
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
178182
break;
179183
case 512:
180184
topk_moe_cuda<512, with_norm>
181-
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
185+
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
182186
break;
183187
default:
184188
GGML_ASSERT(false && "fatal error");

0 commit comments

Comments (0)