diff --git a/Makefile b/Makefile index 1af52ff2..c6c007c5 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,7 @@ headers: models: cd tool/builder && dart bin/main.dart model -m textclassification cd tool/builder && dart bin/main.dart model -m textembedding - + cd tool/builder && dart bin/main.dart model -m languagedetection # Runs `ffigen` for all packages generate: generate_core generate_text diff --git a/packages/mediapipe-core/lib/universal_mediapipe_core.dart b/packages/mediapipe-core/lib/universal_mediapipe_core.dart index 0c64cccd..6e5177f1 100644 --- a/packages/mediapipe-core/lib/universal_mediapipe_core.dart +++ b/packages/mediapipe-core/lib/universal_mediapipe_core.dart @@ -62,7 +62,7 @@ class Classifications extends BaseClassifications { }); @override - Iterable get categories => throw UnimplementedError(); + Iterable get categories => throw UnimplementedError(); @override int get headIndex => throw UnimplementedError(); diff --git a/packages/mediapipe-core/third_party/mediapipe/tasks/c/components/containers/classification_result.h b/packages/mediapipe-core/third_party/mediapipe/tasks/c/components/containers/classification_result.h index 43bb87bc..d03fe05e 100644 --- a/packages/mediapipe-core/third_party/mediapipe/tasks/c/components/containers/classification_result.h +++ b/packages/mediapipe-core/third_party/mediapipe/tasks/c/components/containers/classification_result.h @@ -28,7 +28,6 @@ struct Classifications { // The array of predicted categories, usually sorted by descending scores, // e.g. from high to low probability. struct Category* categories; - // The number of elements in the categories array. uint32_t categories_count; @@ -58,7 +57,6 @@ struct ClassificationResult { // exceed the maximum size that the model can process: to solve this, the // input data is split into multiple chunks starting at different timestamps. int64_t timestamp_ms; - // Specifies whether the timestamp contains a valid value. bool has_timestamp_ms; }; diff --git a/packages/mediapipe-core/third_party/mediapipe/tasks/c/components/processors/classifier_options.h b/packages/mediapipe-core/third_party/mediapipe/tasks/c/components/processors/classifier_options.h index be82bc7a..32ad22b0 100644 --- a/packages/mediapipe-core/third_party/mediapipe/tasks/c/components/processors/classifier_options.h +++ b/packages/mediapipe-core/third_party/mediapipe/tasks/c/components/processors/classifier_options.h @@ -41,7 +41,6 @@ struct ClassifierOptions { // category name is not in this set will be filtered out. Duplicate or unknown // category names are ignored. Mutually exclusive with category_denylist. const char** category_allowlist; - // The number of elements in the category allowlist. uint32_t category_allowlist_count; @@ -49,7 +48,6 @@ struct ClassifierOptions { // category name is in this set will be filtered out. Duplicate or unknown // category names are ignored. Mutually exclusive with category_allowlist. const char** category_denylist; - // The number of elements in the category denylist. uint32_t category_denylist_count; }; diff --git a/packages/mediapipe-task-text/example/lib/enumerate.dart b/packages/mediapipe-task-text/example/lib/enumerate.dart index d0e70110..97c941fc 100644 --- a/packages/mediapipe-task-text/example/lib/enumerate.dart +++ b/packages/mediapipe-task-text/example/lib/enumerate.dart @@ -2,12 +2,34 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -extension Enumeratable on List { - Iterable enumerate(S Function(T, int) fn) sync* { +extension EnumeratableList on List { + /// Invokes the callback on each element of the list, optionally stopping + /// after [max] (inclusive) invocations. + Iterable enumerate(S Function(T, int) fn, {int? max}) sync* { int count = 0; while (count < length) { yield fn(this[count], count); count++; + + if (max != null && count >= max) { + return; + } + } + } +} + +extension EnumeratableIterable on Iterable { + /// Invokes the callback on each element of the iterable, optionally stopping + /// after [max] (inclusive) invocations. + Iterable enumerate(S Function(T, int) fn, {int? max}) sync* { + int count = 0; + for (final T obj in this) { + yield fn(obj, count); + count++; + + if (max != null && count >= max) { + return; + } } } } diff --git a/packages/mediapipe-task-text/example/lib/language_detection_demo.dart b/packages/mediapipe-task-text/example/lib/language_detection_demo.dart new file mode 100644 index 00000000..c946f850 --- /dev/null +++ b/packages/mediapipe-task-text/example/lib/language_detection_demo.dart @@ -0,0 +1,158 @@ +// Copyright 2014 The Flutter Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import 'dart:async'; +import 'dart:typed_data'; +import 'package:flutter/material.dart'; +import 'package:getwidget/getwidget.dart'; +import 'package:mediapipe_text/mediapipe_text.dart'; +import 'enumerate.dart'; + +class LanguageDetectionDemo extends StatefulWidget { + const LanguageDetectionDemo({super.key, this.detector}); + + final LanguageDetector? detector; + + @override + State createState() => _LanguageDetectionDemoState(); +} + +class _LanguageDetectionDemoState extends State + with AutomaticKeepAliveClientMixin { + final TextEditingController _controller = TextEditingController(); + final Completer _completer = Completer(); + final results = []; + String? _isProcessing; + + @override + void initState() { + super.initState(); + _controller.text = 'Quiero agua, por favor'; + _initDetector(); + } + + Future _initDetector() async { + if (widget.detector != null) { + return _completer.complete(widget.detector!); + } + + ByteData? bytes = await DefaultAssetBundle.of(context) + .load('assets/language_detector.tflite'); + + final detector = LanguageDetector( + LanguageDetectorOptions.fromAssetBuffer( + bytes.buffer.asUint8List(), + ), + ); + _completer.complete(detector); + bytes = null; + } + + void _prepareForDetection() { + setState(() { + _isProcessing = _controller.text; + results.add(const CircularProgressIndicator.adaptive()); + }); + } + + Future _detect() async { + _prepareForDetection(); + _completer.future.then((detector) async { + final result = await detector.detect(_controller.text); + _showDetectionResults(result); + result.dispose(); + }); + } + + void _showDetectionResults(LanguageDetectorResult result) { + setState( + () { + results.last = Card( + key: Key('prediction-"$_isProcessing" ${results.length}'), + margin: const EdgeInsets.all(10), + child: Column( + children: [ + Padding( + padding: const EdgeInsets.all(10), + child: Text(_isProcessing!), + ), + Padding( + padding: const EdgeInsets.all(10.0), + child: Wrap( + children: [ + ...result.predictions + .enumerate( + (prediction, index) => _languagePrediction( + prediction, + predictionColors[index], + ), + // Take first 4 because the model spits out dozens of + // astronomically low probability language predictions + max: predictionColors.length, + ) + .toList(), + ], + ), + ), + ], + ), + ); + _isProcessing = null; + }, + ); + } + + static final predictionColors = [ + Colors.blue[300]!, + Colors.orange[300]!, + Colors.green[300]!, + Colors.red[300]!, + ]; + + Widget _languagePrediction(LanguagePrediction prediction, Color color) { + return Padding( + padding: const EdgeInsets.only(right: 8), + child: GFButton( + onPressed: null, + text: '${prediction.languageCode} :: ' + '${prediction.probability.roundTo(8)}', + shape: GFButtonShape.pills, + color: color, + ), + ); + } + + @override + Widget build(BuildContext context) { + super.build(context); + return Scaffold( + body: SafeArea( + child: Padding( + padding: const EdgeInsets.all(16.0), + child: SingleChildScrollView( + child: Column( + children: [ + TextField(controller: _controller), + ...results.reversed, + ], + ), + ), + ), + ), + floatingActionButton: FloatingActionButton( + onPressed: + _isProcessing != null && _controller.text != '' ? null : _detect, + child: const Icon(Icons.search), + ), + ); + } + + @override + bool get wantKeepAlive => true; +} + +extension on double { + double roundTo(int decimalPlaces) => + double.parse(toStringAsFixed(decimalPlaces)); +} diff --git a/packages/mediapipe-task-text/example/lib/main.dart b/packages/mediapipe-task-text/example/lib/main.dart index db9ad11f..1830c75a 100644 --- a/packages/mediapipe-task-text/example/lib/main.dart +++ b/packages/mediapipe-task-text/example/lib/main.dart @@ -1,4 +1,5 @@ import 'package:flutter/material.dart'; +import 'language_detection_demo.dart'; import 'logging.dart'; import 'text_classification_demo.dart'; import 'text_embedding_demo.dart'; @@ -31,7 +32,7 @@ class TextTaskPages extends StatefulWidget { class TextTaskPagesState extends State { final PageController controller = PageController(); - final titles = ['Classify', 'Embed']; + final titles = ['Classify', 'Embed', 'Detect Languages']; int titleIndex = 0; void switchToPage(int index) { @@ -61,28 +62,39 @@ class TextTaskPagesState extends State { children: const [ TextClassificationDemo(), TextEmbeddingDemo(), + LanguageDetectionDemo(), ], ), - bottomNavigationBar: ColoredBox( - color: Colors.blueGrey, - child: Row( - mainAxisAlignment: MainAxisAlignment.spaceEvenly, - children: [ - TextButton( - onPressed: () => switchToPage(0), - child: Text( - 'Classify', - style: titleIndex == 0 ? activeTextStyle : inactiveTextStyle, + bottomNavigationBar: SizedBox( + height: 50, + child: ColoredBox( + color: Colors.blueGrey, + child: Row( + mainAxisAlignment: MainAxisAlignment.spaceEvenly, + children: [ + TextButton( + onPressed: () => switchToPage(0), + child: Text( + 'Classify', + style: titleIndex == 0 ? activeTextStyle : inactiveTextStyle, + ), ), - ), - TextButton( - onPressed: () => switchToPage(1), - child: Text( - 'Embed', - style: titleIndex == 1 ? activeTextStyle : inactiveTextStyle, + TextButton( + onPressed: () => switchToPage(1), + child: Text( + 'Embed', + style: titleIndex == 1 ? activeTextStyle : inactiveTextStyle, + ), ), - ), - ], + TextButton( + onPressed: () => switchToPage(2), + child: Text( + 'Detect Languages', + style: titleIndex == 2 ? activeTextStyle : inactiveTextStyle, + ), + ), + ], + ), ), ), ); diff --git a/packages/mediapipe-task-text/example/lib/text_classification_demo.dart b/packages/mediapipe-task-text/example/lib/text_classification_demo.dart index 38266e9f..bda040cb 100644 --- a/packages/mediapipe-task-text/example/lib/text_classification_demo.dart +++ b/packages/mediapipe-task-text/example/lib/text_classification_demo.dart @@ -5,8 +5,10 @@ import 'dart:async'; import 'dart:typed_data'; import 'package:flutter/material.dart'; +import 'package:getwidget/getwidget.dart'; +import 'package:mediapipe_core/mediapipe_core.dart'; import 'package:mediapipe_text/mediapipe_text.dart'; -import 'logging.dart'; +import 'enumerate.dart'; class TextClassificationDemo extends StatefulWidget { const TextClassificationDemo({super.key, this.classifier}); @@ -17,14 +19,13 @@ class TextClassificationDemo extends StatefulWidget { State createState() => _TextClassificationDemoState(); } -class _TextClassificationDemoState extends State { +class _TextClassificationDemoState extends State + with AutomaticKeepAliveClientMixin { final TextEditingController _controller = TextEditingController(); - List results = []; + final Completer _completer = Completer(); + final results = []; String? _isProcessing; - TextClassifier? _classifier; - Completer? _completer; - @override void initState() { super.initState(); @@ -32,28 +33,20 @@ class _TextClassificationDemoState extends State { _initClassifier(); } - Future get classifier { + Future _initClassifier() async { if (widget.classifier != null) { - return Future.value(widget.classifier!); + return _completer.complete(widget.classifier!); } - if (_completer == null) { - _initClassifier(); - } - return _completer!.future; - } - Future _initClassifier() async { - _classifier?.dispose(); - _completer = Completer(); ByteData? classifierBytes = await DefaultAssetBundle.of(context) .load('assets/bert_classifier.tflite'); - TextClassifier classifier = TextClassifier( + final classifier = TextClassifier( TextClassifierOptions.fromAssetBuffer( classifierBytes.buffer.asUint8List(), ), ); - _completer!.complete(classifier); + _completer.complete(classifier); classifierBytes = null; } @@ -64,51 +57,102 @@ class _TextClassificationDemoState extends State { }); } - void _showClassificationResults(TextClassifierResult classification) { - setState(() { - final categoryName = - classification.firstClassification?.firstCategory?.categoryName; - final score = classification.firstClassification?.firstCategory?.score; - // Replace "..." with the results - final message = '"$_isProcessing" $categoryName :: $score'; - log.info(message); - results.last = Card( - key: Key('Classification::"$_isProcessing" ${results.length}'), - margin: const EdgeInsets.all(10), - child: Padding( - padding: const EdgeInsets.all(10), - child: Text(message), - ), - ); - _isProcessing = null; - }); + void _showClassificationResults(TextClassifierResult result) { + final categoryWidgets = []; + for (final classifications in result.classifications) { + categoryWidgets.addAll(_textClassifications(classifications)); + } + + setState( + () { + results.last = Card( + key: Key('Classification::"$_isProcessing" ${results.length}'), + margin: const EdgeInsets.all(10), + child: Column( + children: [ + Padding( + padding: const EdgeInsets.all(10), + child: Text(_isProcessing!), + ), + Padding( + padding: const EdgeInsets.all(10.0), + child: Wrap( + children: [ + ...categoryWidgets, + ], + ), + ), + ], + ), + ); + _isProcessing = null; + }, + ); + } + + static final categoryColors = [ + Colors.blue[300]!, + Colors.orange[300]!, + Colors.green[300]!, + Colors.red[300]!, + ]; + + List _textClassifications(Classifications classifications) { + return classifications.categories + .enumerate((category, index) => + _textClassification(category, categoryColors[index])) + .toList(); + } + + Widget _textClassification(Category category, Color color) { + return Padding( + padding: const EdgeInsets.only(right: 8), + child: GFButton( + onPressed: null, + text: '${category.displayName ?? category.categoryName} :: ' + '${category.score.roundTo(4)}', + shape: GFButtonShape.pills, + color: color, + ), + ); } Future _classify() async { _prepareForClassification(); - final classification = await (await classifier).classify(_controller.text); - _showClassificationResults(classification); + _completer.future.then((classifier) async { + final result = await classifier.classify(_controller.text); + _showClassificationResults(result); + }); } @override - Widget build(BuildContext context) => // - Scaffold( - body: SafeArea( - child: Padding( - padding: const EdgeInsets.all(16.0), - child: Column( - children: [ - TextField(controller: _controller), - ...results, - ], - ), + Widget build(BuildContext context) { + super.build(context); + return Scaffold( + body: SafeArea( + child: Padding( + padding: const EdgeInsets.all(16.0), + child: Column( + children: [ + TextField(controller: _controller), + ...results, + ], ), ), - floatingActionButton: FloatingActionButton( - onPressed: _isProcessing != null && _controller.text != '' - ? null - : _classify, - child: const Icon(Icons.search), - ), - ); + ), + floatingActionButton: FloatingActionButton( + onPressed: + _isProcessing != null && _controller.text != '' ? null : _classify, + child: const Icon(Icons.search), + ), + ); + } + + @override + bool get wantKeepAlive => true; +} + +extension on double { + double roundTo(int decimalPlaces) => + double.parse(toStringAsFixed(decimalPlaces)); } diff --git a/packages/mediapipe-task-text/example/lib/text_embedding_demo.dart b/packages/mediapipe-task-text/example/lib/text_embedding_demo.dart index 9376fd40..c2a92817 100644 --- a/packages/mediapipe-task-text/example/lib/text_embedding_demo.dart +++ b/packages/mediapipe-task-text/example/lib/text_embedding_demo.dart @@ -24,7 +24,8 @@ class TextEmbeddingDemo extends StatefulWidget { State createState() => _TextEmbeddingDemoState(); } -class _TextEmbeddingDemoState extends State { +class _TextEmbeddingDemoState extends State + with AutomaticKeepAliveClientMixin { final TextEditingController _controller = TextEditingController(); List feed = []; EmbeddingType type = EmbeddingType.quantized; @@ -178,102 +179,107 @@ class _TextEmbeddingDemoState extends State { } @override - Widget build(BuildContext context) => // - Scaffold( - body: SafeArea( - child: Padding( - padding: const EdgeInsets.all(16.0), - child: SingleChildScrollView( - child: Column( - children: [ - Row( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - Column( - mainAxisAlignment: MainAxisAlignment.start, - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - Row( - children: [ - const Text('Float:'), - Checkbox( - value: type == EmbeddingType.float, - onChanged: (_) { - toggleMode(); - }, - ), - ], - ), - // Quantized checkbox - Row( - children: [ - const Text('Quantize:'), - Checkbox( - value: type == EmbeddingType.quantized, - onChanged: (bool? newValue) { - toggleMode(); - }, - ), - ], - ), - ], - ), - Column( - mainAxisAlignment: MainAxisAlignment.start, - children: [ - Row( - children: [ - const Text('L2 Normalize:'), - Checkbox( - value: l2Normalize, - onChanged: (_) { - toggleL2Normalize(); - }, - ), - ], - ), - ], - ), - ], - ), - // Float checkbox - TextField(controller: _controller), - ...feed.reversed.toList().enumerate( - (EmbeddingFeedItem feedItem, index) { - return switch (feedItem._type) { - _EmbeddingFeedItemType.result => - TextEmbedderResultDisplay( - embeddedText: feedItem.embeddingResult!, - index: index, - ), - _EmbeddingFeedItemType.emptyComparison => TextButton( - // Subtract `index` from `feed.length` because we - // are looping through the list in reverse order - onPressed: () => _compare(feed.length - index - 1), - style: ButtonStyle( - backgroundColor: WidgetStateProperty.all( - Colors.purple[100], - ), + Widget build(BuildContext context) { + super.build(context); + return Scaffold( + body: SafeArea( + child: Padding( + padding: const EdgeInsets.all(16.0), + child: SingleChildScrollView( + child: Column( + children: [ + Row( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Column( + mainAxisAlignment: MainAxisAlignment.start, + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Row( + children: [ + const Text('Float:'), + Checkbox( + value: type == EmbeddingType.float, + onChanged: (_) { + toggleMode(); + }, + ), + ], + ), + // Quantized checkbox + Row( + children: [ + const Text('Quantize:'), + Checkbox( + value: type == EmbeddingType.quantized, + onChanged: (bool? newValue) { + toggleMode(); + }, + ), + ], + ), + ], + ), + Column( + mainAxisAlignment: MainAxisAlignment.start, + children: [ + Row( + children: [ + const Text('L2 Normalize:'), + Checkbox( + value: l2Normalize, + onChanged: (_) { + toggleL2Normalize(); + }, + ), + ], + ), + ], + ), + ], + ), + // Float checkbox + TextField(controller: _controller), + ...feed.reversed.toList().enumerate( + (EmbeddingFeedItem feedItem, index) { + return switch (feedItem._type) { + _EmbeddingFeedItemType.result => + TextEmbedderResultDisplay( + embeddedText: feedItem.embeddingResult!, + index: index, + ), + _EmbeddingFeedItemType.emptyComparison => TextButton( + // Subtract `index` from `feed.length` because we + // are looping through the list in reverse order + onPressed: () => _compare(feed.length - index - 1), + style: ButtonStyle( + backgroundColor: WidgetStateProperty.all( + Colors.purple[100], ), - child: const Text('Compare'), ), - _EmbeddingFeedItemType.comparison => - ComparisonDisplay(similarity: feedItem.similarity!), - _EmbeddingFeedItemType.incomparable => const Text( - 'Embeddings of different types cannot be compared'), - }; - }, - ), - ], - ), + child: const Text('Compare'), + ), + _EmbeddingFeedItemType.comparison => + ComparisonDisplay(similarity: feedItem.similarity!), + _EmbeddingFeedItemType.incomparable => const Text( + 'Embeddings of different types cannot be compared'), + }; + }, + ), + ], ), ), ), - floatingActionButton: FloatingActionButton( - onPressed: isProcessing || _controller.text == '' ? null : _embed, - child: const Icon(Icons.search), - ), - ); + ), + floatingActionButton: FloatingActionButton( + onPressed: isProcessing || _controller.text == '' ? null : _embed, + child: const Icon(Icons.search), + ), + ); + } + + @override + bool get wantKeepAlive => true; } /// Shows, in the activity feed, the results of invoking `cosineSimilarity` diff --git a/packages/mediapipe-task-text/example/pubspec.yaml b/packages/mediapipe-task-text/example/pubspec.yaml index a8380ebb..d6376364 100644 --- a/packages/mediapipe-task-text/example/pubspec.yaml +++ b/packages/mediapipe-task-text/example/pubspec.yaml @@ -28,4 +28,5 @@ flutter: assets: - assets/bert_classifier.tflite + - assets/language_detector.tflite - assets/universal_sentence_encoder.tflite diff --git a/packages/mediapipe-task-text/example/test/widgets_test.dart b/packages/mediapipe-task-text/example/test/widgets_test.dart index c30892ab..fc94fa4f 100644 --- a/packages/mediapipe-task-text/example/test/widgets_test.dart +++ b/packages/mediapipe-task-text/example/test/widgets_test.dart @@ -2,6 +2,7 @@ import 'package:flutter/material.dart'; import 'package:flutter_test/flutter_test.dart'; import 'package:mediapipe_core/mediapipe_core.dart'; import 'package:mediapipe_text/mediapipe_text.dart'; +import 'package:example/language_detection_demo.dart'; import 'package:example/text_classification_demo.dart'; class FakeTextClassifier extends TextClassifier { @@ -18,7 +19,7 @@ class FakeTextClassifier extends TextClassifier { index: 0, score: 0.9, categoryName: 'happy-go-lucky', - displayName: null, + displayName: 'Happy go Lucky', ), ], headIndex: 0, @@ -30,10 +31,33 @@ class FakeTextClassifier extends TextClassifier { } } +class FakeLanguageDetector extends LanguageDetector { + FakeLanguageDetector(LanguageDetectorOptions options) : super(options); + + @override + Future detect(String text) { + return Future.value( + LanguageDetectorResult( + predictions: [ + LanguagePrediction( + languageCode: 'es', + probability: 0.99, + ), + LanguagePrediction( + languageCode: 'en', + probability: 0.01, + ), + ], + ), + ); + } +} + void main() { TestWidgetsFlutterBinding.ensureInitialized(); - testWidgets('TextClassificationResults should show results', - (WidgetTester tester) async { + testWidgets('TextClassificationResult should show results', ( + WidgetTester tester, + ) async { final app = MaterialApp( home: TextClassificationDemo( classifier: FakeTextClassifier( @@ -49,9 +73,28 @@ void main() { find.byKey(const Key('Classification::"Hello, world!" 1')), findsOneWidget, ); + expect(find.text('Happy go Lucky :: 0.9'), findsOneWidget); + }); + + testWidgets('LanguageDetectorResult should show results', ( + WidgetTester tester, + ) async { + final app = MaterialApp( + home: LanguageDetectionDemo( + detector: FakeLanguageDetector( + LanguageDetectorOptions.fromAssetPath('fake'), + ), + ), + ); + + await tester.pumpWidget(app); + await tester.tap(find.byType(Icon)); + await tester.pumpAndSettle(); expect( - find.text('"Hello, world!" happy-go-lucky :: 0.9'), + find.byKey(const Key('prediction-"Quiero agua, por favor" 1')), findsOneWidget, ); + expect(find.text('es :: 0.99'), findsOneWidget); + expect(find.text('en :: 0.01'), findsOneWidget); }); } diff --git a/packages/mediapipe-task-text/lib/src/interface/tasks/language_detection/language_detection.dart b/packages/mediapipe-task-text/lib/src/interface/tasks/language_detection/language_detection.dart new file mode 100644 index 00000000..66095b88 --- /dev/null +++ b/packages/mediapipe-task-text/lib/src/interface/tasks/language_detection/language_detection.dart @@ -0,0 +1,7 @@ +// Copyright 2014 The Flutter Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +export 'language_detector_result.dart'; +export 'language_detector_options.dart'; +export 'language_detector.dart'; diff --git a/packages/mediapipe-task-text/lib/src/interface/tasks/language_detection/language_detector.dart b/packages/mediapipe-task-text/lib/src/interface/tasks/language_detection/language_detector.dart new file mode 100644 index 00000000..8421b673 --- /dev/null +++ b/packages/mediapipe-task-text/lib/src/interface/tasks/language_detection/language_detector.dart @@ -0,0 +1,28 @@ +// Copyright 2014 The Flutter Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import 'language_detection.dart'; + +/// {@template LanguageDetector} +/// Predicts the language of an input text. +/// +/// Usage: +/// ```dart +/// final options = LanguageDetectorOptions(); // optional parameters +/// final detector = LanguageDetector(options); +/// final result = await detector.detect('¿Como estas?') +/// print(result.predictions.first.languageCode); +/// > "es" +/// detector.dispose(); +/// ``` +/// {@endtemplate} +abstract class BaseLanguageDetector { + /// {@template LanguageDetector.detect} + /// Sends a [String] value to MediaPipe for language detection. + /// {@endtemplate} + Future detect(String text); + + /// Cleans up all resources. + void dispose(); +} diff --git a/packages/mediapipe-task-text/lib/src/interface/tasks/language_detection/language_detector_options.dart b/packages/mediapipe-task-text/lib/src/interface/tasks/language_detection/language_detector_options.dart new file mode 100644 index 00000000..b4d52ea3 --- /dev/null +++ b/packages/mediapipe-task-text/lib/src/interface/tasks/language_detection/language_detector_options.dart @@ -0,0 +1,31 @@ +// Copyright 2014 The Flutter Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import 'package:mediapipe_core/interface.dart'; + +/// {@template TextClassifierOptions} +/// Configuration object for a MediaPipe text classifier. +/// +/// See also: +/// * [MediaPipe's TextClassifierOptions documentation](https://developers.google.com/mediapipe/api/solutions/js/tasks-text.textclassifieroptions) +/// {@endtemplate} +/// +/// This implementation is not immutable to track whether `dispose` has been +/// called. All values used by pkg:equatable are in fact immutable. +// ignore: must_be_immutable +abstract class BaseLanguageDetectorOptions extends BaseTaskOptions { + /// Contains parameter options for how this classifier should behave, + /// including allow and denylists, thresholds, maximum results, etc. + /// + /// See also: + /// * [BaseClassifierOptions] for each available field. + BaseClassifierOptions get classifierOptions; + + @override + String toString() => 'LanguageDetectorOptions(baseOptions: $baseOptions, ' + 'classifierOptions: $classifierOptions)'; + + @override + List get props => [baseOptions, classifierOptions]; +} diff --git a/packages/mediapipe-task-text/lib/src/interface/tasks/language_detection/language_detector_result.dart b/packages/mediapipe-task-text/lib/src/interface/tasks/language_detection/language_detector_result.dart new file mode 100644 index 00000000..31e4f76a --- /dev/null +++ b/packages/mediapipe-task-text/lib/src/interface/tasks/language_detection/language_detector_result.dart @@ -0,0 +1,28 @@ +// Copyright 2014 The Flutter Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import 'package:mediapipe_core/interface.dart'; + +/// {@template LanguageDetectionResult} +/// Container with results of MediaPipe's language detection task. +/// +/// See also: +/// * [MediaPipe's LanguageDetectionResult documentation](https://developers.google.com/mediapipe/api/solutions/java/com/google/mediapipe/tasks/text/languagedetector/LanguageDetectorResult) +/// {@endtemplate} +abstract class BaseLanguageDetectorResult extends TaskResult { + /// A list of predictions from the LanguageDetector. + Iterable get predictions; +} + +/// {@template LanguagePrediction} +/// A language code and its probability. Used as part of the output of +/// a language detector. +/// {@endtemplate} +abstract class BaseLanguagePrediction { + /// The i18n language / locale code for the prediction. + String get languageCode; + + /// The probability for the prediction. + double get probability; +} diff --git a/packages/mediapipe-task-text/lib/src/interface/tasks/tasks.dart b/packages/mediapipe-task-text/lib/src/interface/tasks/tasks.dart index 3f846929..d5915de3 100644 --- a/packages/mediapipe-task-text/lib/src/interface/tasks/tasks.dart +++ b/packages/mediapipe-task-text/lib/src/interface/tasks/tasks.dart @@ -2,5 +2,6 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +export 'language_detection/language_detection.dart'; export 'text_embedding/text_embedding.dart'; export 'text_classification/text_classification.dart'; diff --git a/packages/mediapipe-task-text/lib/src/interface/tasks/text_embedding/text_embedder.dart b/packages/mediapipe-task-text/lib/src/interface/tasks/text_embedding/text_embedder.dart index f4d72715..864c9a4e 100644 --- a/packages/mediapipe-task-text/lib/src/interface/tasks/text_embedding/text_embedder.dart +++ b/packages/mediapipe-task-text/lib/src/interface/tasks/text_embedding/text_embedder.dart @@ -15,7 +15,8 @@ abstract class BaseTextEmbedder { Future embed(String text); /// {@template TextEmbedder.cosineSimilarity} - /// Sends a [String] value to MediaPipe for conversion into an [Embedding]. + /// Compares the similarity between two [Embedding] values. Identical + /// embeddings will yield a similarity value of 1.0. /// {@endtemplate} Future cosineSimilarity(Embedding a, Embedding b); diff --git a/packages/mediapipe-task-text/lib/src/io/tasks/language_detection/language_detection.dart b/packages/mediapipe-task-text/lib/src/io/tasks/language_detection/language_detection.dart new file mode 100644 index 00000000..28208370 --- /dev/null +++ b/packages/mediapipe-task-text/lib/src/io/tasks/language_detection/language_detection.dart @@ -0,0 +1,8 @@ +// Copyright 2014 The Flutter Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +export 'language_detector_result.dart'; +export 'language_detector_executor.dart'; +export 'language_detector_options.dart'; +export 'language_detector.dart'; diff --git a/packages/mediapipe-task-text/lib/src/io/tasks/language_detection/language_detector.dart b/packages/mediapipe-task-text/lib/src/io/tasks/language_detection/language_detector.dart new file mode 100644 index 00000000..c40411de --- /dev/null +++ b/packages/mediapipe-task-text/lib/src/io/tasks/language_detection/language_detector.dart @@ -0,0 +1,106 @@ +// Copyright 2014 The Flutter Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import 'dart:async'; +import 'dart:io' as io; +import 'dart:isolate'; +import 'package:async/async.dart'; +import 'package:logging/logging.dart'; +import 'package:mediapipe_core/mediapipe_core.dart'; +import 'package:mediapipe_text/interface.dart'; +import 'package:mediapipe_text/io.dart'; + +final _log = Logger('LanguageDetector'); + +/// LanguageDetector implementation able to use FFI and `dart:io`. +class LanguageDetector extends BaseLanguageDetector { + /// Generative constructor. + LanguageDetector(this._options) : _readyCompleter = Completer() { + _createIsolate(_options).then((results) { + _events = results.$1; + _sendPort = results.$2; + _readyCompleter.complete(); + }); + } + + late SendPort _sendPort; + late StreamQueue _events; + final Completer _readyCompleter; + Future get _ready => _readyCompleter.future; + + final LanguageDetectorOptions _options; + + /// Closes down the background isolate, releasing all resources. + @override + void dispose() => _sendPort.send(null); + + /// {@macro LanguageDetector.embed} + @override + Future detect(String text) async { + _log.fine('Detecting language of "${text.shorten()}"'); + await _ready; + _sendPort.send(text); + while (true) { + final response = await _events.next; + if (response is LanguageDetectorResult) { + return response; + } else if (response is String) { + _log.fine(response); + } else { + throw Exception( + 'Unexpected language detection result of type ${response.runtimeType} ' + ': $response', + ); + } + } + } +} + +Future<(StreamQueue, SendPort)> _createIsolate( + LanguageDetectorOptions options) async { + final p = ReceivePort(); + await Isolate.spawn( + (SendPort port) => _languageDetectionService( + port, + options, + ), + p.sendPort, + ); + + final events = StreamQueue(p); + final SendPort sendPort = await events.next; + return (events, sendPort); +} + +Future _languageDetectionService( + SendPort p, + LanguageDetectorOptions options, +) async { + final commandPort = ReceivePort(); + p.send(commandPort.sendPort); + + Logger.root.level = Level.FINEST; + Logger.root.onRecord.listen((record) { + io.stdout.writeln('${record.level.name} [${record.loggerName}]' + '[' + '${record.time.hour.toString()}:' + '${record.time.minute.toString().padLeft(2, "0")}:' + '${record.time.second.toString().padLeft(2, "0")}.' + '${record.time.millisecond.toString().padRight(3, "0")}' + '] ${record.message}'); + }); + + final executor = LanguageDetectorExecutor(options); + + await for (final String? message in commandPort) { + if (message != null) { + final LanguageDetectorResult result = executor.detect(message); + p.send(result); + } else { + break; + } + } + executor.dispose(); + Isolate.exit(); +} diff --git a/packages/mediapipe-task-text/lib/src/io/tasks/language_detection/language_detector_executor.dart b/packages/mediapipe-task-text/lib/src/io/tasks/language_detection/language_detector_executor.dart new file mode 100644 index 00000000..c896e83b --- /dev/null +++ b/packages/mediapipe-task-text/lib/src/io/tasks/language_detection/language_detector_executor.dart @@ -0,0 +1,77 @@ +// Copyright 2014 The Flutter Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import 'dart:ffi'; +import 'package:ffi/ffi.dart'; +import 'package:logging/logging.dart'; +import 'package:mediapipe_core/io.dart'; +import 'package:mediapipe_text/io.dart'; +import 'package:mediapipe_text/src/io/third_party/mediapipe/generated/mediapipe_text_bindings.dart' + as bindings; + +final _log = Logger('LanguageDetectorExecutor'); + +/// Executes MediaPipe's "detect language" task. +/// +/// {@macro TaskExecutor} +class LanguageDetectorExecutor extends TaskExecutor< + bindings.LanguageDetectorOptions, + LanguageDetectorOptions, + bindings.LanguageDetectorResult, + LanguageDetectorResult> { + /// {@macro LanguageDetectorExecutor} + LanguageDetectorExecutor(super.options); + + @override + final String taskName = 'LanguageDetection'; + + @override + Pointer createWorker( + Pointer options, + Pointer> error, + ) { + _log.fine('Creating LanguageDetector in native memory'); + final worker = bindings.language_detector_create(options, error); + _log.finest( + 'Created LanguageDetector at 0x${worker.address.toRadixString(16)}', + ); + return worker; + } + + @override + Pointer createResultsPointer() { + _log.fine('Allocating LanguageDetectorResult in native memory'); + final results = calloc(); + _log.finest( + 'Allocated LanguageDetectorResult at 0x${results.address.toRadixString(16)}', + ); + return results; + } + + @override + int closeWorker(Pointer worker, Pointer> error) { + final status = bindings.language_detector_close(worker, error); + _log.finest('Closed LanguageDetector in native memory with status $status'); + return status; + } + + /// Passes [text] to MediaPipe for classification, yielding a + /// [LanguageDetectorResult] or throwing an exception. + LanguageDetectorResult detect(String text) { + final resultPtr = createResultsPointer(); + final errorMessageMemory = calloc>(); + final textMemory = text.copyToNative(); + final status = bindings.language_detector_detect( + worker, + textMemory, + resultPtr, + errorMessageMemory, + ); + _log.finest('Detected with status $status'); + textMemory.free(); + handleErrorMessage(errorMessageMemory, status); + errorMessageMemory.free(1); + return LanguageDetectorResult.native(resultPtr); + } +} diff --git a/packages/mediapipe-task-text/lib/src/io/tasks/language_detection/language_detector_options.dart b/packages/mediapipe-task-text/lib/src/io/tasks/language_detection/language_detector_options.dart new file mode 100644 index 00000000..77553c29 --- /dev/null +++ b/packages/mediapipe-task-text/lib/src/io/tasks/language_detection/language_detector_options.dart @@ -0,0 +1,94 @@ +// Copyright 2014 The Flutter Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import 'dart:typed_data'; + +import 'dart:ffi'; +import 'package:ffi/ffi.dart'; +import 'package:mediapipe_core/io.dart'; +import 'package:mediapipe_text/interface.dart'; +import '../../third_party/mediapipe/generated/mediapipe_text_bindings.dart' + as bindings; + +/// {@macro LanguageDetectorOptions} +/// +/// This io-friendly implementation is not immutable strictly to track whether +/// [dispose] has been called. +// ignore: must_be_immutable +class LanguageDetectorOptions extends BaseLanguageDetectorOptions + with TaskOptions { + /// {@macro LanguageDetectorOptions} + LanguageDetectorOptions({ + required this.baseOptions, + this.classifierOptions = const ClassifierOptions(), + }); + + /// {@macro LanguageDetectorOptions.fromAssetPath} + factory LanguageDetectorOptions.fromAssetPath( + String assetPath, { + ClassifierOptions classifierOptions = const ClassifierOptions(), + }) { + return LanguageDetectorOptions( + baseOptions: BaseOptions.path(assetPath), + classifierOptions: classifierOptions, + ); + } + + /// {@macro LanguageDetectorOptions.fromAssetBuffer} + factory LanguageDetectorOptions.fromAssetBuffer( + Uint8List assetBuffer, { + ClassifierOptions classifierOptions = const ClassifierOptions(), + }) => + LanguageDetectorOptions( + baseOptions: BaseOptions.memory(assetBuffer), + classifierOptions: classifierOptions, + ); + + @override + final BaseOptions baseOptions; + + @override + final ClassifierOptions classifierOptions; + + /// {@macro TaskOptions.memory} + Pointer? _pointer; + + @override + Pointer copyToNative() { + _pointer = calloc(); + baseOptions.assignToStruct(_pointer!.ref.base_options); + classifierOptions.assignToStruct(_pointer!.ref.classifier_options); + return _pointer!; + } + + bool _isClosed = false; + + /// Tracks whether [dispose] has been called. + bool get isClosed => _isClosed; + + @override + void dispose() { + assert(() { + if (isClosed) { + throw Exception( + 'Attempted to call dispose on an already-disposed task options' + 'object. Task options should only ever be disposed after they are at ' + 'end-of-life and will never be accessed again.', + ); + } + if (_pointer == null) { + throw Exception( + 'Attempted to call dispose on a LanguageDetectorOptions object which ' + 'was never used by a LanguageDetector, which you do not need to do. ' + 'Did you forget to create your LanguageDetector?', + ); + } + return true; + }()); + baseOptions.freeStructFields(_pointer!.ref.base_options); + classifierOptions.freeStructFields(_pointer!.ref.classifier_options); + calloc.free(_pointer!); + _isClosed = true; + } +} diff --git a/packages/mediapipe-task-text/lib/src/io/tasks/language_detection/language_detector_result.dart b/packages/mediapipe-task-text/lib/src/io/tasks/language_detection/language_detector_result.dart new file mode 100644 index 00000000..7e054df7 --- /dev/null +++ b/packages/mediapipe-task-text/lib/src/io/tasks/language_detection/language_detector_result.dart @@ -0,0 +1,106 @@ +// Copyright 2014 The Flutter Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import 'dart:ffi'; +import 'package:mediapipe_core/io.dart'; +import 'package:mediapipe_text/interface.dart'; +import '../../third_party/mediapipe/generated/mediapipe_text_bindings.dart' + as bindings; + +/// {@macro LanguageDetectionResult} +class LanguageDetectorResult extends BaseLanguageDetectorResult + with IOTaskResult { + /// {@macro LanguageDetectionResult} + LanguageDetectorResult({ + required Iterable predictions, + }) : _predictions = predictions, + _pointer = null; + + /// {@template LanguageDetectorResult.native} + /// Initializes a [LanguageDetectorResult] instance as a wrapper around native + /// memory. + /// + /// See also: + /// * [TextEmbedderExecutor.embed] where this is called. + /// {@endtemplate} + LanguageDetectorResult.native(this._pointer); + + final Pointer? _pointer; + + Iterable? _predictions; + @override + Iterable get predictions => + _predictions ??= _getpredictions(); + Iterable _getpredictions() { + if (_pointer.isNullOrNullPointer) { + throw Exception( + 'Could not determine value for LanguageDetectorResult.predictions', + ); + } + return LanguagePrediction.fromNativeArray( + _pointer!.ref.predictions, + _pointer!.ref.predictions_count, + ); + } +} + +/// {@macro LanguagePrediction} +class LanguagePrediction extends BaseLanguagePrediction { + /// {@macro LanguagePrediction} + LanguagePrediction({ + required String languageCode, + required double probability, + }) : _languageCode = languageCode, + _probability = probability, + _pointer = null; + + /// Initializes a [LanguagePrediction] instance as a wrapper around native + /// memory. + /// + /// {@macro Container.memoryManagement} + LanguagePrediction.native(this._pointer); + + final Pointer? _pointer; + + String? _languageCode; + + @override + String get languageCode => _languageCode ??= _getLanguageCode(); + String _getLanguageCode() { + if (_pointer.isNullOrNullPointer) { + throw Exception( + 'Could not determine value for ' + 'LanguagePrediction.languageCode', + ); + } + if (_pointer!.ref.language_code.isNullPointer) { + throw Exception('Corrupted memory in LanguagePrediction'); + } + return _pointer!.ref.language_code.toDartString(); + } + + double? _probability; + @override + double get probability => _probability ??= _getProbability(); + double _getProbability() { + if (_pointer.isNullOrNullPointer) { + throw Exception( + 'Could not determine value for ' + 'LanguageDetector.probability', + ); + } + return _pointer!.ref.probability; + } + + /// Accepts a pointer to a list of structs, and a count representing the length + /// of the list, and returns a list of pure-Dart [Category] instances. + static Iterable fromNativeArray( + Pointer structs, + int count, + ) sync* { + for (int i = 0; i < count; i++) { + yield LanguagePrediction.native(structs + i); + } + } +} diff --git a/packages/mediapipe-task-text/lib/src/io/tasks/tasks.dart b/packages/mediapipe-task-text/lib/src/io/tasks/tasks.dart index c7ff4814..a873e9b0 100644 --- a/packages/mediapipe-task-text/lib/src/io/tasks/tasks.dart +++ b/packages/mediapipe-task-text/lib/src/io/tasks/tasks.dart @@ -2,5 +2,6 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +export 'language_detection/language_detection.dart'; export 'text_classification/text_classification.dart'; export 'text_embedding/text_embedding.dart'; diff --git a/packages/mediapipe-task-text/lib/src/io/third_party/mediapipe/generated/mediapipe_text_bindings.dart b/packages/mediapipe-task-text/lib/src/io/third_party/mediapipe/generated/mediapipe_text_bindings.dart index 57653f0a..df8ef994 100644 --- a/packages/mediapipe-task-text/lib/src/io/third_party/mediapipe/generated/mediapipe_text_bindings.dart +++ b/packages/mediapipe-task-text/lib/src/io/third_party/mediapipe/generated/mediapipe_text_bindings.dart @@ -18,6 +18,41 @@ import 'dart:ffi' as ffi; import 'package:mediapipe_core/src/io/third_party/mediapipe/generated/mediapipe_common_bindings.dart' as imp1; +@ffi.Native< + ffi.Pointer Function(ffi.Pointer, + ffi.Pointer>)>(symbol: 'language_detector_create') +external ffi.Pointer language_detector_create( + ffi.Pointer options, + ffi.Pointer> error_msg, +); + +@ffi.Native< + ffi.Int Function( + ffi.Pointer, + ffi.Pointer, + ffi.Pointer, + ffi.Pointer>)>(symbol: 'language_detector_detect') +external int language_detector_detect( + ffi.Pointer detector, + ffi.Pointer utf8_str, + ffi.Pointer result, + ffi.Pointer> error_msg, +); + +@ffi.Native)>( + symbol: 'language_detector_close_result') +external void language_detector_close_result( + ffi.Pointer result, +); + +@ffi.Native< + ffi.Int Function(ffi.Pointer, + ffi.Pointer>)>(symbol: 'language_detector_close') +external int language_detector_close( + ffi.Pointer detector, + ffi.Pointer> error_msg, +); + @ffi.Native< ffi.Pointer Function(ffi.Pointer, ffi.Pointer>)>(symbol: 'text_classifier_create') @@ -102,6 +137,26 @@ external int text_embedder_cosine_similarity( ffi.Pointer> error_msg, ); +final class LanguageDetectorPrediction extends ffi.Struct { + external ffi.Pointer language_code; + + @ffi.Float() + external double probability; +} + +final class LanguageDetectorResult extends ffi.Struct { + external ffi.Pointer predictions; + + @ffi.Uint32() + external int predictions_count; +} + +final class LanguageDetectorOptions extends ffi.Struct { + external imp1.BaseOptions base_options; + + external imp1.ClassifierOptions classifier_options; +} + final class TextClassifierOptions extends ffi.Struct { external imp1.BaseOptions base_options; @@ -118,12 +173,6 @@ final class TextEmbedderOptions extends ffi.Struct { typedef TextEmbedderResult = imp1.EmbeddingResult; -const int __bool_true_false_are_defined = 1; - -const int true1 = 1; - -const int false1 = 0; - const int __WORDSIZE = 64; const int __DARWIN_ONLY_64_BIT_INO_T = 1; @@ -289,3 +338,9 @@ const int WINT_MAX = 2147483647; const int SIG_ATOMIC_MIN = -2147483648; const int SIG_ATOMIC_MAX = 2147483647; + +const int __bool_true_false_are_defined = 1; + +const int true1 = 1; + +const int false1 = 0; diff --git a/packages/mediapipe-task-text/lib/universal_mediapipe_text.dart b/packages/mediapipe-task-text/lib/universal_mediapipe_text.dart index 7bcf112f..7d930db4 100644 --- a/packages/mediapipe-task-text/lib/universal_mediapipe_text.dart +++ b/packages/mediapipe-task-text/lib/universal_mediapipe_text.dart @@ -60,8 +60,7 @@ class TextClassifierResult extends BaseTextClassifierResult { TextClassifierResult({required Iterable classifications}); @override - Iterable get classifications => - throw UnimplementedError(); + Iterable get classifications => throw UnimplementedError(); @override // ignore: must_call_super @@ -127,3 +126,72 @@ class TextEmbedderResult extends BaseEmbedderResult { // ignore: must_call_super void dispose() => throw UnimplementedError(); } + +/// {@macro LanguageDetector} +class LanguageDetector extends BaseLanguageDetector { + /// {@macro LanguageDetector} + LanguageDetector(LanguageDetectorOptions options); + + @override + Future detect(String text) => + throw UnimplementedError(); + + @override + void dispose() => throw UnimplementedError(); +} + +/// {@macro LanguageDetectorOptions} +class LanguageDetectorOptions extends BaseLanguageDetectorOptions { + /// {@template LanguageDetectorOptions.fromAssetPath} + /// Convenience constructor that looks for the model asset at the given file + /// system location. + /// {@endtemplate} + LanguageDetectorOptions.fromAssetPath( + String assetPath, { + ClassifierOptions classifierOptions = const ClassifierOptions(), + }); + + /// {@template LanguageDetectorOptions.fromAssetBuffer} + /// Convenience constructor that uses a model existing in memory. + /// {@endtemplate} + LanguageDetectorOptions.fromAssetBuffer( + Uint8List assetBuffer, { + ClassifierOptions classifierOptions = const ClassifierOptions(), + }); + + @override + BaseOptions get baseOptions => throw UnimplementedError(); + + @override + ClassifierOptions get classifierOptions => throw UnimplementedError(); +} + +/// {@macro LanguageDetectorResult} +class LanguageDetectorResult extends BaseLanguageDetectorResult { + /// {@template LanguageDetectorResult.fake} + /// Instantiates a [LanguageDetectorResult] with fake data for testing. + /// {@endtemplate} + LanguageDetectorResult({required Iterable predictions}); + + @override + Iterable get predictions => throw UnimplementedError(); + + @override + // ignore: must_call_super + void dispose() => throw UnimplementedError(); +} + +/// {@macro LanguagePrediction} +class LanguagePrediction extends BaseLanguagePrediction { + /// {@macro LanguagePrediction} + LanguagePrediction({ + required String languageCode, + required double probability, + }); + + @override + String get languageCode => throw UnimplementedError(); + + @override + double get probability => throw UnimplementedError(); +} diff --git a/packages/mediapipe-task-text/test/language_detector_executor_test.dart b/packages/mediapipe-task-text/test/language_detector_executor_test.dart new file mode 100644 index 00000000..888b9308 --- /dev/null +++ b/packages/mediapipe-task-text/test/language_detector_executor_test.dart @@ -0,0 +1,104 @@ +// Copyright 2014 The Flutter Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// `native-assets` tag allows test runs to opt in or out of running integration +// tests via `flutter test -x native-assets` or `flutter test -t native-assets` +@Tags(['native-assets']) + +import 'dart:io' as io; +import 'package:flutter_test/flutter_test.dart'; +import 'package:path/path.dart' as path; +import 'package:mediapipe_core/io.dart'; +import 'package:mediapipe_text/io.dart'; + +void main() { + final pathToModel = path.joinAll([ + io.Directory.current.absolute.path, + 'example/assets/language_detector.tflite', + ]); + final modelBytes = io.File(pathToModel).readAsBytesSync(); + + group('LanguageDetectorExecutor should', () { + test('run a task', () { + final executor = LanguageDetectorExecutor( + LanguageDetectorOptions.fromAssetBuffer(modelBytes), + ); + final LanguageDetectorResult result = executor.detect('Hello, world!'); + expect(result.predictions, isNotEmpty); + result.dispose(); + executor.dispose(); + }); + + test('run multiple tasks', () { + final executor = LanguageDetectorExecutor( + LanguageDetectorOptions.fromAssetBuffer(modelBytes), + ); + final LanguageDetectorResult result = executor.detect('Hello, world!'); + expect(result.predictions, isNotEmpty); + final LanguageDetectorResult result2 = + executor.detect('Hello, world, again!'); + expect(result2.predictions, isNotEmpty); + result.dispose(); + executor.dispose(); + }); + + test('unpack a result', () { + final executor = LanguageDetectorExecutor( + LanguageDetectorOptions.fromAssetBuffer(modelBytes), + ); + final LanguageDetectorResult result = executor.detect('Hello, world!'); + final prediction = result.predictions.first; + expect(prediction.languageCode, equals('en')); + expect(prediction.probability, greaterThan(0.99)); + result.dispose(); + executor.dispose(); + }); + + test('unpack a Spanish result', () { + final executor = LanguageDetectorExecutor( + LanguageDetectorOptions.fromAssetBuffer(modelBytes), + ); + final LanguageDetectorResult result = executor.detect('¡Hola, mundo!'); + final prediction = result.predictions.first; + expect(prediction.languageCode, equals('es')); + expect(prediction.probability, greaterThan(0.99)); + result.dispose(); + executor.dispose(); + }); + + test('use the denylist', () { + final executor = LanguageDetectorExecutor( + LanguageDetectorOptions.fromAssetBuffer( + modelBytes, + classifierOptions: ClassifierOptions( + categoryDenylist: ['en'], + ), + ), + ); + final LanguageDetectorResult result = executor.detect('Hello, world!'); + final prediction = result.predictions.first; + expect(prediction.languageCode, 'de'); + expect(prediction.probability, closeTo(0.0011, 0.0001)); + result.dispose(); + executor.dispose(); + }); + + test('use the allowlist', () { + final executor = LanguageDetectorExecutor( + LanguageDetectorOptions.fromAssetBuffer( + modelBytes, + classifierOptions: ClassifierOptions( + categoryAllowlist: ['en'], + ), + ), + ); + final LanguageDetectorResult result = executor.detect('Hello, world!'); + expect(result.predictions, hasLength(1)); + final prediction = result.predictions.first; + expect(prediction.languageCode, equals('en')); + result.dispose(); + executor.dispose(); + }); + }); +} diff --git a/packages/mediapipe-task-text/test/language_detector_result_test.dart b/packages/mediapipe-task-text/test/language_detector_result_test.dart new file mode 100644 index 00000000..2af83831 --- /dev/null +++ b/packages/mediapipe-task-text/test/language_detector_result_test.dart @@ -0,0 +1,43 @@ +// Copyright 2014 The Flutter Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import 'dart:ffi'; +import 'package:ffi/ffi.dart'; +import 'package:flutter_test/flutter_test.dart'; +import 'package:mediapipe_core/io.dart'; +import 'package:mediapipe_text/io.dart'; +import 'package:mediapipe_text/src/io/third_party/mediapipe/generated/mediapipe_text_bindings.dart' + as bindings; + +void main() { + group('LanguageDetectorResult.native should', () { + test('load an empty object', () { + final Pointer ptr = + calloc(); + // These fields are provided by the real MediaPipe implementation, but + // Dart ignores them because they are meaningless in context of text tasks + ptr.ref.predictions_count = 0; + + final result = LanguageDetectorResult.native(ptr); + expect(result.predictions, isEmpty); + }); + + test('load a hydrated object', () { + final Pointer resultPtr = + calloc(); + + final predictionsPtr = calloc(2); + predictionsPtr[0].language_code = 'es'.copyToNative(); + predictionsPtr[0].probability = 0.99; + predictionsPtr[1].language_code = 'en'.copyToNative(); + predictionsPtr[1].probability = 0.01; + + resultPtr.ref.predictions_count = 2; + resultPtr.ref.predictions = predictionsPtr; + + final result = LanguageDetectorResult.native(resultPtr); + expect(result.predictions, hasLength(2)); + }, timeout: const Timeout(Duration(milliseconds: 10))); + }); +} diff --git a/packages/mediapipe-task-text/third_party/mediapipe/tasks/c/text/language_detector/language_detector.h b/packages/mediapipe-task-text/third_party/mediapipe/tasks/c/text/language_detector/language_detector.h new file mode 100644 index 00000000..523f46ec --- /dev/null +++ b/packages/mediapipe-task-text/third_party/mediapipe/tasks/c/text/language_detector/language_detector.h @@ -0,0 +1,90 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_ +#define MEDIAPIPE_TASKS_C_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_ + +#include + +#include "../../../../../../../mediapipe-core/third_party/mediapipe/tasks/c/components/processors/classifier_options.h" +#include "../../../../../../../mediapipe-core/third_party/mediapipe/tasks/c/core/base_options.h" + +#ifndef MP_EXPORT +#define MP_EXPORT __attribute__((visibility("default"))) +#endif // MP_EXPORT + +#ifdef __cplusplus +extern "C" { +#endif + +// A language code and its probability. +struct LanguageDetectorPrediction { + // An i18n language / locale code, e.g. "en" for English, "uz" for Uzbek, + // "ja"-Latn for Japanese (romaji). + char* language_code; + + float probability; +}; + +// Task output. +struct LanguageDetectorResult { + struct LanguageDetectorPrediction* predictions; + + // The count of predictions. + uint32_t predictions_count; +}; + +// The options for configuring a MediaPipe language detector task. +struct LanguageDetectorOptions { + // Base options for configuring MediaPipe Tasks, such as specifying the model + // file with metadata, accelerator options, op resolver, etc. + struct BaseOptions base_options; + + // Options for configuring the detector behavior, such as score threshold, + // number of results, etc. + struct ClassifierOptions classifier_options; +}; + +// Creates a LanguageDetector from the provided `options`. +// Returns a pointer to the language detector on success. +// If an error occurs, returns `nullptr` and sets the error parameter to an +// an error message (if `error_msg` is not `nullptr`). You must free the memory +// allocated for the error message. +MP_EXPORT void* language_detector_create( + struct LanguageDetectorOptions* options, char** error_msg); + +// Performs language detection on the input `text`. Returns `0` on success. +// If an error occurs, returns an error code and sets the error parameter to an +// an error message (if `error_msg` is not `nullptr`). You must free the memory +// allocated for the error message. +MP_EXPORT int language_detector_detect(void* detector, const char* utf8_str, + LanguageDetectorResult* result, + char** error_msg); + +// Frees the memory allocated inside a LanguageDetectorResult result. Does not +// free the result pointer itself. +MP_EXPORT void language_detector_close_result(LanguageDetectorResult* result); + +// Shuts down the LanguageDetector when all the work is done. Frees all memory. +// If an error occurs, returns an error code and sets the error parameter to an +// an error message (if `error_msg` is not `nullptr`). You must free the memory +// allocated for the error message. +MP_EXPORT int language_detector_close(void* detector, char** error_msg); + +#ifdef __cplusplus +} // extern C +#endif + +#endif // MEDIAPIPE_TASKS_C_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_ diff --git a/packages/mediapipe-task-text/third_party/mediapipe/tasks/c/text/text_embedder/text_embedder.h b/packages/mediapipe-task-text/third_party/mediapipe/tasks/c/text/text_embedder/text_embedder.h index 5b3b5300..d4e12901 100644 --- a/packages/mediapipe-task-text/third_party/mediapipe/tasks/c/text/text_embedder/text_embedder.h +++ b/packages/mediapipe-task-text/third_party/mediapipe/tasks/c/text/text_embedder/text_embedder.h @@ -75,7 +75,7 @@ MP_EXPORT int text_embedder_close(void* embedder, char** error_msg); MP_EXPORT int text_embedder_cosine_similarity(const struct Embedding* u, const struct Embedding* v, double* similarity, - char** error_msg); + char** error_msg); #ifdef __cplusplus } // extern C diff --git a/tool/builder/lib/sync_headers.dart b/tool/builder/lib/sync_headers.dart index 956805bb..6d23aad9 100644 --- a/tool/builder/lib/sync_headers.dart +++ b/tool/builder/lib/sync_headers.dart @@ -17,6 +17,7 @@ final processors = 'mediapipe/tasks/c/components/processors'; final core = 'mediapipe/tasks/c/core'; final tc = 'mediapipe/tasks/c/text/text_classifier'; final te = 'mediapipe/tasks/c/text/text_embedder'; +final ld = 'mediapipe/tasks/c/text/language_detector'; /// google/flutter-mediapipe package paths final corePackage = 'packages/mediapipe-core/third_party'; @@ -35,6 +36,7 @@ List<(String, String, String, Function(io.File)?)> headerPaths = [ (processors, corePackage, 'embedder_options.h', null), (tc, textPackage, 'text_classifier.h', relativeIncludes), (te, textPackage, 'text_embedder.h', relativeIncludes), + (ld, textPackage, 'language_detector.h', relativeIncludes), ]; /// Command to copy all necessary C header files into this repository. diff --git a/tool/ci_script_shared.sh b/tool/ci_script_shared.sh index 787d327d..1ed76a08 100644 --- a/tool/ci_script_shared.sh +++ b/tool/ci_script_shared.sh @@ -1,11 +1,13 @@ function ci_text_package() { # Download bert_classifier.tflite model into example/assets for integration tests - echo "Downloading TextClassification model" pushd ../../tool/builder dart pub get + echo "Downloading TextClassification model" dart bin/main.dart model -m textclassification echo "Downloading TextEmbedding model" dart bin/main.dart model -m textembedding + echo "Downloading Language Detection model" + dart bin/main.dart model -m languagedetection popd }