diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart index 710b8f89fed2..48e309268939 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart @@ -1023,8 +1023,13 @@ ModalityTokenCount _parseModalityTokenCount(Object? jsonObject) { if (jsonObject is! Map) { throw unhandledFormat('ModalityTokenCount', jsonObject); } - return ModalityTokenCount(ContentModality._parseValue(jsonObject['modality']), - jsonObject['tokenCount'] as int); + var modality = ContentModality._parseValue(jsonObject['modality']); + + if (jsonObject.containsKey('tokenCount')) { + return ModalityTokenCount(modality, jsonObject['tokenCount'] as int); + } else { + return ModalityTokenCount(modality, 0); + } } SafetyRating _parseSafetyRating(Object? jsonObject) { diff --git a/packages/firebase_vertexai/firebase_vertexai/test/api_test.dart b/packages/firebase_vertexai/firebase_vertexai/test/api_test.dart index 5c292288fe42..0b6359f9a6ef 100644 --- a/packages/firebase_vertexai/firebase_vertexai/test/api_test.dart +++ b/packages/firebase_vertexai/firebase_vertexai/test/api_test.dart @@ -437,7 +437,9 @@ void main() { 'totalTokens': 120, 'totalBillableCharacters': 240, 'promptTokensDetails': [ - {'modality': 'TEXT', 'tokenCount': 100}, + { + 'modality': 'TEXT', + }, {'modality': 'IMAGE', 'tokenCount': 20} ] }; @@ -447,7 +449,7 @@ void main() { expect(response.promptTokensDetails, isNotNull); expect(response.promptTokensDetails, hasLength(2)); expect(response.promptTokensDetails![0].modality, ContentModality.text); - expect(response.promptTokensDetails![0].tokenCount, 100); + expect(response.promptTokensDetails![0].tokenCount, 0); expect( response.promptTokensDetails![1].modality, ContentModality.image); expect(response.promptTokensDetails![1].tokenCount, 20); @@ -597,6 +599,47 @@ void main() { expect(response.candidates.first.finishMessage, isNull); }); + test('parses usageMetadata for no tokenCount', () { + final json = { + 'candidates': [basicCandidateJson], + 'usageMetadata': { + 'promptTokenCount': 10, + 'candidatesTokenCount': 20, + 'totalTokenCount': 30, + 'promptTokensDetails': [ + {'modality': 'TEXT', 'tokenCount': 10} + ], + 'candidatesTokensDetails': [ + { + 'modality': 'TEXT', + } + ], + } + }; + final response = parseGenerateContentResponse(json); + expect(response.candidates, hasLength(1)); + expect(response.candidates.first.text, 'Hello world'); + expect(response.candidates.first.finishReason, FinishReason.stop); + expect(response.candidates.first.safetyRatings, isNotNull); + expect(response.candidates.first.safetyRatings, hasLength(1)); + + expect(response.usageMetadata, isNotNull); + expect(response.usageMetadata!.promptTokenCount, 10); + expect(response.usageMetadata!.candidatesTokenCount, 20); + expect(response.usageMetadata!.totalTokenCount, 30); + expect(response.usageMetadata!.promptTokensDetails, hasLength(1)); + expect(response.usageMetadata!.promptTokensDetails!.first.modality, + ContentModality.text); + expect( + response.usageMetadata!.promptTokensDetails!.first.tokenCount, 10); + expect(response.usageMetadata!.candidatesTokensDetails, hasLength(1)); + expect(response.usageMetadata!.candidatesTokensDetails!.first.modality, + ContentModality.text); + expect( + response.usageMetadata!.candidatesTokensDetails!.first.tokenCount, + 0); + }); + test('parses citationMetadata with "citationSources"', () { final json = { 'candidates': [