
Commit a98ad84

jspark1105 authored and facebook-github-bot committed
optimize exact row-wise sparse adagrad on cpu (#589)
Summary: Pull Request resolved: #589

Use the embedding SpMDM JIT'ed FBGEMM kernel for gradient coalescing, and also use the row-wise sparse adagrad JIT'ed kernel.

Reviewed By: jiyuanzFB

Differential Revision: D27562367

fbshipit-source-id: 7c0b6f4f2628e19706772e786c8f3ac253eac21b
1 parent 1dfb0c3 commit a98ad84
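
The summary describes two JIT'ed FBGEMM kernels that the new CPU backward path chains together: an embedding SpMDM kernel that coalesces grad_output rows per embedding row (the CSC column segments built by batched_csr2csc), followed by a row-wise sparse AdaGrad kernel applied to the coalesced gradients. The sketch below illustrates that call pattern for a single table; the function name, buffers, and shapes are illustrative rather than part of this commit, and the argument lists simply mirror the calls visible in the diff, so exact signatures may differ across FBGEMM versions.

// Minimal sketch (illustrative, not the codegen template itself): coalesce
// gradients for the column segments of one table with a JIT'ed EmbeddingSpMDM
// kernel, then apply the JIT'ed row-wise sparse AdaGrad update.
#include <cstddef>
#include <cstdint>
#include <vector>

#include "fbgemm/FbgemmEmbedding.h"

void rowwise_adagrad_one_table(
    int D,                       // embedding dimension of this table
    int B,                       // batch size (rows of grad_output)
    std::int64_t grad_stride,    // row stride of grad_output
    const float* grad_output,    // [B, grad_stride] gradient w.r.t. pooled output
    const std::vector<std::int32_t>& col_segment_ptr,     // size num_segments + 1
    const std::vector<std::int32_t>& row_indices,         // sample id b per nonzero
    const std::vector<std::int64_t>& col_segment_indices, // embedding row per segment
    std::int64_t hash_size,      // number of rows in this embedding table
    float* weights,              // hash_size * D parameters
    float* momentum1,            // hash_size row-wise AdaGrad accumulators
    float eps,
    float learning_rate) {
  // Gradient coalescing: for each column segment, sum the grad_output rows of
  // the samples that touched that embedding row.
  auto spmdm_kernel = fbgemm::GenerateEmbeddingSpMDMWithStrides<
      float,
      /*IndexType=*/std::int32_t,
      /*OffsetType=*/std::int32_t>(
      D,
      /*has_weight=*/false,
      /*normalize_by_lengths=*/false,
      /*prefetch=*/16,
      /*is_weight_positional=*/false,
      /*use_offsets=*/true,
      /*output_stride=*/-1,
      /*input_stride=*/grad_stride);
  // Row-wise sparse AdaGrad: one accumulator per embedding row.
  auto rowwise_adagrad_kernel =
      fbgemm::GenerateSparseAdaGrad</*IndexType=*/std::int64_t>(
          D, /*rowwise=*/true);

  const int num_segments = static_cast<int>(col_segment_indices.size());
  std::vector<float> grad_buffer(static_cast<std::size_t>(num_segments) * D);

  bool success = spmdm_kernel(
      num_segments,                  // output rows (one per column segment)
      col_segment_ptr[num_segments], // total number of nonzeros
      B,                             // rows available in grad_output
      grad_output,
      row_indices.data(),
      col_segment_ptr.data(),
      /*weights=*/nullptr,
      grad_buffer.data());
  if (!success) {
    return;
  }

  int num_rows_processed = rowwise_adagrad_kernel(
      num_segments,
      hash_size * D, // total parameter count, used for bounds checking
      weights,
      grad_buffer.data(),
      momentum1,
      col_segment_indices.data(),
      eps,
      -learning_rate, // negated, as in the kernel template: AdaGrad descends
      /*weight_decay=*/0,
      /*counter=*/nullptr,
      /*counter_halflife=*/0);
  (void)num_rows_processed; // expected to equal num_segments on success
}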

File tree

3 files changed: +265 -73 lines changed


fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp

Lines changed: 147 additions & 33 deletions
@@ -12,16 +12,31 @@
 #include <ATen/AccumulateType.h>
 
 #include "codegen/embedding_forward_split_cpu.h"
+#include "fbgemm/FbgemmEmbedding.h"
+#include "fbgemm/Types.h"
 
 using namespace at;
 
+namespace internal {
+template <typename T>
+struct half2float16 {
+  using type = T;
+};
+
+template <>
+struct half2float16<at::Half> {
+  using type = fbgemm::float16;
+};
+} // namespace internal
+
 namespace {
 template <typename scalar_t>
 void split_embedding_backward_exact_cpu_kernel(
     Tensor grad_output,
     Tensor host_weights,
     const TensorAccessor<int64_t, 1> weights_offsets_data,
     const TensorAccessor<int, 1> D_offsets_data,
+    Tensor hash_size_cumsum,
     Tensor indices,
     Tensor offsets,
     int64_t pooling_mode,
@@ -37,50 +52,146 @@ void split_embedding_backward_exact_cpu_kernel(
     {% endif %}
     {{ args.split_cpu_kernel_args | join(", ") }}) {
   using grad_t = acc_type<scalar_t, true>;
-  ::internal::BatchedHyperCompressedSparseColumn batched_csc;
-  ::internal::batched_csr2csc(
-      batched_csc,
-      num_tables,
-      B,
-      offsets.accessor<int64_t, 1>(),
-      indices.accessor<int64_t, 1>(),
-      indice_weights.defined()
-          ? indice_weights.accessor<grad_t, 1>()
-          : TensorAccessor<grad_t, 1>(nullptr, nullptr, nullptr),
-      pooling_mode,
-      table_to_feature_offset);
-  std::vector<int>& table_ptr = batched_csc.table_ptr;
-  std::vector<int>& column_ptr = batched_csc.column_ptr;
 
-  auto grad_output_data = grad_output.accessor<grad_t, 2>();
+  // const auto grad_output_accessor = grad_output.accessor<grad_t, 2>();
+  const grad_t* grad_output_data = grad_output.data_ptr<grad_t>();
   auto host_weights_data = host_weights.accessor<scalar_t, 1>();
+  const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();
+
+  const bool has_weights = indice_weights.defined();
+  auto grad_stride = grad_output.size(1);
 
-  const bool has_weights = !batched_csc.weights.empty();
+  std::vector<::internal::BatchedHyperCompressedSparseColumn> batched_cscs(
+      num_tables);
+
+  at::parallel_for(0, num_tables, 0, [&](int64_t t_begin, int64_t t_end) {
+    for (int t = t_begin; t < t_end; ++t) {
+      int feature_begin = table_to_feature_offset[t];
+
+      ::internal::batched_csr2csc(
+          batched_cscs[t],
+          1,
+          B,
+          offsets.accessor<int64_t, 1>(),
+          indices.accessor<int64_t, 1>(),
+          indice_weights.defined()
+              ? indice_weights.accessor<grad_t, 1>()
+              : TensorAccessor<grad_t, 1>(nullptr, nullptr, nullptr),
+          pooling_mode,
+          table_to_feature_offset + t);
+    }
+  });
 
   for (int t = 0; t < num_tables; ++t) {
     int feature_begin = table_to_feature_offset[t];
+
+    int c_begin = batched_cscs[t].table_ptr[0];
+    int c_end = batched_cscs[t].table_ptr[1];
+    std::vector<int>& col_segment_ptr = batched_cscs[t].column_segment_ptr;
+    std::vector<int64_t>& col_segment_indices =
+        batched_cscs[t].column_segment_indices;
+
+    int64_t hash_size;
+    int t_temp = feature_begin + 1;
+    do {
+      hash_size =
+          hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[feature_begin];
+      ++t_temp;
+    } while (hash_size == 0);
+
     const auto D_begin = D_offsets_data[feature_begin];
     const auto D =
         D_offsets_data[feature_begin + 1] - D_offsets_data[feature_begin];
     const auto table_begin = weights_offsets_data[feature_begin];
-    grad_t grad_buffer[D];
-    for (int c = table_ptr[t]; c < table_ptr[t + 1]; ++c) {
-      memset(grad_buffer, 0, D * sizeof(grad_t));
-      int idx = batched_csc.column_indices[c];
-      const int64_t embedding_begin = table_begin + idx * D;
-      for (int r = column_ptr[c]; r < column_ptr[c + 1]; ++r) {
-        int f_times_b = batched_csc.row_indices[r];
-        int feature = f_times_b / B;
-        int b = f_times_b % B;
-        int D_offset = D_begin + (feature - feature_begin) * D;
-        for (int64_t d = 0; d < D; ++d) {
-          grad_buffer[d] += has_weights
-              ? grad_output_data[b][D_offset + d] * batched_csc.weights[r]
-              : grad_output_data[b][D_offset + d];
+
+    {% if optimizer == "rowwise_adagrad" %}
+    constexpr bool use_fbgemm = std::is_same<scalar_t, float>::value;
+    // || std::is_same<scalar_t, at::Half>::value;
+    if (use_fbgemm &&
+        table_to_feature_offset[t + 1] == table_to_feature_offset[t] + 1) {
+      // fbgemm handles common case of no shared table
+      using fbgemm_weight_t = typename ::internal::half2float16<scalar_t>::type;
+      auto spmdm_kernel = fbgemm::GenerateEmbeddingSpMDMWithStrides<
+          fbgemm_weight_t,
+          /*IndexType=*/int32_t,
+          /*OffsetType=*/int32_t>(
+          D,
+          !batched_cscs[t].weights.empty(),
+          /*normalize_by_lengths=*/false,
+          /*prefetch=*/16,
+          /*is_weight_positional=*/false,
+          /*use_offsets=*/true,
+          /*output_stride=*/-1,
+          /*input_stride=*/grad_stride);
+      auto rowwise_adagrad_kernel =
+          fbgemm::GenerateSparseAdaGrad</*IndexType=*/int64_t>(
+              D, /*rowwise=*/true);
+
+      constexpr int C_BLOCK = 64;
+      at::parallel_for(c_begin, c_end, C_BLOCK, [&](int64_t c0, int64_t c1) {
+        grad_t grad_blocked_buffer[C_BLOCK * D];
+        for (int64_t c = c0; c < c1; c += C_BLOCK) {
+          const int* offsets_begin_ptr = col_segment_ptr.data() + c;
+          int64_t c_block_end = std::min(c + C_BLOCK, c1);
+          bool success = spmdm_kernel(
+              c_block_end - c,
+              col_segment_ptr[c_block_end] - *offsets_begin_ptr,
+              B,
+              reinterpret_cast<const fbgemm_weight_t*>(
+                  grad_output_data + D_begin),
+              batched_cscs[t].row_indices.data() + *offsets_begin_ptr,
+              offsets_begin_ptr,
+              batched_cscs[t].weights.empty()
+                  ? nullptr
+                  : batched_cscs[t].weights.data() + *offsets_begin_ptr,
+              reinterpret_cast<float*>(grad_blocked_buffer));
+          // TODO: more friendly error msg.
+          TORCH_CHECK(success);
+          int num_rows_processed = rowwise_adagrad_kernel(
+              c_block_end - c,
+              hash_size * D,
+              reinterpret_cast<float*>(&host_weights_data[table_begin]),
+              reinterpret_cast<const float*>(grad_blocked_buffer),
+              reinterpret_cast<float*>(
+                  &momentum1_host[momentum1_offsets_data[feature_begin]]),
+              col_segment_indices.data() + c,
+              eps,
+              -learning_rate,
+              /*weight_decay=*/0,
+              /*counter=*/nullptr,
+              /*counter_halflife=*/0);
+          // TODO: more friendly error msg.
+          TORCH_CHECK(num_rows_processed == c_block_end - c);
+        } // for each c
+      }); // parallel for
+    } else
+    {% endif %}
+    {
+      // no fbgemm
+      // TODO: to parallelize, we should easily identify segments belong to
+      // the same column.
+      grad_t grad_buffer[D];
+      for (int c = c_begin; c < c_end; ++c) {
+        int64_t idx = col_segment_indices[c];
+        if (c == c_begin || col_segment_indices[c - 1] != idx) {
+          memset(grad_buffer, 0, D * sizeof(grad_t));
         }
-      }
-      {{ split_weight_update_cpu }}
-    }
+        const int64_t embedding_begin = table_begin + idx * D;
+        int D_offset = D_begin + batched_cscs[t].column_segment_ids[c] * D;
+        for (int r = col_segment_ptr[c]; r < col_segment_ptr[c + 1]; ++r) {
+          int b = batched_cscs[t].row_indices[r];
+          for (int64_t d = 0; d < D; ++d) {
+            grad_buffer[d] += !batched_cscs[t].weights.empty()
+                ? grad_output_data[b * grad_stride + D_offset + d] *
+                    batched_cscs[t].weights[r]
+                : grad_output_data[b * grad_stride + D_offset + d];
+          }
+        }
+        if (c == c_end - 1 || col_segment_indices[c + 1] != idx) {
+          {{ split_weight_update_cpu }}
+        }
+      } // for each c
+    } // no fbgemm
   } // for each table
 }

@@ -200,13 +311,16 @@ void split_embedding_backward_exact_cpu_dense_kernel(
   const auto momentum2_offsets_data = momentum2_offsets.accessor<int64_t, 1>();
   {% endif %}
 
+  grad_output = grad_output.contiguous();
+
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       host_weights.scalar_type(), "split_embedding_backward_exact_cpu", [&]() {
         split_embedding_backward_exact_cpu_kernel<scalar_t>(
             grad_output,
             host_weights,
             weights_offsets_data,
             D_offsets_data,
+            hash_size_cumsum,
             indices,
             offsets,
             pooling_mode,
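
One detail from the hunks above: the kernel now reads grad_output through a raw data pointer plus an explicit row stride instead of a 2-D accessor, which is why the dispatch site calls grad_output.contiguous() first. A minimal sketch of the indexing equivalence this relies on, with an illustrative helper that is not part of the commit:

// Sketch: for a contiguous 2-D float tensor, accessor indexing and raw-pointer
// indexing with the row stride read the same element.
#include <ATen/ATen.h>
#include <c10/util/Exception.h>

float read_grad(const at::Tensor& grad_output, int64_t b, int64_t j) {
  const at::Tensor g = grad_output.contiguous();
  auto acc = g.accessor<float, 2>();
  const float* data = g.data_ptr<float>();
  const int64_t grad_stride = g.size(1); // row stride == row length when contiguous
  TORCH_CHECK(data[b * grad_stride + j] == acc[b][j]);
  return data[b * grad_stride + j];
}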

fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp

Lines changed: 111 additions & 37 deletions
@@ -346,53 +346,127 @@ void batched_csr2csc(
     const int* table_to_feature_offset) {
   batched_csc.num_tables = num_tables;
   batched_csc.table_ptr.resize(num_tables + 1);
-  int64_t nnz = batched_csr_offsets[table_to_feature_offset[num_tables] * B];
+  int64_t nnz = batched_csr_offsets[table_to_feature_offset[num_tables] * B] -
+      batched_csr_offsets[table_to_feature_offset[0] * B];
   batched_csc.row_indices.resize(nnz);
   bool has_weights = batched_csr_weights.data() != nullptr;
   if (has_weights || pooling_mode == MEAN) {
     batched_csc.weights.resize(nnz);
   }
 
-  batched_csc.table_ptr.push_back(0);
-  batched_csc.column_ptr.push_back(0);
+  batched_csc.table_ptr[0] = 0;
+  batched_csc.column_segment_ptr.push_back(0);
   int column_ptr_curr = 0;
   for (int t = 0; t < num_tables; ++t) {
-    std::unordered_map<int64_t, std::vector<std::pair<int, scalar_t>>>
-        non_empty_columns;
-    for (int feature = table_to_feature_offset[t];
-         feature < table_to_feature_offset[t + 1];
-         ++feature) {
-      for (int b = 0; b < B; ++b) {
-        int64_t pool_begin = batched_csr_offsets[feature * B + b];
-        int64_t pool_end = batched_csr_offsets[feature * B + b + 1];
-        int64_t L = pool_end - pool_begin;
-        // MEAN pooling will not work with indice_weights!
-        double scale_factor =
-            (pooling_mode == MEAN && !has_weights && L > 0) ? 1.0 / L : 1.0;
-        for (int64_t p = pool_begin; p < pool_end; ++p) {
-          non_empty_columns[batched_csr_indices[p]].emplace_back(
-              feature * B + b,
-              scale_factor * (has_weights ? batched_csr_weights[p] : 1.0f));
+    int num_non_empty_segments = 0;
+    if (batched_csc.weights.empty()) {
+      std::unordered_map<int64_t, std::vector<std::vector<int>>>
+          non_empty_columns;
+      int f_begin = table_to_feature_offset[t];
+      int f_end = table_to_feature_offset[t + 1];
+
+      for (int feature = f_begin; feature < f_end; ++feature) {
+        for (int b = 0; b < B; ++b) {
+          int64_t pool_begin = batched_csr_offsets[feature * B + b];
+          int64_t pool_end = batched_csr_offsets[feature * B + b + 1];
+          for (int64_t p = pool_begin; p < pool_end; ++p) {
+            auto itr = non_empty_columns.find(batched_csr_indices[p]);
+            if (itr == non_empty_columns.end()) {
+              itr = non_empty_columns
+                        .emplace(
+                            batched_csr_indices[p],
+                            std::vector<std::vector<int>>(f_end - f_begin))
+                        .first;
+            }
+            if (itr->second[feature - f_begin].empty()) {
+              ++num_non_empty_segments;
+            }
+            itr->second[feature - f_begin].push_back(b);
+          }
         }
-      }
-    } // for each feature
-
-    batched_csc.table_ptr[t + 1] =
-        batched_csc.table_ptr[t] + non_empty_columns.size();
-    batched_csc.column_ptr.reserve(batched_csc.table_ptr[t + 1] + 1);
-    batched_csc.column_indices.reserve(batched_csc.table_ptr[t + 1]);
-    for (auto const& column : non_empty_columns) {
-      batched_csc.column_ptr.push_back(column_ptr_curr + column.second.size());
-      batched_csc.column_indices.push_back(column.first);
-
-      for (auto const& non_zero : column.second) {
-        batched_csc.row_indices[column_ptr_curr] = non_zero.first;
-        if (!batched_csc.weights.empty()) {
-          batched_csc.weights[column_ptr_curr] = non_zero.second;
+      } // for each feature
+
+      batched_csc.table_ptr[t + 1] =
+          batched_csc.table_ptr[t] + num_non_empty_segments;
+      batched_csc.column_segment_ptr.reserve(batched_csc.table_ptr[t + 1] + 1);
+      batched_csc.column_segment_indices.reserve(batched_csc.table_ptr[t + 1]);
+      batched_csc.column_segment_ids.reserve(batched_csc.table_ptr[t + 1]);
+      for (auto const& column : non_empty_columns) {
+        int feature = f_begin;
+        for (auto const& column_segment : column.second) {
+          if (!column_segment.empty()) {
+            batched_csc.column_segment_ptr.push_back(
+                column_ptr_curr + column_segment.size());
+            batched_csc.column_segment_indices.push_back(column.first);
+            batched_csc.column_segment_ids.push_back(feature - f_begin);
+            memcpy(
+                &batched_csc.row_indices[column_ptr_curr],
+                column_segment.data(),
+                column_segment.size() * sizeof(int));
+            column_ptr_curr += column_segment.size();
+          }
+          ++feature;
+        } // for each column segment
+      } // for each column
+    } else {
+      // !batched_csc.weights.empty()
+      std::unordered_map<
+          int64_t,
+          std::vector<std::vector<std::pair<int, scalar_t>>>>
+          non_empty_columns;
+      int f_begin = table_to_feature_offset[t];
+      int f_end = table_to_feature_offset[t + 1];
+      for (int feature = f_begin; feature < f_end; ++feature) {
+        for (int b = 0; b < B; ++b) {
+          int64_t pool_begin = batched_csr_offsets[feature * B + b];
+          int64_t pool_end = batched_csr_offsets[feature * B + b + 1];
+          int64_t L = pool_end - pool_begin;
+          // MEAN pooling will not work with indice_weights!
+          double scale_factor =
+              (pooling_mode == MEAN && !has_weights && L > 0) ? 1.0 / L : 1.0;
+          for (int64_t p = pool_begin; p < pool_end; ++p) {
+            auto itr = non_empty_columns.find(batched_csr_indices[p]);
+            if (itr == non_empty_columns.end()) {
+              itr = non_empty_columns
+                        .emplace(
+                            batched_csr_indices[p],
+                            std::vector<std::vector<std::pair<int, scalar_t>>>(
+                                f_end - f_begin))
+                        .first;
+            }
+            if (itr->second[feature - f_begin].empty()) {
+              ++num_non_empty_segments;
+            }
+            itr->second[feature - f_begin].emplace_back(
+                b,
+                scale_factor * (has_weights ? batched_csr_weights[p] : 1.0f));
+          }
         }
-        ++column_ptr_curr;
-      }
-    }
+      } // for each feature
+
+      batched_csc.table_ptr[t + 1] =
+          batched_csc.table_ptr[t] + num_non_empty_segments;
+      batched_csc.column_segment_ptr.reserve(batched_csc.table_ptr[t + 1] + 1);
+      batched_csc.column_segment_indices.reserve(batched_csc.table_ptr[t + 1]);
+      batched_csc.column_segment_ids.reserve(batched_csc.table_ptr[t + 1]);
+      for (auto const& column : non_empty_columns) {
+        int feature = f_begin;
+        for (auto const& column_segment : column.second) {
+          if (!column_segment.empty()) {
+            batched_csc.column_segment_ptr.push_back(
+                column_ptr_curr + column_segment.size());
+            batched_csc.column_segment_indices.push_back(column.first);
+            batched_csc.column_segment_ids.push_back(feature - f_begin);
+            for (auto const& non_zero : column_segment) {
+              batched_csc.row_indices[column_ptr_curr] = non_zero.first;
+              batched_csc.weights[column_ptr_curr] = non_zero.second;
+              ++column_ptr_curr;
+            }
+          }
+          ++feature;
+        } // for each column segment
+      } // for each column
+    } // !batched_csc.weights.empty()
   } // for each matrix (table)
 
   assert(column_ptr_curr == nnz);
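
For reference, a small hypothetical example of the per-table segment layout batched_csr2csc now produces (values invented; the relative order of different embedding rows is unspecified because non_empty_columns is an unordered_map, while segments of the same embedding row follow feature order):

// Hypothetical example of the per-table CSC segment layout built above.
// Assume one table covering features {0, 1} (f_begin = 0, f_end = 2), B = 2,
// and CSR pooling lists:
//   feature 0: b=0 -> {5, 7}, b=1 -> {5}
//   feature 1: b=0 -> {7},    b=1 -> {}
// One segment groups the samples hitting one embedding row within one feature;
// row_indices stores only the sample id b (no longer feature * B + b).
#include <cstdint>
#include <vector>

struct CscSegmentsExample {
  std::vector<int> column_segment_ptr{0, 2, 3, 4};           // offsets into row_indices
  std::vector<std::int64_t> column_segment_indices{5, 7, 7}; // embedding row per segment
  std::vector<int> column_segment_ids{0, 0, 1};              // feature offset within table
  std::vector<int> row_indices{0, 1, 0, 0};                  // sample ids, nnz == 4
  // Invariant checked at the end of batched_csr2csc: column_ptr_curr == nnz.
};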
