diff --git a/src/Hosting/Hosting/src/GenericHost/GenericWebHostBuilder.cs b/src/Hosting/Hosting/src/GenericHost/GenericWebHostBuilder.cs index 8bf97d576368..a9625a7907a0 100644 --- a/src/Hosting/Hosting/src/GenericHost/GenericWebHostBuilder.cs +++ b/src/Hosting/Hosting/src/GenericHost/GenericWebHostBuilder.cs @@ -21,6 +21,7 @@ internal class GenericWebHostBuilder : IWebHostBuilder, ISupportsStartup, ISuppo { private readonly IHostBuilder _builder; private readonly IConfiguration _config; + private object _startupObject; private readonly object _startupKey = new object(); private AggregateException _hostingStartupErrors; @@ -198,10 +199,12 @@ public IWebHostBuilder UseDefaultServiceProvider(Action { - if (_builder.Properties.TryGetValue("UseStartup.StartupType", out var cachedType) && (Type)cachedType == startupType) + // Run this delegate if the startup type matches + if (object.ReferenceEquals(_startupObject, startupType)) { UseStartup(startupType, context, services); } @@ -210,13 +213,31 @@ public IWebHostBuilder UseStartup(Type startupType) return this; } - private void UseStartup(Type startupType, HostBuilderContext context, IServiceCollection services) + public IWebHostBuilder UseStartup(Func startupFactory) + { + // Clear the startup type + _startupObject = startupFactory; + + _builder.ConfigureServices((context, services) => + { + // UseStartup can be called multiple times. Only run the last one. + if (object.ReferenceEquals(_startupObject, startupFactory)) + { + var webHostBuilderContext = GetWebHostBuilderContext(context); + var instance = startupFactory(webHostBuilderContext) ?? throw new InvalidOperationException("The specified factory returned null startup instance."); + UseStartup(instance.GetType(), context, services, instance); + } + }); + + return this; + } + + private void UseStartup(Type startupType, HostBuilderContext context, IServiceCollection services, object instance = null) { var webHostBuilderContext = GetWebHostBuilderContext(context); var webHostOptions = (WebHostOptions)context.Properties[typeof(WebHostOptions)]; ExceptionDispatchInfo startupError = null; - object instance = null; ConfigureBuilder configureBuilder = null; try @@ -231,7 +252,7 @@ private void UseStartup(Type startupType, HostBuilderContext context, IServiceCo throw new NotSupportedException($"ConfigureServices returning an {typeof(IServiceProvider)} isn't supported."); } - instance = ActivatorUtilities.CreateInstance(new HostServiceProvider(webHostBuilderContext), startupType); + instance ??= ActivatorUtilities.CreateInstance(new HostServiceProvider(webHostBuilderContext), startupType); context.Properties[_startupKey] = instance; // Startup.ConfigureServices @@ -296,13 +317,19 @@ private void ConfigureContainer(HostBuilderContext context, TContain public IWebHostBuilder Configure(Action configure) { + // Clear the startup type + _startupObject = configure; + _builder.ConfigureServices((context, services) => { - services.Configure(options => + if (object.ReferenceEquals(_startupObject, configure)) { - var webhostBuilderContext = GetWebHostBuilderContext(context); - options.ConfigureApplication = app => configure(webhostBuilderContext, app); - }); + services.Configure(options => + { + var webhostBuilderContext = GetWebHostBuilderContext(context); + options.ConfigureApplication = app => configure(webhostBuilderContext, app); + }); + } }); return this; diff --git a/src/Hosting/Hosting/src/GenericHost/HostingStartupWebHostBuilder.cs b/src/Hosting/Hosting/src/GenericHost/HostingStartupWebHostBuilder.cs index 42cda1cc81f9..f4034fc38bd6 100644 --- a/src/Hosting/Hosting/src/GenericHost/HostingStartupWebHostBuilder.cs +++ b/src/Hosting/Hosting/src/GenericHost/HostingStartupWebHostBuilder.cs @@ -75,5 +75,10 @@ public IWebHostBuilder UseStartup(Type startupType) { return _builder.UseStartup(startupType); } + + public IWebHostBuilder UseStartup(Func startupFactory) + { + return _builder.UseStartup(startupFactory); + } } } diff --git a/src/Hosting/Hosting/src/GenericHost/ISupportsStartup.cs b/src/Hosting/Hosting/src/GenericHost/ISupportsStartup.cs index 8a2d1dd8abd5..998f73d06bc0 100644 --- a/src/Hosting/Hosting/src/GenericHost/ISupportsStartup.cs +++ b/src/Hosting/Hosting/src/GenericHost/ISupportsStartup.cs @@ -10,5 +10,6 @@ internal interface ISupportsStartup { IWebHostBuilder Configure(Action configure); IWebHostBuilder UseStartup(Type startupType); + IWebHostBuilder UseStartup(Func startupFactory); } } diff --git a/src/Hosting/Hosting/src/Internal/StartupLoader.cs b/src/Hosting/Hosting/src/Internal/StartupLoader.cs index 46c195c22a1e..8c80884b83f8 100644 --- a/src/Hosting/Hosting/src/Internal/StartupLoader.cs +++ b/src/Hosting/Hosting/src/Internal/StartupLoader.cs @@ -37,15 +37,14 @@ internal class StartupLoader // // If the Startup class ConfigureServices returns an and there is at least an registered we // throw as the filters can't be applied. - public static StartupMethods LoadMethods(IServiceProvider hostingServiceProvider, Type startupType, string environmentName) + public static StartupMethods LoadMethods(IServiceProvider hostingServiceProvider, Type startupType, string environmentName, object instance = null) { var configureMethod = FindConfigureDelegate(startupType, environmentName); var servicesMethod = FindConfigureServicesDelegate(startupType, environmentName); var configureContainerMethod = FindConfigureContainerDelegate(startupType, environmentName); - object instance = null; - if (!configureMethod.MethodInfo.IsStatic || (servicesMethod != null && !servicesMethod.MethodInfo.IsStatic)) + if (instance == null && (!configureMethod.MethodInfo.IsStatic || (servicesMethod != null && !servicesMethod.MethodInfo.IsStatic))) { instance = ActivatorUtilities.GetServiceOrCreateInstance(hostingServiceProvider, startupType); } @@ -54,7 +53,7 @@ public static StartupMethods LoadMethods(IServiceProvider hostingServiceProvider // going to be used for anything. var type = configureContainerMethod.MethodInfo != null ? configureContainerMethod.GetContainerType() : typeof(object); - var builder = (ConfigureServicesDelegateBuilder) Activator.CreateInstance( + var builder = (ConfigureServicesDelegateBuilder)Activator.CreateInstance( typeof(ConfigureServicesDelegateBuilder<>).MakeGenericType(type), hostingServiceProvider, servicesMethod, @@ -104,13 +103,13 @@ Action ConfigureContainerPipeline(Action action) // The ConfigureContainer pipeline needs an Action as source, so we just adapt the // signature with this function. - void Source(TContainerBuilder containerBuilder) => + void Source(TContainerBuilder containerBuilder) => action(containerBuilder); // The ConfigureContainerBuilder.ConfigureContainerFilters expects an Action as value, but our pipeline // produces an Action given a source, so we wrap it on an Action that internally casts // the object containerBuilder to TContainerBuilder to match the expected signature of our ConfigureContainer pipeline. - void Target(object containerBuilder) => + void Target(object containerBuilder) => BuildStartupConfigureContainerFiltersPipeline(Source)((TContainerBuilder)containerBuilder); } } diff --git a/src/Hosting/Hosting/src/WebHostBuilderExtensions.cs b/src/Hosting/Hosting/src/WebHostBuilderExtensions.cs index 9189786b1f90..5d2b7133cc0b 100644 --- a/src/Hosting/Hosting/src/WebHostBuilderExtensions.cs +++ b/src/Hosting/Hosting/src/WebHostBuilderExtensions.cs @@ -61,6 +61,48 @@ private static IWebHostBuilder Configure(this IWebHostBuilder hostBuilder, Actio }); } + /// + /// Specify a factory that creates the startup instance to be used by the web host. + /// + /// The to configure. + /// A delegate that specifies a factory for the startup class. + /// The . + public static IWebHostBuilder UseStartup(this IWebHostBuilder hostBuilder, Func startupFactory) + { + if (startupFactory == null) + { + throw new ArgumentNullException(nameof(startupFactory)); + } + + var startupAssemblyName = startupFactory.GetMethodInfo().DeclaringType.GetTypeInfo().Assembly.GetName().Name; + + hostBuilder.UseSetting(WebHostDefaults.ApplicationKey, startupAssemblyName); + + // Light up the GenericWebHostBuilder implementation + if (hostBuilder is ISupportsStartup supportsStartup) + { + return supportsStartup.UseStartup(startupFactory); + } + + return hostBuilder + .ConfigureServices((context, services) => + { + services.AddSingleton(typeof(IStartup), sp => + { + var instance = startupFactory(context) ?? throw new InvalidOperationException("The specified factory returned null startup instance."); + + var hostingEnvironment = sp.GetRequiredService(); + + // Check if the instance implements IStartup before wrapping + if (instance is IStartup startup) + { + return startup; + } + + return new ConventionBasedStartup(StartupLoader.LoadMethods(sp, instance.GetType(), hostingEnvironment.EnvironmentName, instance)); + }); + }); + } /// /// Specify the startup type to be used by the web host. @@ -70,6 +112,11 @@ private static IWebHostBuilder Configure(this IWebHostBuilder hostBuilder, Actio /// The . public static IWebHostBuilder UseStartup(this IWebHostBuilder hostBuilder, Type startupType) { + if (startupType == null) + { + throw new ArgumentNullException(nameof(startupType)); + } + var startupAssemblyName = startupType.GetTypeInfo().Assembly.GetName().Name; hostBuilder.UseSetting(WebHostDefaults.ApplicationKey, startupAssemblyName); diff --git a/src/Hosting/Hosting/test/Fakes/GenericWebHostBuilderWrapper.cs b/src/Hosting/Hosting/test/Fakes/GenericWebHostBuilderWrapper.cs index 876cc7199b07..e87be3a9ad92 100644 --- a/src/Hosting/Hosting/test/Fakes/GenericWebHostBuilderWrapper.cs +++ b/src/Hosting/Hosting/test/Fakes/GenericWebHostBuilderWrapper.cs @@ -73,5 +73,11 @@ public IWebHostBuilder UseStartup(Type startupType) _builder.UseStartup(startupType); return this; } + + public IWebHostBuilder UseStartup(Func startupFactory) + { + _builder.UseStartup(startupFactory); + return this; + } } } diff --git a/src/Hosting/Hosting/test/WebHostBuilderTests.cs b/src/Hosting/Hosting/test/WebHostBuilderTests.cs index e34ae4938805..81a8ab494124 100644 --- a/src/Hosting/Hosting/test/WebHostBuilderTests.cs +++ b/src/Hosting/Hosting/test/WebHostBuilderTests.cs @@ -6,15 +6,19 @@ using System.IO; using System.Linq; using System.Reflection; +using System.Runtime.ExceptionServices; using System.Threading; using System.Threading.Tasks; +using System.Web; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Hosting.Fakes; using Microsoft.AspNetCore.Hosting.Server; using Microsoft.AspNetCore.Hosting.Tests.Fakes; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Extensions; using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.WebUtilities; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; @@ -68,6 +72,71 @@ public async Task StartupStaticCtorThrows_Fallback(IWebHostBuilder builder) } } + [Theory] + [MemberData(nameof(DefaultWebHostBuildersWithConfig))] + public void UseStartupThrowsWhenFactoryIsNull(IWebHostBuilder builder) + { + var server = new TestServer(); + Assert.Throws(() => builder.UseServer(server).UseStartup((Func)null)); + } + + [Theory] + [MemberData(nameof(DefaultWebHostBuilders))] + public void UseStartupThrowsWhenFactoryReturnsNull(IWebHostBuilder builder) + { + var server = new TestServer(); + var ex = Assert.Throws(() => builder.UseServer(server).UseStartup(context => null).Build()); + Assert.Equal("The specified factory returned null startup instance.", ex.Message); + } + + [Theory] + [MemberData(nameof(DefaultWebHostBuildersWithConfig))] + public async Task MultipleUseStartupCallsLastWins(IWebHostBuilder builder) + { + var server = new TestServer(); + var host = builder.UseServer(server) + .UseStartup() + .UseStartup(context => throw new InvalidOperationException("This doesn't run")) + .Configure(app => + { + throw new InvalidOperationException("This doesn't run"); + }) + .Configure(app => + { + app.Run(context => + { + return context.Response.WriteAsync("This wins"); + }); + }) + .Build(); + using (host) + { + await host.StartAsync(); + await AssertResponseContains(server.RequestDelegate, "This wins"); + } + } + + [Theory] + [MemberData(nameof(DefaultWebHostBuildersWithConfig))] + public async Task UseStartupFactoryWorks(IWebHostBuilder builder) + { + void ConfigureServices(IServiceCollection services) { } + void Configure(IApplicationBuilder app) + { + app.Run(context => context.Response.WriteAsync("UseStartupFactoryWorks")); + } + + var server = new TestServer(); + var host = builder.UseServer(server) + .UseStartup(context => new DelegatingStartup(ConfigureServices, Configure)) + .Build(); + using (host) + { + await host.StartAsync(); + await AssertResponseContains(server.RequestDelegate, "UseStartupFactoryWorks"); + } + } + [Theory] [MemberData(nameof(DefaultWebHostBuildersWithConfig))] public async Task StartupCtorThrows_Fallback(IWebHostBuilder builder) @@ -199,7 +268,7 @@ public void ConfigureDefaultServiceProviderWithContext(IWebHostBuilder builder) options.ValidateScopes = true; }); - using var host = hostBuilder.Build(); + using var host = hostBuilder.Build(); Assert.Throws(() => host.Start()); Assert.True(configurationCallbackCalled); } @@ -728,6 +797,22 @@ public void DefaultApplicationNameWithConfigure(IWebHostBuilder builder) } } + [Theory] + [MemberData(nameof(DefaultWebHostBuilders))] + public void DefaultApplicationNameWithUseStartupFactory(IWebHostBuilder builder) + { + using (var host = builder + .UseServer(new TestServer()) + .UseStartup(context => new DelegatingStartup(s => { }, app => { })) + .Build()) + { + var hostingEnv = host.Services.GetService(); + + // Should be the assembly containing this test, because that's where the delegate comes from + Assert.Equal(typeof(WebHostBuilderTests).Assembly.GetName().Name, hostingEnv.ApplicationName); + } + } + [Theory] [MemberData(nameof(DefaultWebHostBuilders))] public void Configure_SupportsNonStaticMethodDelegate(IWebHostBuilder builder) @@ -770,6 +855,27 @@ public void Build_DoesNotAllowBuildingMuiltipleTimes() } } + [Fact] + public async Task UseStartupImplementingIStartupWorks() + { + void Configure(IApplicationBuilder app) + { + app.Run(context => context.Response.WriteAsync("Configure")); + } + + IServiceProvider ConfigureServices(IServiceCollection services) => services.BuildServiceProvider(); + + var builder = CreateWebHostBuilder(); + var server = new TestServer(); + using (var host = builder.UseServer(server) + .UseStartup(context => new DelegatingStartupWithIStartup(ConfigureServices, Configure)) + .Build()) + { + await host.StartAsync(); + await AssertResponseContains(server.RequestDelegate, "Configure"); + } + } + [Theory] [MemberData(nameof(DefaultWebHostBuildersWithConfig))] public void Build_DoesNotOverrideILoggerFactorySetByConfigureServices(IWebHostBuilder builder) @@ -1218,7 +1324,7 @@ public void UseConfigurationWithSectionAddsSubKeys(IWebHostBuilder builder) Assert.Equal("nestedvalue", builder.GetSetting("key")); - using var host = builder.Build(); + using var host = builder.Build(); var appConfig = host.Services.GetRequiredService(); Assert.Equal("nestedvalue", appConfig["key"]); } @@ -1574,6 +1680,37 @@ public void Configure(IWebHostBuilder builder) } } + private class DelegatingStartupWithIStartup : IStartup + { + private readonly Func _configureServices; + private readonly Action _configure; + + public DelegatingStartupWithIStartup(Func configureServices, Action configure) + { + _configureServices = configureServices; + _configure = configure; + } + + // These are explicitly implemented to verify they don't get called via reflection + IServiceProvider IStartup.ConfigureServices(IServiceCollection services) => _configureServices(services); + void IStartup.Configure(IApplicationBuilder app) => _configure(app); + } + + public class DelegatingStartup + { + private readonly Action _configureServices; + private readonly Action _configure; + + public DelegatingStartup(Action configureServices, Action configure) + { + _configureServices = configureServices; + _configure = configure; + } + + public void ConfigureServices(IServiceCollection services) => _configureServices(services); + public void Configure(IApplicationBuilder app) => _configure(app); + } + public class StartupWithResolvedDisposableThatThrows { public StartupWithResolvedDisposableThatThrows(DisposableService service)