Skip to content

CSHARP-5572: Implement new KnownSerializerFinder #1700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/MongoDB.Bson/Serialization/IBsonSerializerExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,38 @@ public static TValue Deserialize<TValue>(this IBsonSerializer<TValue> serializer
return serializer.Deserialize(context, args);
}

/// <summary>
/// Gets the serializer for a base type starting from a serializer for a derived type.
/// </summary>
/// <param name="serializer">The serializer for the derived type.</param>
/// <param name="baseType">The base type.</param>
/// <returns>The serializer for the base type.</returns>
public static IBsonSerializer GetBaseTypeSerializer(this IBsonSerializer serializer, Type baseType)
{
if (!baseType.IsAssignableFrom(serializer.ValueType))
{
throw new ArgumentException($"{baseType} is not assignable from {serializer.ValueType}.");
}

return BsonSerializer.LookupSerializer(baseType); // TODO: should be able to navigate from serializer
}

/// <summary>
/// Gets the serializer for a derived type starting from a serializer for a base type.
/// </summary>
/// <param name="serializer">The serializer for the base type.</param>
/// <param name="derivedType">The derived type.</param>
/// <returns>The serializer for the derived type.</returns>
public static IBsonSerializer GetDerivedTypeSerializer(this IBsonSerializer serializer, Type derivedType)
{
if (!serializer.ValueType.IsAssignableFrom(derivedType))
{
throw new ArgumentException($"{serializer.ValueType} is not assignable from {derivedType}.");
}

return BsonSerializer.LookupSerializer(derivedType); // TODO: should be able to navigate from serializer
}

/// <summary>
/// Gets the discriminator convention for a serializer.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,73 @@ public interface INullableSerializer
/// </summary>
public static class NullableSerializer
{
private readonly static IBsonSerializer __nullableBooleanInstance = new NullableSerializer<bool>(BooleanSerializer.Instance);
private readonly static IBsonSerializer __nullableDecimalInstance = new NullableSerializer<decimal>(DecimalSerializer.Instance);
private readonly static IBsonSerializer __nullableDecimal128Instance = new NullableSerializer<Decimal128>(Decimal128Serializer.Instance);
private readonly static IBsonSerializer __nullableDoubleInstance = new NullableSerializer<double>(DoubleSerializer.Instance);
private readonly static IBsonSerializer __nullableInt32Instance = new NullableSerializer<int>(Int32Serializer.Instance);
private readonly static IBsonSerializer __nullableInt64Instance = new NullableSerializer<long>(Int64Serializer.Instance);
private readonly static IBsonSerializer __nullableLocalDateTimeInstance = new NullableSerializer<DateTime>(DateTimeSerializer.LocalInstance);
private readonly static IBsonSerializer __nullableObjectIdInstance = new NullableSerializer<ObjectId>(ObjectIdSerializer.Instance);
private readonly static IBsonSerializer __nullableSingleInstance = new NullableSerializer<float>(SingleSerializer.Instance);
private readonly static IBsonSerializer __nullableStandardGuidInstance = new NullableSerializer<Guid>(GuidSerializer.StandardInstance);
private readonly static IBsonSerializer __nullableUtcDateTimeInstance = new NullableSerializer<DateTime>(DateTimeSerializer.UtcInstance);

/// <summary>
/// Gets a serializer for nullable bools.
/// </summary>
public static IBsonSerializer NullableBooleanInstance => __nullableBooleanInstance;

/// <summary>
/// Gets a serializer for nullable decimals.
/// </summary>
public static IBsonSerializer NullableDecimalInstance => __nullableDecimalInstance;

/// <summary>
/// Gets a serializer for nullable Decimal128s.
/// </summary>
public static IBsonSerializer NullableDecimal128Instance => __nullableDecimal128Instance;

/// <summary>
/// Gets a serializer for nullable doubles.
/// </summary>
public static IBsonSerializer NullableDoubleInstance => __nullableDoubleInstance;

/// <summary>
/// Gets a serializer for nullable ints.
/// </summary>
public static IBsonSerializer NullableInt32Instance => __nullableInt32Instance;

/// <summary>
/// Gets a serializer for nullable longs.
/// </summary>
public static IBsonSerializer NullableInt64Instance => __nullableInt64Instance;

/// <summary>
/// Gets a serializer for local DateTime.
/// </summary>
public static IBsonSerializer NullableLocalDateTimeInstance => __nullableLocalDateTimeInstance;

/// <summary>
/// Gets a serializer for nullable floats.
/// </summary>
public static IBsonSerializer NullableSingleInstance => __nullableSingleInstance;

/// <summary>
/// Gets a serializer for nullable ObjectIds.
/// </summary>
public static IBsonSerializer NullableObjectIdInstance => __nullableObjectIdInstance;

/// <summary>
/// Gets a serializer for nullable Guids with standard representation.
/// </summary>
public static IBsonSerializer NullableStandardGuidInstance => __nullableStandardGuidInstance;

/// <summary>
/// Gets a serializer for UTC DateTime.
/// </summary>
public static IBsonSerializer NullableUtcDateTimeInstance => __nullableUtcDateTimeInstance;

/// <summary>
/// Creates a NullableSerializer.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,17 @@ public static TValue GetConstantValue<TValue>(this Expression expression, Expres
var message = $"Expression must be a constant: {expression} in {containingExpression}.";
throw new ExpressionNotSupportedException(message);
}

