@@ -30,6 +30,7 @@ struct quantize_perf_params {
3030 bool op_quantize_row_q_reference = false ;
3131 bool op_quantize_row_q = false ;
3232 bool op_dequantize_row_q = false ;
33+ bool op_quantize_row_q_dot = false ;
3334 bool op_vec_dot_q = false ;
3435};
3536
@@ -147,6 +148,8 @@ int main(int argc, char * argv[]) {
147148 params.op_quantize_row_q = true ;
148149 } else if (op == " dequantize_row_q" ) {
149150 params.op_dequantize_row_q = true ;
151+ } else if (op == " quantize_row_q_dot" ) {
152+ params.op_quantize_row_q_dot = true ;
150153 } else if (op == " vec_dot_q" ) {
151154 params.op_vec_dot_q = true ;
152155 } else {
@@ -184,8 +187,8 @@ int main(int argc, char * argv[]) {
184187 if (params.test_sizes .empty ()) {
185188 params.test_sizes .push_back (L1_SIZE);
186189 }
187- if (!(params.op_quantize_row_q_reference || params.op_quantize_row_q || params.op_dequantize_row_q || params.op_vec_dot_q )) {
188- params.op_quantize_row_q_reference = params.op_quantize_row_q = params.op_dequantize_row_q = params.op_vec_dot_q = true ;
190+ if (!(params.op_quantize_row_q_reference || params.op_quantize_row_q || params.op_dequantize_row_q || params.op_quantize_row_q_dot || params. op_vec_dot_q )) {
191+ params.op_quantize_row_q_reference = params.op_quantize_row_q = params.op_dequantize_row_q = params.op_quantize_row_q_dot = params. op_vec_dot_q = true ;
189192 }
190193
191194 std::sort (params.test_sizes .begin (), params.test_sizes .end ());
@@ -225,7 +228,7 @@ int main(int argc, char * argv[]) {
225228 if (qfns.quantize_row_q ) {
226229 printf (" %s\n " , ggml_type_name (type));
227230
228- if (params.op_quantize_row_q_reference ) {
231+ if (params.op_quantize_row_q_reference && qfns. quantize_row_q_reference ) {
229232 printf (" quantize_row_q_reference\n " );
230233 for (size_t size : params.test_sizes ) {
231234 printf (" %zu values (%.2f MB)\n " , size, 4 *size/(float )(1024 *1024 ));
@@ -239,7 +242,7 @@ int main(int argc, char * argv[]) {
239242 printf (" \n " );
240243 }
241244
242- if (params.op_quantize_row_q ) {
245+ if (params.op_quantize_row_q && qfns. quantize_row_q ) {
243246 printf (" quantize_row_q\n " );
244247 for (size_t size : params.test_sizes ) {
245248 printf (" %zu values (%.2f MB)\n " , size, 4 *size/(float )(1024 *1024 ));
@@ -253,7 +256,7 @@ int main(int argc, char * argv[]) {
253256 printf (" \n " );
254257 }
255258
256- if (params.op_dequantize_row_q ) {
259+ if (params.op_dequantize_row_q && qfns. dequantize_row_q ) {
257260 printf (" dequantize_row_q\n " );
258261 qfns.quantize_row_q (test_data1, test_q1, largest);
259262 for (size_t size : params.test_sizes ) {
@@ -268,7 +271,21 @@ int main(int argc, char * argv[]) {
268271 printf (" \n " );
269272 }
270273
271- if (params.op_vec_dot_q ) {
274+ if (params.op_quantize_row_q_dot && qfns.quantize_row_q_dot ) {
275+ printf (" quantize_row_q_dot\n " );
276+ for (size_t size : params.test_sizes ) {
277+ printf (" %zu values (%.2f MB)\n " , size, 4 *size/(float )(1024 *1024 ));
278+ auto quantize_fn = [&](void ) {
279+ qfns.quantize_row_q_dot (test_data1, test_q1, size);
280+ return test_q1[0 ];
281+ };
282+ size_t quantized_size = size / ggml_blck_size (type) * ggml_type_size (type);
283+ benchmark_function (size, quantized_size, quantize_fn);
284+ }
285+ printf (" \n " );
286+ }
287+
288+ if (params.op_vec_dot_q && qfns.vec_dot_q ) {
272289 printf (" vec_dot_q\n " );
273290 qfns.quantize_row_q (test_data1, test_q1, largest);
274291 qfns.quantize_row_q (test_data2, test_q2, largest);
0 commit comments