diff --git a/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp b/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp index 0bb026f2e9..5309a78f7b 100644 --- a/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp +++ b/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp @@ -156,27 +156,37 @@ void dram_kv_embedding_inplace_update_cpu( const int64_t* update_row_idx_ptr = update_row_idx.data_ptr(); const int64_t* update_offsets_ptr = update_offsets.data_ptr(); - for (int64_t n = 0; n < N; ++n) { - int32_t t = update_table_idx_ptr[n]; - int64_t row_idx = update_row_idx_ptr[n]; - SparseType weight_ty = static_cast(weights_tys_ptr[t]); - int32_t D_start = D_offsets_ptr[t]; - int32_t D_end = D_offsets_ptr[t + 1]; + int64_t window_start = 0; + while (window_start < N) { + int32_t cur_table = update_table_idx_ptr[window_start]; + int64_t window_end = window_start; + while (window_end < N && update_table_idx_ptr[window_end] == cur_table) { + ++window_end; + } + int window_size = window_end - window_start; + + SparseType weight_ty = static_cast(weights_tys_ptr[0]); + int32_t D_start = D_offsets_ptr[0]; + int32_t D_end = D_offsets_ptr[1]; int32_t D = D_end - D_start; int32_t D_bytes = nbit::padded_row_size_in_bytes(D, weight_ty, row_alignment); - int64_t update_weight_offset = update_offsets_ptr[n]; - const uint8_t* update_weight_row = - update_weights_ptr + update_weight_offset; - std::vector tmp(update_weight_row, update_weight_row + D_bytes); - at::Tensor update_weight = - at::from_blob( - tmp.data(), {1, D_bytes}, at::TensorOptions().dtype(at::kByte)) - .clone(); - at::Tensor row_id = - at::full({1}, row_idx, at::TensorOptions().dtype(at::kLong)); - (*embedding_inplace_update_method)({t, row_id, update_weight}); + uint8_t* batched_weights_ptr = const_cast( + update_weights_ptr + update_offsets_ptr[window_start]); + auto weights_tensor = at::from_blob( + batched_weights_ptr, + {window_size, D_bytes}, + at::TensorOptions().dtype(at::kByte)); + + int64_t* row_ids_ptr = + const_cast(update_row_idx_ptr + window_start); + auto row_id_tensor = at::from_blob( + row_ids_ptr, {window_size}, at::TensorOptions().dtype(at::kLong)); + + (*embedding_inplace_update_method)( + {cur_table, row_id_tensor, weights_tensor}); + window_start = window_end; } if (embedding_log_inplace_update_stats_method.has_value()) { (*embedding_log_inplace_update_stats_method)({});