Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -289,24 +289,49 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext schemaExporterContext, Js
objSchema.InsertAtStart(TypePropertyName, "string");
}

// Include the type keyword in nullable enum types
if (Nullable.GetUnderlyingType(ctx.TypeInfo.Type)?.IsEnum is true && objSchema.ContainsKey(EnumPropertyName) && !objSchema.ContainsKey(TypePropertyName))
{
objSchema.InsertAtStart(TypePropertyName, new JsonArray { (JsonNode)"string", (JsonNode)"null" });
}

// Some consumers of the JSON schema, including Ollama as of v0.3.13, don't understand
// schemas with "type": [...], and only understand "type" being a single value.
// In certain configurations STJ represents .NET numeric types as ["string", "number"], which will then lead to an error.
if (TypeIsIntegerWithStringNumberHandling(ctx, objSchema, out string? numericType))
if (TypeIsIntegerWithStringNumberHandling(ctx, objSchema, out string? numericType, out bool isNullable))
{
// We don't want to emit any array for "type". In this case we know it contains "integer" or "number",
// so reduce the type to that alone, assuming it's the most specific type.
// This makes schemas for Int32 (etc) work with Ollama.
JsonObject obj = ConvertSchemaToObject(ref schema);
obj[TypePropertyName] = numericType;
if (isNullable)
{
// If the type is nullable, we still need use a type array
obj[TypePropertyName] = new JsonArray { (JsonNode)numericType, (JsonNode)"null" };
}
else
{
obj[TypePropertyName] = (JsonNode)numericType;
}

_ = obj.Remove(PatternPropertyName);
}

if (Nullable.GetUnderlyingType(ctx.TypeInfo.Type) is Type nullableElement)
{
// Account for bug https://github.com/dotnet/runtime/issues/117493
// To be removed once System.Text.Json v10 becomes the lowest supported version.
// null not inserted in the type keyword for root-level Nullable<T> types.
if (objSchema.TryGetPropertyValue(TypePropertyName, out JsonNode? typeKeyWord) &&
typeKeyWord?.GetValueKind() is JsonValueKind.String)
{
string typeValue = typeKeyWord.GetValue<string>()!;
if (typeValue is not "null")
{
objSchema[TypePropertyName] = new JsonArray { (JsonNode)typeValue, (JsonNode)"null" };
}
}

// Include the type keyword in nullable enum types
if (nullableElement.IsEnum && objSchema.ContainsKey(EnumPropertyName) && !objSchema.ContainsKey(TypePropertyName))
{
objSchema.InsertAtStart(TypePropertyName, new JsonArray { (JsonNode)"string", (JsonNode)"null" });
}
}
}

if (ctx.Path.IsEmpty && hasDefaultValue)
Expand Down Expand Up @@ -601,11 +626,12 @@ static JsonArray CreateJsonArray(object?[] values, JsonSerializerOptions seriali
}
}

