Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#pragma warning disable S2333 // Redundant modifiers should not be used
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
#pragma warning disable SA1202 // Public members should come before private members
#pragma warning disable SA1203 // Constants should appear before fields

namespace Microsoft.Extensions.AI;

Expand Down Expand Up @@ -825,6 +826,23 @@ static bool IsAsyncMethod(MethodInfo method)
{
try
{
if (value is string text && IsPotentiallyJson(text))
{
Debug.Assert(typeInfo.Type != typeof(string), "string parameters should not enter this branch.");

// Account for the parameter potentially being a JSON string.
// The value is a string but the type is not. Try to deserialize it under the assumption that it's JSON.
// If it's not, we'll fall through to the default path that makes it valid JSON and then tries to deserialize.
try
{
return JsonSerializer.Deserialize(text, typeInfo);
}
catch (JsonException)
{
// If the string is not valid JSON, fall through to the round-trip.
}
}

string json = JsonSerializer.Serialize(value, serializerOptions.GetTypeInfo(value.GetType()));
return JsonSerializer.Deserialize(json, typeInfo);
}
Expand Down Expand Up @@ -1021,6 +1039,35 @@ private record struct DescriptorKey(
AIJsonSchemaCreateOptions SchemaOptions);
}

/// <summary>
/// Quickly checks if the specified string is potentially JSON
/// by checking if the first non-whitespace characters are valid JSON start tokens.
/// </summary>
/// <param name="value">The string to check.</param>
/// <returns>If <see langword="false"/> then the string is definitely not valid JSON.</returns>
private static bool IsPotentiallyJson(string value) => PotentiallyJsonRegex().IsMatch(value);
#if NET
[GeneratedRegex(PotentiallyJsonRegexString, RegexOptions.IgnorePatternWhitespace)]
private static partial Regex PotentiallyJsonRegex();
#else
private static Regex PotentiallyJsonRegex() => _potentiallyJsonRegex;
private static readonly Regex _potentiallyJsonRegex = new(PotentiallyJsonRegexString, RegexOptions.IgnorePatternWhitespace | RegexOptions.Compiled);
#endif
private const string PotentiallyJsonRegexString = """
^\s* # Optional whitespace at the start of the string
( null # null literal
| false # false literal
| true # true literal
| \d # positive number
| -\d # negative number
| " # string
| \[ # start array
| { # start object
| // # Start of single-line comment
| /\* # Start of multi-line comment
)
""";

/// <summary>
/// Removes characters from a .NET member name that shouldn't be used in an AI function name.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,39 @@ public virtual async Task FunctionInvocation_NestedParameters()
AssertUsageAgainstActivities(response, activities);
}

[ConditionalFact]
public virtual async Task FunctionInvocation_ArrayParameter()
{
SkipIfNotEnabled();

var sourceName = Guid.NewGuid().ToString();
var activities = new List<Activity>();
using var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder()
.AddSource(sourceName)
.AddInMemoryExporter(activities)
.Build();

using var chatClient = new FunctionInvokingChatClient(
new OpenTelemetryChatClient(_chatClient, sourceName: sourceName));

List<ChatMessage> messages =
[
new(ChatRole.User, "Can you add bacon, lettuce, and tomatoes to Peter's shopping cart?")
];

string? shopperName = null;
List<string> shoppingCart = [];
AIFunction func = AIFunctionFactory.Create((string[] items, string shopperId) => { shoppingCart.AddRange(items); shopperName = shopperId; }, "AddItemsToShoppingCart");
var response = await chatClient.GetResponseAsync(messages, new()
{
Tools = [func]
});

Assert.Equal("Peter", shopperName);
Assert.Equal(["bacon", "lettuce", "tomatoes"], shoppingCart);
AssertUsageAgainstActivities(response, activities);
}

private static void AssertUsageAgainstActivities(ChatResponse response, List<Activity> activities)
{
// If the underlying IChatClient provides usage data, function invocation should aggregate the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.ComponentModel;
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -75,6 +76,66 @@ public async Task Parameters_MissingRequiredParametersFail_Async()
}
}

[Fact]
public async Task Parameters_ToleratesJsonEncodedParameters()
{
AIFunction func = AIFunctionFactory.Create((int x, int y, int z, int w, int u) => x + y + z + w + u);

var result = await func.InvokeAsync(new()
{
["x"] = "1",
["y"] = JsonNode.Parse("2"),
["z"] = JsonDocument.Parse("3"),
["w"] = JsonDocument.Parse("4").RootElement,
["u"] = 5M, // boxed decimal cannot be cast to int, requires conversion
});

AssertExtensions.EqualFunctionCallResults(15, result);
}

[Theory]
[InlineData(" null")]
[InlineData(" false ")]
[InlineData("true ")]
[InlineData("42")]
[InlineData("0.0")]
[InlineData("-1e15")]
[InlineData(" \"I am a string!\" ")]
[InlineData(" {}")]
[InlineData("[]")]
public async Task Parameters_ToleratesJsonStringParameters(string jsonStringParam)
{
AIFunction func = AIFunctionFactory.Create((JsonElement param) => param);
JsonElement expectedResult = JsonDocument.Parse(jsonStringParam).RootElement;

var result = await func.InvokeAsync(new()
{
["param"] = jsonStringParam
});

AssertExtensions.EqualFunctionCallResults(expectedResult, result);
}

[Theory]
[InlineData("")]
[InlineData(" \r\n")]
[InlineData("I am a string!")]
[InlineData("/* Code snippet */ int main(void) { return 0; }")]
[InlineData("let rec Y F x = F (Y F) x")]
[InlineData("+3")]
public async Task Parameters_ToleratesInvalidJsonStringParameters(string invalidJsonParam)
{
AIFunction func = AIFunctionFactory.Create((JsonElement param) => param);
JsonElement expectedResult = JsonDocument.Parse(JsonSerializer.Serialize(invalidJsonParam, JsonContext.Default.String)).RootElement;

var result = await func.InvokeAsync(new()
{
["param"] = invalidJsonParam
});

AssertExtensions.EqualFunctionCallResults(expectedResult, result);
}

[Fact]
public async Task Parameters_MappedByType_Async()
{
Expand Down
Loading