Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion src/Libraries/Microsoft.Extensions.AI.Abstractions/AITool.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,55 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using Microsoft.Shared.Collections;

namespace Microsoft.Extensions.AI;

#pragma warning disable S1694 // An abstract class should have both abstract and concrete methods

/// <summary>Represents a tool that can be specified to an AI service.</summary>
public class AITool
[DebuggerDisplay("{DebuggerDisplay,nq}")]
public abstract class AITool
{
/// <summary>Initializes a new instance of the <see cref="AITool"/> class.</summary>
protected AITool()
{
}

/// <summary>Gets the name of the tool.</summary>
public virtual string Name => GetType().Name;

/// <summary>Gets a description of the tool, suitable for use in describing the purpose to a model.</summary>
public virtual string Description => string.Empty;

/// <summary>Gets any additional properties associated with the tool.</summary>
public virtual IReadOnlyDictionary<string, object?> AdditionalProperties => EmptyReadOnlyDictionary<string, object?>.Instance;

/// <inheritdoc/>
public override string ToString() => Name;

/// <summary>Gets the string to display in the debugger for this instance.</summary>
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
private string DebuggerDisplay
{
get
{
StringBuilder sb = new(Name);

if (Description is string description && !string.IsNullOrEmpty(description))
{
_ = sb.Append(" (").Append(description).Append(')');
}

foreach (var entry in AdditionalProperties)
{
_ = sb.Append(", ").Append(entry.Key).Append(" = ").Append(entry.Value);
}

return sb.ToString();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Microsoft.Extensions.AI;

/// <summary>Represents a tool that can be specified to an AI service to enable it to execute code it generates.</summary>
/// <remarks>
/// This tool does not itself implement code interpration. It is a marker that can be used to inform a service
/// that the service is allowed to execute its generated code if the service is capable of doing so.
/// </remarks>
public class CodeInterpreterTool : AITool
{
/// <summary>Initializes a new instance of the <see cref="CodeInterpreterTool"/> class.</summary>
public CodeInterpreterTool()
{
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Diagnostics;
using System.Reflection;
using System.Text.Json;
using System.Threading;
Expand All @@ -12,15 +11,8 @@
namespace Microsoft.Extensions.AI;

/// <summary>Represents a function that can be described to an AI service and invoked.</summary>
[DebuggerDisplay("{DebuggerDisplay,nq}")]
public abstract class AIFunction : AITool
{
/// <summary>Gets the name of the function.</summary>
public abstract string Name { get; }

/// <summary>Gets a description of the function, suitable for use in describing the purpose to a model.</summary>
public abstract string Description { get; }

/// <summary>Gets a JSON Schema describing the function and its input parameters.</summary>
/// <remarks>
/// <para>
Expand Down Expand Up @@ -56,11 +48,8 @@ public abstract class AIFunction : AITool
/// </remarks>
public virtual MethodInfo? UnderlyingMethod => null;

/// <summary>Gets any additional properties associated with the function.</summary>
public virtual IReadOnlyDictionary<string, object?> AdditionalProperties => EmptyReadOnlyDictionary<string, object?>.Instance;

/// <summary>Gets a <see cref="JsonSerializerOptions"/> that can be used to marshal function parameters.</summary>
public virtual JsonSerializerOptions? JsonSerializerOptions => AIJsonUtilities.DefaultOptions;
public virtual JsonSerializerOptions JsonSerializerOptions => AIJsonUtilities.DefaultOptions;

/// <summary>Invokes the <see cref="AIFunction"/> and returns its result.</summary>
/// <param name="arguments">The arguments to pass to the function's invocation.</param>
Expand All @@ -75,18 +64,11 @@ public abstract class AIFunction : AITool
return InvokeCoreAsync(arguments, cancellationToken);
}

/// <inheritdoc/>
public override string ToString() => Name;

/// <summary>Invokes the <see cref="AIFunction"/> and returns its result.</summary>
/// <param name="arguments">The arguments to pass to the function's invocation.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
/// <returns>The result of the function's execution.</returns>
protected abstract Task<object?> InvokeCoreAsync(
IEnumerable<KeyValuePair<string, object?>> arguments,
CancellationToken cancellationToken);

/// <summary>Gets the string to display in the debugger for this instance.</summary>
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
private string DebuggerDisplay => string.IsNullOrWhiteSpace(Description) ? Name : $"{Name} ({Description})";
}
Original file line number Diff line number Diff line change
Expand Up @@ -212,19 +212,25 @@ private static (RunCreationOptions RunOptions, List<FunctionResultContent>? Tool
{
foreach (AITool tool in tools)
{
if (tool is AIFunction aiFunction)
switch (tool)
{
bool? strict =
aiFunction.AdditionalProperties.TryGetValue("Strict", out object? strictObj) &&
strictObj is bool strictValue ?
strictValue : null;

var functionParameters = BinaryData.FromBytes(
JsonSerializer.SerializeToUtf8Bytes(
JsonSerializer.Deserialize(aiFunction.JsonSchema, OpenAIJsonContext.Default.OpenAIChatToolJson)!,
OpenAIJsonContext.Default.OpenAIChatToolJson));

runOptions.ToolsOverride.Add(ToolDefinition.CreateFunction(aiFunction.Name, aiFunction.Description, functionParameters, strict));
case AIFunction aiFunction:
bool? strict =
aiFunction.AdditionalProperties.TryGetValue("Strict", out object? strictObj) &&
strictObj is bool strictValue ?
strictValue : null;

var functionParameters = BinaryData.FromBytes(
JsonSerializer.SerializeToUtf8Bytes(
JsonSerializer.Deserialize(aiFunction.JsonSchema, OpenAIJsonContext.Default.OpenAIChatToolJson)!,
OpenAIJsonContext.Default.OpenAIChatToolJson));

runOptions.ToolsOverride.Add(ToolDefinition.CreateFunction(aiFunction.Name, aiFunction.Description, functionParameters, strict));
break;

case CodeInterpreterTool:
runOptions.ToolsOverride.Add(ToolDefinition.CreateCodeInterpreter());
break;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,10 @@ public static ChatOptions FromOpenAIOptions(ChatCompletionOptions? options)
{
foreach (ChatTool tool in tools)
{
result.Tools ??= [];
result.Tools.Add(FromOpenAIChatTool(tool));
if (FromOpenAIChatTool(tool) is { } convertedTool)
{
(result.Tools ??= []).Add(convertedTool);
}
}

using var toolChoiceJson = JsonDocument.Parse(JsonModelHelpers.Serialize(options.ToolChoice).ToMemory());
Expand Down Expand Up @@ -407,17 +409,24 @@ public static ChatCompletionOptions ToOpenAIOptions(ChatOptions? options)
return result;
}

private static AITool FromOpenAIChatTool(ChatTool chatTool)
private static AITool? FromOpenAIChatTool(ChatTool chatTool)
{
AdditionalPropertiesDictionary additionalProperties = [];
if (chatTool.FunctionSchemaIsStrict is bool strictValue)
switch (chatTool.Kind)
{
additionalProperties["Strict"] = strictValue;
}
case ChatToolKind.Function:
AdditionalPropertiesDictionary additionalProperties = [];
if (chatTool.FunctionSchemaIsStrict is bool strictValue)
{
additionalProperties["Strict"] = strictValue;
}

OpenAIChatToolJson openAiChatTool = JsonSerializer.Deserialize(chatTool.FunctionParameters.ToMemory().Span, OpenAIJsonContext.Default.OpenAIChatToolJson)!;
JsonElement schema = JsonSerializer.SerializeToElement(openAiChatTool, OpenAIJsonContext.Default.OpenAIChatToolJson);
return new MetadataOnlyAIFunction(chatTool.FunctionName, chatTool.FunctionDescription, schema, additionalProperties);

OpenAIChatToolJson openAiChatTool = JsonSerializer.Deserialize(chatTool.FunctionParameters.ToMemory().Span, OpenAIJsonContext.Default.OpenAIChatToolJson)!;
JsonElement schema = JsonSerializer.SerializeToElement(openAiChatTool, OpenAIJsonContext.Default.OpenAIChatToolJson);
return new MetadataOnlyAIFunction(chatTool.FunctionName, chatTool.FunctionDescription, schema, additionalProperties);
default:
return null;
}
}

private sealed class MetadataOnlyAIFunction(string name, string description, JsonElement schema, IReadOnlyDictionary<string, object?> additionalProps) : AIFunction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ static bool IsAsyncMethod(MethodInfo method)
JsonTypeInfo typeInfo = serializerOptions.GetTypeInfo(parameterType);

// Create a marshaller that simply looks up the parameter by name in the arguments dictionary.
return (IReadOnlyDictionary<string, object?> arguments, AIFunctionContext? _) =>
return (arguments, _) =>
{
// If the parameter has an argument specified in the dictionary, return that argument.
if (arguments.TryGetValue(parameter.Name, out object? value))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public AIFunctionFactoryOptions()
public string? Description { get; set; }

/// <summary>
/// Gets or sets additional values to store on the resulting <see cref="AIFunction.AdditionalProperties" /> property.
/// Gets or sets additional values to store on the resulting <see cref="AITool.AdditionalProperties" /> property.
/// </summary>
/// <remarks>
/// This property can be used to provide arbitrary information about the function.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Xunit;

namespace Microsoft.Extensions.AI;

public class AIToolTests
{
[Fact]
public void Constructor_Roundtrips()
{
DerivedAITool tool = new();
Assert.Equal(nameof(DerivedAITool), tool.Name);
Assert.Equal(nameof(DerivedAITool), tool.ToString());
Assert.Empty(tool.Description);
Assert.Empty(tool.AdditionalProperties);
}

private sealed class DerivedAITool : AITool;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Xunit;

namespace Microsoft.Extensions.AI;

public class CodeInterpreterToolTests
{
[Fact]
public void Constructor_Roundtrips()
{
var tool = new CodeInterpreterTool();
Assert.Equal(nameof(CodeInterpreterTool), tool.Name);
Assert.Empty(tool.Description);
Assert.Empty(tool.AdditionalProperties);
Assert.Equal(nameof(CodeInterpreterTool), tool.ToString());
}
}