Skip to content

Commit e076556

Browse files
committed
Make work_group_size const for CUDA
1 parent d2596c7 commit e076556

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,20 +1063,20 @@ Tensor {{ embedding_cuda_op }}(
1063 1063
// Compute shared memory size for cta_per_row
1064 1064
constexpr auto kCacheAccBytes = sizeof(at::acc_type<cache_t, true>);
1065 1065
{% if is_rocm %}
1066-
int32_t total_L = indices.numel();
1067-
int32_t num_cta_per_row_groups;
1068-
int32_t work_group_size;
1069-
if (total_L/total_B > 1){
1070-
num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize;
1071-
work_group_size = (kMaxThreads/4);
1072-
}
1073-
else{
1074-
num_cta_per_row_groups = kMaxThreads / kWarpSize;
1075-
work_group_size = kMaxThreads;
1076-
}
1066+
int32_t total_L = indices.numel();
1067+
int32_t num_cta_per_row_groups;
1068+
int32_t work_group_size;
1069+
if (total_L/total_B > 1) {
1070+
num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize;
1071+
work_group_size = (kMaxThreads/4);
1072+
}
1073+
else {
1074+
num_cta_per_row_groups = kMaxThreads / kWarpSize;
1075+
work_group_size = kMaxThreads;
1076+
}
1077 1077
{%- else %}
1078-
int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize;
1079-
int32_t work_group_size = kMaxThreads;
1078+
int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize;
1079+
const int32_t work_group_size = kMaxThreads;
1080 1080
{%- endif %}
1081 1081
const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes(
1082 1082
&num_cta_per_row_groups,

0 commit comments

Comments
 (0)