diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index 308480635d8..09846198802 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -8,8 +8,12 @@ using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Shared.Diagnostics; +#pragma warning disable CA2213 // Disposable fields should be disposed + namespace Microsoft.Extensions.AI; /// @@ -34,8 +38,15 @@ namespace Microsoft.Extensions.AI; /// invocation requests to that same function. /// /// -public class FunctionInvokingChatClient : DelegatingChatClient +public partial class FunctionInvokingChatClient : DelegatingChatClient { + /// The logger to use for logging information about function invocation. + private readonly ILogger _logger; + + /// The to use for telemetry. + /// This component does not own the instance and should not dispose it. + private readonly ActivitySource? _activitySource; + /// Maximum number of roundtrips allowed to the inner client. private int? _maximumIterationsPerRequest; @@ -43,9 +54,12 @@ public class FunctionInvokingChatClient : DelegatingChatClient /// Initializes a new instance of the class. /// /// The underlying , or the next instance in a chain of clients. - public FunctionInvokingChatClient(IChatClient innerClient) + /// An to use for logging information about function invocation. + public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null) : base(innerClient) { + _logger = logger ?? NullLogger.Instance; + _activitySource = innerClient.GetService(); } /// @@ -562,13 +576,95 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul /// /// The to monitor for cancellation requests. The default is . /// The result of the function invocation. This may be null if the function invocation returned null. - protected virtual Task InvokeFunctionAsync(FunctionInvocationContext context, CancellationToken cancellationToken) + protected virtual async Task InvokeFunctionAsync(FunctionInvocationContext context, CancellationToken cancellationToken) { _ = Throw.IfNull(context); - return context.Function.InvokeAsync(context.CallContent.Arguments, cancellationToken); + using Activity? activity = _activitySource?.StartActivity(context.Function.Metadata.Name); + + long startingTimestamp = 0; + if (_logger.IsEnabled(LogLevel.Debug)) + { + startingTimestamp = Stopwatch.GetTimestamp(); + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogInvokingSensitive(context.Function.Metadata.Name, LoggingHelpers.AsJson(context.CallContent.Arguments, context.Function.Metadata.JsonSerializerOptions)); + } + else + { + LogInvoking(context.Function.Metadata.Name); + } + } + + object? result = null; + try + { + result = await context.Function.InvokeAsync(context.CallContent.Arguments, cancellationToken).ConfigureAwait(false); + } + catch (Exception e) + { + if (activity is not null) + { + _ = activity.SetTag("error.type", e.GetType().FullName) + .SetStatus(ActivityStatusCode.Error, e.Message); + } + + if (e is OperationCanceledException) + { + LogInvocationCanceled(context.Function.Metadata.Name); + } + else + { + LogInvocationFailed(context.Function.Metadata.Name, e); + } + + throw; + } + finally + { + if (_logger.IsEnabled(LogLevel.Debug)) + { + TimeSpan elapsed = GetElapsedTime(startingTimestamp); + + if (result is not null && _logger.IsEnabled(LogLevel.Trace)) + { + LogInvocationCompletedSensitive(context.Function.Metadata.Name, elapsed, LoggingHelpers.AsJson(result, context.Function.Metadata.JsonSerializerOptions)); + } + else + { + LogInvocationCompleted(context.Function.Metadata.Name, elapsed); + } + } + } + + return result; } + private static TimeSpan GetElapsedTime(long startingTimestamp) => +#if NET + Stopwatch.GetElapsedTime(startingTimestamp); +#else + new((long)((Stopwatch.GetTimestamp() - startingTimestamp) * ((double)TimeSpan.TicksPerSecond / Stopwatch.Frequency))); +#endif + + [LoggerMessage(LogLevel.Debug, "Invoking {MethodName}.", SkipEnabledCheck = true)] + private partial void LogInvoking(string methodName); + + [LoggerMessage(LogLevel.Trace, "Invoking {MethodName}({Arguments}).", SkipEnabledCheck = true)] + private partial void LogInvokingSensitive(string methodName, string arguments); + + [LoggerMessage(LogLevel.Debug, "{MethodName} invocation completed. Duration: {Duration}", SkipEnabledCheck = true)] + private partial void LogInvocationCompleted(string methodName, TimeSpan duration); + + [LoggerMessage(LogLevel.Trace, "{MethodName} invocation completed. Duration: {Duration}. Result: {Result}", SkipEnabledCheck = true)] + private partial void LogInvocationCompletedSensitive(string methodName, TimeSpan duration, string result); + + [LoggerMessage(LogLevel.Debug, "{MethodName} invocation canceled.")] + private partial void LogInvocationCanceled(string methodName); + + [LoggerMessage(LogLevel.Error, "{MethodName} invocation failed.")] + private partial void LogInvocationFailed(string methodName, Exception error); + /// Provides context for a function invocation. public sealed class FunctionInvocationContext { diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs index 15010b42068..fa64bcedc78 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI; @@ -16,15 +18,21 @@ public static class FunctionInvokingChatClientBuilderExtensions /// /// This works by adding an instance of with default options. /// The being used to build the chat pipeline. + /// An optional to use to create a logger for logging function invocations. /// An optional callback that can be used to configure the instance. /// The supplied . - public static ChatClientBuilder UseFunctionInvocation(this ChatClientBuilder builder, Action? configure = null) + public static ChatClientBuilder UseFunctionInvocation( + this ChatClientBuilder builder, + ILoggerFactory? loggerFactory = null, + Action? configure = null) { _ = Throw.IfNull(builder); - return builder.Use(innerClient => + return builder.Use((services, innerClient) => { - var chatClient = new FunctionInvokingChatClient(innerClient); + loggerFactory ??= services.GetService(); + + var chatClient = new FunctionInvokingChatClient(innerClient, loggerFactory?.CreateLogger(typeof(FunctionInvokingChatClient))); configure?.Invoke(chatClient); return chatClient; }); diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs index fc01b8c21b9..b816af150b7 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs @@ -168,7 +168,7 @@ public override async IAsyncEnumerable CompleteSt } } - private string AsJson(T value) => JsonSerializer.Serialize(value, _jsonSerializerOptions.GetTypeInfo(typeof(T))); + private string AsJson(T value) => LoggingHelpers.AsJson(value, _jsonSerializerOptions); [LoggerMessage(LogLevel.Debug, "{MethodName} invoked.")] private partial void LogInvoked(string methodName); diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs index a6dfe53adf5..6274c39419b 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs @@ -17,6 +17,8 @@ using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Shared.Diagnostics; +#pragma warning disable S3358 // Ternary operators should not be nested + namespace Microsoft.Extensions.AI; /// A delegating chat client that implements the OpenTelemetry Semantic Conventions for Generative AI systems. @@ -106,6 +108,11 @@ protected override void Dispose(bool disposing) /// public bool EnableSensitiveData { get; set; } + /// + public override object? GetService(Type serviceType, object? serviceKey = null) => + serviceType == typeof(ActivitySource) ? _activitySource : + base.GetService(serviceType, serviceKey); + /// public override async Task CompleteAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) { @@ -254,7 +261,7 @@ private static ChatCompletion ComposeStreamingUpdatesIntoChatCompletion( string? modelId = options?.ModelId ?? _modelId; activity = _activitySource.StartActivity( - $"{OpenTelemetryConsts.GenAI.Chat} {modelId}", + string.IsNullOrWhiteSpace(modelId) ? OpenTelemetryConsts.GenAI.Chat : $"{OpenTelemetryConsts.GenAI.Chat} {modelId}", ActivityKind.Client); if (activity is not null) diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs index c085aaef350..2dce06620a8 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs @@ -72,13 +72,19 @@ public OpenTelemetryEmbeddingGenerator(IEmbeddingGenerator i advice: new() { HistogramBucketBoundaries = OpenTelemetryConsts.GenAI.Client.OperationDuration.ExplicitBucketBoundaries }); } + /// + public override object? GetService(Type serviceType, object? serviceKey = null) => + serviceType == typeof(ActivitySource) ? _activitySource : + base.GetService(serviceType, serviceKey); + /// public override async Task> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) { _ = Throw.IfNull(values); - using Activity? activity = CreateAndConfigureActivity(); + using Activity? activity = CreateAndConfigureActivity(options); Stopwatch? stopwatch = _operationDurationHistogram.Enabled ? Stopwatch.StartNew() : null; + string? requestModelId = options?.ModelId ?? _modelId; GeneratedEmbeddings? response = null; Exception? error = null; @@ -93,7 +99,7 @@ public override async Task> GenerateAsync(IEnume } finally { - TraceCompletion(activity, response, error, stopwatch); + TraceCompletion(activity, requestModelId, response, error, stopwatch); } return response; @@ -112,18 +118,20 @@ protected override void Dispose(bool disposing) } /// Creates an activity for an embedding generation request, or returns null if not enabled. - private Activity? CreateAndConfigureActivity() + private Activity? CreateAndConfigureActivity(EmbeddingGenerationOptions? options) { Activity? activity = null; if (_activitySource.HasListeners()) { + string? modelId = options?.ModelId ?? _modelId; + activity = _activitySource.StartActivity( - $"{OpenTelemetryConsts.GenAI.Embed} {_modelId}", + string.IsNullOrWhiteSpace(modelId) ? OpenTelemetryConsts.GenAI.Embed : $"{OpenTelemetryConsts.GenAI.Embed} {modelId}", ActivityKind.Client, default(ActivityContext), [ new(OpenTelemetryConsts.GenAI.Operation.Name, OpenTelemetryConsts.GenAI.Embed), - new(OpenTelemetryConsts.GenAI.Request.Model, _modelId), + new(OpenTelemetryConsts.GenAI.Request.Model, modelId), new(OpenTelemetryConsts.GenAI.SystemName, _modelProvider), ]); @@ -149,6 +157,7 @@ protected override void Dispose(bool disposing) /// Adds embedding generation response information to the activity. private void TraceCompletion( Activity? activity, + string? requestModelId, GeneratedEmbeddings? embeddings, Exception? error, Stopwatch? stopwatch) @@ -167,7 +176,7 @@ private void TraceCompletion( if (_operationDurationHistogram.Enabled && stopwatch is not null) { TagList tags = default; - AddMetricTags(ref tags, responseModelId); + AddMetricTags(ref tags, requestModelId, responseModelId); if (error is not null) { tags.Add(OpenTelemetryConsts.Error.Type, error.GetType().FullName); @@ -180,7 +189,7 @@ private void TraceCompletion( { TagList tags = default; tags.Add(OpenTelemetryConsts.GenAI.Token.Type, "input"); - AddMetricTags(ref tags, responseModelId); + AddMetricTags(ref tags, requestModelId, responseModelId); _tokenUsageHistogram.Record(inputTokens.Value); } @@ -206,13 +215,13 @@ private void TraceCompletion( } } - private void AddMetricTags(ref TagList tags, string? responseModelId) + private void AddMetricTags(ref TagList tags, string? requestModelId, string? responseModelId) { tags.Add(OpenTelemetryConsts.GenAI.Operation.Name, OpenTelemetryConsts.GenAI.Embed); - if (_modelId is string requestModel) + if (requestModelId is not null) { - tags.Add(OpenTelemetryConsts.GenAI.Request.Model, requestModel); + tags.Add(OpenTelemetryConsts.GenAI.Request.Model, requestModelId); } tags.Add(OpenTelemetryConsts.GenAI.SystemName, _modelProvider); diff --git a/src/Libraries/Microsoft.Extensions.AI/LoggingHelpers.cs b/src/Libraries/Microsoft.Extensions.AI/LoggingHelpers.cs new file mode 100644 index 00000000000..72a7e283988 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/LoggingHelpers.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#pragma warning disable CA1031 // Do not catch general exception types +#pragma warning disable S108 // Nested blocks of code should not be left empty +#pragma warning disable S2486 // Generic exceptions should not be ignored + +using System.Text.Json; + +namespace Microsoft.Extensions.AI; + +/// Provides internal helpers for implementing logging. +internal static class LoggingHelpers +{ + /// Serializes as JSON for logging purposes. + public static string AsJson(T value, JsonSerializerOptions? options) + { + if (options?.TryGetTypeInfo(typeof(T), out var typeInfo) is true || + AIJsonUtilities.DefaultOptions.TryGetTypeInfo(typeof(T), out typeInfo)) + { + try + { + return JsonSerializer.Serialize(value, typeInfo); + } + catch + { + } + } + + // If we're unable to get a type info for the value, or if we fail to serialize, + // return an empty JSON object. We do not want lack of type info to disrupt application behavior with exceptions. + return "{}"; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs index 5eacced35b7..64a632d0846 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs @@ -18,7 +18,7 @@ public sealed class TestChatClient : IChatClient public Func, ChatOptions?, CancellationToken, IAsyncEnumerable>? CompleteStreamingAsyncCallback { get; set; } - public Func? GetServiceCallback { get; set; } + public Func GetServiceCallback { get; set; } = (_, _) => null; public Task CompleteAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) => CompleteAsyncCallback!.Invoke(chatMessages, options, cancellationToken); @@ -27,7 +27,7 @@ public IAsyncEnumerable CompleteStreamingAsync(IL => CompleteStreamingAsyncCallback!.Invoke(chatMessages, options, cancellationToken); public object? GetService(Type serviceType, object? serviceKey = null) - => GetServiceCallback!(serviceType, serviceKey); + => GetServiceCallback(serviceType, serviceKey); void IDisposable.Dispose() { diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestEmbeddingGenerator.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestEmbeddingGenerator.cs index 5b79b1908da..7438edc752e 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestEmbeddingGenerator.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestEmbeddingGenerator.cs @@ -14,13 +14,13 @@ public sealed class TestEmbeddingGenerator : IEmbeddingGenerator, EmbeddingGenerationOptions?, CancellationToken, Task>>>? GenerateAsyncCallback { get; set; } - public Func? GetServiceCallback { get; set; } + public Func GetServiceCallback { get; set; } = (_, _) => null; public Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) => GenerateAsyncCallback!.Invoke(values, options, cancellationToken); public object? GetService(Type serviceType, object? serviceKey = null) - => GetServiceCallback!(serviceType, serviceKey); + => GetServiceCallback(serviceType, serviceKey); void IDisposable.Dispose() { diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs index 20780d968f7..542851baa69 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs @@ -3,15 +3,26 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using OpenTelemetry.Trace; using Xunit; namespace Microsoft.Extensions.AI; public class FunctionInvokingChatClientTests { + [Fact] + public void InvalidArgs_Throws() + { + Assert.Throws("innerClient", () => new FunctionInvokingChatClient(null!)); + Assert.Throws("builder", () => ((ChatClientBuilder)null!).UseFunctionInvocation()); + } + [Fact] public void Ctor_HasExpectedDefaults() { @@ -294,6 +305,89 @@ public async Task RejectsMultipleChoicesAsync() Assert.Single(chat); // It didn't add anything to the chat history } + [Theory] + [InlineData(LogLevel.Trace)] + [InlineData(LogLevel.Debug)] + [InlineData(LogLevel.Information)] + public async Task FunctionInvocationsLogged(LogLevel level) + { + using CapturingLoggerProvider clp = new(); + + ServiceCollection c = new(); + c.AddLogging(b => b.AddProvider(clp).SetMinimumLevel(level)); + var services = c.BuildServiceProvider(); + + var options = new ChatOptions + { + Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")] + }; + + await InvokeAndAssertAsync(options, [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1", new Dictionary { ["arg1"] = "value1" })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, "world"), + ], configurePipeline: b => b.Use(c => new FunctionInvokingChatClient(c, services.GetRequiredService>()))); + + if (level is LogLevel.Trace) + { + Assert.Collection(clp.Logger.Entries, + entry => Assert.True(entry.Message.Contains("Invoking Func1({") && entry.Message.Contains("\"arg1\": \"value1\"")), + entry => Assert.True(entry.Message.Contains("Func1 invocation completed. Duration:") && entry.Message.Contains("Result: \"Result 1\""))); + } + else if (level is LogLevel.Debug) + { + Assert.Collection(clp.Logger.Entries, + entry => Assert.True(entry.Message.Contains("Invoking Func1") && !entry.Message.Contains("arg1")), + entry => Assert.True(entry.Message.Contains("Func1 invocation completed. Duration:") && !entry.Message.Contains("Result"))); + } + else + { + Assert.Empty(clp.Logger.Entries); + } + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task FunctionInvocationTrackedWithActivity(bool enableTelemetry) + { + string sourceName = Guid.NewGuid().ToString(); + var activities = new List(); + using TracerProvider? tracerProvider = enableTelemetry ? + OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddInMemoryExporter(activities) + .Build() : + null; + + var options = new ChatOptions + { + Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")] + }; + + await InvokeAndAssertAsync(options, [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1", new Dictionary { ["arg1"] = "value1" })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, "world"), + ], configurePipeline: b => b.Use(c => + new FunctionInvokingChatClient( + new OpenTelemetryChatClient(c, sourceName: sourceName)))); + + if (enableTelemetry) + { + Assert.Collection(activities, + activity => Assert.Equal("chat", activity.DisplayName), + activity => Assert.Equal("Func1", activity.DisplayName), + activity => Assert.Equal("chat", activity.DisplayName)); + } + else + { + Assert.Empty(activities); + } + } + private static async Task> InvokeAndAssertAsync( ChatOptions options, List plan,