Skip to content
171 changes: 130 additions & 41 deletions packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import 'content.dart';
import 'error.dart';
import 'function_calling.dart' show Tool, ToolConfig;
import 'schema.dart';

/// Response for Count Tokens
Expand Down Expand Up @@ -155,6 +156,23 @@ final class UsageMetadata {
final List<ModalityTokenCount>? candidatesTokensDetails;
}

/// Constructe a UsageMetadata with all it's fields.
///
/// Expose access to the private constructor for use within the package..
UsageMetadata createUsageMetadata({
required int? promptTokenCount,
required int? candidatesTokenCount,
required int? totalTokenCount,
required List<ModalityTokenCount>? promptTokensDetails,
required List<ModalityTokenCount>? candidatesTokensDetails,
}) =>
UsageMetadata._(
promptTokenCount: promptTokenCount,
candidatesTokenCount: candidatesTokenCount,
totalTokenCount: totalTokenCount,
promptTokensDetails: promptTokensDetails,
candidatesTokensDetails: candidatesTokensDetails);

/// Response candidate generated from a [GenerativeModel].
final class Candidate {
// TODO: token count?
Expand Down Expand Up @@ -842,53 +860,124 @@ enum TaskType {
Object toJson() => _jsonString;
}

/// Parse the json to [GenerateContentResponse]
GenerateContentResponse parseGenerateContentResponse(Object jsonObject) {
if (jsonObject case {'error': final Object error}) throw parseError(error);
final candidates = switch (jsonObject) {
{'candidates': final List<Object?> candidates} =>
candidates.map(_parseCandidate).toList(),
_ => <Candidate>[]
};
final promptFeedback = switch (jsonObject) {
{'promptFeedback': final promptFeedback?} =>
_parsePromptFeedback(promptFeedback),
_ => null,
};
final usageMedata = switch (jsonObject) {
{'usageMetadata': final usageMetadata?} =>
_parseUsageMetadata(usageMetadata),
_ => null,
};
return GenerateContentResponse(candidates, promptFeedback,
usageMetadata: usageMedata);
// ignore: public_member_api_docs
abstract interface class SerializationStrategy {
// ignore: public_member_api_docs
GenerateContentResponse parseGenerateContentResponse(Object jsonObject);
// ignore: public_member_api_docs
CountTokensResponse parseCountTokensResponse(Object jsonObject);
// ignore: public_member_api_docs
Map<String, Object?> generateContentRequest(
Iterable<Content> contents,
({String prefix, String name}) model,
List<SafetySetting> safetySettings,
GenerationConfig? generationConfig,
List<Tool>? tools,
ToolConfig? toolConfig,
Content? systemInstruction,
);

// ignore: public_member_api_docs
Map<String, Object?> countTokensRequest(
Iterable<Content> contents,
({String prefix, String name}) model,
List<SafetySetting> safetySettings,
GenerationConfig? generationConfig,
List<Tool>? tools,
ToolConfig? toolConfig,
);
}

/// Parse the json to [CountTokensResponse]
CountTokensResponse parseCountTokensResponse(Object jsonObject) {
if (jsonObject case {'error': final Object error}) throw parseError(error);
// ignore: public_member_api_docs
final class VertexSerialization implements SerializationStrategy {
/// Parse the json to [GenerateContentResponse]
@override
GenerateContentResponse parseGenerateContentResponse(Object jsonObject) {
if (jsonObject case {'error': final Object error}) throw parseError(error);
final candidates = switch (jsonObject) {
{'candidates': final List<Object?> candidates} =>
candidates.map(_parseCandidate).toList(),
_ => <Candidate>[]
};
final promptFeedback = switch (jsonObject) {
{'promptFeedback': final promptFeedback?} =>
_parsePromptFeedback(promptFeedback),
_ => null,
};
final usageMedata = switch (jsonObject) {
{'usageMetadata': final usageMetadata?} =>
_parseUsageMetadata(usageMetadata),
{'totalTokens': final int totalTokens} =>
UsageMetadata._(totalTokenCount: totalTokens),
_ => null,
};
return GenerateContentResponse(candidates, promptFeedback,
usageMetadata: usageMedata);
}

if (jsonObject is! Map) {
throw unhandledFormat('CountTokensResponse', jsonObject);
/// Parse the json to [CountTokensResponse]
@override
CountTokensResponse parseCountTokensResponse(Object jsonObject) {
if (jsonObject case {'error': final Object error}) throw parseError(error);

if (jsonObject is! Map) {
throw unhandledFormat('CountTokensResponse', jsonObject);
}

final totalTokens = jsonObject['totalTokens'] as int;
final totalBillableCharacters = switch (jsonObject) {
{'totalBillableCharacters': final int totalBillableCharacters} =>
totalBillableCharacters,
_ => null,
};
final promptTokensDetails = switch (jsonObject) {
{'promptTokensDetails': final List<Object?> promptTokensDetails} =>
promptTokensDetails.map(_parseModalityTokenCount).toList(),
_ => null,
};

return CountTokensResponse(
totalTokens,
totalBillableCharacters: totalBillableCharacters,
promptTokensDetails: promptTokensDetails,
);
}

final totalTokens = jsonObject['totalTokens'] as int;
final totalBillableCharacters = switch (jsonObject) {
{'totalBillableCharacters': final int totalBillableCharacters} =>
totalBillableCharacters,
_ => null,
};
final promptTokensDetails = switch (jsonObject) {
{'promptTokensDetails': final List<Object?> promptTokensDetails} =>
promptTokensDetails.map(_parseModalityTokenCount).toList(),
_ => null,
};
@override
Map<String, Object?> generateContentRequest(
Iterable<Content> contents,
({String prefix, String name}) model,
List<SafetySetting> safetySettings,
GenerationConfig? generationConfig,
List<Tool>? tools,
ToolConfig? toolConfig,
Content? systemInstruction,
) {
return {
'model': '${model.prefix}/${model.name}',
'contents': contents.map((c) => c.toJson()).toList(),
if (safetySettings.isNotEmpty)
'safetySettings': safetySettings.map((s) => s.toJson()).toList(),
if (generationConfig != null)
'generationConfig': generationConfig.toJson(),
if (tools != null) 'tools': tools.map((t) => t.toJson()).toList(),
if (toolConfig != null) 'toolConfig': toolConfig.toJson(),
if (systemInstruction != null)
'systemInstruction': systemInstruction.toJson(),
};
}

return CountTokensResponse(
totalTokens,
totalBillableCharacters: totalBillableCharacters,
promptTokensDetails: promptTokensDetails,
);
@override
Map<String, Object?> countTokensRequest(
Iterable<Content> contents,
({String prefix, String name}) model,
List<SafetySetting> safetySettings,
GenerationConfig? generationConfig,
List<Tool>? tools,
ToolConfig? toolConfig,
) =>
// Everything except contents is ignored.
{'contents': contents.map((c) => c.toJson()).toList()};
}

Candidate _parseCandidate(Object? jsonObject) {
Expand Down
110 changes: 86 additions & 24 deletions packages/firebase_vertexai/firebase_vertexai/lib/src/base_model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import 'package:web_socket_channel/io.dart';
import 'api.dart';
import 'client.dart';
import 'content.dart';
import 'developer/api.dart';
import 'function_calling.dart';
import 'imagen_api.dart';
import 'imagen_content.dart';
Expand All @@ -52,33 +53,28 @@ enum Task {
predict,
}

/// Base class for models.
///
/// Do not instantiate directly.
abstract class BaseModel {
// ignore: public_member_api_docs
BaseModel(
abstract interface class _ModelUri {
String get baseAuthority;
Uri taskUri(Task task);
({String prefix, String name}) get model;
}

final class _VertexUri implements _ModelUri {
_VertexUri(
{required String model,
required String location,
required FirebaseApp app})
: _model = normalizeModelName(model),
: model = _normalizeModelName(model),
_projectUri = _vertexUri(app, location);

static const _baseUrl = 'firebasevertexai.googleapis.com';
static const _baseAuthority = 'firebasevertexai.googleapis.com';
static const _apiVersion = 'v1beta';

final ({String prefix, String name}) _model;

final Uri _projectUri;

/// The normalized model name.
({String prefix, String name}) get model => _model;

/// Returns the model code for a user friendly model name.
///
/// If the model name is already a model code (contains a `/`), use the parts
/// directly. Otherwise, return a `models/` model code.
static ({String prefix, String name}) normalizeModelName(String modelName) {
static ({String prefix, String name}) _normalizeModelName(String modelName) {
if (!modelName.contains('/')) return (prefix: 'models', name: modelName);
final parts = modelName.split('/');
return (prefix: parts.first, name: parts.skip(1).join('/'));
Expand All @@ -87,11 +83,79 @@ abstract class BaseModel {
static Uri _vertexUri(FirebaseApp app, String location) {
var projectId = app.options.projectId;
return Uri.https(
_baseUrl,
_baseAuthority,
'/$_apiVersion/projects/$projectId/locations/$location/publishers/google',
);
}

final Uri _projectUri;
@override
final ({String prefix, String name}) model;

@override
String get baseAuthority => _baseAuthority;

@override
Uri taskUri(Task task) {
return _projectUri.replace(
pathSegments: _projectUri.pathSegments
.followedBy([model.prefix, '${model.name}:${task.name}']));
}
}

final class _GoogleAIUri implements _ModelUri {
_GoogleAIUri({
required String model,
required FirebaseApp app,
}) : model = _normalizeModelName(model),
_baseUri = _googleAIBaseUri(app: app);

/// Returns the model code for a user friendly model name.
///
/// If the model name is already a model code (contains a `/`), use the parts
/// directly. Otherwise, return a `models/` model code.
static ({String prefix, String name}) _normalizeModelName(String modelName) {
if (!modelName.contains('/')) return (prefix: 'models', name: modelName);
final parts = modelName.split('/');
return (prefix: parts.first, name: parts.skip(1).join('/'));
}

static const _apiVersion = 'v1beta';
static const _baseAuthority = 'firebasevertexai.googleapis.com';
static Uri _googleAIBaseUri(
{String apiVersion = _apiVersion, required FirebaseApp app}) =>
Uri.https(
_baseAuthority, '$apiVersion/projects/${app.options.projectId}');
final Uri _baseUri;

@override
final ({String prefix, String name}) model;

@override
String get baseAuthority => _baseAuthority;

@override
Uri taskUri(Task task) => _baseUri.replace(
pathSegments: _baseUri.pathSegments
.followedBy([model.prefix, '${model.name}:${task.name}']));
}

/// Base class for models.
///
/// Do not instantiate directly.
abstract class BaseModel {
BaseModel._(
{required SerializationStrategy serializationStrategy,
required _ModelUri modelUri})
: _serializationStrategy = serializationStrategy,
_modelUri = modelUri;

final SerializationStrategy _serializationStrategy;
final _ModelUri _modelUri;

/// The normalized model name.
({String prefix, String name}) get model => _modelUri.model;

/// Returns a function that generates Firebase auth tokens.
static FutureOr<Map<String, String>> Function() firebaseTokens(
FirebaseAppCheck? appCheck, FirebaseAuth? auth, FirebaseApp? app) {
Expand Down Expand Up @@ -120,9 +184,7 @@ abstract class BaseModel {
}

/// Returns a URI for the given [task].
Uri taskUri(Task task) => _projectUri.replace(
pathSegments: _projectUri.pathSegments
.followedBy([_model.prefix, '${_model.name}:${task.name}']));
Uri taskUri(Task task) => _modelUri.taskUri(task);
}

/// An abstract base class for models that interact with an API using an [ApiClient].
Expand All @@ -136,11 +198,11 @@ abstract class BaseModel {
abstract class BaseApiClientModel extends BaseModel {
// ignore: public_member_api_docs
BaseApiClientModel({
required super.model,
required super.location,
required super.app,
required super.serializationStrategy,
required super.modelUri,
required ApiClient client,
}) : _client = client;
}) : _client = client,
super._();

final ApiClient _client;

Expand Down
Loading
Loading