Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ public static JsonElement CreateFunctionJsonSchema(

JsonNode parameterSchema = CreateJsonSchemaCore(
type: parameter.ParameterType,
parameterName: parameter.Name,
parameter: parameter,
description: parameter.GetCustomAttribute<DescriptionAttribute>(inherit: true)?.Description,
hasDefaultValue: parameter.HasDefaultValue,
defaultValue: GetDefaultValueNormalized(parameter),
Expand Down Expand Up @@ -178,7 +178,7 @@ public static JsonElement CreateJsonSchema(
{
serializerOptions ??= DefaultOptions;
inferenceOptions ??= AIJsonSchemaCreateOptions.Default;
JsonNode schema = CreateJsonSchemaCore(type, parameterName: null, description, hasDefaultValue, defaultValue, serializerOptions, inferenceOptions);
JsonNode schema = CreateJsonSchemaCore(type, parameter: null, description, hasDefaultValue, defaultValue, serializerOptions, inferenceOptions);

// Finally, apply any schema transformations if specified.
if (inferenceOptions.TransformOptions is { } options)
Expand Down Expand Up @@ -208,7 +208,7 @@ internal static void ValidateSchemaDocument(JsonElement document, [CallerArgumen
#endif
private static JsonNode CreateJsonSchemaCore(
Type? type,
string? parameterName,
ParameterInfo? parameter,
string? description,
bool hasDefaultValue,
object? defaultValue,
Expand Down Expand Up @@ -272,14 +272,14 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext schemaExporterContext, Js
// The resulting schema might be a $ref using a pointer to a different location in the document.
// As JSON pointer doesn't support relative paths, parameter schemas need to fix up such paths
// to accommodate the fact that they're being nested inside of a higher-level schema.
if (parameterName is not null && objSchema.TryGetPropertyValue(RefPropertyName, out JsonNode? paramName))
if (parameter?.Name is not null && objSchema.TryGetPropertyValue(RefPropertyName, out JsonNode? paramName))
{
// Fix up any $ref URIs to match the path from the root document.
string refUri = paramName!.GetValue<string>();
Debug.Assert(refUri is "#" || refUri.StartsWith("#/", StringComparison.Ordinal), $"Expected {nameof(refUri)} to be either # or start with #/, got {refUri}");
refUri = refUri == "#"
? $"#/{PropertiesPropertyName}/{parameterName}"
: $"#/{PropertiesPropertyName}/{parameterName}/{refUri.AsMemory("#/".Length)}";
? $"#/{PropertiesPropertyName}/{parameter.Name}"
: $"#/{PropertiesPropertyName}/{parameter.Name}/{refUri.AsMemory("#/".Length)}";

objSchema[RefPropertyName] = (JsonNode)refUri;
}
Expand Down Expand Up @@ -359,7 +359,7 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext schemaExporterContext, Js
ConvertSchemaToObject(ref schema).InsertAtStart(SchemaPropertyName, (JsonNode)SchemaKeywordUri);
}

ApplyDataAnnotations(parameterName, ref schema, ctx);
ApplyDataAnnotations(ref schema, ctx);

// Finally, apply any user-defined transformations if specified.
if (inferenceOptions.TransformSchemaNode is { } transformer)
Expand Down Expand Up @@ -389,30 +389,30 @@ static JsonObject ConvertSchemaToObject(ref JsonNode schema)
}
}

void ApplyDataAnnotations(string? parameterName, ref JsonNode schema, AIJsonSchemaCreateContext ctx)
void ApplyDataAnnotations(ref JsonNode schema, AIJsonSchemaCreateContext ctx)
{
if (ctx.GetCustomAttribute<DisplayNameAttribute>() is { } displayNameAttribute)
if (ResolveAttribute<DisplayNameAttribute>() is { } displayNameAttribute)
{
ConvertSchemaToObject(ref schema)[TitlePropertyName] ??= displayNameAttribute.DisplayName;
}

#if NET || NETFRAMEWORK
if (ctx.GetCustomAttribute<EmailAddressAttribute>() is { } emailAttribute)
if (ResolveAttribute<EmailAddressAttribute>() is { } emailAttribute)
{
ConvertSchemaToObject(ref schema)[FormatPropertyName] ??= "email";
}

if (ctx.GetCustomAttribute<UrlAttribute>() is { } urlAttribute)
if (ResolveAttribute<UrlAttribute>() is { } urlAttribute)
{
ConvertSchemaToObject(ref schema)[FormatPropertyName] ??= "uri";
}

if (ctx.GetCustomAttribute<RegularExpressionAttribute>() is { } regexAttribute)
if (ResolveAttribute<RegularExpressionAttribute>() is { } regexAttribute)
{
ConvertSchemaToObject(ref schema)[PatternPropertyName] ??= regexAttribute.Pattern;
}

if (ctx.GetCustomAttribute<StringLengthAttribute>() is { } stringLengthAttribute)
if (ResolveAttribute<StringLengthAttribute>() is { } stringLengthAttribute)
{
JsonObject obj = ConvertSchemaToObject(ref schema);

Expand All @@ -424,7 +424,7 @@ void ApplyDataAnnotations(string? parameterName, ref JsonNode schema, AIJsonSche
obj[MaxLengthStringPropertyName] ??= stringLengthAttribute.MaximumLength;
}

if (ctx.GetCustomAttribute<MinLengthAttribute>() is { } minLengthAttribute)
if (ResolveAttribute<MinLengthAttribute>() is { } minLengthAttribute)
{
JsonObject obj = ConvertSchemaToObject(ref schema);
if (obj[TypePropertyName] is JsonNode typeNode && typeNode.GetValueKind() is JsonValueKind.String && typeNode.GetValue<string>() is "string")
Expand All @@ -437,7 +437,7 @@ void ApplyDataAnnotations(string? parameterName, ref JsonNode schema, AIJsonSche
}
}

if (ctx.GetCustomAttribute<MaxLengthAttribute>() is { } maxLengthAttribute)
if (ResolveAttribute<MaxLengthAttribute>() is { } maxLengthAttribute)
{
JsonObject obj = ConvertSchemaToObject(ref schema);
if (obj[TypePropertyName] is JsonNode typeNode && typeNode.GetValueKind() is JsonValueKind.String && typeNode.GetValue<string>() is "string")
Expand All @@ -450,7 +450,7 @@ void ApplyDataAnnotations(string? parameterName, ref JsonNode schema, AIJsonSche
}
}

if (ctx.GetCustomAttribute<RangeAttribute>() is { } rangeAttribute)
if (ResolveAttribute<RangeAttribute>() is { } rangeAttribute)
{
JsonObject obj = ConvertSchemaToObject(ref schema);

Expand Down Expand Up @@ -521,12 +521,12 @@ void ApplyDataAnnotations(string? parameterName, ref JsonNode schema, AIJsonSche
#endif

#if NET
if (ctx.GetCustomAttribute<Base64StringAttribute>() is { } base64Attribute)
if (ResolveAttribute<Base64StringAttribute>() is { } base64Attribute)
{
ConvertSchemaToObject(ref schema)[ContentEncodingPropertyName] ??= "base64";
}

if (ctx.GetCustomAttribute<LengthAttribute>() is { } lengthAttribute)
if (ResolveAttribute<LengthAttribute>() is { } lengthAttribute)
{
JsonObject obj = ConvertSchemaToObject(ref schema);

Expand All @@ -550,7 +550,7 @@ void ApplyDataAnnotations(string? parameterName, ref JsonNode schema, AIJsonSche
}
}

if (ctx.GetCustomAttribute<AllowedValuesAttribute>() is { } allowedValuesAttribute)
if (ResolveAttribute<AllowedValuesAttribute>() is { } allowedValuesAttribute)
{
JsonObject obj = ConvertSchemaToObject(ref schema);
if (!obj.ContainsKey(EnumPropertyName))
Expand All @@ -562,7 +562,7 @@ void ApplyDataAnnotations(string? parameterName, ref JsonNode schema, AIJsonSche
}
}

if (ctx.GetCustomAttribute<DeniedValuesAttribute>() is { } deniedValuesAttribute)
if (ResolveAttribute<DeniedValuesAttribute>() is { } deniedValuesAttribute)
{
JsonObject obj = ConvertSchemaToObject(ref schema);

Expand Down Expand Up @@ -597,7 +597,7 @@ static JsonArray CreateJsonArray(object?[] values, JsonSerializerOptions seriali
return enumArray;
}

if (ctx.GetCustomAttribute<DataTypeAttribute>() is { } dataTypeAttribute)
if (ResolveAttribute<DataTypeAttribute>() is { } dataTypeAttribute)
{
JsonObject obj = ConvertSchemaToObject(ref schema);
switch (dataTypeAttribute.DataType)
Expand Down Expand Up @@ -629,6 +629,17 @@ static JsonArray CreateJsonArray(object?[] values, JsonSerializerOptions seriali
}
}
#endif
TAttribute? ResolveAttribute<TAttribute>()
where TAttribute : Attribute
{
// If this is the root schema, check for any parameter attributes first.
if (ctx.Path.IsEmpty && parameter?.GetCustomAttribute<TAttribute>(inherit: true) is TAttribute attr)
{
return attr;
}

return ctx.GetCustomAttribute<TAttribute>(inherit: true);
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,26 @@ public enum MyEnumValue
B = 2
}

[Fact]
public static void CreateFunctionJsonSchema_ReadsParameterDataAnnotationAttributes()
{
JsonSerializerOptions options = new(AIJsonUtilities.DefaultOptions) { NumberHandling = JsonNumberHandling.AllowReadingFromString };
AIFunction func = AIFunctionFactory.Create(([Range(1, 10)] int num, [StringLength(100, MinimumLength = 1)] string str) => num + str.Length, serializerOptions: options);

using JsonDocument expectedSchema = JsonDocument.Parse("""
{
"type":"object",
"properties": {
"num": { "type":"integer", "minimum": 1, "maximum": 10 },
"str": { "type":"string", "minLength": 1, "maxLength": 100 }
},
"required":["num","str"]
}
""");

AssertDeepEquals(expectedSchema.RootElement, func.JsonSchema);
}

[Fact]
public static void CreateJsonSchema_CanBeBoolean()
{
Expand Down
Loading