@@ -1063,20 +1063,20 @@ Tensor {{ embedding_cuda_op }}(
10631063 // Compute shared memory size for cta_per_row
10641064 constexpr auto kCacheAccBytes = sizeof (at::acc_type<cache_t , true >);
10651065 {% if is_rocm %}
1066- int32_t total_L = indices.numel ();
1067- int32_t num_cta_per_row_groups;
1068- int32_t work_group_size;
1069- if (total_L/total_B > 1 ){
1070- num_cta_per_row_groups = (kMaxThreads /4 ) / kWarpSize ;
1071- work_group_size = (kMaxThreads /4 );
1072- }
1073- else {
1074- num_cta_per_row_groups = kMaxThreads / kWarpSize ;
1075- work_group_size = kMaxThreads ;
1076- }
1066+ int32_t total_L = indices.numel ();
1067+ int32_t num_cta_per_row_groups;
1068+ int32_t work_group_size;
1069+ if (total_L/total_B > 1 ) {
1070+ num_cta_per_row_groups = (kMaxThreads /4 ) / kWarpSize ;
1071+ work_group_size = (kMaxThreads /4 );
1072+ }
1073+ else {
1074+ num_cta_per_row_groups = kMaxThreads / kWarpSize ;
1075+ work_group_size = kMaxThreads ;
1076+ }
10771077 {%- else %}
1078- int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize ;
1079- int32_t work_group_size = kMaxThreads ;
1078+ int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize ;
1079+ const int32_t work_group_size = kMaxThreads ;
10801080 {%- endif %}
10811081 const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes (
10821082 &num_cta_per_row_groups,
0 commit comments