11
11
#include " tensorflow/lite/micro/micro_interpreter.h"
12
12
13
13
#include " libtf.h"
14
-
15
- #define LIBTF_MAX_OPS 72
14
+ #define LIBTF_MAX_OPS 80
16
15
17
16
extern " C" {
18
17
// These are set by openmv py_tf.c code to redirect printing to an error message buffer...
@@ -62,16 +61,108 @@ extern "C" {
62
61
}
63
62
}
64
63
65
- static void libtf_init_op_resolver (tflite::MicroMutableOpResolver<LIBTF_MAX_OPS> &resolver) {
64
+ typedef void (*libtf_resolver_init_t ) (tflite::MicroMutableOpResolver<LIBTF_MAX_OPS> &);
65
+
66
+ static void libtf_reduced_ops_init_op_resolver (tflite::MicroMutableOpResolver<LIBTF_MAX_OPS> &resolver) {
67
+ // resolver.AddAbs();
68
+ resolver.AddAdd ();
69
+ resolver.AddAddN ();
70
+ // resolver.AddArgMax();
71
+ // resolver.AddArgMin();
72
+ resolver.AddAveragePool2D ();
73
+ // resolver.AddBatchMatMul(); - Doesn't compile in OpenMV Cam firmware
74
+ // resolver.AddBatchToSpaceNd();
75
+ // resolver.AddCast();
76
+ // resolver.AddCeil();
77
+ // resolver.AddCircularBuffer(); - Doesn't compile in OpenMV Cam firmware
78
+ // resolver.AddComplexAbs();
79
+ // resolver.AddConcatenation();
80
+ resolver.AddConv2D ();
81
+ // resolver.AddCos();
82
+ resolver.AddDepthwiseConv2D ();
83
+ // resolver.AddDequantize();
84
+ // resolver.AddDetectionPostprocess();
85
+ // resolver.AddDiv();
86
+ // resolver.AddElu();
87
+ // resolver.AddEqual();
88
+ // resolver.AddEthosU();
89
+ // resolver.AddExp();
90
+ // resolver.AddExpandDims();
91
+ // resolver.AddFloor();
92
+ resolver.AddFullyConnected ();
93
+ // resolver.AddGather();
94
+ // resolver.AddGreater();
95
+ // resolver.AddGreaterEqual();
96
+ // resolver.AddHardSwish();
97
+ // resolver.AddImag();
98
+ // resolver.AddL2Normalization();
99
+ // resolver.AddL2Pool2D();
100
+ resolver.AddLeakyRelu ();
101
+ // resolver.AddLess();
102
+ // resolver.AddLessEqual();
103
+ // resolver.AddLog();
104
+ // resolver.AddLogicalAnd();
105
+ // resolver.AddLogicalNot();
106
+ // resolver.AddLogicalOr();
107
+ resolver.AddLogistic ();
108
+ resolver.AddMaxPool2D ();
109
+ // resolver.AddMaximum();
110
+ resolver.AddMean ();
111
+ // resolver.AddMinimum();
112
+ // resolver.AddMul();
113
+ // resolver.AddNeg();
114
+ // resolver.AddNotEqual();
115
+ // resolver.AddPack();
116
+ resolver.AddPad ();
117
+ // resolver.AddPadV2();
118
+ // resolver.AddPrelu();
119
+ // resolver.AddQuantize();
120
+ // resolver.AddReal();
121
+ // resolver.AddReduceMax();
122
+ // resolver.AddReduceMin();
123
+ resolver.AddRelu ();
124
+ resolver.AddRelu6 ();
125
+ resolver.AddReshape ();
126
+ // resolver.AddResizeNearestNeighbor();
127
+ // resolver.AddRfft2D(); - Doesn't compile in OpenMV Cam firmware
128
+ // resolver.AddRound();
129
+ // resolver.AddRsqrt();
130
+ // resolver.AddSelect();
131
+ // resolver.AddSelectV2();
132
+ resolver.AddShape ();
133
+ // resolver.AddSin();
134
+ // resolver.AddSlice(); - Doesn't compile in OpenMV Cam firmware
135
+ resolver.AddSoftmax ();
136
+ // resolver.AddSpaceToBatchNd();
137
+ // resolver.AddSplit();
138
+ // resolver.AddSplitV();
139
+ // resolver.AddSqrt();
140
+ // resolver.AddSquare();
141
+ // resolver.AddSquaredDifference(); - Doesn't compile in OpenMV Cam firmware
142
+ // resolver.AddSqueeze();
143
+ // resolver.AddStridedSlice();
144
+ resolver.AddSub ();
145
+ // resolver.AddSum();
146
+ // resolver.AddSvdf();
147
+ resolver.AddTanh ();
148
+ // resolver.AddTranspose(); - Doesn't compile in OpenMV Cam firmware
149
+ // resolver.AddTransposeConv();
150
+ // resolver.AddUnpack();
151
+ // resolver.AddZerosLike();
152
+ }
153
+
154
+ static void libtf_all_ops_init_op_resolver (tflite::MicroMutableOpResolver<LIBTF_MAX_OPS> &resolver) {
66
155
resolver.AddAbs ();
67
156
resolver.AddAdd ();
68
157
resolver.AddAddN ();
69
158
resolver.AddArgMax ();
70
159
resolver.AddArgMin ();
71
160
resolver.AddAveragePool2D ();
72
- // resolver.AddBatchMatMul();
161
+ // resolver.AddBatchMatMul(); - Doesn't compile in OpenMV Cam firmware
73
162
resolver.AddBatchToSpaceNd ();
163
+ resolver.AddCast ();
74
164
resolver.AddCeil ();
165
+ // resolver.AddCircularBuffer(); - Doesn't compile in OpenMV Cam firmware
75
166
resolver.AddComplexAbs ();
76
167
resolver.AddConcatenation ();
77
168
resolver.AddConv2D ();
@@ -121,35 +212,37 @@ extern "C" {
121
212
resolver.AddRelu6 ();
122
213
resolver.AddReshape ();
123
214
resolver.AddResizeNearestNeighbor ();
124
- // resolver.AddRfft2D();
215
+ // resolver.AddRfft2D(); - Doesn't compile in OpenMV Cam firmware
125
216
resolver.AddRound ();
126
217
resolver.AddRsqrt ();
127
218
// resolver.AddSelect();
128
219
// resolver.AddSelectV2();
129
220
resolver.AddShape ();
130
221
resolver.AddSin ();
131
- // resolver.AddSlice();
222
+ // resolver.AddSlice(); - Doesn't compile in OpenMV Cam firmware
132
223
resolver.AddSoftmax ();
133
224
resolver.AddSpaceToBatchNd ();
134
225
resolver.AddSplit ();
135
226
resolver.AddSplitV ();
136
227
resolver.AddSqrt ();
137
228
resolver.AddSquare ();
138
- // resolver.AddSquaredDifference();
229
+ // resolver.AddSquaredDifference(); - Doesn't compile in OpenMV Cam firmware
139
230
resolver.AddSqueeze ();
140
231
resolver.AddStridedSlice ();
141
232
resolver.AddSub ();
142
233
resolver.AddSum ();
143
234
resolver.AddSvdf ();
144
235
resolver.AddTanh ();
145
- // resolver.AddTranspose();
236
+ // resolver.AddTranspose(); - Doesn't compile in OpenMV Cam firmware
146
237
resolver.AddTransposeConv ();
147
238
resolver.AddUnpack ();
239
+ resolver.AddZerosLike ();
148
240
}
149
241
150
- int libtf_get_parameters (const unsigned char *model_data,
151
- unsigned char *tensor_arena, size_t tensor_arena_size,
152
- libtf_parameters_t *params) {
242
+ static int libtf_get_parameters (const unsigned char *model_data,
243
+ unsigned char *tensor_arena, size_t tensor_arena_size,
244
+ libtf_parameters_t *params,
245
+ libtf_resolver_init_t libtf_resolver_init) {
153
246
RegisterDebugLogCallback (libtf_debug_log);
154
247
155
248
tflite::MicroErrorReporter micro_error_reporter;
@@ -168,7 +261,7 @@ extern "C" {
168
261
}
169
262
170
263
tflite::MicroMutableOpResolver<LIBTF_MAX_OPS> resolver;
171
- libtf_init_op_resolver (resolver);
264
+ libtf_resolver_init (resolver);
172
265
173
266
tflite::MicroInterpreter interpreter (model, resolver, tensor_arena, tensor_arena_size, error_reporter);
174
267
@@ -287,13 +380,36 @@ extern "C" {
287
380
return 0 ;
288
381
}
289
382
290
- int libtf_invoke (const unsigned char *model_data,
291
- unsigned char *tensor_arena,
292
- libtf_parameters_t *params,
293
- libtf_input_data_callback_t input_callback,
294
- void *input_callback_data,
295
- libtf_output_data_callback_t output_callback,
296
- void *output_callback_data) {
383
+ int libtf_reduced_ops_get_parameters (const unsigned char *model_data,
384
+ unsigned char *tensor_arena,
385
+ size_t tensor_arena_size,
386
+ libtf_parameters_t *params) {
387
+ return libtf_get_parameters (model_data,
388
+ tensor_arena,
389
+ tensor_arena_size,
390
+ params,
391
+ libtf_reduced_ops_init_op_resolver);
392
+ }
393
+
394
+ int libtf_all_ops_get_parameters (const unsigned char *model_data,
395
+ unsigned char *tensor_arena,
396
+ size_t tensor_arena_size,
397
+ libtf_parameters_t *params) {
398
+ return libtf_get_parameters (model_data,
399
+ tensor_arena,
400
+ tensor_arena_size,
401
+ params,
402
+ libtf_all_ops_init_op_resolver);
403
+ }
404
+
405
+ static int libtf_invoke (const unsigned char *model_data,
406
+ unsigned char *tensor_arena,
407
+ libtf_parameters_t *params,
408
+ libtf_input_data_callback_t input_callback,
409
+ void *input_callback_data,
410
+ libtf_output_data_callback_t output_callback,
411
+ void *output_callback_data,
412
+ libtf_resolver_init_t libtf_resolver_init) {
297
413
RegisterDebugLogCallback (libtf_debug_log);
298
414
299
415
tflite::MicroErrorReporter micro_error_reporter;
@@ -314,7 +430,7 @@ extern "C" {
314
430
}
315
431
316
432
tflite::MicroMutableOpResolver<LIBTF_MAX_OPS> resolver;
317
- libtf_init_op_resolver (resolver);
433
+ libtf_resolver_init (resolver);
318
434
319
435
tflite::MicroInterpreter interpreter (model, resolver, tensor_arena, tensor_arena_size, error_reporter);
320
436
@@ -335,6 +451,40 @@ extern "C" {
335
451
return 0 ;
336
452
}
337
453
454
+ int libtf_reduced_ops_invoke (const unsigned char *model_data,
455
+ unsigned char *tensor_arena,
456
+ libtf_parameters_t *params,
457
+ libtf_input_data_callback_t input_callback,
458
+ void *input_callback_data,
459
+ libtf_output_data_callback_t output_callback,
460
+ void *output_callback_data) {
461
+ return libtf_invoke (model_data,
462
+ tensor_arena,
463
+ params,
464
+ input_callback,
465
+ input_callback_data,
466
+ output_callback,
467
+ output_callback_data,
468
+ libtf_reduced_ops_init_op_resolver);
469
+ }
470
+
471
+ int libtf_all_ops_invoke (const unsigned char *model_data,
472
+ unsigned char *tensor_arena,
473
+ libtf_parameters_t *params,
474
+ libtf_input_data_callback_t input_callback,
475
+ void *input_callback_data,
476
+ libtf_output_data_callback_t output_callback,
477
+ void *output_callback_data) {
478
+ return libtf_invoke (model_data,
479
+ tensor_arena,
480
+ params,
481
+ input_callback,
482
+ input_callback_data,
483
+ output_callback,
484
+ output_callback_data,
485
+ libtf_all_ops_init_op_resolver);
486
+ }
487
+
338
488
int libtf_initialize_micro_features () {
339
489
RegisterDebugLogCallback (libtf_debug_log);
340
490
0 commit comments