 #include <ATen/AccumulateType.h>
 
 #include "codegen/embedding_forward_split_cpu.h"
+#include "fbgemm/FbgemmEmbedding.h"
+#include "fbgemm/Types.h"
 
 using namespace at;
 
+namespace internal {
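+// Maps at::Half to fbgemm::float16 so the FBGEMM kernels below can be
+// instantiated for half-precision weights; other scalar types pass through
+// unchanged.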
+template <typename T>
+struct half2float16 {
+  using type = T;
+};
+
+template <>
+struct half2float16<at::Half> {
+  using type = fbgemm::float16;
+};
+} // namespace internal
+
 namespace {
 template <typename scalar_t>
 void split_embedding_backward_exact_cpu_kernel(
     Tensor grad_output,
     Tensor host_weights,
     const TensorAccessor<int64_t, 1> weights_offsets_data,
     const TensorAccessor<int, 1> D_offsets_data,
+    Tensor hash_size_cumsum,
     Tensor indices,
     Tensor offsets,
     int64_t pooling_mode,
@@ -37,50 +52,146 @@ void split_embedding_backward_exact_cpu_kernel(
     {% endif %}
     {{ args.split_cpu_kernel_args | join(", ") }}) {
   using grad_t = acc_type<scalar_t, true>;
-  ::internal::BatchedHyperCompressedSparseColumn batched_csc;
-  ::internal::batched_csr2csc(
-      batched_csc,
-      num_tables,
-      B,
-      offsets.accessor<int64_t, 1>(),
-      indices.accessor<int64_t, 1>(),
-      indice_weights.defined()
-          ? indice_weights.accessor<grad_t, 1>()
-          : TensorAccessor<grad_t, 1>(nullptr, nullptr, nullptr),
-      pooling_mode,
-      table_to_feature_offset);
-  std::vector<int>& table_ptr = batched_csc.table_ptr;
-  std::vector<int>& column_ptr = batched_csc.column_ptr;
 
-  auto grad_output_data = grad_output.accessor<grad_t, 2>();
+  // const auto grad_output_accessor = grad_output.accessor<grad_t, 2>();
+  const grad_t* grad_output_data = grad_output.data_ptr<grad_t>();
   auto host_weights_data = host_weights.accessor<scalar_t, 1>();
+  const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();
+
+  const bool has_weights = indice_weights.defined();
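+  // Row stride used for raw-pointer indexing into the 2D grad_output tensor.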
+  auto grad_stride = grad_output.size(1);
 
-  const bool has_weights = !batched_csc.weights.empty();
+  std::vector<::internal::BatchedHyperCompressedSparseColumn> batched_cscs(
+      num_tables);
+
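+  // Transpose each table's slice of the CSR inputs (offsets/indices) into CSC
+  // form, one table per parallel task.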
+  at::parallel_for(0, num_tables, 0, [&](int64_t t_begin, int64_t t_end) {
+    for (int t = t_begin; t < t_end; ++t) {
+      int feature_begin = table_to_feature_offset[t];
+
+      ::internal::batched_csr2csc(
+          batched_cscs[t],
+          1,
+          B,
+          offsets.accessor<int64_t, 1>(),
+          indices.accessor<int64_t, 1>(),
+          indice_weights.defined()
+              ? indice_weights.accessor<grad_t, 1>()
+              : TensorAccessor<grad_t, 1>(nullptr, nullptr, nullptr),
+          pooling_mode,
+          table_to_feature_offset + t);
+    }
+  });
 
   for (int t = 0; t < num_tables; ++t) {
     int feature_begin = table_to_feature_offset[t];
+
+    int c_begin = batched_cscs[t].table_ptr[0];
+    int c_end = batched_cscs[t].table_ptr[1];
+    std::vector<int>& col_segment_ptr = batched_cscs[t].column_segment_ptr;
+    std::vector<int64_t>& col_segment_indices =
+        batched_cscs[t].column_segment_indices;
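+    // col_segment_ptr[c]..col_segment_ptr[c + 1] delimits the rows of column
+    // segment c, and col_segment_indices[c] is the embedding row that segment
+    // accumulates into.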
+
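+    // Walk past features with zero hash size to find this table's number of
+    // embedding rows.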
+    int64_t hash_size;
+    int t_temp = feature_begin + 1;
+    do {
+      hash_size =
+          hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[feature_begin];
+      ++t_temp;
+    } while (hash_size == 0);
+
     const auto D_begin = D_offsets_data[feature_begin];
     const auto D =
         D_offsets_data[feature_begin + 1] - D_offsets_data[feature_begin];
     const auto table_begin = weights_offsets_data[feature_begin];
-    grad_t grad_buffer[D];
-    for (int c = table_ptr[t]; c < table_ptr[t + 1]; ++c) {
-      memset(grad_buffer, 0, D * sizeof(grad_t));
-      int idx = batched_csc.column_indices[c];
-      const int64_t embedding_begin = table_begin + idx * D;
-      for (int r = column_ptr[c]; r < column_ptr[c + 1]; ++r) {
-        int f_times_b = batched_csc.row_indices[r];
-        int feature = f_times_b / B;
-        int b = f_times_b % B;
-        int D_offset = D_begin + (feature - feature_begin) * D;
-        for (int64_t d = 0; d < D; ++d) {
-          grad_buffer[d] += has_weights
-              ? grad_output_data[b][D_offset + d] * batched_csc.weights[r]
-              : grad_output_data[b][D_offset + d];
+
+    {% if optimizer == "rowwise_adagrad" %}
+    constexpr bool use_fbgemm = std::is_same<scalar_t, float>::value;
+    // || std::is_same<scalar_t, at::Half>::value;
+    if (use_fbgemm &&
+        table_to_feature_offset[t + 1] == table_to_feature_offset[t] + 1) {
+      // fbgemm handles common case of no shared table
+      using fbgemm_weight_t = typename ::internal::half2float16<scalar_t>::type;
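+      // JIT two FBGEMM kernels: an SpMDM kernel that sums the grad_output rows
+      // of each column segment into a dense buffer, and a rowwise Adagrad
+      // kernel that applies that buffer to the embedding table and momentum.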
+      auto spmdm_kernel = fbgemm::GenerateEmbeddingSpMDMWithStrides<
+          fbgemm_weight_t,
+          /*IndexType=*/int32_t,
+          /*OffsetType=*/int32_t>(
+          D,
+          !batched_cscs[t].weights.empty(),
+          /*normalize_by_lengths=*/false,
+          /*prefetch=*/16,
+          /*is_weight_positional=*/false,
+          /*use_offsets=*/true,
+          /*output_stride=*/-1,
+          /*input_stride=*/grad_stride);
+      auto rowwise_adagrad_kernel =
+          fbgemm::GenerateSparseAdaGrad</*IndexType=*/int64_t>(
+              D, /*rowwise=*/true);
+
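+      // Process column segments in blocks of C_BLOCK unique rows so the dense
+      // gradient buffer stays small, and distribute the blocks with
+      // at::parallel_for.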
+      constexpr int C_BLOCK = 64;
+      at::parallel_for(c_begin, c_end, C_BLOCK, [&](int64_t c0, int64_t c1) {
+        grad_t grad_blocked_buffer[C_BLOCK * D];
+        for (int64_t c = c0; c < c1; c += C_BLOCK) {
+          const int* offsets_begin_ptr = col_segment_ptr.data() + c;
+          int64_t c_block_end = std::min(c + C_BLOCK, c1);
+          bool success = spmdm_kernel(
+              c_block_end - c,
+              col_segment_ptr[c_block_end] - *offsets_begin_ptr,
+              B,
+              reinterpret_cast<const fbgemm_weight_t*>(
+                  grad_output_data + D_begin),
+              batched_cscs[t].row_indices.data() + *offsets_begin_ptr,
+              offsets_begin_ptr,
+              batched_cscs[t].weights.empty()
+                  ? nullptr
+                  : batched_cscs[t].weights.data() + *offsets_begin_ptr,
+              reinterpret_cast<float*>(grad_blocked_buffer));
+          // TODO: more friendly error msg.
+          TORCH_CHECK(success);
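+          // Scatter the block's accumulated gradients into the embedding table
+          // with a rowwise Adagrad update; col_segment_indices gives the rows
+          // to touch.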
+          int num_rows_processed = rowwise_adagrad_kernel(
+              c_block_end - c,
+              hash_size * D,
+              reinterpret_cast<float*>(&host_weights_data[table_begin]),
+              reinterpret_cast<const float*>(grad_blocked_buffer),
+              reinterpret_cast<float*>(
+                  &momentum1_host[momentum1_offsets_data[feature_begin]]),
+              col_segment_indices.data() + c,
+              eps,
+              -learning_rate,
+              /*weight_decay=*/0,
+              /*counter=*/nullptr,
+              /*counter_halflife=*/0);
+          // TODO: more friendly error msg.
+          TORCH_CHECK(num_rows_processed == c_block_end - c);
+        } // for each c
+      }); // parallel for
+    } else
+    {% endif %}
+    {
+      // no fbgemm
+      // TODO: to parallelize, we should easily identify segments that belong
+      // to the same column.
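+      // Segments that share a column index accumulate into one grad_buffer:
+      // it is zeroed at the first segment of a column and the optimizer update
+      // runs after the last one.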
+      grad_t grad_buffer[D];
+      for (int c = c_begin; c < c_end; ++c) {
+        int64_t idx = col_segment_indices[c];
+        if (c == c_begin || col_segment_indices[c - 1] != idx) {
+          memset(grad_buffer, 0, D * sizeof(grad_t));
         }
-      }
-      {{ split_weight_update_cpu }}
-    }
+        const int64_t embedding_begin = table_begin + idx * D;
+        int D_offset = D_begin + batched_cscs[t].column_segment_ids[c] * D;
+        for (int r = col_segment_ptr[c]; r < col_segment_ptr[c + 1]; ++r) {
+          int b = batched_cscs[t].row_indices[r];
+          for (int64_t d = 0; d < D; ++d) {
+            grad_buffer[d] += !batched_cscs[t].weights.empty()
+                ? grad_output_data[b * grad_stride + D_offset + d] *
+                    batched_cscs[t].weights[r]
+                : grad_output_data[b * grad_stride + D_offset + d];
+          }
+        }
+        if (c == c_end - 1 || col_segment_indices[c + 1] != idx) {
+          {{ split_weight_update_cpu }}
+        }
+      } // for each c
+    } // no fbgemm
   } // for each table
 }
 
@@ -200,13 +311,16 @@ void split_embedding_backward_exact_cpu_dense_kernel(
   const auto momentum2_offsets_data = momentum2_offsets.accessor<int64_t, 1>();
   {% endif %}
 
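+  // The kernel indexes grad_output through a raw data pointer with a fixed
+  // row stride, so it must be contiguous.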
+  grad_output = grad_output.contiguous();
+
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       host_weights.scalar_type(), "split_embedding_backward_exact_cpu", [&]() {
         split_embedding_backward_exact_cpu_kernel<scalar_t>(
             grad_output,
             host_weights,
             weights_offsets_data,
             D_offsets_data,
+            hash_size_cumsum,
             indices,
             offsets,
             pooling_mode,