From e9d57a34409f1483e070c653e7f4f5178e7de798 Mon Sep 17 00:00:00 2001 From: TFLM-bot Date: Tue, 25 Jul 2023 14:03:42 +0000 Subject: [PATCH] Sync from upstream TF. --- tensorflow/lite/core/api/op_resolver.h | 89 +++++++++++++++++++++++++- 1 file changed, 88 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/core/api/op_resolver.h b/tensorflow/lite/core/api/op_resolver.h index e8a4e32771a..b43f1adc664 100644 --- a/tensorflow/lite/core/api/op_resolver.h +++ b/tensorflow/lite/core/api/op_resolver.h @@ -16,7 +16,10 @@ limitations under the License. #define TENSORFLOW_LITE_CORE_API_OP_RESOLVER_H_ #include +#include #include +#include +#include #include #include "tensorflow/lite/core/api/error_reporter.h" @@ -25,6 +28,16 @@ limitations under the License. namespace tflite { +#ifndef DOXYGEN_SKIP +class OpResolverInternal; // For friend declaration below. +class Subgraph; // For friend declaration below. + +namespace internal { +class CommonOpaqueConversionUtil; // For friend declaration below. +class RegistrationExternalsCache; // Forward decl. +} // namespace internal +#endif + /// Abstract interface that returns TfLiteRegistrations given op codes or custom /// op names. This is the mechanism that ops being referenced in the flatbuffer /// model are mapped to executable function pointers (TfLiteRegistrations). @@ -104,7 +117,9 @@ class OpResolver { return {}; } - virtual ~OpResolver() {} + virtual ~OpResolver() = default; + OpResolver() = default; + OpResolver(const OpResolver& other) = default; private: /// Returns true if this OpResolver may contain any "user defined" ops. @@ -120,8 +135,80 @@ class OpResolver { /// "builtin" ops, and may not support all of the "builtin" op enum values. virtual bool MayContainUserDefinedOps() const { return true; } +#ifndef DOXYGEN_SKIP friend class OpResolverInternal; + friend class Subgraph; // For OpId. + friend class tflite::internal::CommonOpaqueConversionUtil; + friend class tflite::internal::RegistrationExternalsCache; +#endif + + // This holds the identity of an operator. + // Ths is used as the key for the RegistrationExternalsCache below. + struct OpId { + int builtin_code; + const char* custom_name; + int version; + bool operator==(const OpId& other) const { + return builtin_code == other.builtin_code && + custom_name == other.custom_name && version == other.version; + } + struct Hasher { + size_t operator()(const OpId& op_id) const { + size_t hash_builtin_code = std::hash()(op_id.builtin_code); + size_t hash_custom_name = + op_id.custom_name != nullptr + ? std::hash()(std::string(op_id.custom_name)) + : 0; + size_t hash_version = std::hash()(op_id.version); + return Combine(hash_builtin_code, + Combine(hash_custom_name, hash_version)); + } + + private: + static size_t Combine(size_t hash1, size_t hash2) { + constexpr int num_bits_to_rotate_left = 21; + constexpr int num_bits_to_rotate_right = + std::numeric_limits::digits - num_bits_to_rotate_left; + size_t hash1_rotated = (hash1 << num_bits_to_rotate_left) | + (hash1 >> num_bits_to_rotate_right); + return hash1_rotated + hash2; + } + }; + }; + + // A set of 'TfLiteRegistrationExternal' objects whose lifetimes need to + // last at least as long as the lifetime of the OpResolver. + // We use shared_ptr rather than unique_ptr here, to allow the + // RegistrationExternalsCache to be shared with other classes such as the + // InterpreterBuilder and Interpreter. This is so that the + // TfLiteRegistrationExternal objects allocated by an OpResolver, + // which may be referenced by a Subgraph in an Interpreter, can remain live + // even if the OpResolver is destroyed, while also allowing the same + // OpResolver to be used with multiple InterpreterBuilders and multiple + // Interpreters. + mutable std::shared_ptr + registration_externals_cache_; +}; + +#ifndef DOXYGEN_SKIP +// Type for a set of owned 'TfLiteRegistrationExternal' objects. +// This is needed when converting TfLiteRegistration to +// TfLiteRegistrationExternal, to ensure that the number of +// TfLiteRegistrationExternal objects that we allocate is bounded, and to +// ensure that those objects get deallocated at the appropriate time. +// We use a public class rather than a typedef or using declaration here, +// to ensure that the class can be forward-declared. +// WARNING: Experimental interface, subject to change. +namespace internal { +class RegistrationExternalsCache + : private std::unordered_map, + OpResolver::OpId::Hasher> { + friend class ::tflite::Subgraph; + friend class ::tflite::internal::CommonOpaqueConversionUtil; }; +} // namespace internal +#endif // Handles the logic for converting between an OperatorCode structure extracted // from a flatbuffer and information about a registered operator