diff --git a/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod.Tests/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod.Tests.csproj b/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod.Tests/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod.Tests.csproj
new file mode 100644
index 000000000..ff9b195ac
--- /dev/null
+++ b/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod.Tests/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod.Tests.csproj
@@ -0,0 +1,21 @@
+
+
+
+ net6.0
+ enable
+
+ false
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod.Tests/LabsUITestMethodTests.cs b/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod.Tests/LabsUITestMethodTests.cs
new file mode 100644
index 000000000..3656bb557
--- /dev/null
+++ b/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod.Tests/LabsUITestMethodTests.cs
@@ -0,0 +1,296 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod;
+using CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod.Diagnostics;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace CommunityToolkit.Labs.Core.SourceGenerators.Tests;
+
+[TestClass]
+public partial class LabsUITestMethodTests
+{
+ private const string DispatcherQueueDefinition = @"
+namespace MyApp
+{
+ public partial class Test
+ {
+ public System.Threading.Tasks.Task EnqueueAsync(System.Func> function) => System.Threading.Tasks.Task.Run(function);
+
+ public System.Threading.Tasks.Task EnqueueAsync(System.Action function) => System.Threading.Tasks.Task.Run(function);
+ }
+}
+";
+
+ [TestMethod]
+ public void TestControlHasConstructorWithParameters()
+ {
+ string source = @"
+ using System.ComponentModel;
+ using CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod;
+
+ namespace MyApp
+ {
+ public partial class Test
+ {
+ public System.Threading.Tasks.Task LoadTestContentAsync(Microsoft.UI.Xaml.FrameworkElement content) => System.Threading.Tasks.Task.CompletedTask;
+ public System.Threading.Tasks.Task UnloadTestContentAsync(Microsoft.UI.Xaml.FrameworkElement content) => System.Threading.Tasks.Task.CompletedTask;
+
+ [LabsUITestMethod]
+ public void TestMethod(MyControl control)
+ {
+ }
+ }
+
+ public class MyControl : Microsoft.UI.Xaml.FrameworkElement
+ {
+ public MyControl(string id)
+ {
+ }
+ }
+ }
+
+ namespace Microsoft.UI.Xaml
+ {
+ public class FrameworkElement { }
+ }";
+
+ VerifyGeneratedDiagnostics(source + DispatcherQueueDefinition, DiagnosticDescriptors.TestControlHasConstructorWithParameters.Id);
+ }
+
+ [TestMethod]
+ public void Async_Mux_NoErrors()
+ {
+ string source = @"
+ using System.ComponentModel;
+ using CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod;
+
+ namespace MyApp
+ {
+ public partial class Test
+ {
+ public System.Threading.Tasks.Task LoadTestContentAsync(Microsoft.UI.Xaml.FrameworkElement content) => System.Threading.Tasks.Task.CompletedTask;
+ public System.Threading.Tasks.Task UnloadTestContentAsync(Microsoft.UI.Xaml.FrameworkElement content) => System.Threading.Tasks.Task.CompletedTask;
+
+ [LabsUITestMethod]
+ public async System.Threading.Tasks.Task TestMethod(MyControl control)
+ {
+ }
+ }
+
+ public class MyControl : Microsoft.UI.Xaml.FrameworkElement
+ {
+ }
+ }
+
+ namespace Microsoft.UI.Xaml
+ {
+ public class FrameworkElement { }
+ }";
+
+ VerifyGeneratedDiagnostics(source + DispatcherQueueDefinition);
+ }
+
+ [TestMethod]
+ public void Async_Wux_NoErrors()
+ {
+ string source = @"
+ using System.ComponentModel;
+ using CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod;
+
+ namespace MyApp
+ {
+ public partial class Test
+ {
+ public System.Threading.Tasks.Task LoadTestContentAsync(Windows.UI.Xaml.FrameworkElement content) => System.Threading.Tasks.Task.CompletedTask;
+ public System.Threading.Tasks.Task UnloadTestContentAsync(Windows.UI.Xaml.FrameworkElement content) => System.Threading.Tasks.Task.CompletedTask;
+
+ [LabsUITestMethod]
+ public async System.Threading.Tasks.Task TestMethod(MyControl control)
+ {
+ }
+ }
+
+ public class MyControl : Windows.UI.Xaml.FrameworkElement
+ {
+ }
+ }
+
+ namespace Windows.UI.Xaml
+ {
+ public class FrameworkElement { }
+ }";
+
+ VerifyGeneratedDiagnostics(source + DispatcherQueueDefinition);
+ }
+
+ [TestMethod]
+ public void Async_NoMethodParams_NoErrors()
+ {
+ string source = @"
+ using System.ComponentModel;
+ using CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod;
+
+ namespace MyApp
+ {
+ public partial class Test
+ {
+ [LabsUITestMethod]
+ public async System.Threading.Tasks.Task TestMethod()
+ {
+ }
+ }
+ }";
+
+ VerifyGeneratedDiagnostics(source + DispatcherQueueDefinition);
+ }
+
+ [TestMethod]
+ public void Synchronous_Mux_NoErrors()
+ {
+ string source = @"
+ using System.ComponentModel;
+ using CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod;
+
+ namespace MyApp
+ {
+ public partial class Test
+ {
+ public System.Threading.Tasks.Task LoadTestContentAsync(Microsoft.UI.Xaml.FrameworkElement content) => System.Threading.Tasks.Task.CompletedTask;
+ public System.Threading.Tasks.Task UnloadTestContentAsync(Microsoft.UI.Xaml.FrameworkElement content) => System.Threading.Tasks.Task.CompletedTask;
+
+ [LabsUITestMethod]
+ public void TestMethod(MyControl control)
+ {
+ }
+ }
+
+ public class MyControl : Microsoft.UI.Xaml.FrameworkElement
+ {
+ }
+ }
+
+ namespace Microsoft.UI.Xaml
+ {
+ public class FrameworkElement { }
+ }";
+
+ VerifyGeneratedDiagnostics(source + DispatcherQueueDefinition);
+ }
+
+ [TestMethod]
+ public void Synchronous_Wux_NoErrors()
+ {
+ string source = @"
+ using System.ComponentModel;
+ using CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod;
+
+ namespace MyApp
+ {
+ public partial class Test
+ {
+ public System.Threading.Tasks.Task LoadTestContentAsync(Windows.UI.Xaml.FrameworkElement content) => System.Threading.Tasks.Task.CompletedTask;
+ public System.Threading.Tasks.Task UnloadTestContentAsync(Windows.UI.Xaml.FrameworkElement content) => System.Threading.Tasks.Task.CompletedTask;
+
+ [LabsUITestMethod]
+ public void TestMethod(MyControl control)
+ {
+ }
+ }
+
+ public class MyControl : Windows.UI.Xaml.FrameworkElement
+ {
+ }
+ }
+
+ namespace Windows.UI.Xaml
+ {
+ public class FrameworkElement { }
+ }";
+
+ VerifyGeneratedDiagnostics(source + DispatcherQueueDefinition);
+ }
+
+ [TestMethod]
+ public void Synchronous_NoMethodParams_NoErrors()
+ {
+ string source = @"
+ using System.ComponentModel;
+ using CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod;
+
+ namespace MyApp
+ {
+ public partial class Test
+ {
+ [LabsUITestMethod]
+ public void TestMethod()
+ {
+ }
+ }
+ }";
+
+ VerifyGeneratedDiagnostics(source + DispatcherQueueDefinition);
+ }
+
+ ///
+ /// Verifies the output of a source generator.
+ ///
+ /// The generator type to use.
+ /// The input source to process.
+ /// The input documentation info to process.
+ /// The diagnostic ids to expect for the input source code.
+ private static void VerifyGeneratedDiagnostics(string source, params string[] diagnosticsIds)
+ where TGenerator : class, IIncrementalGenerator, new()
+ {
+ VerifyGeneratedDiagnostics(CSharpSyntaxTree.ParseText(source), diagnosticsIds);
+ }
+
+ ///
+ /// Verifies the output of a source generator.
+ ///
+ /// The generator type to use.
+ /// The input source tree to process.
+ /// The input documentation info to process.
+ /// The diagnostic ids to expect for the input source code.
+ private static void VerifyGeneratedDiagnostics(SyntaxTree syntaxTree, params string[] diagnosticsIds)
+ where TGenerator : class, IIncrementalGenerator, new()
+ {
+ var attributeType = typeof(LabsUITestMethodAttribute);
+
+ var references =
+ from assembly in AppDomain.CurrentDomain.GetAssemblies()
+ where !assembly.IsDynamic
+ let reference = MetadataReference.CreateFromFile(assembly.Location)
+ select reference;
+
+ var compilation = CSharpCompilation.Create(
+ "original.Sample",
+ new[] { syntaxTree },
+ references,
+ new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
+
+ var compilationDiagnostics = compilation.GetDiagnostics();
+
+ Assert.IsTrue(compilationDiagnostics.All(x => x.Severity != DiagnosticSeverity.Error), $"Expected no compilation errors before source generation. Got: \n{string.Join("\n", compilationDiagnostics.Where(x => x.Severity == DiagnosticSeverity.Error).Select(x => $"[{x.Id}: {x.GetMessage()}]"))}");
+
+ IIncrementalGenerator generator = new TGenerator();
+
+ GeneratorDriver driver =
+ CSharpGeneratorDriver
+ .Create(generator)
+ .WithUpdatedParseOptions((CSharpParseOptions)syntaxTree.Options);
+
+ _ = driver.RunGeneratorsAndUpdateCompilation(compilation, out Compilation outputCompilation, out ImmutableArray diagnostics);
+
+ HashSet resultingIds = diagnostics.Select(diagnostic => diagnostic.Id).ToHashSet();
+ var generatedCompilationDiaghostics = outputCompilation.GetDiagnostics();
+
+ Assert.IsTrue(resultingIds.SetEquals(diagnosticsIds), $"Expected one of [{string.Join(", ", diagnosticsIds)}] diagnostic Ids. Got [{string.Join(", ", resultingIds)}]");
+ Assert.IsTrue(generatedCompilationDiaghostics.All(x => x.Severity != DiagnosticSeverity.Error), $"Expected no generated compilation errors. Got: \n{string.Join("\n", generatedCompilationDiaghostics.Where(x => x.Severity == DiagnosticSeverity.Error).Select(x => $"[{x.Id}: {x.GetMessage()}]"))}");
+
+ GC.KeepAlive(attributeType);
+ }
+}
diff --git a/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod.csproj b/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod.csproj
new file mode 100644
index 000000000..9b3c8d237
--- /dev/null
+++ b/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod.csproj
@@ -0,0 +1,14 @@
+
+
+
+ netstandard2.0
+ enable
+ nullable
+ 10.0
+
+
+
+
+
+
+
diff --git a/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod/Diagnostics/DiagnosticDescriptors.cs b/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod/Diagnostics/DiagnosticDescriptors.cs
new file mode 100644
index 000000000..c9a304039
--- /dev/null
+++ b/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod/Diagnostics/DiagnosticDescriptors.cs
@@ -0,0 +1,28 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.CodeAnalysis;
+
+namespace CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod.Diagnostics;
+
+///
+/// A container for all instances for errors reported by analyzers in this project.
+///
+public static class DiagnosticDescriptors
+{
+ ///
+ /// Gets a indicating that a test method decorated with asks for a control instance with a non-parameterless constructor.
+ ///
+ /// Format: "Cannot generate test with type {0} as it has a constructor with parameters.".
+ ///
+ ///
+ public static readonly DiagnosticDescriptor TestControlHasConstructorWithParameters = new(
+ id: "LUITM0001",
+ title: $"Provided control must not have a constructor with parameters.",
+ messageFormat: $"Cannot generate test with control {{0}} as it has a constructor with parameters.",
+ category: typeof(LabsUITestMethodGenerator).FullName,
+ defaultSeverity: DiagnosticSeverity.Error,
+ isEnabledByDefault: true,
+ description: $"Cannot generate test method with provided control.");
+}
diff --git a/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod/Extensions/ISymbolExtensions.cs b/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod/Extensions/ISymbolExtensions.cs
new file mode 100644
index 000000000..a95702b75
--- /dev/null
+++ b/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod/Extensions/ISymbolExtensions.cs
@@ -0,0 +1,122 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Immutable;
+using Microsoft.CodeAnalysis;
+
+namespace CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod.Extensions;
+
+///
+/// Extension methods for the type.
+///
+///
+/// Borrowed from ISymbolExtensions in the dotnet Toolkit's CommunityToolkit.Mvvm.SourceGenerators project.
+///
+internal static class ISymbolExtensions
+{
+ ///
+ /// Gets the fully qualified name for a given symbol.
+ ///
+ /// The input instance.
+ /// The fully qualified name for .
+ public static string GetFullyQualifiedName(this ISymbol symbol)
+ {
+ return symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
+ }
+
+ ///
+ /// Gets the fully qualified name for a given symbol, including nullability annotations
+ ///
+ /// The input instance.
+ /// The fully qualified name for .
+ public static string GetFullyQualifiedNameWithNullabilityAnnotations(this ISymbol symbol)
+ {
+ return symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat.AddMiscellaneousOptions(SymbolDisplayMiscellaneousOptions.IncludeNullableReferenceTypeModifier));
+ }
+
+ ///
+ /// Checks whether or not a given type symbol has a specified full name.
+ ///
+ /// The input instance to check.
+ /// The full name to check.
+ /// Whether has a full name equals to .
+ public static bool HasFullyQualifiedName(this ISymbol symbol, string name)
+ {
+ return symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == name;
+ }
+
+ ///
+ /// Checks whether or not a given symbol has an attribute with the specified full name.
+ ///
+ /// The input instance to check.
+ /// The attribute name to look for.
+ /// Whether or not has an attribute with the specified name.
+ public static bool HasAttributeWithFullyQualifiedName(this ISymbol symbol, string name)
+ {
+ ImmutableArray attributes = symbol.GetAttributes();
+
+ foreach (AttributeData attribute in attributes)
+ {
+ if (attribute.AttributeClass?.HasFullyQualifiedName(name) == true)
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ ///
+ /// Calculates the effective accessibility for a given symbol.
+ ///
+ /// The instance to check.
+ /// The effective accessibility for .
+ public static Accessibility GetEffectiveAccessibility(this ISymbol symbol)
+ {
+ // Start by assuming it's visible
+ Accessibility visibility = Accessibility.Public;
+
+ // Handle special cases
+ switch (symbol.Kind)
+ {
+ case SymbolKind.Alias: return Accessibility.Private;
+ case SymbolKind.Parameter: return GetEffectiveAccessibility(symbol.ContainingSymbol);
+ case SymbolKind.TypeParameter: return Accessibility.Private;
+ }
+
+ // Traverse the symbol hierarchy to determine the effective accessibility
+ while (symbol is not null && symbol.Kind != SymbolKind.Namespace)
+ {
+ switch (symbol.DeclaredAccessibility)
+ {
+ case Accessibility.NotApplicable:
+ case Accessibility.Private:
+ return Accessibility.Private;
+ case Accessibility.Internal:
+ case Accessibility.ProtectedAndInternal:
+ visibility = Accessibility.Internal;
+ break;
+ }
+
+ symbol = symbol.ContainingSymbol;
+ }
+
+ return visibility;
+ }
+
+ ///
+ /// Checks whether or not a given symbol can be accessed from a specified assembly.
+ ///
+ /// The input instance to check.
+ /// The assembly to check the accessibility of for.
+ /// Whether can access .
+ public static bool CanBeAccessedFrom(this ISymbol symbol, IAssemblySymbol assembly)
+ {
+ Accessibility accessibility = symbol.GetEffectiveAccessibility();
+
+ return
+ accessibility == Accessibility.Public ||
+ accessibility == Accessibility.Internal && symbol.ContainingAssembly.GivesAccessTo(assembly);
+ }
+}
diff --git a/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod/Extensions/ITypeSymbolExtensions.cs b/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod/Extensions/ITypeSymbolExtensions.cs
new file mode 100644
index 000000000..7d4f813c4
--- /dev/null
+++ b/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod/Extensions/ITypeSymbolExtensions.cs
@@ -0,0 +1,133 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Linq;
+using Microsoft.CodeAnalysis;
+
+namespace CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod.Extensions;
+
+///
+/// Extension methods for the type.
+///
+/// Borrowed from ITypeSymbolExtensions in the dotnet Toolkit's CommunityToolkit.Mvvm.SourceGenerators project.
+///
+internal static class ITypeSymbolExtensions
+{
+ ///
+ /// Checks whether or not a given has or inherits from a specified type.
+ ///
+ /// The target instance to check.
+ /// The full name of the type to check for inheritance.
+ /// Whether or not is or inherits from .
+ public static bool HasOrInheritsFromFullyQualifiedName(this ITypeSymbol typeSymbol, string name)
+ {
+ for (var currentType = typeSymbol; currentType is not null; currentType = currentType.BaseType)
+ {
+ if (currentType.HasFullyQualifiedName(name))
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ ///
+ /// Checks whether or not a given inherits from a specified type.
+ ///
+ /// The target instance to check.
+ /// The full name of the type to check for inheritance.
+ /// Whether or not inherits from .
+ public static bool InheritsFromFullyQualifiedName(this ITypeSymbol typeSymbol, string name)
+ {
+ var baseType = typeSymbol.BaseType;
+
+ while (baseType != null)
+ {
+ if (baseType.HasFullyQualifiedName(name))
+ {
+ return true;
+ }
+
+ baseType = baseType.BaseType;
+ }
+
+ return false;
+ }
+
+ ///
+ /// Checks whether or not a given implements an interface with a specified name.
+ ///
+ /// The target instance to check.
+ /// The full name of the type to check for interface implementation.
+ /// Whether or not has an interface with the specified name.
+ public static bool HasInterfaceWithFullyQualifiedName(this ITypeSymbol typeSymbol, string name)
+ {
+ foreach (var interfaceType in typeSymbol.AllInterfaces)
+ {
+ if (interfaceType.HasFullyQualifiedName(name))
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ ///
+ /// Checks whether or not a given has or inherits a specified attribute.
+ ///
+ /// The target instance to check.
+ /// The predicate used to match available attributes.
+ /// Whether or not has an attribute matching .
+ public static bool HasOrInheritsAttribute(this ITypeSymbol typeSymbol, Func predicate)
+ {
+ for (var currentType = typeSymbol; currentType is not null; currentType = currentType.BaseType)
+ {
+ if (currentType.GetAttributes().Any(predicate))
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ ///
+ /// Checks whether or not a given has or inherits a specified attribute.
+ ///
+ /// The target instance to check.
+ /// The name of the attribute to look for.
+ /// Whether or not has an attribute with the specified type name.
+ public static bool HasOrInheritsAttributeWithFullyQualifiedName(this ITypeSymbol typeSymbol, string name)
+ {
+ for (var currentType = typeSymbol; currentType is not null; currentType = currentType.BaseType)
+ {
+ if (currentType.HasAttributeWithFullyQualifiedName(name))
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ ///
+ /// Checks whether or not a given inherits a specified attribute.
+ /// If the type has no base type, this method will automatically handle that and return .
+ ///
+ /// The target instance to check.
+ /// The name of the attribute to look for.
+ /// Whether or not has an attribute with the specified type name.
+ public static bool InheritsAttributeWithFullyQualifiedName(this ITypeSymbol typeSymbol, string name)
+ {
+ if (typeSymbol.BaseType is INamedTypeSymbol baseTypeSymbol)
+ {
+ return baseTypeSymbol.HasOrInheritsAttributeWithFullyQualifiedName(name);
+ }
+
+ return false;
+ }
+}
diff --git a/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod/LabsUITestMethodAttribute.cs b/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod/LabsUITestMethodAttribute.cs
new file mode 100644
index 000000000..b5f05b575
--- /dev/null
+++ b/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod/LabsUITestMethodAttribute.cs
@@ -0,0 +1,16 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+
+namespace CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod;
+
+///
+/// Generates a test method that auto-dispatches method contents to the UI thread,
+/// and provides an instance of a control as a parameter if present in the method signature.
+///
+[AttributeUsage(AttributeTargets.Method, Inherited = false, AllowMultiple = false)]
+public sealed class LabsUITestMethodAttribute : Attribute
+{
+}
diff --git a/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod/LabsUITestMethodGenerator.cs b/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod/LabsUITestMethodGenerator.cs
new file mode 100644
index 000000000..e60a249fc
--- /dev/null
+++ b/common/CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod/LabsUITestMethodGenerator.cs
@@ -0,0 +1,104 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod.Diagnostics;
+using CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod.Extensions;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using System.Linq;
+
+namespace CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod;
+
+///
+/// Generates code that provides access to XAML elements with x:Name from code-behind by wrapping an instance of a control, without the need to use x:FieldProvider="public" directly in markup.
+///
+[Generator]
+public class LabsUITestMethodGenerator : IIncrementalGenerator
+{
+ public void Initialize(IncrementalGeneratorInitializationContext context)
+ {
+ // Get all method declarations with at least one attribute
+ var methodSymbols = context.SyntaxProvider
+ .CreateSyntaxProvider(
+ static (node, _) => node is MethodDeclarationSyntax { Parent: ClassDeclarationSyntax, AttributeLists.Count: > 0 },
+ static (context, token) => context.SemanticModel.GetDeclaredSymbol((MethodDeclarationSyntax)context.Node, cancellationToken: token))
+ .Where(x => x is not null)
+ .Select((x, _) => x!);
+
+ // Filter the methods using [LabsUITestMethod]
+ var methodAndPageTypeSymbols = methodSymbols
+ .Select(static (item, _) =>
+ (
+ Symbol: item,
+ Attribute: item.GetAttributes().FirstOrDefault(a => a.AttributeClass?.HasFullyQualifiedName("global::CommunityToolkit.Labs.Core.SourceGenerators.LabsUITestMethod.LabsUITestMethodAttribute") ?? false)
+ ))
+
+ .Where(static item => item.Attribute is not null && item.Symbol is IMethodSymbol)
+ .Select(static (x, _) => (IMethodSymbol)x.Symbol)
+
+ .Select(static (x, _) => (MethodSymbol: x, ControlTypeSymbol: GetControlTypeSymbolFromMethodParameters(x)))
+
+ .Where(static x => x.ControlTypeSymbol is not null)
+ .Select(static (x, _) => (x.MethodSymbol, ControlTypeSymbol: x.ControlTypeSymbol!));
+
+ // Generate source
+ context.RegisterSourceOutput(methodAndPageTypeSymbols, (x, y) => GenerateTestMethod(x, y.MethodSymbol, y.ControlTypeSymbol));
+ }
+
+ private static void GenerateTestMethod(SourceProductionContext context, IMethodSymbol methodSymbol, INamedTypeSymbol? controlTypeSymbol)
+ {
+ if (controlTypeSymbol is not null && controlTypeSymbol.Constructors.Any(x => x.DeclaredAccessibility == Accessibility.Public && !x.Parameters.IsEmpty))
+ {
+ context.ReportDiagnostic(Diagnostic.Create(DiagnosticDescriptors.TestControlHasConstructorWithParameters, methodSymbol.Locations.FirstOrDefault(), controlTypeSymbol.Name));
+ return;
+ }
+
+ var isAsync = methodSymbol.ReturnType.HasFullyQualifiedName("global::System.Threading.Tasks.Task") ||
+ methodSymbol.ReturnType.InheritsFromFullyQualifiedName("global::System.Threading.Tasks.Task");
+
+ var source = $@"using System.Threading.Tasks;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace {methodSymbol.ContainingType.ContainingNamespace}
+{{
+ partial class {methodSymbol.ContainingType.Name}
+ {{
+ [TestMethod]
+ public Task {methodSymbol.Name}_Test()
+ {{
+ return EnqueueAsync(async () => {{
+ {(controlTypeSymbol is not null ? @$"
+ // Create content
+ var testControl = new {controlTypeSymbol.GetFullyQualifiedName()}();
+
+ // Load content
+ await LoadTestContentAsync(testControl);" : string.Empty)}
+
+ // Run test
+ {(isAsync ? "await " : string.Empty)}{methodSymbol.Name}({(controlTypeSymbol is not null ? "testControl" : string.Empty)});
+
+ {(controlTypeSymbol is not null ?
+ @"// Unload content
+ await UnloadTestContentAsync(testControl);" : string.Empty)}
+ }});
+ }}
+ }}
+}}
+";
+
+ context.AddSource($"{methodSymbol.Name}.g", source);
+ }
+
+ private static bool ControlTypeInheritsFrameworkElement(ITypeSymbol controlType)
+ {
+ return controlType.HasOrInheritsFromFullyQualifiedName("global::Windows.UI.Xaml.FrameworkElement") ||
+ controlType.HasOrInheritsFromFullyQualifiedName("global::Microsoft.UI.Xaml.FrameworkElement");
+ }
+
+ private static INamedTypeSymbol? GetControlTypeSymbolFromMethodParameters(IMethodSymbol methodSymbol)
+ {
+ return methodSymbol.Parameters.FirstOrDefault(x => ControlTypeInheritsFrameworkElement(x.Type))?.Type as INamedTypeSymbol;
+ }
+}
+
diff --git a/common/CommunityToolkit.Labs.UnitTests.Shared/App.xaml.cs b/common/CommunityToolkit.Labs.UnitTests.Shared/App.xaml.cs
index 3b77613fb..0225f5757 100644
--- a/common/CommunityToolkit.Labs.UnitTests.Shared/App.xaml.cs
+++ b/common/CommunityToolkit.Labs.UnitTests.Shared/App.xaml.cs
@@ -45,20 +45,8 @@ public sealed partial class App : Application
// Holder for test content to abstract Window.Current.Content
public static FrameworkElement? ContentRoot
{
- get
- {
- var rootFrame = currentWindow.Content as Frame;
- return rootFrame?.Content as FrameworkElement;
- }
-
- set
- {
- var rootFrame = currentWindow.Content as Frame;
- if (rootFrame != null)
- {
- rootFrame.Content = value;
- }
- }
+ get => currentWindow.Content as FrameworkElement;
+ set => currentWindow.Content = value;
}
// Abstract CoreApplication.MainView.DispatcherQueue
diff --git a/common/CommunityToolkit.Labs.UnitTests.Shared/CommunityToolkit.Labs.UnitTests.Shared.projitems b/common/CommunityToolkit.Labs.UnitTests.Shared/CommunityToolkit.Labs.UnitTests.Shared.projitems
index f14f94ed2..8a363fae1 100644
--- a/common/CommunityToolkit.Labs.UnitTests.Shared/CommunityToolkit.Labs.UnitTests.Shared.projitems
+++ b/common/CommunityToolkit.Labs.UnitTests.Shared/CommunityToolkit.Labs.UnitTests.Shared.projitems
@@ -19,7 +19,6 @@
App.xaml
-
\ No newline at end of file
diff --git a/common/CommunityToolkit.Labs.UnitTests.Shared/TestPageAttribute.cs b/common/CommunityToolkit.Labs.UnitTests.Shared/TestPageAttribute.cs
deleted file mode 100644
index 1c499a074..000000000
--- a/common/CommunityToolkit.Labs.UnitTests.Shared/TestPageAttribute.cs
+++ /dev/null
@@ -1,24 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-namespace CommunityToolkit.Labs.UnitTests;
-
-///
-/// Attribute to add to a implementation in order to load a XAML based page to use within that test. Class with containing method needs to inherit from for functionality to work. Requires to be set to function.
-///
-[AttributeUsage(AttributeTargets.Method, AllowMultiple = false)]
-public sealed class TestPageAttribute : Attribute
-{
- public TestPageAttribute(Type pageType)
- {
- if (pageType == null)
- {
- throw new ArgumentException($"'{nameof(pageType)}' cannot be null", nameof(pageType));
- }
-
- PageType = pageType;
- }
-
- public Type PageType { get; private set; }
-}
diff --git a/common/CommunityToolkit.Labs.UnitTests.Shared/VisualUITestBase.cs b/common/CommunityToolkit.Labs.UnitTests.Shared/VisualUITestBase.cs
index ce59a7675..300c41b50 100644
--- a/common/CommunityToolkit.Labs.UnitTests.Shared/VisualUITestBase.cs
+++ b/common/CommunityToolkit.Labs.UnitTests.Shared/VisualUITestBase.cs
@@ -11,33 +11,14 @@ namespace CommunityToolkit.Labs.UnitTests;
///
public class VisualUITestBase
{
- public TestContext? TestContext { get; set; }
+ // Used by source generators to dispatch to the UI thread
+ // Methods must be declared or interfaced for compatibility with the source generator's unit tests.
- public FrameworkElement? TestPage { get; private set; }
+ protected Task EnqueueAsync(Func> function) => App.DispatcherQueue.EnqueueAsync(function);
- [TestInitialize]
- public async Task TestInitialize()
- {
- if (TestContext != null)
- {
- await App.DispatcherQueue.EnqueueAsync(async () =>
- {
- TestPage = GetPageForTest(TestContext);
-
- if (TestPage != null)
- {
- Task result = SetTestContentAsync(TestPage);
-
- await result;
-
- if (!result.IsCompletedSuccessfully)
- {
- throw new Exception($"Failed to load page for {TestContext.TestName} with Exception: {result.Exception?.Message}", result.Exception);
- }
- }
- });
- }
- }
+ protected Task EnqueueAsync(Func function) => App.DispatcherQueue.EnqueueAsync(function);
+
+ protected Task EnqueueAsync(Action function) => App.DispatcherQueue.EnqueueAsync(function);
///
/// Sets the content of the test app to a simple to load into the visual tree.
@@ -45,111 +26,62 @@ await App.DispatcherQueue.EnqueueAsync(async () =>
///
/// Content to set in test app.
/// When UI is loaded.
- protected Task SetTestContentAsync(FrameworkElement content)
- {
- return App.DispatcherQueue.EnqueueAsync(() =>
- {
- var taskCompletionSource = new TaskCompletionSource();
-
- async void Callback(object sender, RoutedEventArgs args)
- {
- content.Loaded -= Callback;
-
- // Wait for first Render pass
- await CompositionTargetHelper.ExecuteAfterCompositionRenderingAsync(() => { });
+ protected async Task LoadTestContentAsync(FrameworkElement content)
+ {
+ var taskCompletionSource = new TaskCompletionSource