@@ -73,8 +73,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
     float wt_sum = 0.f;
 
-    extern __shared__ float data_topk_shared[];
-    float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;
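+    // per-thread slice of the top-k weights, kept in registers instead of shared memory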
+    float output_weights[experts_per_thread];
 
     for (int k = 0; k < n_expert_used; k++) {
         float max_val = wt[0];
@@ -99,11 +98,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
             }
         }
 
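+        // after the warp reduction every lane holds max_val; the lane that owns slot k keeps it in a register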
+        if ((k & (WARP_SIZE - 1)) == threadIdx.x) {
+            output_weights[k / WARP_SIZE] = max_val;
+        }
+
         if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
             wt[max_expert / WARP_SIZE] = -INFINITY;
 
-            wt_shared_ptr[k] = max_val;
-            ids[k]           = max_expert;
+            ids[k] = max_expert;
             if constexpr (with_norm) {
                 wt_sum += max_val;
             }
@@ -115,12 +117,16 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         const float inv_sum = 1.0f / wt_sum;
 
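+        // slot i is owned by lane i % WARP_SIZE and lives at register index i / WARP_SIZE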
         for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
-            wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
+            output_weights[i / WARP_SIZE] *= inv_sum;
         }
     }
 
-    for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
-        weights[i] = wt_shared_ptr[i];
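+    // each lane writes its register-held weights back to global memory (slot idx = i * WARP_SIZE + lane)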
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int idx = i * WARP_SIZE + threadIdx.x;
+        if (idx < n_expert_used) {
+            weights[idx] = output_weights[i];
+        }
     }
 }
 
@@ -137,48 +143,46 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     dim3 block_dims(WARP_SIZE, rows_per_block, 1);
     cudaStream_t stream = ctx.stream();
 
-    const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);
-
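+    // dynamic shared memory is no longer needed: every launch below passes 0 bytes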
     switch (n_expert) {
         case 1:
             topk_moe_cuda<1, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 2:
             topk_moe_cuda<2, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 4:
             topk_moe_cuda<4, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 8:
             topk_moe_cuda<8, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 16:
             topk_moe_cuda<16, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 32:
             topk_moe_cuda<32, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 64:
             topk_moe_cuda<64, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 128:
             topk_moe_cuda<128, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 256:
             topk_moe_cuda<256, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 512:
             topk_moe_cuda<512, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         default:
             GGML_ASSERT(false && "fatal error");