diff --git a/Microsoft.Toolkit.Mvvm/ComponentModel/ObservableRecipient.cs b/Microsoft.Toolkit.Mvvm/ComponentModel/ObservableRecipient.cs index 80249075d38..c4ec80709ce 100644 --- a/Microsoft.Toolkit.Mvvm/ComponentModel/ObservableRecipient.cs +++ b/Microsoft.Toolkit.Mvvm/ComponentModel/ObservableRecipient.cs @@ -25,11 +25,11 @@ public abstract class ObservableRecipient : ObservableObject /// Initializes a new instance of the class. /// /// - /// This constructor will produce an instance that will use the instance + /// This constructor will produce an instance that will use the instance /// to perform requested operations. It will also be available locally through the property. /// protected ObservableRecipient() - : this(Messaging.Messenger.Default) + : this(WeakReferenceMessenger.Default) { } @@ -78,7 +78,7 @@ public bool IsActive /// /// The base implementation registers all messages for this recipients that have been declared /// explicitly through the interface, using the default channel. - /// For more details on how this works, see the method. + /// For more details on how this works, see the method. /// If you need more fine tuned control, want to register messages individually or just prefer /// the lambda-style syntax for message registration, override this method and register manually. /// diff --git a/Microsoft.Toolkit.Mvvm/Messaging/IMessenger.cs b/Microsoft.Toolkit.Mvvm/Messaging/IMessenger.cs index 39e38d5af01..2e41916e950 100644 --- a/Microsoft.Toolkit.Mvvm/Messaging/IMessenger.cs +++ b/Microsoft.Toolkit.Mvvm/Messaging/IMessenger.cs @@ -7,8 +7,70 @@ namespace Microsoft.Toolkit.Mvvm.Messaging { + /// + /// A used to represent actions to invoke when a message is received. + /// The recipient is given as an input argument to allow message registrations to avoid creating + /// closures: if an instance method on a recipient needs to be invoked it is possible to just + /// cast the recipient to the right type and then access the local method from that instance. + /// + /// The type of recipient for the message. + /// The type of message to receive. + /// The recipient that is receiving the message. + /// The message being received. + public delegate void MessageHandler(TRecipient recipient, TMessage message) + where TRecipient : class + where TMessage : class; + /// /// An interface for a type providing the ability to exchange messages between different objects. + /// This can be useful to decouple different modules of an application without having to keep strong + /// references to types being referenced. It is also possible to send messages to specific channels, uniquely + /// identified by a token, and to have different messengers in different sections of an applications. + /// In order to use the functionalities, first define a message type, like so: + /// + /// public sealed class LoginCompletedMessage { } + /// + /// Then, register your a recipient for this message: + /// + /// Messenger.Default.Register<MyRecipientType, LoginCompletedMessage>(this, (r, m) => + /// { + /// // Handle the message here... + /// }); + /// + /// The message handler here is a lambda expression taking two parameters: the recipient and the message. + /// This is done to avoid the allocations for the closures that would've been generated if the expression + /// had captured the current instance. The recipient type parameter is used so that the recipient can be + /// directly accessed within the handler without the need to manually perform type casts. This allows the + /// code to be less verbose and more reliable, as all the checks are done just at build time. If the handler + /// is defined within the same type as the recipient, it is also possible to directly access private members. + /// This allows the message handler to be a static method, which enables the C# compiler to perform a number + /// of additional memory optimizations (such as caching the delegate, avoiding unnecessary memory allocations). + /// Finally, send a message when needed, like so: + /// + /// Messenger.Default.Send<LoginCompletedMessage>(); + /// + /// Additionally, the method group syntax can also be used to specify the message handler + /// to invoke when receiving a message, if a method with the right signature is available + /// in the current scope. This is helpful to keep the registration and handling logic separate. + /// Following up from the previous example, consider a class having this method: + /// + /// private static void Receive(MyRecipientType recipient, LoginCompletedMessage message) + /// { + /// // Handle the message there + /// } + /// + /// The registration can then be performed in a single line like so: + /// + /// Messenger.Default.Register(this, Receive); + /// + /// The C# compiler will automatically convert that expression to a instance + /// compatible with . + /// This will also work if multiple overloads of that method are available, each handling a different + /// message type: the C# compiler will automatically pick the right one for the current message type. + /// It is also possible to register message handlers explicitly using the interface. + /// To do so, the recipient just needs to implement the interface and then call the + /// extension, which will automatically register + /// all the handlers that are declared by the recipient type. Registration for individual handlers is supported as well. /// public interface IMessenger { @@ -28,13 +90,15 @@ bool IsRegistered(object recipient, TToken token) /// /// Registers a recipient for a given type of message. /// + /// The type of recipient for the message. /// The type of message to receive. /// The type of token to use to pick the messages to receive. /// The recipient that will receive the messages. /// A token used to determine the receiving channel to use. - /// The to invoke when a message is received. + /// The to invoke when a message is received. /// Thrown when trying to register the same message twice. - void Register(object recipient, TToken token, Action action) + void Register(TRecipient recipient, TToken token, MessageHandler handler) + where TRecipient : class where TMessage : class where TToken : IEquatable; @@ -83,6 +147,14 @@ TMessage Send(TMessage message, TToken token) where TMessage : class where TToken : IEquatable; + /// + /// Performs a cleanup on the current messenger. + /// Invoking this method does not unregister any of the currently registered + /// recipient, and it can be used to perform cleanup operations such as + /// trimming the internal data structures of a messenger implementation. + /// + void Cleanup(); + /// /// Resets the instance and unregisters all the existing recipients. /// diff --git a/Microsoft.Toolkit.Mvvm/Messaging/MessengerExtensions.cs b/Microsoft.Toolkit.Mvvm/Messaging/IMessengerExtensions.cs similarity index 83% rename from Microsoft.Toolkit.Mvvm/Messaging/MessengerExtensions.cs rename to Microsoft.Toolkit.Mvvm/Messaging/IMessengerExtensions.cs index a344b76a4f3..72caa9fb7dd 100644 --- a/Microsoft.Toolkit.Mvvm/Messaging/MessengerExtensions.cs +++ b/Microsoft.Toolkit.Mvvm/Messaging/IMessengerExtensions.cs @@ -8,19 +8,20 @@ using System.Linq.Expressions; using System.Reflection; using System.Runtime.CompilerServices; +using Microsoft.Toolkit.Mvvm.Messaging.Internals; namespace Microsoft.Toolkit.Mvvm.Messaging { /// /// Extensions for the type. /// - public static partial class MessengerExtensions + public static class IMessengerExtensions { /// /// A class that acts as a container to load the instance linked to /// the method. /// This class is needed to avoid forcing the initialization code in the static constructor to run as soon as - /// the type is referenced, even if that is done just to use methods + /// the type is referenced, even if that is done just to use methods /// that do not actually require this instance to be available. /// We're effectively using this type to leverage the lazy loading of static constructors done by the runtime. /// @@ -32,7 +33,7 @@ private static class MethodInfos static MethodInfos() { RegisterIRecipient = ( - from methodInfo in typeof(MessengerExtensions).GetMethods() + from methodInfo in typeof(IMessengerExtensions).GetMethods() where methodInfo.Name == nameof(Register) && methodInfo.IsGenericMethod && methodInfo.GetGenericArguments().Length == 2 @@ -174,7 +175,7 @@ static Action GetRegistrationAction(Type type, Metho public static void Register(this IMessenger messenger, IRecipient recipient) where TMessage : class { - messenger.Register(recipient, default, recipient.Receive); + messenger.Register, TMessage, Unit>(recipient, default, (r, m) => r.Receive(m)); } /// @@ -191,7 +192,7 @@ public static void Register(this IMessenger messenger, IRecipi where TMessage : class where TToken : IEquatable { - messenger.Register(recipient, token, recipient.Receive); + messenger.Register, TMessage, TToken>(recipient, token, (r, m) => r.Receive(m)); } /// @@ -200,13 +201,47 @@ public static void Register(this IMessenger messenger, IRecipi /// The type of message to receive. /// The instance to use to register the recipient. /// The recipient that will receive the messages. - /// The to invoke when a message is received. + /// The to invoke when a message is received. /// Thrown when trying to register the same message twice. /// This method will use the default channel to perform the requested registration. - public static void Register(this IMessenger messenger, object recipient, Action action) + public static void Register(this IMessenger messenger, object recipient, MessageHandler handler) where TMessage : class { - messenger.Register(recipient, default(Unit), action); + messenger.Register(recipient, default(Unit), handler); + } + + /// + /// Registers a recipient for a given type of message. + /// + /// The type of recipient for the message. + /// The type of message to receive. + /// The instance to use to register the recipient. + /// The recipient that will receive the messages. + /// The to invoke when a message is received. + /// Thrown when trying to register the same message twice. + /// This method will use the default channel to perform the requested registration. + public static void Register(this IMessenger messenger, TRecipient recipient, MessageHandler handler) + where TRecipient : class + where TMessage : class + { + messenger.Register(recipient, default(Unit), handler); + } + + /// + /// Registers a recipient for a given type of message. + /// + /// The type of message to receive. + /// The type of token to use to pick the messages to receive. + /// The instance to use to register the recipient. + /// The recipient that will receive the messages. + /// A token used to determine the receiving channel to use. + /// The to invoke when a message is received. + /// Thrown when trying to register the same message twice. + public static void Register(this IMessenger messenger, object recipient, TToken token, MessageHandler handler) + where TMessage : class + where TToken : IEquatable + { + messenger.Register(recipient, token, handler); } /// diff --git a/Microsoft.Toolkit.Mvvm/Messaging/Microsoft.Collections.Extensions/DictionarySlim{TKey,TValue}.cs b/Microsoft.Toolkit.Mvvm/Messaging/Internals/Microsoft.Collections.Extensions/DictionarySlim{TKey,TValue}.cs similarity index 95% rename from Microsoft.Toolkit.Mvvm/Messaging/Microsoft.Collections.Extensions/DictionarySlim{TKey,TValue}.cs rename to Microsoft.Toolkit.Mvvm/Messaging/Internals/Microsoft.Collections.Extensions/DictionarySlim{TKey,TValue}.cs index ab4706de5f7..11012a1e350 100644 --- a/Microsoft.Toolkit.Mvvm/Messaging/Microsoft.Collections.Extensions/DictionarySlim{TKey,TValue}.cs +++ b/Microsoft.Toolkit.Mvvm/Messaging/Internals/Microsoft.Collections.Extensions/DictionarySlim{TKey,TValue}.cs @@ -130,7 +130,11 @@ public void Clear() this.entries = InitialEntries; } - /// + /// + /// Checks whether or not the dictionary contains a pair with a specified key. + /// + /// The key to look for. + /// Whether or not the key was present in the dictionary. public bool ContainsKey(TKey key) { Entry[] entries = this.entries; @@ -176,7 +180,18 @@ public bool TryGetValue(TKey key, out TValue? value) } /// - public bool TryRemove(TKey key, out object? result) + public bool TryRemove(TKey key) + { + return TryRemove(key, out _); + } + + /// + /// Tries to remove a value with a specified key, if present. + /// + /// The key of the value to remove. + /// The removed value, if it was present. + /// Whether or not the key was present. + public bool TryRemove(TKey key, out TValue? result) { Entry[] entries = this.entries; int bucketIndex = key.GetHashCode() & (this.buckets.Length - 1); @@ -218,13 +233,6 @@ public bool TryRemove(TKey key, out object? result) return false; } - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public bool Remove(TKey key) - { - return TryRemove(key, out _); - } - /// /// Gets the value for the specified key, or, if the key is not present, /// adds an entry and returns the value by ref. This makes it possible to diff --git a/Microsoft.Toolkit.Mvvm/Messaging/Microsoft.Collections.Extensions/HashHelpers.cs b/Microsoft.Toolkit.Mvvm/Messaging/Internals/Microsoft.Collections.Extensions/HashHelpers.cs similarity index 100% rename from Microsoft.Toolkit.Mvvm/Messaging/Microsoft.Collections.Extensions/HashHelpers.cs rename to Microsoft.Toolkit.Mvvm/Messaging/Internals/Microsoft.Collections.Extensions/HashHelpers.cs diff --git a/Microsoft.Toolkit.Mvvm/Messaging/Microsoft.Collections.Extensions/IDictionarySlim.cs b/Microsoft.Toolkit.Mvvm/Messaging/Internals/Microsoft.Collections.Extensions/IDictionarySlim.cs similarity index 100% rename from Microsoft.Toolkit.Mvvm/Messaging/Microsoft.Collections.Extensions/IDictionarySlim.cs rename to Microsoft.Toolkit.Mvvm/Messaging/Internals/Microsoft.Collections.Extensions/IDictionarySlim.cs diff --git a/Microsoft.Toolkit.Mvvm/Messaging/Microsoft.Collections.Extensions/IDictionarySlim{TKey,TValue}.cs b/Microsoft.Toolkit.Mvvm/Messaging/Internals/Microsoft.Collections.Extensions/IDictionarySlim{TKey,TValue}.cs similarity index 100% rename from Microsoft.Toolkit.Mvvm/Messaging/Microsoft.Collections.Extensions/IDictionarySlim{TKey,TValue}.cs rename to Microsoft.Toolkit.Mvvm/Messaging/Internals/Microsoft.Collections.Extensions/IDictionarySlim{TKey,TValue}.cs diff --git a/Microsoft.Toolkit.Mvvm/Messaging/Microsoft.Collections.Extensions/IDictionarySlim{TKey}.cs b/Microsoft.Toolkit.Mvvm/Messaging/Internals/Microsoft.Collections.Extensions/IDictionarySlim{TKey}.cs similarity index 57% rename from Microsoft.Toolkit.Mvvm/Messaging/Microsoft.Collections.Extensions/IDictionarySlim{TKey}.cs rename to Microsoft.Toolkit.Mvvm/Messaging/Internals/Microsoft.Collections.Extensions/IDictionarySlim{TKey}.cs index 7de2505148a..eeb4fa75f60 100644 --- a/Microsoft.Toolkit.Mvvm/Messaging/Microsoft.Collections.Extensions/IDictionarySlim{TKey}.cs +++ b/Microsoft.Toolkit.Mvvm/Messaging/Internals/Microsoft.Collections.Extensions/IDictionarySlim{TKey}.cs @@ -14,18 +14,10 @@ internal interface IDictionarySlim : IDictionarySlim where TKey : IEquatable { /// - /// Tries to remove a value with a specified key. + /// Tries to remove a value with a specified key, if present. /// /// The key of the value to remove. - /// The removed value, if it was present. - /// .Whether or not the key was present. - bool TryRemove(TKey key, out object? result); - - /// - /// Removes an item from the dictionary with the specified key, if present. - /// - /// The key of the item to remove. - /// Whether or not an item was removed. - bool Remove(TKey key); + /// Whether or not the key was present. + bool TryRemove(TKey key); } } diff --git a/Microsoft.Toolkit.Mvvm/Messaging/Internals/Type2.cs b/Microsoft.Toolkit.Mvvm/Messaging/Internals/Type2.cs new file mode 100644 index 00000000000..411b8809ba6 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm/Messaging/Internals/Type2.cs @@ -0,0 +1,83 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Runtime.CompilerServices; + +namespace Microsoft.Toolkit.Mvvm.Messaging.Internals +{ + /// + /// A simple type representing an immutable pair of types. + /// + /// + /// This type replaces a simple as it's faster in its + /// and methods, and because + /// unlike a value tuple it exposes its fields as immutable. Additionally, the + /// and fields provide additional clarity reading + /// the code compared to and . + /// + internal readonly struct Type2 : IEquatable + { + /// + /// The type of registered message. + /// + public readonly Type TMessage; + + /// + /// The type of registration token. + /// + public readonly Type TToken; + + /// + /// Initializes a new instance of the struct. + /// + /// The type of registered message. + /// The type of registration token. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Type2(Type tMessage, Type tToken) + { + TMessage = tMessage; + TToken = tToken; + } + + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public bool Equals(Type2 other) + { + // We can't just use reference equality, as that's technically not guaranteed + // to work and might fail in very rare cases (eg. with type forwarding between + // different assemblies). Instead, we can use the == operator to compare for + // equality, which still avoids the callvirt overhead of calling Type.Equals, + // and is also implemented as a JIT intrinsic on runtimes such as .NET Core. + return + TMessage == other.TMessage && + TToken == other.TToken; + } + + /// + public override bool Equals(object? obj) + { + return obj is Type2 other && Equals(other); + } + + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public override int GetHashCode() + { + unchecked + { + // To combine the two hashes, we can simply use the fast djb2 hash algorithm. + // This is not a problem in this case since we already know that the base + // RuntimeHelpers.GetHashCode method is providing hashes with a good enough distribution. + int hash = RuntimeHelpers.GetHashCode(TMessage); + + hash = (hash << 5) + hash; + + hash += RuntimeHelpers.GetHashCode(TToken); + + return hash; + } + } + } +} diff --git a/Microsoft.Toolkit.Mvvm/Messaging/Internals/Unit.cs b/Microsoft.Toolkit.Mvvm/Messaging/Internals/Unit.cs new file mode 100644 index 00000000000..6f136210b34 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm/Messaging/Internals/Unit.cs @@ -0,0 +1,35 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Runtime.CompilerServices; + +namespace Microsoft.Toolkit.Mvvm.Messaging.Internals +{ + /// + /// An empty type representing a generic token with no specific value. + /// + internal readonly struct Unit : IEquatable + { + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public bool Equals(Unit other) + { + return true; + } + + /// + public override bool Equals(object? obj) + { + return obj is Unit; + } + + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public override int GetHashCode() + { + return 0; + } + } +} diff --git a/Microsoft.Toolkit.Mvvm/Messaging/MessengerExtensions.Unit.cs b/Microsoft.Toolkit.Mvvm/Messaging/MessengerExtensions.Unit.cs deleted file mode 100644 index 67dfbc1c7ea..00000000000 --- a/Microsoft.Toolkit.Mvvm/Messaging/MessengerExtensions.Unit.cs +++ /dev/null @@ -1,41 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Runtime.CompilerServices; - -namespace Microsoft.Toolkit.Mvvm.Messaging -{ - /// - /// Extensions for the type. - /// - public static partial class MessengerExtensions - { - /// - /// An empty type representing a generic token with no specific value. - /// - private readonly struct Unit : IEquatable - { - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public bool Equals(Unit other) - { - return true; - } - - /// - public override bool Equals(object? obj) - { - return obj is Unit; - } - - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public override int GetHashCode() - { - return 0; - } - } - } -} diff --git a/Microsoft.Toolkit.Mvvm/Messaging/Messenger.cs b/Microsoft.Toolkit.Mvvm/Messaging/StrongReferenceMessenger.cs similarity index 57% rename from Microsoft.Toolkit.Mvvm/Messaging/Messenger.cs rename to Microsoft.Toolkit.Mvvm/Messaging/StrongReferenceMessenger.cs index 9c27badd255..5f4d5dccdec 100644 --- a/Microsoft.Toolkit.Mvvm/Messaging/Messenger.cs +++ b/Microsoft.Toolkit.Mvvm/Messaging/StrongReferenceMessenger.cs @@ -8,71 +8,46 @@ using System.Runtime.CompilerServices; using System.Threading; using Microsoft.Collections.Extensions; +using Microsoft.Toolkit.Mvvm.Messaging.Internals; namespace Microsoft.Toolkit.Mvvm.Messaging { /// - /// A type that can be used to exchange messages between different objects. - /// This can be useful to decouple different modules of an application without having to keep strong - /// references to types being referenced. It is also possible to send messages to specific channels, uniquely - /// identified by a token, and to have different messengers in different sections of an applications. - /// In order to use the functionalities, first define a message type, like so: - /// - /// public sealed class LoginCompletedMessage { } - /// - /// Then, register your a recipient for this message: - /// - /// Messenger.Default.Register<LoginCompletedMessage>(this, m => - /// { - /// // Handle the message here... - /// }); - /// - /// Finally, send a message when needed, like so: - /// - /// Messenger.Default.Send<LoginCompletedMessage>(); - /// - /// Additionally, the method group syntax can also be used to specify the action - /// to invoke when receiving a message, if a method with the right signature is available - /// in the current scope. This is helpful to keep the registration and handling logic separate. - /// Following up from the previous example, consider a class having this method: - /// - /// private void Receive(LoginCompletedMessage message) - /// { - /// // Handle the message there - /// } - /// - /// The registration can then be performed in a single line like so: - /// - /// Messenger.Default.Register<LoginCompletedMessage>(this, Receive); - /// - /// The C# compiler will automatically convert that expression to an instance - /// compatible with the method. - /// This will also work if multiple overloads of that method are available, each handling a different - /// message type: the C# compiler will automatically pick the right one for the current message type. - /// For info on the other available features, check the interface. + /// A class providing a reference implementation for the interface. /// - public sealed class Messenger : IMessenger + /// + /// This implementation uses strong references to track the registered + /// recipients, so it is necessary to manually unregister them when they're no longer needed. + /// + public sealed class StrongReferenceMessenger : IMessenger { - // The Messenger class uses the following logic to link stored instances together: + // This messenger uses the following logic to link stored instances together: // -------------------------------------------------------------------------------------------------------- // DictionarySlim> recipientsMap; // | \________________[*]IDictionarySlim> - // | \___ / / / - // | ________(recipients registrations)___________\________/ / __/ - // | / _______(channel registrations)_____\___________________/ / - // | / / \ / - // DictionarySlim>> mapping = Mapping - // / / \ / / - // ___(Type2.tToken)____/ / \______/___________________/ - // /________________(Type2.tMessage)____/ / - // / ________________________________________/ + // | \____________/_________ / + // | ________(recipients registrations)____________________/ \ / + // | / ____(channel registrations)________________\____________/ + // | / / \ + // DictionarySlim>> mapping = Mapping + // / / / + // ___(Type2.TToken)____/ / / + // /________________(Type2.TMessage)________________________/ / + // / ____________________________________________________________/ // / / // DictionarySlim typesMap; // -------------------------------------------------------------------------------------------------------- - // Each combination of results in a concrete Mapping type, which holds - // the references from registered recipients to handlers. The handlers are stored in a > - // dictionary, so that each recipient can have up to one registered handler for a given token, for each - // message type. Each mapping is stored in the types map, which associates each pair of concrete types to its + // Each combination of results in a concrete Mapping type, which holds the references + // from registered recipients to handlers. The handlers are stored in a > dictionary, + // so that each recipient can have up to one registered handler for a given token, for each message type. + // Note that the registered handlers are only stored as object references, even if they were actually of type + // MessageHandler, to avoid unnecessary unsafe casts. Each handler is also generic with respect to the + // recipient type, in order to allow the messenger to track and invoke type-specific handlers without using reflection and + // without having to capture the input handler in a proxy delegate, causing one extra memory allocations and adding overhead. + // This allows users to retain type information on each registered recipient, instead of having to manually cast each recipient + // to the right type within the handler. The type conversion is guaranteed to be respected due to how the messenger type + // itself works - as registered handlers are always invoked on their respective recipients. + // Each mapping is stored in the types map, which associates each pair of concrete types to its // mapping instance. Mapping instances are exposed as IMapping items, as each will be a closed type over // a different combination of TMessage and TToken generic type parameters. Each existing recipient is also stored in // the main recipients map, along with a set of all the existing dictionaries of handlers for that recipient (for all @@ -109,9 +84,9 @@ public sealed class Messenger : IMessenger private readonly DictionarySlim typesMap = new DictionarySlim(); /// - /// Gets the default instance. + /// Gets the default instance. /// - public static Messenger Default { get; } = new Messenger(); + public static StrongReferenceMessenger Default { get; } = new StrongReferenceMessenger(); /// public bool IsRegistered(object recipient, TToken token) @@ -132,7 +107,8 @@ public bool IsRegistered(object recipient, TToken token) } /// - public void Register(object recipient, TToken token, Action action) + public void Register(TRecipient recipient, TToken token, MessageHandler handler) + where TRecipient : class where TMessage : class where TToken : IEquatable { @@ -141,22 +117,20 @@ public void Register(object recipient, TToken token, Action registration list for this recipient Mapping mapping = GetOrAddMapping(); var key = new Recipient(recipient); - ref DictionarySlim>? map = ref mapping.GetOrAddValueRef(key); + ref DictionarySlim? map = ref mapping.GetOrAddValueRef(key); - map ??= new DictionarySlim>(); + map ??= new DictionarySlim(); // Add the new registration entry - ref Action? handler = ref map.GetOrAddValueRef(token); + ref object? registeredHandler = ref map.GetOrAddValueRef(token); - if (!(handler is null)) + if (!(registeredHandler is null)) { ThrowInvalidOperationExceptionForDuplicateRegistration(); } - handler = action; - - // Update the total counter for handlers for the current type parameters - mapping.TotalHandlersCount++; + // Treat the input delegate as if it was covariant (see comments below in the Send method) + registeredHandler = handler; // Make sure this registration map is tracked for the current recipient ref HashSet? set = ref this.recipientsMap.GetOrAddValueRef(key); @@ -183,37 +157,23 @@ public void UnregisterAll(object recipient) // Removes all the lists of registered handlers for the recipient foreach (IMapping mapping in set!) { - if (mapping.TryRemove(key, out object? handlersMap)) + if (mapping.TryRemove(key) && + mapping.Count == 0) { - // If this branch is taken, it means the target recipient to unregister - // had at least one registered handler for the current - // pair of type parameters, which here is masked out by the IMapping interface. - // Before removing the handlers, we need to retrieve the count of how many handlers - // are being removed, in order to update the total counter for the mapping. - // Just casting the dictionary to the base interface and accessing the Count - // property directly gives us O(1) access time to retrieve this count. - // The handlers map is the IDictionary instance for the mapping. - int handlersCount = Unsafe.As(handlersMap).Count; - - mapping.TotalHandlersCount -= handlersCount; - - if (mapping.Count == 0) - { - // Maps here are really of type Mapping<,> and with unknown type arguments. - // If after removing the current recipient a given map becomes empty, it means - // that there are no registered recipients at all for a given pair of message - // and token types. In that case, we also remove the map from the types map. - // The reason for keeping a key in each mapping is that removing items from a - // dictionary (a hashed collection) only costs O(1) in the best case, while - // if we had tried to iterate the whole dictionary every time we would have - // paid an O(n) minimum cost for each single remove operation. - this.typesMap.Remove(mapping.TypeArguments); - } + // Maps here are really of type Mapping<,> and with unknown type arguments. + // If after removing the current recipient a given map becomes empty, it means + // that there are no registered recipients at all for a given pair of message + // and token types. In that case, we also remove the map from the types map. + // The reason for keeping a key in each mapping is that removing items from a + // dictionary (a hashed collection) only costs O(1) in the best case, while + // if we had tried to iterate the whole dictionary every time we would have + // paid an O(n) minimum cost for each single remove operation. + this.typesMap.TryRemove(mapping.TypeArguments, out _); } } // Remove the associated set in the recipients map - this.recipientsMap.Remove(key); + this.recipientsMap.TryRemove(key, out _); } } @@ -222,7 +182,7 @@ public void UnregisterAll(object recipient, TToken token) where TToken : IEquatable { bool lockTaken = false; - IDictionarySlim>[]? maps = null; + object[]? maps = null; int i = 0; // We use an explicit try/finally block here instead of the lock syntax so that we can use a single @@ -243,11 +203,16 @@ public void UnregisterAll(object recipient, TToken token) return; } - // Copy the candidate mappings for the target recipient to a local - // array, as we can't modify the contents of the set while iterating it. - // The rented buffer is oversized and will also include mappings for - // handlers of messages that are registered through a different token. - maps = ArrayPool>>.Shared.Rent(set!.Count); + // Copy the candidate mappings for the target recipient to a local array, as we can't modify the + // contents of the set while iterating it. The rented buffer is oversized and will also include + // mappings for handlers of messages that are registered through a different token. Note that + // we're using just an object array to minimize the number of total rented buffers, that would + // just remain in the shared pool unused, other than when they are rented here. Instead, we're + // using a type that would possibly also be used by the users of the library, which increases + // the opportunities to reuse existing buffers for both. When we need to reference an item + // stored in the buffer with the type we know it will have, we use Unsafe.As to avoid the + // expensive type check in the cast, since we already know the assignment will be valid. + maps = ArrayPool.Shared.Rent(set!.Count); foreach (IMapping item in set) { @@ -265,8 +230,10 @@ public void UnregisterAll(object recipient, TToken token) // without having to know the concrete type in advance, and without having // to deal with reflection: we can just check if the type of the closed interface // matches with the token type currently in use, and operate on those instances. - foreach (IDictionarySlim> map in maps.AsSpan(0, i)) + foreach (object obj in maps.AsSpan(0, i)) { + var map = Unsafe.As>>(obj); + // We don't need whether or not the map contains the recipient, as the // sequence of maps has already been copied from the set containing all // the mappings for the target recipients: it is guaranteed to be here. @@ -274,36 +241,22 @@ public void UnregisterAll(object recipient, TToken token) // Try to remove the registered handler for the input token, // for the current message type (unknown from here). - if (holder.Remove(token)) + if (holder.TryRemove(token) && + holder.Count == 0) { - // As above, we need to update the total number of registered handlers for the map. - // In this case we also know that the current TToken type parameter is of interest - // for the current method, as we're only unsubscribing handlers using that token. - // This is because we're already working on the final mapping, - // which associates a single handler with a given token, for a given recipient. - // This means that we don't have to retrieve the count to subtract in this case, - // we're just removing a single handler at a time. So, we just decrement the total. - Unsafe.As(map).TotalHandlersCount--; - - if (holder.Count == 0) + // If the map is empty, remove the recipient entirely from its container + map.TryRemove(key); + + // If no handlers are left at all for the recipient, across all + // message types and token types, remove the set of mappings + // entirely for the current recipient, and lost the strong + // reference to it as well. This is the same situation that + // would've been achieved by just calling UnregisterAll(recipient). + if (map.Count == 0 && + set.Remove(Unsafe.As(map)) && + set.Count == 0) { - // If the map is empty, remove the recipient entirely from its container - map.Remove(key); - - if (map.Count == 0) - { - // If no handlers are left at all for the recipient, across all - // message types and token types, remove the set of mappings - // entirely for the current recipient, and lost the strong - // reference to it as well. This is the same situation that - // would've been achieved by just calling UnregisterAll(recipient). - set.Remove(Unsafe.As(map)); - - if (set.Count == 0) - { - this.recipientsMap.Remove(key); - } - } + this.recipientsMap.TryRemove(key, out _); } } } @@ -324,7 +277,7 @@ public void UnregisterAll(object recipient, TToken token) { maps.AsSpan(0, i).Clear(); - ArrayPool>>.Shared.Return(maps); + ArrayPool.Shared.Return(maps); } } } @@ -344,84 +297,92 @@ public void Unregister(object recipient, TToken token) var key = new Recipient(recipient); - if (!mapping!.TryGetValue(key, out DictionarySlim>? dictionary)) + if (!mapping!.TryGetValue(key, out DictionarySlim? dictionary)) { return; } // Remove the target handler - if (dictionary!.Remove(token)) + if (dictionary!.TryRemove(token, out _) && + dictionary.Count == 0) { - // Decrement the total count, as above - mapping.TotalHandlersCount--; - // If the map is empty, it means that the current recipient has no remaining // registered handlers for the current combination, regardless, // of the specific token value (ie. the channel used to receive messages of that type). // We can remove the map entirely from this container, and remove the link to the map itself // to the current mapping between existing registered recipients (or entire recipients too). - if (dictionary.Count == 0) - { - mapping.Remove(key); + mapping.TryRemove(key, out _); - HashSet set = this.recipientsMap[key]; + HashSet set = this.recipientsMap[key]; - set.Remove(mapping); - - if (set.Count == 0) - { - this.recipientsMap.Remove(key); - } + if (set.Remove(mapping) && + set.Count == 0) + { + this.recipientsMap.TryRemove(key, out _); } } } } /// - public TMessage Send(TMessage message, TToken token) + public unsafe TMessage Send(TMessage message, TToken token) where TMessage : class where TToken : IEquatable { - Action[] entries; + object[] handlers; + object[] recipients; + ref object handlersRef = ref Unsafe.AsRef(null); + ref object recipientsRef = ref Unsafe.AsRef(null); int i = 0; lock (this.recipientsMap) { // Check whether there are any registered recipients - if (!TryGetMapping(out Mapping? mapping)) + _ = TryGetMapping(out Mapping? mapping); + + // We need to make a local copy of the currently registered handlers, since users might + // try to unregister (or register) new handlers from inside one of the currently existing + // handlers. We can use memory pooling to reuse arrays, to minimize the average memory + // usage. In practice, we usually just need to pay the small overhead of copying the items. + // The current mapping contains all the currently registered recipients and handlers for + // the combination in use. In the worst case scenario, all recipients + // will have a registered handler with a token matching the input one, meaning that we could + // have at worst a number of pending handlers to invoke equal to the total number of recipient + // in the mapping. This relies on the fact that tokens are unique, and that there is only + // one handler associated with a given token. We can use this upper bound as the requested + // size for each array rented from the pool, which guarantees that we'll have enough space. + int totalHandlersCount = mapping?.Count ?? 0; + + if (totalHandlersCount == 0) { return message; } - // We need to make a local copy of the currently registered handlers, - // since users might try to unregister (or register) new handlers from - // inside one of the currently existing handlers. We can use memory pooling - // to reuse arrays, to minimize the average memory usage. In practice, - // we usually just need to pay the small overhead of copying the items. - entries = ArrayPool>.Shared.Rent(mapping!.TotalHandlersCount); + handlers = ArrayPool.Shared.Rent(totalHandlersCount); + recipients = ArrayPool.Shared.Rent(totalHandlersCount); + handlersRef = ref handlers[0]; + recipientsRef = ref recipients[0]; // Copy the handlers to the local collection. - // Both types being enumerate expose a struct enumerator, - // so we're not actually allocating the enumerator here. // The array is oversized at this point, since it also includes // handlers for different tokens. We can reuse the same variable // to count the number of matching handlers to invoke later on. - // This will be the array slice with valid actions in the rented buffer. - var mappingEnumerator = mapping.GetEnumerator(); + // This will be the array slice with valid handler in the rented buffer. + var mappingEnumerator = mapping!.GetEnumerator(); // Explicit enumerator usage here as we're using a custom one // that doesn't expose the single standard Current property. while (mappingEnumerator.MoveNext()) { - var pairsEnumerator = mappingEnumerator.Value.GetEnumerator(); + object recipient = mappingEnumerator.Key.Target; - while (pairsEnumerator.MoveNext()) + // Pick the target handler, if the token is a match for the recipient + if (mappingEnumerator.Value.TryGetValue(token, out object? handler)) { - // Only select the ones with a matching token - if (pairsEnumerator.Key.Equals(token)) - { - entries[i++] = pairsEnumerator.Value; - } + // We can manually offset here to skip the bounds checks in this inner loop when + // indexing the array (the size is already verified and guaranteed to be enough). + Unsafe.Add(ref handlersRef, (IntPtr)(void*)(uint)i) = handler!; + Unsafe.Add(ref recipientsRef, (IntPtr)(void*)(uint)i++) = recipient; } } } @@ -429,23 +390,44 @@ public TMessage Send(TMessage message, TToken token) try { // Invoke all the necessary handlers on the local copy of entries - foreach (var entry in entries.AsSpan(0, i)) + for (int j = 0; j < i; j++) { - entry(message); + // We're doing an unsafe cast to skip the type checks again. + // See the comments in the UnregisterAll method for more info. + object handler = Unsafe.Add(ref handlersRef, (IntPtr)(void*)(uint)j); + object recipient = Unsafe.Add(ref recipientsRef, (IntPtr)(void*)(uint)j); + + // Here we perform an unsafe cast to enable covariance for delegate types. + // We know that the input recipient will always respect the type constraints + // of each original input delegate, and doing so allows us to still invoke + // them all from here without worrying about specific generic type arguments. + Unsafe.As>(handler)(recipient, message); } } finally { // As before, we also need to clear it first to avoid having potentially long // lasting memory leaks due to leftover references being stored in the pool. - entries.AsSpan(0, i).Clear(); + handlers.AsSpan(0, i).Clear(); + recipients.AsSpan(0, i).Clear(); - ArrayPool>.Shared.Return(entries); + ArrayPool.Shared.Return(handlers); + ArrayPool.Shared.Return(recipients); } return message; } + /// + void IMessenger.Cleanup() + { + // The current implementation doesn't require any kind of cleanup operation, as + // all the internal data structures are already kept in sync whenever a recipient + // is added or removed. This method is implemented through an explicit interface + // implementation so that developers using this type directly will not see it in + // the API surface (as it wouldn't be useful anyway, since it's a no-op here). + } + /// public void Reset() { @@ -515,7 +497,7 @@ private Mapping GetOrAddMapping() /// This type is defined for simplicity and as a workaround for the lack of support for using type aliases /// over open generic types in C# (using type aliases can only be used for concrete, closed types). /// - private sealed class Mapping : DictionarySlim>>, IMapping + private sealed class Mapping : DictionarySlim>, IMapping where TMessage : class where TToken : IEquatable { @@ -529,9 +511,6 @@ public Mapping() /// public Type2 TypeArguments { get; } - - /// - public int TotalHandlersCount { get; set; } } /// @@ -544,11 +523,6 @@ private interface IMapping : IDictionarySlim /// Gets the instance representing the current type arguments. /// Type2 TypeArguments { get; } - - /// - /// Gets or sets the total number of handlers in the current instance. - /// - int TotalHandlersCount { get; set; } } /// @@ -567,7 +541,7 @@ private interface IMapping : IDictionarySlim /// /// The registered recipient. /// - private readonly object target; + public readonly object Target; /// /// Initializes a new instance of the struct. @@ -576,14 +550,14 @@ private interface IMapping : IDictionarySlim [MethodImpl(MethodImplOptions.AggressiveInlining)] public Recipient(object target) { - this.target = target; + Target = target; } /// [MethodImpl(MethodImplOptions.AggressiveInlining)] public bool Equals(Recipient other) { - return ReferenceEquals(this.target, other.target); + return ReferenceEquals(Target, other.Target); } /// @@ -596,81 +570,7 @@ public override bool Equals(object? obj) [MethodImpl(MethodImplOptions.AggressiveInlining)] public override int GetHashCode() { - return RuntimeHelpers.GetHashCode(this.target); - } - } - - /// - /// A simple type representing an immutable pair of types. - /// - /// - /// This type replaces a simple as it's faster in its - /// and methods, and because - /// unlike a value tuple it exposes its fields as immutable. Additionally, the - /// and fields provide additional clarity reading - /// the code compared to and . - /// - private readonly struct Type2 : IEquatable - { - /// - /// The type of registered message. - /// - private readonly Type tMessage; - - /// - /// The type of registration token. - /// - private readonly Type tToken; - - /// - /// Initializes a new instance of the struct. - /// - /// The type of registered message. - /// The type of registration token. - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public Type2(Type tMessage, Type tToken) - { - this.tMessage = tMessage; - this.tToken = tToken; - } - - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public bool Equals(Type2 other) - { - // We can't just use reference equality, as that's technically not guaranteed - // to work and might fail in very rare cases (eg. with type forwarding between - // different assemblies). Instead, we can use the == operator to compare for - // equality, which still avoids the callvirt overhead of calling Type.Equals, - // and is also implemented as a JIT intrinsic on runtimes such as .NET Core. - return - this.tMessage == other.tMessage && - this.tToken == other.tToken; - } - - /// - public override bool Equals(object? obj) - { - return obj is Type2 other && Equals(other); - } - - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public override int GetHashCode() - { - unchecked - { - // To combine the two hashes, we can simply use the fast djb2 hash algorithm. - // This is not a problem in this case since we already know that the base - // RuntimeHelpers.GetHashCode method is providing hashes with a good enough distribution. - int hash = RuntimeHelpers.GetHashCode(this.tMessage); - - hash = (hash << 5) + hash; - - hash += RuntimeHelpers.GetHashCode(this.tToken); - - return hash; - } + return RuntimeHelpers.GetHashCode(this.Target); } } diff --git a/Microsoft.Toolkit.Mvvm/Messaging/WeakReferenceMessenger.cs b/Microsoft.Toolkit.Mvvm/Messaging/WeakReferenceMessenger.cs new file mode 100644 index 00000000000..72b1683f577 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm/Messaging/WeakReferenceMessenger.cs @@ -0,0 +1,488 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics.Contracts; +using System.Runtime.CompilerServices; +using Microsoft.Collections.Extensions; +using Microsoft.Toolkit.Mvvm.Messaging.Internals; +#if NETSTANDARD2_1 +using RecipientsTable = System.Runtime.CompilerServices.ConditionalWeakTable; +#else +using RecipientsTable = Microsoft.Toolkit.Mvvm.Messaging.WeakReferenceMessenger.ConditionalWeakTable; +#endif + +namespace Microsoft.Toolkit.Mvvm.Messaging +{ + /// + /// A class providing a reference implementation for the interface. + /// + /// + /// This implementation uses weak references to track the registered + /// recipients, so it is not necessary to manually unregister them when they're no longer needed. + /// + public sealed class WeakReferenceMessenger : IMessenger + { + // This messenger uses the following logic to link stored instances together: + // -------------------------------------------------------------------------------------------------------- + // DictionarySlim> mapping + // / / / + // ___(Type2.TToken)___/ / / + // /_________________(Type2.TMessage)______________________/ / + // / ___________________________/ + // / / + // DictionarySlim> recipientsMap; + // -------------------------------------------------------------------------------------------------------- + // Just like in the strong reference variant, each pair of message and token types is used as a key in the + // recipients map. In this case, the values in the dictionary are ConditionalWeakTable<,> instances, that + // link each registered recipient to a map of currently registered handlers, through a weak reference. + // The value in each conditional table is Dictionary>, using + // the same unsafe cast as before to allow the generic handler delegates to be invoked without knowing + // what type each recipient was registered with, and without the need to use reflection. + + /// + /// The map of currently registered recipients for all message types. + /// + private readonly DictionarySlim recipientsMap = new DictionarySlim(); + + /// + /// Gets the default instance. + /// + public static WeakReferenceMessenger Default { get; } = new WeakReferenceMessenger(); + + /// + public bool IsRegistered(object recipient, TToken token) + where TMessage : class + where TToken : IEquatable + { + lock (this.recipientsMap) + { + Type2 type2 = new Type2(typeof(TMessage), typeof(TToken)); + + // Get the conditional table associated with the target recipient, for the current pair + // of token and message types. If it exists, check if there is a matching token. + return + this.recipientsMap.TryGetValue(type2, out RecipientsTable? table) && + table!.TryGetValue(recipient, out IDictionarySlim? mapping) && + Unsafe.As>(mapping).ContainsKey(token); + } + } + + /// + public void Register(TRecipient recipient, TToken token, MessageHandler handler) + where TRecipient : class + where TMessage : class + where TToken : IEquatable + { + lock (this.recipientsMap) + { + Type2 type2 = new Type2(typeof(TMessage), typeof(TToken)); + + // Get the conditional table for the pair of type arguments, or create it if it doesn't exist + ref RecipientsTable? mapping = ref this.recipientsMap.GetOrAddValueRef(type2); + + mapping ??= new RecipientsTable(); + + // Get or create the handlers dictionary for the target recipient + var map = Unsafe.As>(mapping.GetValue(recipient, _ => new DictionarySlim())); + + // Add the new registration entry + ref object? registeredHandler = ref map.GetOrAddValueRef(token); + + if (!(registeredHandler is null)) + { + ThrowInvalidOperationExceptionForDuplicateRegistration(); + } + + // Store the input handler + registeredHandler = handler; + } + } + + /// + public void UnregisterAll(object recipient) + { + lock (this.recipientsMap) + { + var enumerator = this.recipientsMap.GetEnumerator(); + + // Traverse all the existing conditional tables and remove all the ones + // with the target recipient as key. We don't perform a cleanup here, + // as that is responsability of a separate method defined below. + while (enumerator.MoveNext()) + { + enumerator.Value.Remove(recipient); + } + } + } + + /// + public void UnregisterAll(object recipient, TToken token) + where TToken : IEquatable + { + lock (this.recipientsMap) + { + var enumerator = this.recipientsMap.GetEnumerator(); + + // Same as above, with the difference being that this time we only go through + // the conditional tables having a matching token type as key, and that we + // only try to remove handlers with a matching token, if any. + while (enumerator.MoveNext()) + { + if (enumerator.Key.TToken == typeof(TToken) && + enumerator.Value.TryGetValue(recipient, out IDictionarySlim mapping)) + { + Unsafe.As>(mapping).TryRemove(token, out _); + } + } + } + } + + /// + public void Unregister(object recipient, TToken token) + where TMessage : class + where TToken : IEquatable + { + lock (this.recipientsMap) + { + var type2 = new Type2(typeof(TMessage), typeof(TToken)); + var enumerator = this.recipientsMap.GetEnumerator(); + + // Traverse all the existing token and message pairs matching the current type + // arguments, and remove all the handlers with a matching token, as above. + while (enumerator.MoveNext()) + { + if (enumerator.Key.Equals(type2) && + enumerator.Value.TryGetValue(recipient, out IDictionarySlim mapping)) + { + Unsafe.As>(mapping).TryRemove(token, out _); + } + } + } + } + + /// + public TMessage Send(TMessage message, TToken token) + where TMessage : class + where TToken : IEquatable + { + ArrayPoolBufferWriter recipients; + ArrayPoolBufferWriter handlers; + + lock (this.recipientsMap) + { + Type2 type2 = new Type2(typeof(TMessage), typeof(TToken)); + + // Try to get the target table + if (!this.recipientsMap.TryGetValue(type2, out RecipientsTable? table)) + { + return message; + } + + recipients = ArrayPoolBufferWriter.Create(); + handlers = ArrayPoolBufferWriter.Create(); + + // We need a local, temporary copy of all the pending recipients and handlers to + // invoke, to avoid issues with handlers unregistering from messages while we're + // holding the lock. To do this, we can just traverse the conditional table in use + // to enumerate all the existing recipients for the token and message types pair + // corresponding to the generic arguments for this invocation, and then track the + // handlers with a matching token, and their corresponding recipients. + foreach (KeyValuePair pair in table!) + { + var map = Unsafe.As>(pair.Value); + + if (map.TryGetValue(token, out object? handler)) + { + recipients.Add(pair.Key); + handlers.Add(handler!); + } + } + } + + try + { + ReadOnlySpan + recipientsSpan = recipients.Span, + handlersSpan = handlers.Span; + + for (int i = 0; i < recipientsSpan.Length; i++) + { + // Just like in the other messenger, here we need an unsafe cast to be able to + // invoke a generic delegate with a contravariant input argument, with a less + // derived reference, without reflection. This is guaranteed to work by how the + // messenger tracks registered recipients and their associated handlers, so the + // type conversion will always be valid (the recipients are the rigth instances). + Unsafe.As>(handlersSpan[i])(recipientsSpan[i], message); + } + } + finally + { + recipients.Dispose(); + handlers.Dispose(); + } + + return message; + } + + /// + public void Cleanup() + { + lock (this.recipientsMap) + { + using ArrayPoolBufferWriter type2s = ArrayPoolBufferWriter.Create(); + using ArrayPoolBufferWriter emptyRecipients = ArrayPoolBufferWriter.Create(); + + var enumerator = this.recipientsMap.GetEnumerator(); + + // First, we go through all the currently registered pairs of token and message types. + // These represents all the combinations of generic arguments with at least one registered + // handler, with the exception of those with recipients that have already been collected. + while (enumerator.MoveNext()) + { + emptyRecipients.Reset(); + + bool hasAtLeastOneHandler = false; + + // Go through the currently alive recipients to look for those with no handlers left. We track + // the ones we find to remove them outside of the loop (can't modify during enumeration). + foreach (KeyValuePair pair in enumerator.Value) + { + if (pair.Value.Count == 0) + { + emptyRecipients.Add(pair.Key); + } + else + { + hasAtLeastOneHandler = true; + } + } + + // Remove the handler maps for recipients that are still alive but with no handlers + foreach (object recipient in emptyRecipients.Span) + { + enumerator.Value.Remove(recipient); + } + + // Track the type combinations with no recipients or handlers left + if (!hasAtLeastOneHandler) + { + type2s.Add(enumerator.Key); + } + } + + // Remove all the mappings with no handlers left + foreach (Type2 key in type2s.Span) + { + this.recipientsMap.TryRemove(key, out _); + } + } + } + + /// + public void Reset() + { + lock (this.recipientsMap) + { + this.recipientsMap.Clear(); + } + } + +#if !NETSTANDARD2_1 + /// + /// A wrapper for + /// that backports the enumerable support to .NET Standard 2.0 through an auxiliary list. + /// + /// Tke key of items to store in the table. + /// The values to store in the table. + internal sealed class ConditionalWeakTable + where TKey : class + where TValue : class? + { + /// + /// The underlying instance. + /// + private readonly System.Runtime.CompilerServices.ConditionalWeakTable table; + + /// + /// A supporting linked list to store keys in . This is needed to expose + /// the ability to enumerate existing keys when there is no support for that in the BCL. + /// + private readonly LinkedList> keys; + + /// + /// Initializes a new instance of the class. + /// + public ConditionalWeakTable() + { + this.table = new System.Runtime.CompilerServices.ConditionalWeakTable(); + this.keys = new LinkedList>(); + } + + /// + public bool TryGetValue(TKey key, out TValue value) + { + return this.table.TryGetValue(key, out value); + } + + /// + public TValue GetValue(TKey key, System.Runtime.CompilerServices.ConditionalWeakTable.CreateValueCallback createValueCallback) + { + // Get or create the value. When this method returns, the key will be present in the table + TValue value = this.table.GetValue(key, createValueCallback); + + // Check if the list of keys contains the given key. + // If it does, we can just stop here and return the result. + foreach (WeakReference node in this.keys) + { + if (node.TryGetTarget(out TKey? target) && + ReferenceEquals(target, key)) + { + return value; + } + } + + // Add the key to the list of weak references to track it + this.keys.AddFirst(new WeakReference(key)); + + return value; + } + + /// + public bool Remove(TKey key) + { + return this.table.Remove(key); + } + + /// + public IEnumerator> GetEnumerator() + { + for (LinkedListNode>? node = this.keys.First; !(node is null);) + { + LinkedListNode>? next = node.Next; + + // Get the key and value for the current node + if (node.Value.TryGetTarget(out TKey? target) && + this.table.TryGetValue(target!, out TValue value)) + { + yield return new KeyValuePair(target, value); + } + else + { + // If the current key has been collected, trim the list + this.keys.Remove(node); + } + + node = next; + } + } + } +#endif + + /// + /// A simple buffer writer implementation using pooled arrays. + /// + /// The type of items to store in the list. + /// + /// This type is a to avoid the object allocation and to + /// enable the pattern-based support. We aren't worried with consumers not + /// using this type correctly since it's private and only accessible within the parent type. + /// + private ref struct ArrayPoolBufferWriter + { + /// + /// The default buffer size to use to expand empty arrays. + /// + private const int DefaultInitialBufferSize = 128; + + /// + /// The underlying array. + /// + private T[] array; + + /// + /// The starting offset within . + /// + private int index; + + /// + /// Creates a new instance of the struct. + /// + [Pure] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ArrayPoolBufferWriter Create() + { + return new ArrayPoolBufferWriter { array = ArrayPool.Shared.Rent(DefaultInitialBufferSize) }; + } + + /// + /// Gets a with the current items. + /// + public ReadOnlySpan Span + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => this.array.AsSpan(0, this.index); + } + + /// + /// Adds a new item to the current collection. + /// + /// The item to add. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Add(T item) + { + if (this.index == this.array.Length) + { + ResizeBuffer(); + } + + this.array[this.index++] = item; + } + + /// + /// Resets the underlying array and the stored items. + /// + public void Reset() + { + Array.Clear(this.array, 0, this.index); + + this.index = 0; + } + + /// + /// Resizes when there is no space left for new items. + /// + [MethodImpl(MethodImplOptions.NoInlining)] + private void ResizeBuffer() + { + T[] rent = ArrayPool.Shared.Rent(this.index << 2); + + Array.Copy(this.array, 0, rent, 0, this.index); + Array.Clear(this.array, 0, this.index); + + ArrayPool.Shared.Return(this.array); + + this.array = rent; + } + + /// + public void Dispose() + { + Array.Clear(this.array, 0, this.index); + + ArrayPool.Shared.Return(this.array); + } + } + + /// + /// Throws an when trying to add a duplicate handler. + /// + private static void ThrowInvalidOperationExceptionForDuplicateRegistration() + { + throw new InvalidOperationException("The target recipient has already subscribed to the target message"); + } + } +} diff --git a/Microsoft.Toolkit.Mvvm/Microsoft.Toolkit.Mvvm.csproj b/Microsoft.Toolkit.Mvvm/Microsoft.Toolkit.Mvvm.csproj index d9e4a16cdea..24b6ccc43b5 100644 --- a/Microsoft.Toolkit.Mvvm/Microsoft.Toolkit.Mvvm.csproj +++ b/Microsoft.Toolkit.Mvvm/Microsoft.Toolkit.Mvvm.csproj @@ -4,6 +4,7 @@ netstandard2.0;netstandard2.1 8.0 enable + true Windows Community Toolkit MVVM Toolkit This package includes a .NET Standard MVVM library with helpers such as: diff --git a/UnitTests/UnitTests.Shared/Mvvm/Test_AsyncRelayCommand{T}.cs b/UnitTests/UnitTests.Shared/Mvvm/Test_AsyncRelayCommand{T}.cs new file mode 100644 index 00000000000..bf6320fa083 --- /dev/null +++ b/UnitTests/UnitTests.Shared/Mvvm/Test_AsyncRelayCommand{T}.cs @@ -0,0 +1,141 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Threading.Tasks; +using Microsoft.Toolkit.Mvvm.Input; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace UnitTests.Mvvm +{ + [TestClass] + [SuppressMessage("StyleCop.CSharp.DocumentationRules", "SA1649", Justification = "Generic type")] + public class Test_AsyncRelayCommandOfT + { + [TestCategory("Mvvm")] + [TestMethod] + public async Task Test_AsyncRelayCommandOfT_AlwaysEnabled() + { + int ticks = 0; + + var command = new AsyncRelayCommand(async s => + { + await Task.Delay(1000); + ticks = int.Parse(s); + await Task.Delay(1000); + }); + + Assert.IsTrue(command.CanExecute(null)); + Assert.IsTrue(command.CanExecute("1")); + + (object, EventArgs) args = default; + + command.CanExecuteChanged += (s, e) => args = (s, e); + + command.NotifyCanExecuteChanged(); + + Assert.AreSame(args.Item1, command); + Assert.AreSame(args.Item2, EventArgs.Empty); + + Assert.IsNull(command.ExecutionTask); + Assert.IsFalse(command.IsRunning); + + Task task = command.ExecuteAsync((object)"42"); + + Assert.IsNotNull(command.ExecutionTask); + Assert.AreSame(command.ExecutionTask, task); + Assert.IsTrue(command.IsRunning); + + await task; + + Assert.IsFalse(command.IsRunning); + + Assert.AreEqual(ticks, 42); + + command.Execute("2"); + + await command.ExecutionTask!; + + Assert.AreEqual(ticks, 2); + + Assert.ThrowsException(() => command.Execute(new object())); + } + + [TestCategory("Mvvm")] + [TestMethod] + public void Test_AsyncRelayCommandOfT_WithCanExecuteFunctionTrue() + { + int ticks = 0; + + var command = new AsyncRelayCommand( + s => + { + ticks = int.Parse(s); + return Task.CompletedTask; + }, s => true); + + Assert.IsTrue(command.CanExecute(null)); + Assert.IsTrue(command.CanExecute("1")); + + command.Execute("42"); + + Assert.AreEqual(ticks, 42); + + command.Execute("2"); + + Assert.AreEqual(ticks, 2); + } + + [TestCategory("Mvvm")] + [TestMethod] + public void Test_AsyncRelayCommandOfT_WithCanExecuteFunctionFalse() + { + int ticks = 0; + + var command = new AsyncRelayCommand( + s => + { + ticks = int.Parse(s); + return Task.CompletedTask; + }, s => false); + + Assert.IsFalse(command.CanExecute(null)); + Assert.IsFalse(command.CanExecute("1")); + + command.Execute("2"); + + Assert.AreEqual(ticks, 0); + + command.Execute("42"); + + Assert.AreEqual(ticks, 0); + } + + [TestCategory("Mvvm")] + [TestMethod] + public void Test_AsyncRelayCommandOfT_NullWithValueType() + { + int n = 0; + + var command = new AsyncRelayCommand(i => + { + n = i; + return Task.CompletedTask; + }); + + // Special case for null value types + Assert.IsTrue(command.CanExecute(null)); + + command = new AsyncRelayCommand( + i => + { + n = i; + return Task.CompletedTask; + }, i => i > 0); + + Assert.ThrowsException(() => command.CanExecute(null)); + } + } +} diff --git a/UnitTests/UnitTests.Shared/Mvvm/Test_Messenger.Request.cs b/UnitTests/UnitTests.Shared/Mvvm/Test_Messenger.Request.cs index 07fef8d9300..24ef9067ff5 100644 --- a/UnitTests/UnitTests.Shared/Mvvm/Test_Messenger.Request.cs +++ b/UnitTests/UnitTests.Shared/Mvvm/Test_Messenger.Request.cs @@ -15,13 +15,18 @@ public partial class Test_Messenger { [TestCategory("Mvvm")] [TestMethod] - public void Test_Messenger_RequestMessage_Ok() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_RequestMessage_Ok(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var recipient = new object(); + object test = null; - void Receive(NumberRequestMessage m) + void Receive(object recipient, NumberRequestMessage m) { + test = recipient; + Assert.IsFalse(m.HasReceivedResponse); m.Reply(42); @@ -33,28 +38,33 @@ void Receive(NumberRequestMessage m) int result = messenger.Send(); + Assert.AreSame(test, recipient); Assert.AreEqual(result, 42); } [TestCategory("Mvvm")] [TestMethod] + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] [ExpectedException(typeof(InvalidOperationException))] - public void Test_Messenger_RequestMessage_Fail_NoReply() + public void Test_Messenger_RequestMessage_Fail_NoReply(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); int result = messenger.Send(); } [TestCategory("Mvvm")] [TestMethod] + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] [ExpectedException(typeof(InvalidOperationException))] - public void Test_Messenger_RequestMessage_Fail_MultipleReplies() + public void Test_Messenger_RequestMessage_Fail_MultipleReplies(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var recipient = new object(); - void Receive(NumberRequestMessage m) + void Receive(object recipient, NumberRequestMessage m) { m.Reply(42); m.Reply(42); @@ -71,12 +81,14 @@ public class NumberRequestMessage : RequestMessage [TestCategory("Mvvm")] [TestMethod] - public async Task Test_Messenger_AsyncRequestMessage_Ok_Sync() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public async Task Test_Messenger_AsyncRequestMessage_Ok_Sync(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var recipient = new object(); - void Receive(AsyncNumberRequestMessage m) + void Receive(object recipient, AsyncNumberRequestMessage m) { Assert.IsFalse(m.HasReceivedResponse); @@ -94,9 +106,11 @@ void Receive(AsyncNumberRequestMessage m) [TestCategory("Mvvm")] [TestMethod] - public async Task Test_Messenger_AsyncRequestMessage_Ok_Async() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public async Task Test_Messenger_AsyncRequestMessage_Ok_Async(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var recipient = new object(); async Task GetNumberAsync() @@ -106,7 +120,7 @@ async Task GetNumberAsync() return 42; } - void Receive(AsyncNumberRequestMessage m) + void Receive(object recipient, AsyncNumberRequestMessage m) { Assert.IsFalse(m.HasReceivedResponse); @@ -124,23 +138,27 @@ void Receive(AsyncNumberRequestMessage m) [TestCategory("Mvvm")] [TestMethod] + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] [ExpectedException(typeof(InvalidOperationException))] - public async Task Test_Messenger_AsyncRequestMessage_Fail_NoReply() + public async Task Test_Messenger_AsyncRequestMessage_Fail_NoReply(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); int result = await messenger.Send(); } [TestCategory("Mvvm")] [TestMethod] + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] [ExpectedException(typeof(InvalidOperationException))] - public async Task Test_Messenger_AsyncRequestMessage_Fail_MultipleReplies() + public async Task Test_Messenger_AsyncRequestMessage_Fail_MultipleReplies(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var recipient = new object(); - void Receive(AsyncNumberRequestMessage m) + void Receive(object recipient, AsyncNumberRequestMessage m) { m.Reply(42); m.Reply(42); @@ -157,12 +175,14 @@ public class AsyncNumberRequestMessage : AsyncRequestMessage [TestCategory("Mvvm")] [TestMethod] - public void Test_Messenger_CollectionRequestMessage_Ok_NoReplies() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_CollectionRequestMessage_Ok_NoReplies(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var recipient = new object(); - void Receive(NumbersCollectionRequestMessage m) + void Receive(object recipient, NumbersCollectionRequestMessage m) { } @@ -175,17 +195,36 @@ void Receive(NumbersCollectionRequestMessage m) [TestCategory("Mvvm")] [TestMethod] - public void Test_Messenger_CollectionRequestMessage_Ok_MultipleReplies() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_CollectionRequestMessage_Ok_MultipleReplies(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); object recipient1 = new object(), recipient2 = new object(), - recipient3 = new object(); + recipient3 = new object(), + r1 = null, + r2 = null, + r3 = null; - void Receive1(NumbersCollectionRequestMessage m) => m.Reply(1); - void Receive2(NumbersCollectionRequestMessage m) => m.Reply(2); - void Receive3(NumbersCollectionRequestMessage m) => m.Reply(3); + void Receive1(object recipient, NumbersCollectionRequestMessage m) + { + r1 = recipient; + m.Reply(1); + } + + void Receive2(object recipient, NumbersCollectionRequestMessage m) + { + r2 = recipient; + m.Reply(2); + } + + void Receive3(object recipient, NumbersCollectionRequestMessage m) + { + r3 = recipient; + m.Reply(3); + } messenger.Register(recipient1, Receive1); messenger.Register(recipient2, Receive2); @@ -198,6 +237,10 @@ public void Test_Messenger_CollectionRequestMessage_Ok_MultipleReplies() responses.Add(response); } + Assert.AreSame(r1, recipient1); + Assert.AreSame(r2, recipient2); + Assert.AreSame(r3, recipient3); + CollectionAssert.AreEquivalent(responses, new[] { 1, 2, 3 }); } @@ -207,12 +250,14 @@ public class NumbersCollectionRequestMessage : CollectionRequestMessage [TestCategory("Mvvm")] [TestMethod] - public async Task Test_Messenger_AsyncCollectionRequestMessage_Ok_NoReplies() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public async Task Test_Messenger_AsyncCollectionRequestMessage_Ok_NoReplies(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var recipient = new object(); - void Receive(AsyncNumbersCollectionRequestMessage m) + void Receive(object recipient, AsyncNumbersCollectionRequestMessage m) { } @@ -225,9 +270,11 @@ void Receive(AsyncNumbersCollectionRequestMessage m) [TestCategory("Mvvm")] [TestMethod] - public async Task Test_Messenger_AsyncCollectionRequestMessage_Ok_MultipleReplies() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public async Task Test_Messenger_AsyncCollectionRequestMessage_Ok_MultipleReplies(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); object recipient1 = new object(), recipient2 = new object(), @@ -241,10 +288,10 @@ async Task GetNumberAsync() return 3; } - void Receive1(AsyncNumbersCollectionRequestMessage m) => m.Reply(1); - void Receive2(AsyncNumbersCollectionRequestMessage m) => m.Reply(Task.FromResult(2)); - void Receive3(AsyncNumbersCollectionRequestMessage m) => m.Reply(GetNumberAsync()); - void Receive4(AsyncNumbersCollectionRequestMessage m) => m.Reply(_ => GetNumberAsync()); + void Receive1(object recipient, AsyncNumbersCollectionRequestMessage m) => m.Reply(1); + void Receive2(object recipient, AsyncNumbersCollectionRequestMessage m) => m.Reply(Task.FromResult(2)); + void Receive3(object recipient, AsyncNumbersCollectionRequestMessage m) => m.Reply(GetNumberAsync()); + void Receive4(object recipient, AsyncNumbersCollectionRequestMessage m) => m.Reply(_ => GetNumberAsync()); messenger.Register(recipient1, Receive1); messenger.Register(recipient2, Receive2); diff --git a/UnitTests/UnitTests.Shared/Mvvm/Test_Messenger.cs b/UnitTests/UnitTests.Shared/Mvvm/Test_Messenger.cs index 33d6e8cd19b..f0939b8859b 100644 --- a/UnitTests/UnitTests.Shared/Mvvm/Test_Messenger.cs +++ b/UnitTests/UnitTests.Shared/Mvvm/Test_Messenger.cs @@ -3,6 +3,8 @@ // See the LICENSE file in the project root for more information. using System; +using System.Linq; +using System.Reflection; using Microsoft.Toolkit.Mvvm.Messaging; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -13,9 +15,11 @@ public partial class Test_Messenger { [TestCategory("Mvvm")] [TestMethod] - public void Test_Messenger_UnregisterRecipientWithMessageType() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_UnregisterRecipientWithMessageType(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var recipient = new object(); messenger.Unregister(recipient); @@ -23,9 +27,11 @@ public void Test_Messenger_UnregisterRecipientWithMessageType() [TestCategory("Mvvm")] [TestMethod] - public void Test_Messenger_UnregisterRecipientWithMessageTypeAndToken() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_UnregisterRecipientWithMessageTypeAndToken(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var recipient = new object(); messenger.Unregister(recipient, nameof(MessageA)); @@ -33,9 +39,11 @@ public void Test_Messenger_UnregisterRecipientWithMessageTypeAndToken() [TestCategory("Mvvm")] [TestMethod] - public void Test_Messenger_UnregisterRecipientWithToken() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_UnregisterRecipientWithToken(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var recipient = new object(); messenger.UnregisterAll(recipient, nameof(MessageA)); @@ -43,9 +51,11 @@ public void Test_Messenger_UnregisterRecipientWithToken() [TestCategory("Mvvm")] [TestMethod] - public void Test_Messenger_UnregisterRecipientWithRecipient() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_UnregisterRecipientWithRecipient(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var recipient = new object(); messenger.UnregisterAll(recipient); @@ -53,12 +63,14 @@ public void Test_Messenger_UnregisterRecipientWithRecipient() [TestCategory("Mvvm")] [TestMethod] - public void Test_Messenger_RegisterAndUnregisterRecipientWithMessageType() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_RegisterAndUnregisterRecipientWithMessageType(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var recipient = new object(); - messenger.Register(recipient, m => { }); + messenger.Register(recipient, (r, m) => { }); messenger.Unregister(recipient); @@ -67,12 +79,14 @@ public void Test_Messenger_RegisterAndUnregisterRecipientWithMessageType() [TestCategory("Mvvm")] [TestMethod] - public void Test_Messenger_RegisterAndUnregisterRecipientWithMessageTypeAndToken() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_RegisterAndUnregisterRecipientWithMessageTypeAndToken(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var recipient = new object(); - messenger.Register(recipient, nameof(MessageA), m => { }); + messenger.Register(recipient, nameof(MessageA), (r, m) => { }); messenger.Unregister(recipient, nameof(MessageA)); @@ -81,12 +95,14 @@ public void Test_Messenger_RegisterAndUnregisterRecipientWithMessageTypeAndToken [TestCategory("Mvvm")] [TestMethod] - public void Test_Messenger_RegisterAndUnregisterRecipientWithToken() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_RegisterAndUnregisterRecipientWithToken(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var recipient = new object(); - messenger.Register(recipient, nameof(MessageA), m => { }); + messenger.Register(recipient, nameof(MessageA), (r, m) => { }); messenger.UnregisterAll(recipient, nameof(MessageA)); @@ -95,12 +111,14 @@ public void Test_Messenger_RegisterAndUnregisterRecipientWithToken() [TestCategory("Mvvm")] [TestMethod] - public void Test_Messenger_RegisterAndUnregisterRecipientWithRecipient() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_RegisterAndUnregisterRecipientWithRecipient(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var recipient = new object(); - messenger.Register(recipient, nameof(MessageA), m => { }); + messenger.Register(recipient, nameof(MessageA), (r, m) => { }); messenger.UnregisterAll(recipient); @@ -109,120 +127,145 @@ public void Test_Messenger_RegisterAndUnregisterRecipientWithRecipient() [TestCategory("Mvvm")] [TestMethod] - public void Test_Messenger_IsRegistered_Register_Send_UnregisterOfTMessage_WithNoToken() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_IsRegistered_Register_Send_UnregisterOfTMessage_WithNoToken(Type type) { + var messenger = (IMessenger)Activator.CreateInstance(type); object a = new object(); - Assert.IsFalse(Messenger.Default.IsRegistered(a)); + Assert.IsFalse(messenger.IsRegistered(a)); + object recipient = null; string result = null; - Messenger.Default.Register(a, m => result = m.Text); - Assert.IsTrue(Messenger.Default.IsRegistered(a)); + messenger.Register(a, (r, m) => + { + recipient = r; + result = m.Text; + }); + + Assert.IsTrue(messenger.IsRegistered(a)); - Messenger.Default.Send(new MessageA { Text = nameof(MessageA) }); + messenger.Send(new MessageA { Text = nameof(MessageA) }); + Assert.AreSame(recipient, a); Assert.AreEqual(result, nameof(MessageA)); - Messenger.Default.Unregister(a); + messenger.Unregister(a); - Assert.IsFalse(Messenger.Default.IsRegistered(a)); + Assert.IsFalse(messenger.IsRegistered(a)); + recipient = null; result = null; - Messenger.Default.Send(new MessageA { Text = nameof(MessageA) }); + messenger.Send(new MessageA { Text = nameof(MessageA) }); + + Assert.IsNull(recipient); Assert.IsNull(result); } [TestCategory("Mvvm")] [TestMethod] - public void Test_Messenger_IsRegistered_Register_Send_UnregisterRecipient_WithNoToken() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_IsRegistered_Register_Send_UnregisterRecipient_WithNoToken(Type type) { + var messenger = (IMessenger)Activator.CreateInstance(type); object a = new object(); - Assert.IsFalse(Messenger.Default.IsRegistered(a)); + Assert.IsFalse(messenger.IsRegistered(a)); string result = null; - Messenger.Default.Register(a, m => result = m.Text); + messenger.Register(a, (r, m) => result = m.Text); - Assert.IsTrue(Messenger.Default.IsRegistered(a)); + Assert.IsTrue(messenger.IsRegistered(a)); - Messenger.Default.Send(new MessageA { Text = nameof(MessageA) }); + messenger.Send(new MessageA { Text = nameof(MessageA) }); Assert.AreEqual(result, nameof(MessageA)); - Messenger.Default.UnregisterAll(a); + messenger.UnregisterAll(a); - Assert.IsFalse(Messenger.Default.IsRegistered(a)); + Assert.IsFalse(messenger.IsRegistered(a)); result = null; - Messenger.Default.Send(new MessageA { Text = nameof(MessageA) }); + messenger.Send(new MessageA { Text = nameof(MessageA) }); Assert.IsNull(result); } [TestCategory("Mvvm")] [TestMethod] - public void Test_Messenger_IsRegistered_Register_Send_UnregisterOfTMessage_WithToken() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_IsRegistered_Register_Send_UnregisterOfTMessage_WithToken(Type type) { + var messenger = (IMessenger)Activator.CreateInstance(type); object a = new object(); - Assert.IsFalse(Messenger.Default.IsRegistered(a)); + Assert.IsFalse(messenger.IsRegistered(a)); string result = null; - Messenger.Default.Register(a, nameof(MessageA), m => result = m.Text); + messenger.Register(a, nameof(MessageA), (r, m) => result = m.Text); - Assert.IsTrue(Messenger.Default.IsRegistered(a, nameof(MessageA))); + Assert.IsTrue(messenger.IsRegistered(a, nameof(MessageA))); - Messenger.Default.Send(new MessageA { Text = nameof(MessageA) }, nameof(MessageA)); + messenger.Send(new MessageA { Text = nameof(MessageA) }, nameof(MessageA)); Assert.AreEqual(result, nameof(MessageA)); - Messenger.Default.Unregister(a, nameof(MessageA)); + messenger.Unregister(a, nameof(MessageA)); - Assert.IsFalse(Messenger.Default.IsRegistered(a, nameof(MessageA))); + Assert.IsFalse(messenger.IsRegistered(a, nameof(MessageA))); result = null; - Messenger.Default.Send(new MessageA { Text = nameof(MessageA) }, nameof(MessageA)); + messenger.Send(new MessageA { Text = nameof(MessageA) }, nameof(MessageA)); Assert.IsNull(result); } [TestCategory("Mvvm")] [TestMethod] - public void Test_Messenger_DuplicateRegistrationWithMessageType() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_DuplicateRegistrationWithMessageType(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var recipient = new object(); - messenger.Register(recipient, m => { }); + messenger.Register(recipient, (r, m) => { }); Assert.ThrowsException(() => { - messenger.Register(recipient, m => { }); + messenger.Register(recipient, (r, m) => { }); }); } [TestCategory("Mvvm")] [TestMethod] - public void Test_Messenger_DuplicateRegistrationWithMessageTypeAndToken() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_DuplicateRegistrationWithMessageTypeAndToken(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var recipient = new object(); - messenger.Register(recipient, nameof(MessageA), m => { }); + messenger.Register(recipient, nameof(MessageA), (r, m) => { }); Assert.ThrowsException(() => { - messenger.Register(recipient, nameof(MessageA), m => { }); + messenger.Register(recipient, nameof(MessageA), (r, m) => { }); }); } [TestCategory("Mvvm")] [TestMethod] - public void Test_Messenger_IRecipient_NoMessages() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_IRecipient_NoMessages(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var recipient = new RecipientWithNoMessages(); messenger.RegisterAll(recipient); @@ -233,9 +276,11 @@ public void Test_Messenger_IRecipient_NoMessages() [TestCategory("Mvvm")] [TestMethod] - public void Test_Messenger_IRecipient_SomeMessages_NoToken() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_IRecipient_SomeMessages_NoToken(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var recipient = new RecipientWithSomeMessages(); messenger.RegisterAll(recipient); @@ -264,9 +309,11 @@ public void Test_Messenger_IRecipient_SomeMessages_NoToken() [TestCategory("Mvvm")] [TestMethod] - public void Test_Messenger_IRecipient_SomeMessages_WithToken() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_IRecipient_SomeMessages_WithToken(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var recipient = new RecipientWithSomeMessages(); var token = nameof(Test_Messenger_IRecipient_SomeMessages_WithToken); @@ -297,12 +344,182 @@ public void Test_Messenger_IRecipient_SomeMessages_WithToken() Assert.IsFalse(messenger.IsRegistered(recipient)); } + [TestCategory("Mvvm")] + [TestMethod] + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_RegisterWithTypeParameter(Type type) + { + var messenger = (IMessenger)Activator.CreateInstance(type); + var recipient = new RecipientWithNoMessages { Number = 42 }; + + int number = 0; + + messenger.Register(recipient, (r, m) => number = r.Number); + + messenger.Send(); + + Assert.AreEqual(number, 42); + } + + [TestCategory("Mvvm")] + [TestMethod] + [DataRow(typeof(StrongReferenceMessenger), false)] + [DataRow(typeof(WeakReferenceMessenger), true)] + public void Test_Messenger_Collect_Test(Type type, bool isWeak) + { + var messenger = (IMessenger)Activator.CreateInstance(type); + + WeakReference weakRecipient; + + void Test() + { + var recipient = new RecipientWithNoMessages { Number = 42 }; + weakRecipient = new WeakReference(recipient); + + messenger.Register(recipient, (r, m) => { }); + + Assert.IsTrue(messenger.IsRegistered(recipient)); + Assert.IsTrue(weakRecipient.IsAlive); + + GC.KeepAlive(recipient); + } + + Test(); + + GC.Collect(); + + Assert.AreEqual(!isWeak, weakRecipient.IsAlive); + } + + [TestCategory("Mvvm")] + [TestMethod] + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_Reset(Type type) + { + var messenger = (IMessenger)Activator.CreateInstance(type); + var recipient = new RecipientWithSomeMessages(); + + messenger.RegisterAll(recipient); + + Assert.IsTrue(messenger.IsRegistered(recipient)); + Assert.IsTrue(messenger.IsRegistered(recipient)); + + messenger.Reset(); + + Assert.IsFalse(messenger.IsRegistered(recipient)); + Assert.IsFalse(messenger.IsRegistered(recipient)); + } + + [TestCategory("Mvvm")] + [TestMethod] + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_Default(Type type) + { + PropertyInfo defaultInfo = type.GetProperty("Default"); + + var default1 = defaultInfo!.GetValue(null); + var default2 = defaultInfo!.GetValue(null); + + Assert.IsNotNull(default1); + Assert.IsNotNull(default2); + Assert.AreSame(default1, default2); + } + + [TestCategory("Mvvm")] + [TestMethod] + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_Cleanup(Type type) + { + var messenger = (IMessenger)Activator.CreateInstance(type); + var recipient = new RecipientWithSomeMessages(); + + messenger.Register(recipient); + + Assert.IsTrue(messenger.IsRegistered(recipient)); + + void Test() + { + var recipient2 = new RecipientWithSomeMessages(); + + messenger.Register(recipient2); + + Assert.IsTrue(messenger.IsRegistered(recipient2)); + + GC.KeepAlive(recipient2); + } + + Test(); + + GC.Collect(); + + // Here we just check that calling Cleanup doesn't alter the state + // of the messenger. This method shouldn't really do anything visible + // to consumers, it's just a way for messengers to compact their data. + messenger.Cleanup(); + + Assert.IsTrue(messenger.IsRegistered(recipient)); + } + + [TestCategory("Mvvm")] + [TestMethod] + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_Messenger_ManyRecipients(Type type) + { + var messenger = (IMessenger)Activator.CreateInstance(type); + + void Test() + { + var recipients = Enumerable.Range(0, 512).Select(_ => new RecipientWithSomeMessages()).ToArray(); + + foreach (var recipient in recipients) + { + messenger.RegisterAll(recipient); + } + + foreach (var recipient in recipients) + { + Assert.IsTrue(messenger.IsRegistered(recipient)); + Assert.IsTrue(messenger.IsRegistered(recipient)); + } + + messenger.Send(); + messenger.Send(); + messenger.Send(); + + foreach (var recipient in recipients) + { + Assert.AreEqual(recipient.As, 1); + Assert.AreEqual(recipient.Bs, 2); + } + + foreach (ref var recipient in recipients.AsSpan()) + { + recipient = null; + } + } + + Test(); + + GC.Collect(); + + // Just invoke a final cleanup to improve coverage, this is unrelated to this test in particular + messenger.Cleanup(); + } + public sealed class RecipientWithNoMessages { + public int Number { get; set; } } - public sealed class RecipientWithSomeMessages - : IRecipient, ICloneable, IRecipient + public sealed class RecipientWithSomeMessages : + IRecipient, + IRecipient, + ICloneable { public int As { get; private set; } diff --git a/UnitTests/UnitTests.Shared/Mvvm/Test_ObservableRecipient.cs b/UnitTests/UnitTests.Shared/Mvvm/Test_ObservableRecipient.cs index d5e330e1e53..e40776f4244 100644 --- a/UnitTests/UnitTests.Shared/Mvvm/Test_ObservableRecipient.cs +++ b/UnitTests/UnitTests.Shared/Mvvm/Test_ObservableRecipient.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using Microsoft.Toolkit.Mvvm.ComponentModel; using Microsoft.Toolkit.Mvvm.Messaging; using Microsoft.Toolkit.Mvvm.Messaging.Messages; @@ -14,9 +15,12 @@ public class Test_ObservableRecipient { [TestCategory("Mvvm")] [TestMethod] - public void Test_ObservableRecipient_Activation() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_ObservableRecipient_Activation(Type type) { - var viewmodel = new SomeRecipient(); + var messenger = (IMessenger)Activator.CreateInstance(type); + var viewmodel = new SomeRecipient(messenger); Assert.IsFalse(viewmodel.IsActivatedCheck); @@ -33,18 +37,32 @@ public void Test_ObservableRecipient_Activation() [TestCategory("Mvvm")] [TestMethod] - public void Test_ObservableRecipient_Defaults() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_ObservableRecipient_IsSame(Type type) + { + var messenger = (IMessenger)Activator.CreateInstance(type); + var viewmodel = new SomeRecipient(messenger); + + Assert.AreSame(viewmodel.CurrentMessenger, messenger); + } + + [TestCategory("Mvvm")] + [TestMethod] + public void Test_ObservableRecipient_Default() { var viewmodel = new SomeRecipient(); - Assert.AreSame(viewmodel.CurrentMessenger, Messenger.Default); + Assert.AreSame(viewmodel.CurrentMessenger, WeakReferenceMessenger.Default); } [TestCategory("Mvvm")] [TestMethod] - public void Test_ObservableRecipient_Injection() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_ObservableRecipient_Injection(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var viewmodel = new SomeRecipient(messenger); Assert.AreSame(viewmodel.CurrentMessenger, messenger); @@ -52,14 +70,16 @@ public void Test_ObservableRecipient_Injection() [TestCategory("Mvvm")] [TestMethod] - public void Test_ObservableRecipient_Broadcast() + [DataRow(typeof(StrongReferenceMessenger))] + [DataRow(typeof(WeakReferenceMessenger))] + public void Test_ObservableRecipient_Broadcast(Type type) { - var messenger = new Messenger(); + var messenger = (IMessenger)Activator.CreateInstance(type); var viewmodel = new SomeRecipient(messenger); PropertyChangedMessage message = null; - messenger.Register>(messenger, m => message = m); + messenger.Register>(messenger, (r, m) => message = m); viewmodel.Data = 42; @@ -97,7 +117,7 @@ protected override void OnActivated() { IsActivatedCheck = true; - Messenger.Register(this, m => { }); + Messenger.Register(this, (r, m) => { }); } protected override void OnDeactivated() diff --git a/UnitTests/UnitTests.Shared/Mvvm/Test_RelayCommand{T}.cs b/UnitTests/UnitTests.Shared/Mvvm/Test_RelayCommand{T}.cs index d2e13e7ed21..1c606b4b884 100644 --- a/UnitTests/UnitTests.Shared/Mvvm/Test_RelayCommand{T}.cs +++ b/UnitTests/UnitTests.Shared/Mvvm/Test_RelayCommand{T}.cs @@ -35,7 +35,7 @@ public void Test_RelayCommandOfT_AlwaysEnabled() Assert.AreSame(args.Item1, command); Assert.AreSame(args.Item2, EventArgs.Empty); - command.Execute("Hello"); + command.Execute((object)"Hello"); Assert.AreEqual(text, "Hello"); @@ -57,7 +57,7 @@ public void Test_RelayCommand_WithCanExecuteFunction() Assert.ThrowsException(() => command.CanExecute(new object())); - command.Execute("Hello"); + command.Execute((object)"Hello"); Assert.AreEqual(text, "Hello"); @@ -65,5 +65,21 @@ public void Test_RelayCommand_WithCanExecuteFunction() Assert.AreEqual(text, "Hello"); } + + [TestCategory("Mvvm")] + [TestMethod] + public void Test_RelayCommand_NullWithValueType() + { + int n = 0; + + var command = new RelayCommand(i => n = i); + + // Special case for null value types + Assert.IsTrue(command.CanExecute(null)); + + command = new RelayCommand(i => n = i, i => i > 0); + + Assert.ThrowsException(() => command.CanExecute(null)); + } } } diff --git a/UnitTests/UnitTests.Shared/UnitTests.Shared.projitems b/UnitTests/UnitTests.Shared/UnitTests.Shared.projitems index 1d958a159c5..9d9bd736b94 100644 --- a/UnitTests/UnitTests.Shared/UnitTests.Shared.projitems +++ b/UnitTests/UnitTests.Shared/UnitTests.Shared.projitems @@ -28,6 +28,7 @@ +