public static bool IsConvert(this Expression expression, out Expression operand)
{
if (expression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression)
{
operand = unaryExpression.Operand;
return true;
}

operand = null;
return false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ private AstStage RenderProjectStage(
out IBsonSerializer<TOutput> outputSerializer)
{
var partiallyEvaluatedOutput = (Expression<Func<TGrouping, TOutput>>)PartialEvaluator.EvaluatePartially(_output);
var context = TranslationContext.Create(translationOptions);
var parameter = partiallyEvaluatedOutput.Parameters.Single();
var context = TranslationContext.Create(partiallyEvaluatedOutput, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions);
var outputTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedOutput, inputSerializer, asRoot: true);
var (projectStage, projectSerializer) = ProjectionHelper.CreateProjectStage(outputTranslation);
outputSerializer = (IBsonSerializer<TOutput>)projectSerializer;
Expand Down Expand Up @@ -106,7 +107,8 @@ protected override AstStage RenderGroupingStage(
out IBsonSerializer<IGrouping<TValue, TInput>> groupingOutputSerializer)
{
var partiallyEvaluatedGroupBy = (Expression<Func<TInput, TValue>>)PartialEvaluator.EvaluatePartially(_groupBy);
var context = TranslationContext.Create(translationOptions);
var parameter = partiallyEvaluatedGroupBy.Parameters.Single();
var context = TranslationContext.Create(partiallyEvaluatedGroupBy, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions);
var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true);

var valueSerializer = (IBsonSerializer<TValue>)groupByTranslation.Serializer;
Expand Down Expand Up @@ -150,7 +152,8 @@ protected override AstStage RenderGroupingStage(
out IBsonSerializer<IGrouping<AggregateBucketAutoResultId<TValue>, TInput>> groupingOutputSerializer)
{
var partiallyEvaluatedGroupBy = (Expression<Func<TInput, TValue>>)PartialEvaluator.EvaluatePartially(_groupBy);
var context = TranslationContext.Create(translationOptions);
var parameter = partiallyEvaluatedGroupBy.Parameters.Single();
var context = TranslationContext.Create(partiallyEvaluatedGroupBy, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions);
var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true);

var valueSerializer = (IBsonSerializer<TValue>)groupByTranslation.Serializer;
Expand Down Expand Up @@ -188,7 +191,8 @@ protected override AstStage RenderGroupingStage(
out IBsonSerializer<IGrouping<TValue, TInput>> groupingOutputSerializer)
{
var partiallyEvaluatedGroupBy = (Expression<Func<TInput, TValue>>)PartialEvaluator.EvaluatePartially(_groupBy);
var context = TranslationContext.Create(translationOptions);
var parameter = partiallyEvaluatedGroupBy.Parameters.Single();
var context = TranslationContext.Create(partiallyEvaluatedGroupBy, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions);
var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true);
var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.RootVar);
var groupBySerializer = (IBsonSerializer<TValue>)groupByTranslation.Serializer;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System.Linq.Expressions;
using MongoDB.Bson.Serialization;

namespace MongoDB.Driver.Linq.Linq3Implementation.KnownSerializerFinders;

internal static class KnownSerializerFinder
{
public static KnownSerializerMap FindKnownSerializers(
Expression expression)
{
var knownSerializers = new KnownSerializerMap();
return FindKnownSerializers(expression, knownSerializers);
}

public static KnownSerializerMap FindKnownSerializers(
Expression expression,
Expression initialNode,
IBsonSerializer knownSerializer)
{
var knownSerializers = new KnownSerializerMap();
knownSerializers.AddSerializer(initialNode, knownSerializer);
return FindKnownSerializers(expression, knownSerializers);
}

public static KnownSerializerMap FindKnownSerializers(
Expression expression,
(Expression Node, IBsonSerializer KnownSerializer)[] initialNodes)
{
var knownSerializers = new KnownSerializerMap();
foreach (var (initialNode, knownSerializer) in initialNodes)
{
knownSerializers.AddSerializer(initialNode, knownSerializer);

}
return FindKnownSerializers(expression, knownSerializers);
}

public static KnownSerializerMap FindKnownSerializers(
Expression expression,
KnownSerializerMap knownSerializers)
{
var visitor = new KnownSerializerFinderVisitor(knownSerializers);

int oldSerializerCount;
int newSerializerCount;
do
{
visitor.StartNextPass();
oldSerializerCount = knownSerializers.Count;
visitor.Visit(expression);
newSerializerCount = knownSerializers.Count;

// TODO: prevent infinite loop, throw after 100000 passes?
}
while (visitor.Pass == 1 || newSerializerCount > oldSerializerCount); // I don't know yet if this can be done in a single pass

//#if DEBUG
var expressionWithUnknownSerializer = UnknownSerializerFinder.FindExpressionWithUnknownSerializer(expression, knownSerializers);
if (expressionWithUnknownSerializer != null)
{
throw new ExpressionNotSupportedException(expressionWithUnknownSerializer, because: "we were unable to determine which serializer to use for the result");
}
//#endif

return knownSerializers;
}
}
Loading