diff --git a/packages/mediapipe-core/lib/src/task_options.dart b/packages/mediapipe-core/lib/src/task_options.dart index c0f6fa7b..6291efb2 100644 --- a/packages/mediapipe-core/lib/src/task_options.dart +++ b/packages/mediapipe-core/lib/src/task_options.dart @@ -21,31 +21,68 @@ import 'third_party/mediapipe/generated/mediapipe_common_bindings.dart' /// classifier's desired behavior. class BaseOptions extends Equatable { /// Generative constructor that creates a [BaseOptions] instance. - const BaseOptions({this.modelAssetBuffer, this.modelAssetPath}) - : assert( + const BaseOptions._({ + this.modelAssetBuffer, + this.modelAssetPath, + this.modelAssetBufferCount, + required _BaseOptionsType type, + }) : assert( !(modelAssetBuffer == null && modelAssetPath == null), 'You must supply either `modelAssetBuffer` or `modelAssetPath`', ), assert( !(modelAssetBuffer != null && modelAssetPath != null), 'You must only supply one of `modelAssetBuffer` and `modelAssetPath`', - ); + ), + assert( + (modelAssetBuffer == null) == (modelAssetBufferCount == null), + 'modelAssetBuffer and modelAssetBufferCount must only be submitted ' + 'together', + ), + _type = type; + + /// Constructor for [BaseOptions] classes using a file system path. + /// + /// In practice, this is unsupported, as assets in Flutter are bundled into + /// the build output and not available on disk. However, it can potentially + /// be helpful for testing / development purposes. + factory BaseOptions.path(String path) => BaseOptions._( + modelAssetPath: path, + type: _BaseOptionsType.path, + ); + + /// Constructor for [BaseOptions] classes using an in-memory pointer to the + /// MediaPipe SDK. + /// + /// In practice, this is the only option supported for production builds. + factory BaseOptions.memory(Uint8List buffer) { + return BaseOptions._( + modelAssetBuffer: buffer, + modelAssetBufferCount: buffer.lengthInBytes, + type: _BaseOptionsType.memory, + ); + } /// The model asset file contents as bytes; final Uint8List? modelAssetBuffer; + /// The size of the model assets buffer (or `0` if not set). + final int? modelAssetBufferCount; + /// Path to the model asset file. final String? modelAssetPath; + final _BaseOptionsType _type; + /// Converts this pure-Dart representation into C-memory suitable for the /// MediaPipe SDK to instantiate various classifiers. Pointer toStruct() { final struct = calloc(); - if (modelAssetPath != null) { + if (_type == _BaseOptionsType.path) { struct.ref.model_asset_path = prepareString(modelAssetPath!); } - if (modelAssetBuffer != null) { + if (_type == _BaseOptionsType.memory) { struct.ref.model_asset_buffer = prepareUint8List(modelAssetBuffer!); struct.ref.model_asset_buffer_count = modelAssetBuffer!.lengthInBytes; } @@ -53,19 +90,15 @@ class BaseOptions extends Equatable { } @override - List get props => [modelAssetBuffer, modelAssetPath]; - - /// Releases all C memory held by this [bindings.BaseOptions] struct. - static void freeStruct(bindings.BaseOptions struct) { - if (struct.model_asset_buffer.address != 0) { - calloc.free(struct.model_asset_buffer); - } - if (struct.model_asset_path.address != 0) { - calloc.free(struct.model_asset_path); - } - } + List get props => [ + modelAssetBuffer, + modelAssetPath, + modelAssetBufferCount, + ]; } +enum _BaseOptionsType { path, memory } + /// Dart representation of MediaPipe's "ClassifierOptions" concept. /// /// Classifier options shared across MediaPipe classification tasks. diff --git a/packages/mediapipe-core/lib/src/third_party/mediapipe/generated/mediapipe_common_bindings.dart b/packages/mediapipe-core/lib/src/third_party/mediapipe/generated/mediapipe_common_bindings.dart index f6217336..51a2a2ef 100644 --- a/packages/mediapipe-core/lib/src/third_party/mediapipe/generated/mediapipe_common_bindings.dart +++ b/packages/mediapipe-core/lib/src/third_party/mediapipe/generated/mediapipe_common_bindings.dart @@ -24,7 +24,7 @@ final class BaseOptions extends ffi.Struct { external ffi.Pointer model_asset_path; - @ffi.UnsignedInt() + @ffi.Int() external int model_asset_buffer_count; } diff --git a/packages/mediapipe-core/test/task_options_test.dart b/packages/mediapipe-core/test/task_options_test.dart index 143a9ff4..f0947679 100644 --- a/packages/mediapipe-core/test/task_options_test.dart +++ b/packages/mediapipe-core/test/task_options_test.dart @@ -9,32 +9,16 @@ import 'package:test/test.dart'; import 'package:mediapipe_core/mediapipe_core.dart'; void main() { - group('BaseOptions constructor should', () { - test('enforce exactly one of modelPath and modelBuffer', () { - expect( - () => BaseOptions( - modelAssetPath: 'abc', - modelAssetBuffer: Uint8List.fromList([1, 2, 3]), - ), - throwsA(TypeMatcher()), - ); - - expect(BaseOptions.new, throwsA(TypeMatcher())); - }); - }); - group('BaseOptions.toStruct/fromStruct should', () { test('allocate memory in C for a modelAssetPath', () { - final options = BaseOptions(modelAssetPath: 'abc'); + final options = BaseOptions.path('abc'); final struct = options.toStruct(); expect(toDartString(struct.ref.model_asset_path), 'abc'); expectNullPtr(struct.ref.model_asset_buffer); }); test('allocate memory in C for a modelAssetBuffer', () { - final options = BaseOptions( - modelAssetBuffer: Uint8List.fromList([1, 2, 3]), - ); + final options = BaseOptions.memory(Uint8List.fromList([1, 2, 3])); final struct = options.toStruct(); expect( toUint8List(struct.ref.model_asset_buffer), @@ -44,9 +28,7 @@ void main() { }); test('allocate memory in C for a modelAssetBuffer containing 0', () { - final options = BaseOptions( - modelAssetBuffer: Uint8List.fromList([1, 2, 0, 3]), - ); + final options = BaseOptions.memory(Uint8List.fromList([1, 2, 0, 3])); final struct = options.toStruct(); expect( toUint8List(struct.ref.model_asset_buffer), diff --git a/packages/mediapipe-task-text/build.dart b/packages/mediapipe-task-text/build.dart new file mode 100644 index 00000000..f35c7e3b --- /dev/null +++ b/packages/mediapipe-task-text/build.dart @@ -0,0 +1,41 @@ +import 'dart:io'; +import 'package:native_assets_cli/native_assets_cli.dart'; +import 'package:http/http.dart' as http; + +const cloudAssetFilename = 'libtext_classifier-v0.0.3.dylib'; +const localAssetFilename = 'libtext_classifier.dylib'; +const assetLocation = + 'https://storage.googleapis.com/random-storage-asdf/$cloudAssetFilename'; + +Future main(List args) async { + final buildConfig = await BuildConfig.fromArgs(args); + final buildOutput = BuildOutput(); + final downloadFileLocation = buildConfig.outDir.resolve(localAssetFilename); + if (!buildConfig.dryRun) { + final downloadUri = Uri.parse(assetLocation); + final downloadResponse = await http.get(downloadUri); + final downloadedFile = File(downloadFileLocation.toFilePath()); + if (downloadResponse.statusCode == 200) { + if (downloadedFile.existsSync()) { + downloadedFile.deleteSync(); + } + downloadedFile.createSync(); + downloadedFile.writeAsBytes(downloadResponse.bodyBytes); + } else { + throw Exception( + '${downloadResponse.statusCode} :: ${downloadResponse.body}'); + } + } + buildOutput.dependencies.dependencies + .add(buildConfig.packageRoot.resolve('build.dart')); + buildOutput.assets.add( + Asset( + // What should this `id` be? + id: 'package:mediapipe_text/src/mediapipe_text_bindings.dart', + linkMode: LinkMode.dynamic, + target: Target.macOSArm64, + path: AssetAbsolutePath(downloadFileLocation), + ), + ); + await buildOutput.writeToFile(outDir: buildConfig.outDir); +} diff --git a/packages/mediapipe-task-text/example/lib/main.dart b/packages/mediapipe-task-text/example/lib/main.dart index acbf1237..26a4d060 100644 --- a/packages/mediapipe-task-text/example/lib/main.dart +++ b/packages/mediapipe-task-text/example/lib/main.dart @@ -38,6 +38,7 @@ class _MainAppState extends State { super.initState(); _controller.text = 'Hello, world!'; _initClassifier(); + Future.delayed(const Duration(milliseconds: 500)).then((_) => _classify()); } Future _initClassifier() async { diff --git a/packages/mediapipe-task-text/example/pubspec.yaml b/packages/mediapipe-task-text/example/pubspec.yaml index 714a533a..1335c091 100644 --- a/packages/mediapipe-task-text/example/pubspec.yaml +++ b/packages/mediapipe-task-text/example/pubspec.yaml @@ -25,4 +25,3 @@ flutter: assets: - assets/bert_classifier.tflite - - assets/libtext_classifier.dylib diff --git a/packages/mediapipe-task-text/lib/src/tasks/text_classification/containers/text_classifier_options.dart b/packages/mediapipe-task-text/lib/src/tasks/text_classification/containers/text_classifier_options.dart index f7fb7e55..00a76d51 100644 --- a/packages/mediapipe-task-text/lib/src/tasks/text_classification/containers/text_classifier_options.dart +++ b/packages/mediapipe-task-text/lib/src/tasks/text_classification/containers/text_classifier_options.dart @@ -32,7 +32,7 @@ class TextClassifierOptions { }) { assert(!kIsWeb, 'fromAssetPath cannot be used on the web'); return TextClassifierOptions( - baseOptions: BaseOptions(modelAssetPath: assetPath), + baseOptions: BaseOptions.path(assetPath), classifierOptions: classifierOptions, ); } @@ -45,7 +45,7 @@ class TextClassifierOptions { ClassifierOptions classifierOptions = const ClassifierOptions(), }) => TextClassifierOptions( - baseOptions: BaseOptions(modelAssetBuffer: assetBuffer), + baseOptions: BaseOptions.memory(assetBuffer), classifierOptions: classifierOptions, ); diff --git a/packages/mediapipe-task-text/pubspec.yaml b/packages/mediapipe-task-text/pubspec.yaml index 830a0af7..a443216d 100644 --- a/packages/mediapipe-task-text/pubspec.yaml +++ b/packages/mediapipe-task-text/pubspec.yaml @@ -16,8 +16,11 @@ dependencies: sdk: flutter http: ^1.1.0 logging: ^1.2.0 + http: ^1.1.0 mediapipe_core: path: ../mediapipe-core + native_assets_cli: ^0.3.0 + native_toolchain_c: ^0.3.0 dev_dependencies: ffigen: ^9.0.1