Skip to content

Commit 9482ef1

Browse files
committed
model memory troubleshooting
1 parent 00d7c39 commit 9482ef1

File tree

5 files changed

+60
-77
lines changed

5 files changed

+60
-77
lines changed

packages/mediapipe-core/lib/src/task_options.dart

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,51 +21,84 @@ import 'third_party/mediapipe/generated/mediapipe_common_bindings.dart'
2121
/// classifier's desired behavior.
2222
class BaseOptions extends Equatable {
2323
/// Generative constructor that creates a [BaseOptions] instance.
24-
const BaseOptions({this.modelAssetBuffer, this.modelAssetPath})
25-
: assert(
24+
const BaseOptions._({
25+
this.modelAssetBuffer,
26+
this.modelAssetPath,
27+
this.modelAssetBufferCount,
28+
required _BaseOptionsType type,
29+
}) : assert(
2630
!(modelAssetBuffer == null && modelAssetPath == null),
2731
'You must supply either `modelAssetBuffer` or `modelAssetPath`',
2832
),
2933
assert(
3034
!(modelAssetBuffer != null && modelAssetPath != null),
3135
'You must only supply one of `modelAssetBuffer` and `modelAssetPath`',
32-
);
36+
),
37+
assert(
38+
(modelAssetBuffer == null) == (modelAssetBufferCount == null),
39+
'modelAssetBuffer and modelAssetBufferCount must only be submitted '
40+
'together',
41+
),
42+
_type = type;
43+
44+
/// Constructor for [BaseOptions] classes using a file system path.
45+
///
46+
/// In practice, this is unsupported, as assets in Flutter are bundled into
47+
/// the build output and not available on disk. However, it can potentially
48+
/// be helpful for testing / development purposes.
49+
factory BaseOptions.path(String path) => BaseOptions._(
50+
modelAssetPath: path,
51+
type: _BaseOptionsType.path,
52+
);
53+
54+
/// Constructor for [BaseOptions] classes using an in-memory pointer to the
55+
/// MediaPipe SDK.
56+
///
57+
/// In practice, this is the only option supported for production builds.
58+
factory BaseOptions.memory(Uint8List buffer) {
59+
return BaseOptions._(
60+
modelAssetBuffer: buffer,
61+
modelAssetBufferCount: buffer.lengthInBytes,
62+
type: _BaseOptionsType.memory,
63+
);
64+
}
3365

3466
/// The model asset file contents as bytes;
3567
final Uint8List? modelAssetBuffer;
3668

69+
/// The size of the model assets buffer (or `0` if not set).
70+
final int? modelAssetBufferCount;
71+
3772
/// Path to the model asset file.
3873
final String? modelAssetPath;
3974

75+
final _BaseOptionsType _type;
76+
4077
/// Converts this pure-Dart representation into C-memory suitable for the
4178
/// MediaPipe SDK to instantiate various classifiers.
4279
Pointer<bindings.BaseOptions> toStruct() {
4380
final struct = calloc<bindings.BaseOptions>();
4481

45-
if (modelAssetPath != null) {
82+
if (_type == _BaseOptionsType.path) {
4683
struct.ref.model_asset_path = prepareString(modelAssetPath!);
4784
}
48-
if (modelAssetBuffer != null) {
85+
if (_type == _BaseOptionsType.memory) {
4986
struct.ref.model_asset_buffer = prepareUint8List(modelAssetBuffer!);
5087
struct.ref.model_asset_buffer_count = modelAssetBuffer!.lengthInBytes;
5188
}
5289
return struct;
5390
}
5491

5592
@override
56-
List<Object?> get props => [modelAssetBuffer, modelAssetPath];
57-
58-
/// Releases all C memory held by this [bindings.BaseOptions] struct.
59-
static void freeStruct(bindings.BaseOptions struct) {
60-
if (struct.model_asset_buffer.address != 0) {
61-
calloc.free(struct.model_asset_buffer);
62-
}
63-
if (struct.model_asset_path.address != 0) {
64-
calloc.free(struct.model_asset_path);
65-
}
66-
}
93+
List<Object?> get props => [
94+
modelAssetBuffer,
95+
modelAssetPath,
96+
modelAssetBufferCount,
97+
];
6798
}
6899

100+
enum _BaseOptionsType { path, memory }
101+
69102
/// Dart representation of MediaPipe's "ClassifierOptions" concept.
70103
///
71104
/// Classifier options shared across MediaPipe classification tasks.

packages/mediapipe-core/lib/src/third_party/mediapipe/generated/mediapipe_common_bindings.dart

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ final class BaseOptions extends ffi.Struct {
2424

2525
external ffi.Pointer<ffi.Char> model_asset_path;
2626

27-
@ffi.UnsignedInt()
27+
@ffi.Int()
2828
external int model_asset_buffer_count;
2929
}
3030

packages/mediapipe-core/test/task_options_test.dart

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,32 +9,16 @@ import 'package:test/test.dart';
99
import 'package:mediapipe_core/mediapipe_core.dart';
1010

1111
void main() {
12-
group('BaseOptions constructor should', () {
13-
test('enforce exactly one of modelPath and modelBuffer', () {
14-
expect(
15-
() => BaseOptions(
16-
modelAssetPath: 'abc',
17-
modelAssetBuffer: Uint8List.fromList([1, 2, 3]),
18-
),
19-
throwsA(TypeMatcher<AssertionError>()),
20-
);
21-
22-
expect(BaseOptions.new, throwsA(TypeMatcher<AssertionError>()));
23-
});
24-
});
25-
2612
group('BaseOptions.toStruct/fromStruct should', () {
2713
test('allocate memory in C for a modelAssetPath', () {
28-
final options = BaseOptions(modelAssetPath: 'abc');
14+
final options = BaseOptions.path('abc');
2915
final struct = options.toStruct();
3016
expect(toDartString(struct.ref.model_asset_path), 'abc');
3117
expectNullPtr(struct.ref.model_asset_buffer);
3218
});
3319

3420
test('allocate memory in C for a modelAssetBuffer', () {
35-
final options = BaseOptions(
36-
modelAssetBuffer: Uint8List.fromList([1, 2, 3]),
37-
);
21+
final options = BaseOptions.memory(Uint8List.fromList([1, 2, 3]));
3822
final struct = options.toStruct();
3923
expect(
4024
toUint8List(struct.ref.model_asset_buffer),
@@ -44,9 +28,7 @@ void main() {
4428
});
4529

4630
test('allocate memory in C for a modelAssetBuffer containing 0', () {
47-
final options = BaseOptions(
48-
modelAssetBuffer: Uint8List.fromList([1, 2, 0, 3]),
49-
);
31+
final options = BaseOptions.memory(Uint8List.fromList([1, 2, 0, 3]));
5032
final struct = options.toStruct();
5133
expect(
5234
toUint8List(struct.ref.model_asset_buffer),

packages/mediapipe-task-text/example/lib/main.dart

Lines changed: 5 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import 'dart:io' as io;
22
import 'dart:typed_data';
33
import 'package:logging/logging.dart';
4-
import 'package:path/path.dart' as path;
54
import 'package:flutter/material.dart';
65
import 'package:mediapipe_text/mediapipe_text.dart';
7-
import 'package:path_provider/path_provider.dart';
86

97
final _log = Logger('TextClassificationExample');
108

@@ -33,54 +31,24 @@ class _MainAppState extends State<MainApp> {
3331
late final TextClassifier _classifier;
3432
final TextEditingController _controller = TextEditingController();
3533
String? results;
34+
late final ByteData classifierBytes;
3635

3736
@override
3837
void initState() {
3938
super.initState();
4039
_controller.text = 'Hello, world!';
4140
_initClassifier();
41+
Future.delayed(const Duration(milliseconds: 500)).then((_) => _classify());
4242
}
4343

4444
Future<void> _initClassifier() async {
45-
// getApplicationDocumentsDirectory().then((dir) {
46-
// print(dir.absolute.path);
47-
// });
48-
// final dir = await getApplicationSupportDirectory();
49-
// print('app support: ${dir.absolute.path}');
50-
51-
// DefaultAssetBundle.of(context).
52-
53-
final ByteData classifierBytes = await DefaultAssetBundle.of(context)
45+
classifierBytes = await DefaultAssetBundle.of(context)
5446
.load('assets/bert_classifier.tflite');
5547

56-
// final dir = io.Directory(path.current);
57-
// final modelPath = path.joinAll(
58-
// [dir.absolute.path, 'assets/bert_classifier.tflite'],
59-
// );
60-
// _log.finest('modelPath: $modelPath');
61-
// if (io.File(modelPath).existsSync()) {
62-
// _log.fine('Successfully found model.');
63-
// } else {
64-
// _log.severe('Invalid model path \n\t$modelPath.\n\nModel not found.');
65-
// io.exit(1);
66-
// }
67-
68-
// final sdkPath = path.joinAll(
69-
// [dir.absolute.path, 'assets/libtext_classifier.dylib'],
70-
// );
71-
// _log.finest('sdkPath: $sdkPath');
72-
// if (io.File(sdkPath).existsSync()) {
73-
// _log.fine('Successfully found SDK.');
74-
// } else {
75-
// _log.severe('Invalid SDK path $sdkPath. SDK not found.');
76-
// io.exit(1);
77-
// }
78-
7948
_classifier = TextClassifier(
80-
// options: TextClassifierOptions.fromAssetPath(modelPath),
8149
options: TextClassifierOptions.fromAssetBuffer(
82-
Uint8List.view(classifierBytes.buffer)),
83-
// sdkPath: sdkPath,
50+
classifierBytes.buffer.asUint8List(),
51+
),
8452
);
8553
}
8654

packages/mediapipe-task-text/lib/src/tasks/text_classification/containers/text_classifier_options.dart

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class TextClassifierOptions {
3232
}) {
3333
assert(!kIsWeb, 'fromAssetPath cannot be used on the web');
3434
return TextClassifierOptions(
35-
baseOptions: BaseOptions(modelAssetPath: assetPath),
35+
baseOptions: BaseOptions.path(assetPath),
3636
classifierOptions: classifierOptions,
3737
);
3838
}
@@ -45,7 +45,7 @@ class TextClassifierOptions {
4545
ClassifierOptions classifierOptions = const ClassifierOptions(),
4646
}) =>
4747
TextClassifierOptions(
48-
baseOptions: BaseOptions(modelAssetBuffer: assetBuffer),
48+
baseOptions: BaseOptions.memory(assetBuffer),
4949
classifierOptions: classifierOptions,
5050
);
5151

0 commit comments

Comments
 (0)