diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs index 833b1100da7..fe41c128f1c 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs @@ -443,25 +443,7 @@ private ShapedQueryExpression CreateShapedQueryExpression(SelectExpression selec /// doing so can result in application failures when updating to a new Entity Framework Core release. /// protected override ShapedQueryExpression? TranslateAverage(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - { - var selectExpression = (SelectExpression)source.QueryExpression; - if (selectExpression.IsDistinct - || selectExpression.Limit != null - || selectExpression.Offset != null) - { - return null; - } - - if (selector != null) - { - source = TranslateSelect(source, selector); - } - - var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember()); - projection = _sqlExpressionFactory.Function("AVG", new[] { projection }, resultType, _typeMappingSource.FindMapping(resultType)); - - return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType); - } + => TranslateAggregate(source, selector, resultType, "AVG"); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -841,26 +823,7 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou /// doing so can result in application failures when updating to a new Entity Framework Core release. /// protected override ShapedQueryExpression? TranslateMax(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - { - var selectExpression = (SelectExpression)source.QueryExpression; - if (selectExpression.IsDistinct - || selectExpression.Limit != null - || selectExpression.Offset != null) - { - return null; - } - - if (selector != null) - { - source = TranslateSelect(source, selector); - } - - var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember()); - - projection = _sqlExpressionFactory.Function("MAX", new[] { projection }, resultType, projection.TypeMapping); - - return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType); - } + => TranslateAggregate(source, selector, resultType, "MAX"); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -869,26 +832,7 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou /// doing so can result in application failures when updating to a new Entity Framework Core release. /// protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - { - var selectExpression = (SelectExpression)source.QueryExpression; - if (selectExpression.IsDistinct - || selectExpression.Limit != null - || selectExpression.Offset != null) - { - return null; - } - - if (selector != null) - { - source = TranslateSelect(source, selector); - } - - var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember()); - - projection = _sqlExpressionFactory.Function("MIN", new[] { projection }, resultType, projection.TypeMapping); - - return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType); - } + => TranslateAggregate(source, selector, resultType, "MIN"); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -1241,7 +1185,7 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s projection = _sqlExpressionFactory.Function("SUM", new[] { projection }, serverOutputType, projection.TypeMapping); - return AggregateResultShaper(source, projection, throwOnNullResult: false, resultType); + return AggregateResultShaper(source, projection, resultType); } /// @@ -1515,6 +1459,35 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s #endregion Queryable collection support + private ShapedQueryExpression? TranslateAggregate(ShapedQueryExpression source, LambdaExpression? selector, Type resultType, string functionName) + { + var selectExpression = (SelectExpression)source.QueryExpression; + if (selectExpression.IsDistinct + || selectExpression.Limit != null + || selectExpression.Offset != null) + { + return null; + } + + if (selector != null) + { + source = TranslateSelect(source, selector); + } + + if (!_subquery && resultType.IsNullableType()) + { + // For nullable types, we want to return null from Max, Min, and Average, rather than throwing. See Issue #35094. + // Note that relational databases typically return null, which propagates. Cosmos will instead return no elements, + // and hence for Cosmos only we need to change no elements into null. + source = source.UpdateResultCardinality(ResultCardinality.SingleOrDefault); + } + + var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember()); + projection = _sqlExpressionFactory.Function(functionName, [projection], resultType, _typeMappingSource.FindMapping(resultType)); + + return AggregateResultShaper(source, projection, resultType); + } + private bool TryApplyPredicate(ShapedQueryExpression source, LambdaExpression predicate) { var select = (SelectExpression)source.QueryExpression; @@ -1695,7 +1668,6 @@ private Expression RemapLambdaBody(ShapedQueryExpression shapedQueryExpression, private static ShapedQueryExpression AggregateResultShaper( ShapedQueryExpression source, Expression projection, - bool throwOnNullResult, Type resultType) { var selectExpression = (SelectExpression)source.QueryExpression; @@ -1706,29 +1678,7 @@ private static ShapedQueryExpression AggregateResultShaper( var nullableResultType = resultType.MakeNullable(); Expression shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), nullableResultType); - if (throwOnNullResult) - { - var resultVariable = Expression.Variable(nullableResultType, "result"); - var returnValueForNull = resultType.IsNullableType() - ? (Expression)Expression.Constant(null, resultType) - : Expression.Throw( - Expression.New( - typeof(InvalidOperationException).GetConstructors() - .Single(ci => ci.GetParameters().Length == 1), - Expression.Constant(CoreStrings.SequenceContainsNoElements)), - resultType); - - shaper = Expression.Block( - new[] { resultVariable }, - Expression.Assign(resultVariable, shaper), - Expression.Condition( - Expression.Equal(resultVariable, Expression.Default(nullableResultType)), - returnValueForNull, - resultType != resultVariable.Type - ? Expression.Convert(resultVariable, resultType) - : resultVariable)); - } - else if (resultType != shaper.Type) + if (resultType != shaper.Type) { shaper = Expression.Convert(shaper, resultType); } diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index ca96957f49f..ee14b47e0a1 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -518,7 +518,7 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - => TranslateAggregateWithSelector(source, selector, QueryableMethods.GetAverageWithoutSelector, throwWhenEmpty: true, resultType); + => TranslateAggregateWithSelector(source, selector, QueryableMethods.GetAverageWithoutSelector, resultType); /// protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression source, Type resultType) @@ -971,7 +971,7 @@ private SqlExpression CreateJoinPredicate(Expression outerKey, Expression innerK } return TranslateAggregateWithSelector( - source, selector, t => QueryableMethods.MaxWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType); + source, selector, t => QueryableMethods.MaxWithoutSelector.MakeGenericMethod(t), resultType); } /// @@ -990,7 +990,7 @@ private SqlExpression CreateJoinPredicate(Expression outerKey, Expression innerK } return TranslateAggregateWithSelector( - source, selector, t => QueryableMethods.MinWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType); + source, selector, t => QueryableMethods.MinWithoutSelector.MakeGenericMethod(t), resultType); } /// @@ -1241,7 +1241,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp /// protected override ShapedQueryExpression? TranslateSum(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - => TranslateAggregateWithSelector(source, selector, QueryableMethods.GetSumWithoutSelector, throwWhenEmpty: false, resultType); + => TranslateAggregateWithSelector(source, selector, QueryableMethods.GetSumWithoutSelector, resultType); /// protected override ShapedQueryExpression? TranslateTake(ShapedQueryExpression source, Expression count) @@ -1966,7 +1966,6 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape ShapedQueryExpression source, LambdaExpression? selectorLambda, Func methodGenerator, - bool throwWhenEmpty, Type resultType) { var selectExpression = (SelectExpression)source.QueryExpression; @@ -2012,48 +2011,13 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape new Dictionary { { new ProjectionMember(), translation } }); selectExpression.ClearOrdering(); - Expression shaper; - - if (throwWhenEmpty) - { - // Avg/Max/Min case. - // We always read nullable value - // If resultType is nullable then we always return null. Only non-null result shows throwing behavior. - // otherwise, if projection.Type is nullable then server result is passed through DefaultIfEmpty, hence we return default - // otherwise, server would return null only if it is empty, and we throw - var nullableResultType = resultType.MakeNullable(); - shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), nullableResultType); - var resultVariable = Expression.Variable(nullableResultType, "result"); - var returnValueForNull = resultType.IsNullableType() - ? (Expression)Expression.Default(resultType) - : translation.Type.IsNullableType() - ? Expression.Default(resultType) - : Expression.Throw( - Expression.New( - typeof(InvalidOperationException).GetConstructors() - .Single(ci => ci.GetParameters().Length == 1), - Expression.Constant(CoreStrings.SequenceContainsNoElements)), - resultType); - - shaper = Expression.Block( - new[] { resultVariable }, - Expression.Assign(resultVariable, shaper), - Expression.Condition( - Expression.Equal(resultVariable, Expression.Default(nullableResultType)), - returnValueForNull, - resultType != resultVariable.Type - ? Expression.Convert(resultVariable, resultType) - : resultVariable)); - } - else - { - // Sum case. Projection is always non-null. We read nullable value. - shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), translation.Type.MakeNullable()); - if (resultType != shaper.Type) - { - shaper = Expression.Convert(shaper, resultType); - } + // Sum case. Projection is always non-null. We read nullable value. + Expression shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), translation.Type.MakeNullable()); + + if (resultType != shaper.Type) + { + shaper = Expression.Convert(shaper, resultType); } return source.UpdateShaperExpression(shaper); diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/AdHocMiscellaneousQueryCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/AdHocMiscellaneousQueryCosmosTest.cs index 7fc328c16fd..e2dd2b90904 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/AdHocMiscellaneousQueryCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/AdHocMiscellaneousQueryCosmosTest.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.ComponentModel.DataAnnotations.Schema; + namespace Microsoft.EntityFrameworkCore.Query; #nullable disable @@ -50,6 +52,115 @@ public enum MemberType #endregion 34911 + #region 35094 + + // TODO: Move these tests to a better location. They require nullable properties with nulls in the database. + + [ConditionalFact] + public virtual async Task Min_over_value_type_containing_nulls() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Null(await context.Set().MinAsync(p => p.NullableVal)); + } + + [ConditionalFact] + public virtual async Task Min_over_value_type_containing_all_nulls() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Null(await context.Set().Where(e => e.NullableVal == null).MinAsync(p => p.NullableVal)); + } + + [ConditionalFact] + public virtual async Task Min_over_reference_type_containing_nulls() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Null(await context.Set().MinAsync(p => p.NullableRef)); + } + + [ConditionalFact] + public virtual async Task Min_over_reference_type_containing_all_nulls() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Null(await context.Set().Where(e => e.NullableRef == null).MinAsync(p => p.NullableRef)); + } + + [ConditionalFact] + public virtual async Task Min_over_reference_type_containing_no_data() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Null(await context.Set().Where(e => e.Id < 0).MinAsync(p => p.NullableRef)); + } + + [ConditionalFact] + public virtual async Task Max_over_value_type_containing_nulls() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Equal(3.14, await context.Set().MaxAsync(p => p.NullableVal)); + } + + [ConditionalFact] + public virtual async Task Max_over_value_type_containing_all_nulls() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Null(await context.Set().Where(e => e.NullableVal == null).MaxAsync(p => p.NullableVal)); + } + + [ConditionalFact] + public virtual async Task Max_over_reference_type_containing_nulls() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Equal("Value", await context.Set().MaxAsync(p => p.NullableRef)); + } + + [ConditionalFact] + public virtual async Task Max_over_reference_type_containing_all_nulls() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Null(await context.Set().Where(e => e.NullableRef == null).MaxAsync(p => p.NullableRef)); + } + + [ConditionalFact] + public virtual async Task Max_over_reference_type_containing_no_data() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Null(await context.Set().Where(e => e.Id < 0).MaxAsync(p => p.NullableRef)); + } + + [ConditionalFact] + public virtual async Task Average_over_value_type_containing_nulls() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Null(await context.Set().AverageAsync(p => p.NullableVal)); + } + + [ConditionalFact] + public virtual async Task Average_over_value_type_containing_all_nulls() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Null(await context.Set().Where(e => e.NullableVal == null).AverageAsync(p => p.NullableVal)); + } + + protected class Context35094(DbContextOptions options) : DbContext(options) + { + public DbSet Products { get; set; } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + => modelBuilder.Entity().HasData( + new Product { Id = 1, NullableRef = "Value", NullableVal = 3.14 }, + new Product { Id = 2, NullableVal = 3.14 }, + new Product { Id = 3, NullableRef = "Value" }); + + public class Product + { + [DatabaseGenerated(DatabaseGeneratedOption.None)] + public int Id { get; set; } + public double? NullableVal { get; set; } + public string NullableRef { get; set; } + } + } + + #endregion 35094 + protected override string StoreName => "AdHocMiscellaneousQueryTests"; diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindAggregateOperatorsQueryCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindAggregateOperatorsQueryCosmosTest.cs index c5a415fb1b9..6f1291dbecd 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindAggregateOperatorsQueryCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindAggregateOperatorsQueryCosmosTest.cs @@ -555,49 +555,33 @@ FROM root c } } - public override async Task Average_no_data_nullable(bool async) - { - // Sync always throws before getting to exception being tested. - if (async) - { - await Fixture.NoSyncTest( - async, async a => - { - Assert.Equal( - CoreStrings.SequenceContainsNoElements, - (await Assert.ThrowsAsync(() => base.Average_no_data_nullable(a))).Message); + public override Task Average_no_data_nullable(bool async) + => Fixture.NoSyncTest( + async, async a => + { + await base.Average_no_data_nullable(a); - AssertSql( - """ + AssertSql( + """ SELECT VALUE AVG(c["SupplierID"]) FROM root c WHERE ((c["$type"] = "Product") AND (c["SupplierID"] = -1)) """); - }); - } - } + }); - public override async Task Average_no_data_cast_to_nullable(bool async) - { - // Sync always throws before getting to exception being tested. - if (async) - { - await Fixture.NoSyncTest( - async, async a => - { - Assert.Equal( - CoreStrings.SequenceContainsNoElements, - (await Assert.ThrowsAsync(() => base.Average_no_data_cast_to_nullable(a))).Message); + public override Task Average_no_data_cast_to_nullable(bool async) + => Fixture.NoSyncTest( + async, async a => + { + await base.Average_no_data_cast_to_nullable(a); - AssertSql( - """ + AssertSql( + """ SELECT VALUE AVG(c["OrderID"]) FROM root c WHERE ((c["$type"] = "Order") AND (c["OrderID"] = -1)) """); - }); - } - } + }); public override async Task Min_no_data(bool async) { @@ -647,49 +631,33 @@ public override async Task Max_no_data_subquery(bool async) AssertSql(); } - public override async Task Max_no_data_nullable(bool async) - { - // Sync always throws before getting to exception being tested. - if (async) - { - await Fixture.NoSyncTest( - async, async a => - { - Assert.Equal( - CoreStrings.SequenceContainsNoElements, - (await Assert.ThrowsAsync(() => base.Max_no_data_nullable(a))).Message); + public override Task Max_no_data_nullable(bool async) + => Fixture.NoSyncTest( + async, async a => + { + await base.Max_no_data_nullable(a); - AssertSql( - """ + AssertSql( + """ SELECT VALUE MAX(c["SupplierID"]) FROM root c WHERE ((c["$type"] = "Product") AND (c["SupplierID"] = -1)) """); - }); - } - } + }); - public override async Task Max_no_data_cast_to_nullable(bool async) - { - // Sync always throws before getting to exception being tested. - if (async) - { - await Fixture.NoSyncTest( - async, async a => - { - Assert.Equal( - CoreStrings.SequenceContainsNoElements, - (await Assert.ThrowsAsync(() => base.Max_no_data_cast_to_nullable(a))).Message); + public override Task Max_no_data_cast_to_nullable(bool async) + => Fixture.NoSyncTest( + async, async a => + { + await base.Max_no_data_cast_to_nullable(a); - AssertSql( - """ + AssertSql( + """ SELECT VALUE MAX(c["OrderID"]) FROM root c WHERE ((c["$type"] = "Order") AND (c["OrderID"] = -1)) """); - }); - } - } + }); public override async Task Min_no_data_subquery(bool async) { @@ -868,49 +836,33 @@ FROM root c """); }); - public override async Task Min_no_data_nullable(bool async) - { - // Sync always throws before getting to exception being tested. - if (async) - { - await Fixture.NoSyncTest( - async, async a => - { - Assert.Equal( - CoreStrings.SequenceContainsNoElements, - (await Assert.ThrowsAsync(() => base.Min_no_data_nullable(a))).Message); + public override Task Min_no_data_nullable(bool async) + => Fixture.NoSyncTest( + async, async a => + { + await base.Min_no_data_nullable(a); - AssertSql( - """ + AssertSql( + """ SELECT VALUE MIN(c["SupplierID"]) FROM root c WHERE ((c["$type"] = "Product") AND (c["SupplierID"] = -1)) """); - }); - } - } + }); - public override async Task Min_no_data_cast_to_nullable(bool async) - { - // Sync always throws before getting to exception being tested. - if (async) - { - await Fixture.NoSyncTest( - async, async a => - { - Assert.Equal( - CoreStrings.SequenceContainsNoElements, - (await Assert.ThrowsAsync(() => base.Min_no_data_cast_to_nullable(a))).Message); + public override Task Min_no_data_cast_to_nullable(bool async) + => Fixture.NoSyncTest( + async, async a => + { + await base.Min_no_data_cast_to_nullable(a); - AssertSql( - """ + AssertSql( + """ SELECT VALUE MIN(c["OrderID"]) FROM root c WHERE ((c["$type"] = "Order") AND (c["OrderID"] = -1)) """); - }); - } - } + }); public override Task Min_with_coalesce(bool async) => Fixture.NoSyncTest(