private static bool TypeIsIntegerWithStringNumberHandling(AIJsonSchemaCreateContext ctx, JsonObject schema, [NotNullWhen(true)] out string? numericType)
private static bool TypeIsIntegerWithStringNumberHandling(AIJsonSchemaCreateContext ctx, JsonObject schema, [NotNullWhen(true)] out string? numericType, out bool isNullable)
{
numericType = null;
isNullable = false;

if (ctx.TypeInfo.NumberHandling is not JsonNumberHandling.Strict && schema["type"] is JsonArray { Count: 2 } typeArray)
if (ctx.TypeInfo.NumberHandling is not JsonNumberHandling.Strict && schema["type"] is JsonArray typeArray)
{
bool allowString = false;

Expand All @@ -617,11 +643,23 @@ private static bool TypeIsIntegerWithStringNumberHandling(AIJsonSchemaCreateCont
switch (type)
{
case "integer" or "number":
if (numericType is not null)
{
// Conflicting numeric type
return false;
}

numericType = type;
break;
case "string":
allowString = true;
break;
case "null":
isNullable = true;
break;
default:
// keyword is not valid in the context of numeric types.
return false;
}
}
}
Expand Down Expand Up @@ -665,7 +703,7 @@ private static JsonElement ParseJsonElement(ReadOnlySpan<byte> utf8Json)

if (defaultValue is null || (defaultValue == DBNull.Value && parameterType != typeof(DBNull)))
{
return parameterType.IsValueType
return parameterType.IsValueType && Nullable.GetUnderlyingType(parameterType) is null
#if NET
? RuntimeHelpers.GetUninitializedObject(parameterType)
#else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ private static class ReflectionHelpers
public static bool IsBuiltInConverter(JsonConverter converter) =>
converter.GetType().Assembly == typeof(JsonConverter).Assembly;

public static bool CanBeNull(Type type) => !type.IsValueType || Nullable.GetUnderlyingType(type) is not null;

public static Type GetElementType(JsonTypeInfo typeInfo)
{
Debug.Assert(typeInfo.Kind is JsonTypeInfoKind.Enumerable or JsonTypeInfoKind.Dictionary, "TypeInfo must be of collection type");
Expand Down
16 changes: 10 additions & 6 deletions src/Shared/JsonSchemaExporter/JsonSchemaExporter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -452,20 +452,24 @@ JsonSchema CompleteSchema(ref GenerationState state, JsonSchema schema)

bool IsNullableSchema(ref GenerationState state)
{
// A schema is marked as nullable if either
// A schema is marked as nullable if either:
// 1. We have a schema for a property where either the getter or setter are marked as nullable.
// 2. We have a schema for a reference type, unless we're explicitly treating null-oblivious types as non-nullable
// 2. We have a schema for a Nullable<T> type.
// 3. We have a schema for a reference type, unless we're explicitly treating null-oblivious types as non-nullable.

if (propertyInfo != null || parameterInfo != null)
{
return !isNonNullableType;
}
else

if (Nullable.GetUnderlyingType(typeInfo.Type) is not null)
{
return ReflectionHelpers.CanBeNull(typeInfo.Type) &&
!parentPolymorphicTypeIsNonNullable &&
!state.ExporterOptions.TreatNullObliviousAsNonNullable;
return true;
}

return !typeInfo.Type.IsValueType &&
!parentPolymorphicTypeIsNonNullable &&
!state.ExporterOptions.TreatNullObliviousAsNonNullable;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,29 @@ public static void EqualFunctionCallParameters(
public static void EqualFunctionCallResults(object? expected, object? actual, JsonSerializerOptions? options = null)
=> AreJsonEquivalentValues(expected, actual, options);

private static void AreJsonEquivalentValues(object? expected, object? actual, JsonSerializerOptions? options, string? propertyName = null)
/// <summary>
/// Asserts that the two JSON values are equal.
/// </summary>
public static void EqualJsonValues(JsonElement expectedJson, JsonElement actualJson, string? propertyName = null)
{
options ??= AIJsonUtilities.DefaultOptions;
JsonElement expectedElement = NormalizeToElement(expected, options);
JsonElement actualElement = NormalizeToElement(actual, options);
if (!JsonNode.DeepEquals(
JsonSerializer.SerializeToNode(expectedElement, AIJsonUtilities.DefaultOptions),
JsonSerializer.SerializeToNode(actualElement, AIJsonUtilities.DefaultOptions)))
JsonSerializer.SerializeToNode(expectedJson, AIJsonUtilities.DefaultOptions),
JsonSerializer.SerializeToNode(actualJson, AIJsonUtilities.DefaultOptions)))
{
string message = propertyName is null
? $"Function result does not match expected JSON.\r\nExpected: {expectedElement.GetRawText()}\r\nActual: {actualElement.GetRawText()}"
: $"Parameter '{propertyName}' does not match expected JSON.\r\nExpected: {expectedElement.GetRawText()}\r\nActual: {actualElement.GetRawText()}";
? $"JSON result does not match expected JSON.\r\nExpected: {expectedJson.GetRawText()}\r\nActual: {actualJson.GetRawText()}"
: $"Parameter '{propertyName}' does not match expected JSON.\r\nExpected: {expectedJson.GetRawText()}\r\nActual: {actualJson.GetRawText()}";

throw new XunitException(message);
}
}

private static void AreJsonEquivalentValues(object? expected, object? actual, JsonSerializerOptions? options, string? propertyName = null)
{
options ??= AIJsonUtilities.DefaultOptions;
JsonElement expectedElement = NormalizeToElement(expected, options);
JsonElement actualElement = NormalizeToElement(actual, options);
EqualJsonValues(expectedElement, actualElement, propertyName);

static JsonElement NormalizeToElement(object? value, JsonSerializerOptions options)
=> value is JsonElement e ? e : JsonSerializer.SerializeToElement(value, options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,13 +354,21 @@ public static void CreateFunctionJsonSchema_TreatsIntegralTypesAsInteger_EvenWit
int i = 0;
foreach (JsonProperty property in schemaParameters.EnumerateObject())
{
string numericType = Type.GetTypeCode(parameters[i].ParameterType) is TypeCode.Double or TypeCode.Single or TypeCode.Decimal
? "number"
: "integer";
bool isNullable = false;
Type type = parameters[i].ParameterType;
if (Nullable.GetUnderlyingType(type) is { } elementType)
{
type = elementType;
isNullable = true;
}

string numericType = Type.GetTypeCode(type) is TypeCode.Double or TypeCode.Single or TypeCode.Decimal
? "\"number\""
: "\"integer\"";

JsonElement expected = JsonDocument.Parse($$"""
{
"type": "{{numericType}}"
"type": {{(isNullable ? $"[{numericType}, \"null\"]" : numericType)}}
}
""").RootElement;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Nodes;
Expand Down Expand Up @@ -854,6 +855,71 @@ public async Task AIFunctionFactory_DefaultDefaultParameter()
Assert.Contains("00000000-0000-0000-0000-000000000000,0", result?.ToString());
}

[Fact]
public async Task AIFunctionFactory_NullableParameters()
{
Assert.NotEqual(new StructWithDefaultCtor().Value, default(StructWithDefaultCtor).Value);

AIFunction f = AIFunctionFactory.Create(
(int? limit = null, DateTime? from = null) => Enumerable.Repeat(from ?? default, limit ?? 4).Select(d => d.Year).ToArray(),
serializerOptions: JsonContext.Default.Options);

JsonElement expectedSchema = JsonDocument.Parse("""
{
"type": "object",
"properties": {
"limit": {
"type": ["integer", "null"],
"default": null
},
"from": {
"type": ["string", "null"],
"format": "date-time",
"default": null
}
}
}
""").RootElement;

AssertExtensions.EqualJsonValues(expectedSchema, f.JsonSchema);

object? result = await f.InvokeAsync();
Assert.Contains("[1,1,1,1]", result?.ToString());
}

[Fact]
public async Task AIFunctionFactory_NullableParameters_AllowReadingFromString()
{
JsonSerializerOptions options = new(JsonContext.Default.Options) { NumberHandling = JsonNumberHandling.AllowReadingFromString };
Assert.NotEqual(new StructWithDefaultCtor().Value, default(StructWithDefaultCtor).Value);

AIFunction f = AIFunctionFactory.Create(
(int? limit = null, DateTime? from = null) => Enumerable.Repeat(from ?? default, limit ?? 4).Select(d => d.Year).ToArray(),
serializerOptions: options);

JsonElement expectedSchema = JsonDocument.Parse("""
{
"type": "object",
"properties": {
"limit": {
"type": ["integer", "null"],
"default": null
},
"from": {
"type": ["string", "null"],
"format": "date-time",
"default": null
}
}
}
""").RootElement;

AssertExtensions.EqualJsonValues(expectedSchema, f.JsonSchema);

object? result = await f.InvokeAsync();
Assert.Contains("[1,1,1,1]", result?.ToString());
}

[Fact]
public void AIFunctionFactory_ReturnTypeWithDescriptionAttribute()
{
Expand Down Expand Up @@ -959,5 +1025,7 @@ private static AIFunctionFactoryOptions CreateKeyedServicesSupportOptions() =>
[JsonSerializable(typeof(Guid))]
[JsonSerializable(typeof(StructWithDefaultCtor))]
[JsonSerializable(typeof(B))]
[JsonSerializable(typeof(int?))]
[JsonSerializable(typeof(DateTime?))]
private partial class JsonContext : JsonSerializerContext;
}
10 changes: 8 additions & 2 deletions test/Shared/JsonSchemaExporter/TestData.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ internal sealed record TestData<T>(
T? Value,
[StringSyntax(StringSyntaxAttribute.Json)] string ExpectedJsonSchema,
IEnumerable<T?>? AdditionalValues = null,
object? ExporterOptions = null,
#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL
System.Text.Json.Schema.JsonSchemaExporterOptions? ExporterOptions = null,
#endif
JsonSerializerOptions? Options = null,
bool WritesNumbersAsStrings = false)
: ITestData
Expand All @@ -22,7 +24,9 @@ internal sealed record TestData<T>(

public Type Type => typeof(T);
object? ITestData.Value => Value;
#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL
object? ITestData.ExporterOptions => ExporterOptions;
#endif
JsonNode ITestData.ExpectedJsonSchema { get; } =
JsonNode.Parse(ExpectedJsonSchema, documentOptions: _schemaParseOptions)
?? throw new ArgumentNullException("schema must not be null");
Expand All @@ -32,7 +36,7 @@ IEnumerable<ITestData> ITestData.GetTestDataForAllValues()
yield return this;

if (default(T) is null &&
#if NET9_0_OR_GREATER
#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL
ExporterOptions is System.Text.Json.Schema.JsonSchemaExporterOptions { TreatNullObliviousAsNonNullable: false } &&
#endif
Value is not null)
Expand All @@ -58,7 +62,9 @@ public interface ITestData

JsonNode ExpectedJsonSchema { get; }

#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL
object? ExporterOptions { get; }
#endif

JsonSerializerOptions? Options { get; }

Expand Down
Loading
Loading