Skip to content

breaking(go): Refactored primitives APIs to be consistent #3403

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Aug 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions go/ai/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func defineProgrammableModel(r *registry.Registry) *programmableModel {
Tools: true,
Multiturn: true,
}
DefineModel(r, "", "programmableModel", &ModelInfo{Supports: supports}, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) {
DefineModel(r, "programmableModel", &ModelOptions{Supports: supports}, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) {
return pm.Generate(ctx, r, req, &ToolConfig{MaxTurns: 5}, cb)
})
return pm
Expand All @@ -91,10 +91,7 @@ func TestGenerateAction(t *testing.T) {
t.Run(tc.Name, func(t *testing.T) {
ctx := context.Background()

r, err := registry.New()
if err != nil {
t.Fatalf("failed to create registry: %v", err)
}
r := registry.New()
ConfigureFormats(r)

pm := defineProgrammableModel(r)
Expand Down
124 changes: 86 additions & 38 deletions go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ import (
"fmt"

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/internal/base"
"github.com/firebase/genkit/go/internal/registry"
)

// EmbedderFunc is the function type for embedding documents.
type EmbedderFunc = func(context.Context, *EmbedRequest) (*EmbedResponse, error)

// Embedder represents an embedder that can perform content embedding.
type Embedder interface {
// Name returns the registry name of the embedder.
Expand All @@ -33,14 +35,30 @@ type Embedder interface {
Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error)
}

// EmbedderInfo represents the structure of the embedder information object.
type EmbedderInfo struct {
// Label is a user-friendly name for the embedder model (e.g., "Google AI - Gemini Pro").
Label string `json:"label,omitempty"`
// Supports defines the capabilities of the embedder, such as input types and multilingual support.
Supports *EmbedderSupports `json:"supports,omitempty"`
// Dimensions specifies the number of dimensions in the embedding vector.
Dimensions int `json:"dimensions,omitempty"`
// EmbedderArg is the interface for embedder arguments. It can either be the embedder action itself or a reference to be looked up.
type EmbedderArg interface {
Name() string
}

// EmbedderRef is a struct to hold embedder name and configuration.
type EmbedderRef struct {
name string
config any
}

// NewEmbedderRef creates a new EmbedderRef with the given name and configuration.
func NewEmbedderRef(name string, config any) EmbedderRef {
return EmbedderRef{name: name, config: config}
}

// Name returns the name of the embedder.
func (e EmbedderRef) Name() string {
return e.name
}

// Config returns the configuration to use by default for this embedder.
func (e EmbedderRef) Config() any {
return e.config
}

// EmbedderSupports represents the supported capabilities of the embedder model.
Expand All @@ -53,48 +71,66 @@ type EmbedderSupports struct {

// EmbedderOptions represents the configuration options for an embedder.
type EmbedderOptions struct {
// ConfigSchema defines the schema for the embedder's configuration options.
ConfigSchema any `json:"configSchema,omitempty"`
// Info contains metadata about the embedder, such as its label and capabilities.
Info *EmbedderInfo `json:"info,omitempty"`
// ConfigSchema is the JSON schema for the embedder's config.
ConfigSchema map[string]any `json:"configSchema,omitempty"`
// Label is a user-friendly name for the embedder model (e.g., "Google AI - Gemini Pro").
Label string `json:"label,omitempty"`
// Supports defines the capabilities of the embedder, such as input types and multilingual support.
Supports *EmbedderSupports `json:"supports,omitempty"`
// Dimensions specifies the number of dimensions in the embedding vector.
Dimensions int `json:"dimensions,omitempty"`
}

// An embedder is used to convert a document to a multidimensional vector.
// embedder is an action with functions specific to converting documents to multidimensional vectors such as Embed().
type embedder core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}]

// DefineEmbedder registers the given embed function as an action, and returns an
// [Embedder] that runs it.
func DefineEmbedder(
r *registry.Registry,
provider, name string,
opts *EmbedderOptions,
embed func(context.Context, *EmbedRequest) (*EmbedResponse, error),
) Embedder {
metadata := map[string]any{}
metadata["type"] = "embedder"
metadata["info"] = opts.Info
if opts.ConfigSchema != nil {
metadata["embedder"] = map[string]any{"customOptions": base.ToSchemaMap(opts.ConfigSchema)}
func DefineEmbedder(r *registry.Registry, name string, opts *EmbedderOptions, fn EmbedderFunc) Embedder {
if name == "" {
panic("ai.DefineEmbedder: name is required")
}
inputSchema := base.InferJSONSchema(EmbedRequest{})
if inputSchema.Properties != nil && opts.ConfigSchema != nil {
if _, ok := inputSchema.Properties.Get("options"); ok {
inputSchema.Properties.Set("options", base.InferJSONSchema(opts.ConfigSchema))

if opts == nil {
opts = &EmbedderOptions{
Label: name,
}
}
return (*embedder)(core.DefineActionWithInputSchema(r, provider, name, core.ActionTypeEmbedder, metadata, inputSchema, embed))
if opts.Supports == nil {
opts.Supports = &EmbedderSupports{}
}

metadata := map[string]any{
"type": core.ActionTypeEmbedder,
// TODO: This should be under "embedder" but JS has it as "info".
"info": map[string]any{
"label": opts.Label,
"dimensions": opts.Dimensions,
"supports": map[string]any{
"input": opts.Supports.Input,
"multilingual": opts.Supports.Multilingual,
},
},
"embedder": map[string]any{
"customOptions": opts.ConfigSchema,
},
}

inputSchema := core.InferSchemaMap(EmbedRequest{})
if inputSchema != nil && opts.ConfigSchema != nil {
if _, ok := inputSchema["options"]; ok {
inputSchema["options"] = opts.ConfigSchema
}
}

return (*embedder)(core.DefineActionWithInputSchema(r, name, core.ActionTypeEmbedder, metadata, inputSchema, fn))
}

// LookupEmbedder looks up an [Embedder] registered by [DefineEmbedder].
// It will try to resolve the embedder dynamically if the embedder is not found.
// It returns nil if the embedder was not resolved.
func LookupEmbedder(r *registry.Registry, provider, name string) Embedder {
action := core.ResolveActionFor[*EmbedRequest, *EmbedResponse, struct{}](r, core.ActionTypeEmbedder, provider, name)
if action == nil {
return nil
}

return (*embedder)(action)
func LookupEmbedder(r *registry.Registry, name string) Embedder {
return (*embedder)(core.ResolveActionFor[*EmbedRequest, *EmbedResponse, struct{}](r, core.ActionTypeEmbedder, name))
}

// Name returns the name of the embedder.
Expand All @@ -108,14 +144,26 @@ func (e *embedder) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse
}

// Embed invokes the embedder with provided options.
func Embed(ctx context.Context, e Embedder, opts ...EmbedderOption) (*EmbedResponse, error) {
func Embed(ctx context.Context, r *registry.Registry, opts ...EmbedderOption) (*EmbedResponse, error) {
embedOpts := &embedderOptions{}
for _, opt := range opts {
if err := opt.applyEmbedder(embedOpts); err != nil {
return nil, fmt.Errorf("ai.Embed: error applying options: %w", err)
}
}

e, ok := embedOpts.Embedder.(Embedder)
if !ok {
e = LookupEmbedder(r, embedOpts.Embedder.Name())
}
if e == nil {
return nil, fmt.Errorf("ai.Embed: embedder not found: %s", embedOpts.Embedder.Name())
}

if embedRef, ok := embedOpts.Embedder.(EmbedderRef); ok && embedOpts.Config == nil {
embedOpts.Config = embedRef.Config()
}

req := &EmbedRequest{
Input: embedOpts.Documents,
Options: embedOpts.Config,
Expand Down
101 changes: 70 additions & 31 deletions go/ai/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package ai

import (
"context"
"errors"
"fmt"

"github.com/firebase/genkit/go/core"
Expand All @@ -29,6 +28,12 @@ import (
"go.opentelemetry.io/otel/trace"
)

// EvaluatorFunc is the function type for evaluator implementations.
type EvaluatorFunc = func(context.Context, *EvaluatorCallbackRequest) (*EvaluatorCallbackResponse, error)

// BatchEvaluatorFunc is the function type for batch evaluator implementations.
type BatchEvaluatorFunc = func(context.Context, *EvaluatorRequest) (*EvaluatorResponse, error)

// Evaluator represents a evaluator action.
type Evaluator interface {
// Name returns the name of the evaluator.
Expand All @@ -37,6 +42,7 @@ type Evaluator interface {
Evaluate(ctx context.Context, req *EvaluatorRequest) (*EvaluatorResponse, error)
}

// evaluator is an action with functions specific to evaluating a dataset.
type evaluator core.ActionDef[*EvaluatorRequest, *EvaluatorResponse, struct{}]

// Example is a single example that requires evaluation
Expand Down Expand Up @@ -104,9 +110,14 @@ type EvaluationResult struct {
type EvaluatorResponse = []EvaluationResult

type EvaluatorOptions struct {
// ConfigSchema is the JSON schema for the evaluator's config.
ConfigSchema map[string]any `json:"configSchema,omitempty"`
// DisplayName is the name of the evaluator as it appears in the UI.
DisplayName string `json:"displayName"`
Definition string `json:"definition"`
IsBilled bool `json:"isBilled,omitempty"`
// Definition is the definition of the evaluator.
Definition string `json:"definition"`
// IsBilled is a flag indicating if the evaluator is billed.
IsBilled bool `json:"isBilled,omitempty"`
}

// EvaluatorCallbackRequest is the data we pass to the callback function
Expand All @@ -123,18 +134,34 @@ type EvaluatorCallbackResponse = EvaluationResult
// DefineEvaluator registers the given evaluator function as an action, and
// returns a [Evaluator] that runs it. This method process the input dataset
// one-by-one.
func DefineEvaluator(r *registry.Registry, provider, name string, options *EvaluatorOptions, eval func(context.Context, *EvaluatorCallbackRequest) (*EvaluatorCallbackResponse, error)) (Evaluator, error) {
if options == nil {
return nil, errors.New("EvaluatorOptions must be provided")
func DefineEvaluator(r *registry.Registry, name string, opts *EvaluatorOptions, fn EvaluatorFunc) Evaluator {
if name == "" {
panic("ai.DefineEvaluator: evaluator name is required")
}

if opts == nil {
opts = &EvaluatorOptions{}
}

// TODO(ssbushi): Set this on `evaluator` key on action metadata
metadataMap := map[string]any{}
metadataMap["evaluatorIsBilled"] = options.IsBilled
metadataMap["evaluatorDisplayName"] = options.DisplayName
metadataMap["evaluatorDefinition"] = options.Definition
metadata := map[string]any{
"type": core.ActionTypeEvaluator,
"evaluator": map[string]any{
"evaluatorIsBilled": opts.IsBilled,
"evaluatorDisplayName": opts.DisplayName,
"evaluatorDefinition": opts.Definition,
},
}

inputSchema := core.InferSchemaMap(EvaluatorRequest{})
if inputSchema != nil && opts.ConfigSchema != nil {
if _, ok := inputSchema["options"]; ok {
inputSchema["options"] = opts.ConfigSchema
}
}

actionDef := (*evaluator)(core.DefineAction(r, provider, name, core.ActionTypeEvaluator, map[string]any{"evaluator": metadataMap}, func(ctx context.Context, req *EvaluatorRequest) (output *EvaluatorResponse, err error) {
var evalResponses []EvaluationResult
return (*evaluator)(core.DefineActionWithInputSchema(r, name, core.ActionTypeEvaluator, metadata, inputSchema, func(ctx context.Context, req *EvaluatorRequest) (output *EvaluatorResponse, err error) {
var results []EvaluationResult
for _, datapoint := range req.Dataset {
if datapoint.TestCaseId == "" {
datapoint.TestCaseId = uuid.New().String()
Expand All @@ -143,62 +170,74 @@ func DefineEvaluator(r *registry.Registry, provider, name string, options *Evalu
func(ctx context.Context, input *Example) (*EvaluatorCallbackResponse, error) {
traceId := trace.SpanContextFromContext(ctx).TraceID().String()
spanId := trace.SpanContextFromContext(ctx).SpanID().String()

callbackRequest := EvaluatorCallbackRequest{
Input: *input,
Options: req.Options,
}
evaluatorResponse, err := eval(ctx, &callbackRequest)

result, err := fn(ctx, &callbackRequest)
if err != nil {
failedScore := Score{
Status: ScoreStatusFail.String(),
Error: fmt.Sprintf("Evaluation of test case %s failed: \n %s", input.TestCaseId, err.Error()),
}
failedEvalResult := EvaluationResult{
failedResult := EvaluationResult{
TestCaseId: input.TestCaseId,
Evaluation: []Score{failedScore},
TraceID: traceId,
SpanID: spanId,
}
evalResponses = append(evalResponses, failedEvalResult)
results = append(results, failedResult)
// return error to mark span as failed
return nil, err
}
evaluatorResponse.TraceID = traceId
evaluatorResponse.SpanID = spanId
evalResponses = append(evalResponses, *evaluatorResponse)
return evaluatorResponse, nil

result.TraceID = traceId
result.SpanID = spanId

results = append(results, *result)

return result, nil
})
if err != nil {
logger.FromContext(ctx).Debug("EvaluatorAction", "err", err)
continue
}
}
return &evalResponses, nil
return &results, nil
}))
return actionDef, nil
}

// DefineBatchEvaluator registers the given evaluator function as an action, and
// returns a [Evaluator] that runs it. This method provide the full
// [EvaluatorRequest] to the callback function, giving more flexibilty to the
// user for processing the data, such as batching or parallelization.
func DefineBatchEvaluator(r *registry.Registry, provider, name string, options *EvaluatorOptions, batchEval func(context.Context, *EvaluatorRequest) (*EvaluatorResponse, error)) (Evaluator, error) {
if options == nil {
return nil, errors.New("EvaluatorOptions must be provided")
func DefineBatchEvaluator(r *registry.Registry, name string, opts *EvaluatorOptions, fn BatchEvaluatorFunc) Evaluator {
if name == "" {
panic("ai.DefineBatchEvaluator: batch evaluator name is required")
}

if opts == nil {
opts = &EvaluatorOptions{}
}

metadataMap := map[string]any{}
metadataMap["evaluatorIsBilled"] = options.IsBilled
metadataMap["evaluatorDisplayName"] = options.DisplayName
metadataMap["evaluatorDefinition"] = options.Definition
metadata := map[string]any{
"type": core.ActionTypeEvaluator,
"evaluator": map[string]any{
"evaluatorIsBilled": opts.IsBilled,
"evaluatorDisplayName": opts.DisplayName,
"evaluatorDefinition": opts.Definition,
},
}

return (*evaluator)(core.DefineAction(r, provider, name, core.ActionTypeEvaluator, map[string]any{"evaluator": metadataMap}, batchEval)), nil
return (*evaluator)(core.DefineAction(r, name, core.ActionTypeEvaluator, metadata, fn))
}

// LookupEvaluator looks up an [Evaluator] registered by [DefineEvaluator].
// It returns nil if the evaluator was not defined.
func LookupEvaluator(r *registry.Registry, provider, name string) Evaluator {
return (*evaluator)(core.LookupActionFor[*EvaluatorRequest, *EvaluatorResponse, struct{}](r, core.ActionTypeEvaluator, provider, name))
func LookupEvaluator(r *registry.Registry, name string) Evaluator {
return (*evaluator)(core.LookupActionFor[*EvaluatorRequest, *EvaluatorResponse, struct{}](r, core.ActionTypeEvaluator, name))
}

// Evaluate calls the retrivers with provided options.
Expand Down
Loading
Loading