Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -156,27 +156,37 @@ void dram_kv_embedding_inplace_update_cpu(
const int64_t* update_row_idx_ptr = update_row_idx.data_ptr<int64_t>();
const int64_t* update_offsets_ptr = update_offsets.data_ptr<int64_t>();

for (int64_t n = 0; n < N; ++n) {
int32_t t = update_table_idx_ptr[n];
int64_t row_idx = update_row_idx_ptr[n];
SparseType weight_ty = static_cast<SparseType>(weights_tys_ptr[t]);
int32_t D_start = D_offsets_ptr[t];
int32_t D_end = D_offsets_ptr[t + 1];
int64_t window_start = 0;
while (window_start < N) {
int32_t cur_table = update_table_idx_ptr[window_start];
int64_t window_end = window_start;
while (window_end < N && update_table_idx_ptr[window_end] == cur_table) {
++window_end;
}
int window_size = window_end - window_start;

SparseType weight_ty = static_cast<SparseType>(weights_tys_ptr[0]);
int32_t D_start = D_offsets_ptr[0];
int32_t D_end = D_offsets_ptr[1];
int32_t D = D_end - D_start;
int32_t D_bytes =
nbit::padded_row_size_in_bytes(D, weight_ty, row_alignment);

int64_t update_weight_offset = update_offsets_ptr[n];
const uint8_t* update_weight_row =
update_weights_ptr + update_weight_offset;
std::vector<uint8_t> tmp(update_weight_row, update_weight_row + D_bytes);
at::Tensor update_weight =
at::from_blob(
tmp.data(), {1, D_bytes}, at::TensorOptions().dtype(at::kByte))
.clone();
at::Tensor row_id =
at::full({1}, row_idx, at::TensorOptions().dtype(at::kLong));
(*embedding_inplace_update_method)({t, row_id, update_weight});
uint8_t* batched_weights_ptr = const_cast<uint8_t*>(
update_weights_ptr + update_offsets_ptr[window_start]);
auto weights_tensor = at::from_blob(
batched_weights_ptr,
{window_size, D_bytes},
at::TensorOptions().dtype(at::kByte));

int64_t* row_ids_ptr =
const_cast<int64_t*>(update_row_idx_ptr + window_start);
auto row_id_tensor = at::from_blob(
row_ids_ptr, {window_size}, at::TensorOptions().dtype(at::kLong));

(*embedding_inplace_update_method)(
{cur_table, row_id_tensor, weights_tensor});
window_start = window_end;
}
if (embedding_log_inplace_update_stats_method.has_value()) {
(*embedding_log_inplace_update_stats_method)({});
Expand Down
Loading