diff --git a/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs b/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs index a65c0937ac0c..0f049922a85b 100644 --- a/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs +++ b/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs @@ -6,7 +6,6 @@ using System.Net.Security; using System.Security.Cryptography.X509Certificates; using Microsoft.AspNetCore.Server.Kestrel.Core; -using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; using Microsoft.AspNetCore.Server.Kestrel.Https; using Microsoft.AspNetCore.Server.Kestrel.Https.Internal; using Microsoft.Extensions.DependencyInjection; @@ -209,7 +208,8 @@ internal static bool TryUseHttps(this ListenOptions listenOptions) } /// - /// Configure Kestrel to use HTTPS. + /// Configure Kestrel to use HTTPS. This does not use default certificates or other defaults specified via config or + /// . /// /// The to configure. /// Options to configure HTTPS. @@ -230,12 +230,44 @@ public static ListenOptions UseHttps(this ListenOptions listenOptions, HttpsConn return listenOptions; } + /// + /// Configure Kestrel to use HTTPS. This does not use default certificates or other defaults specified via config or + /// . + /// + /// The to configure. + /// Callback to configure HTTPS options. + /// State for the . + /// The . + public static ListenOptions UseHttps(this ListenOptions listenOptions, ServerOptionsSelectionCallback serverOptionsSelectionCallback, object state) + { + return listenOptions.UseHttps(serverOptionsSelectionCallback, state, HttpsConnectionAdapterOptions.DefaultHandshakeTimeout); + } + + /// + /// Configure Kestrel to use HTTPS. This does not use default certificates or other defaults specified via config or + /// . + /// + /// The to configure. + /// Callback to configure HTTPS options. + /// State for the . + /// Specifies the maximum amount of time allowed for the TLS/SSL handshake. This must be positive and finite. + /// The . + public static ListenOptions UseHttps(this ListenOptions listenOptions, ServerOptionsSelectionCallback serverOptionsSelectionCallback, object state, TimeSpan handshakeTimeout) + { + // HttpsOptionsCallback is an internal delegate that is just the ServerOptionsSelectionCallback + a ConnectionContext parameter. + // Given that ConnectionContext will eventually be replaced by System.Net.Connections, it doesn't make much sense to make the HttpsOptionsCallback delegate public. + HttpsOptionsCallback adaptedCallback = (connection, stream, clientHelloInfo, state, cancellationToken) => + serverOptionsSelectionCallback(stream, clientHelloInfo, state, cancellationToken); + + return listenOptions.UseHttps(adaptedCallback, state, handshakeTimeout); + } + /// /// Configure Kestrel to use HTTPS. /// /// The to configure. /// Callback to configure HTTPS options. - /// State for the . + /// State for the . /// Specifies the maximum amount of time allowed for the TLS/SSL handshake. This must be positive and finite. /// The . internal static ListenOptions UseHttps(this ListenOptions listenOptions, HttpsOptionsCallback httpsOptionsCallback, object state, TimeSpan handshakeTimeout) diff --git a/src/Servers/Kestrel/samples/SampleApp/Startup.cs b/src/Servers/Kestrel/samples/SampleApp/Startup.cs index baed7b92b3b1..25d7eca28367 100644 --- a/src/Servers/Kestrel/samples/SampleApp/Startup.cs +++ b/src/Servers/Kestrel/samples/SampleApp/Startup.cs @@ -5,6 +5,7 @@ using System.Diagnostics; using System.IO; using System.Net; +using System.Net.Security; using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; using System.Threading.Tasks; @@ -109,15 +110,21 @@ public static Task Main(string[] args) options.ListenAnyIP(basePort + 5, listenOptions => { - listenOptions.UseHttps(httpsOptions => + var localhostCert = CertificateLoader.LoadFromStoreCert("localhost", "My", StoreLocation.CurrentUser, allowInvalid: true); + + listenOptions.UseHttps((stream, clientHelloInfo, state, cancellationToken) => { - var localhostCert = CertificateLoader.LoadFromStoreCert("localhost", "My", StoreLocation.CurrentUser, allowInvalid: true); - httpsOptions.ServerCertificateSelector = (features, name) => + // Here you would check the name, select an appropriate cert, and provide a fallback or fail for null names. + if (clientHelloInfo.ServerName != null && clientHelloInfo.ServerName != "localhost") { - // Here you would check the name, select an appropriate cert, and provide a fallback or fail for null names. - return localhostCert; - }; - }); + throw new AuthenticationException($"The endpoint is not configured for sever name '{clientHelloInfo.ServerName}'."); + } + + return new ValueTask(new SslServerAuthenticationOptions + { + ServerCertificate = localhostCert + }); + }, state: null); }); options diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/HttpsConnectionMiddlewareTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/HttpsConnectionMiddlewareTests.cs index 0b65712c5efd..6ad3338e2aff 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/HttpsConnectionMiddlewareTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/HttpsConnectionMiddlewareTests.cs @@ -123,6 +123,42 @@ void ConfigureListenOptions(ListenOptions listenOptions) } } + [Fact] + public async Task HandshakeDetailsAreAvailableAfterAsyncCallback() + { + void ConfigureListenOptions(ListenOptions listenOptions) + { + listenOptions.UseHttps(async (stream, clientHelloInfo, state, cancellationToken) => + { + await Task.Yield(); + + return new SslServerAuthenticationOptions + { + ServerCertificate = _x509Certificate2, + }; + }, state: null); + } + + await using (var server = new TestServer(context => + { + var tlsFeature = context.Features.Get(); + Assert.NotNull(tlsFeature); + Assert.True(tlsFeature.Protocol > SslProtocols.None, "Protocol"); + Assert.True(tlsFeature.CipherAlgorithm > CipherAlgorithmType.Null, "Cipher"); + Assert.True(tlsFeature.CipherStrength > 0, "CipherStrength"); + Assert.True(tlsFeature.HashAlgorithm >= HashAlgorithmType.None, "HashAlgorithm"); // May be None on Linux. + Assert.True(tlsFeature.HashStrength >= 0, "HashStrength"); // May be 0 for some algorithms + Assert.True(tlsFeature.KeyExchangeAlgorithm >= ExchangeAlgorithmType.None, "KeyExchangeAlgorithm"); // Maybe None on Windows 7 + Assert.True(tlsFeature.KeyExchangeStrength >= 0, "KeyExchangeStrength"); // May be 0 on mac + + return context.Response.WriteAsync("hello world"); + }, new TestServiceContext(LoggerFactory), ConfigureListenOptions)) + { + var result = await server.HttpClientSlim.GetStringAsync($"https://localhost:{server.Port}/", validateCertificate: false); + Assert.Equal("hello world", result); + } + } + [Fact] public async Task RequireCertificateFailsWhenNoCertificate() { @@ -166,22 +202,18 @@ void ConfigureListenOptions(ListenOptions listenOptions) } [Fact] - [QuarantinedTest("https://github.com/dotnet/runtime/issues/40402")] - public async Task ClientCertificateRequiredConfiguredInCallbackContinuesWhenNoCertificate() + public async Task AsyncCallbackSettingClientCertificateRequiredContinuesWhenNoCertificate() { void ConfigureListenOptions(ListenOptions listenOptions) { - listenOptions.UseHttps((connection, stream, clientHelloInfo, state, cancellationToken) => + listenOptions.UseHttps((stream, clientHelloInfo, state, cancellationToken) => new ValueTask(new SslServerAuthenticationOptions { ServerCertificate = _x509Certificate2, - // From the API Docs: "Note that this is only a request -- - // if no certificate is provided, the server still accepts the connection request." - // Not to mention this is equivalent to the test above. ClientCertificateRequired = true, RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true, CertificateRevocationCheckMode = X509RevocationMode.NoCheck - }), state: null, HttpsConnectionAdapterOptions.DefaultHandshakeTimeout); + }), state: null); } await using (var server = new TestServer(context => @@ -255,6 +287,39 @@ void ConfigureListenOptions(ListenOptions listenOptions) } } + [Fact] + public async Task UsesProvidedAsyncCallback() + { + var selectorCalled = 0; + void ConfigureListenOptions(ListenOptions listenOptions) + { + listenOptions.UseHttps(async (stream, clientHelloInfo, state, cancellationToken) => + { + await Task.Yield(); + + Assert.NotNull(stream); + Assert.Equal("localhost", clientHelloInfo.ServerName); + selectorCalled++; + + return new SslServerAuthenticationOptions + { + ServerCertificate = _x509Certificate2 + }; + }, state: null); + } + + await using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory), ConfigureListenOptions)) + { + using (var connection = server.CreateConnection()) + { + var stream = OpenSslStream(connection.Stream); + await stream.AuthenticateAsClientAsync("localhost"); + Assert.True(stream.RemoteCertificate.Equals(_x509Certificate2)); + Assert.Equal(1, selectorCalled); + } + } + } + [Fact] public async Task UsesProvidedServerCertificateSelectorEachTime() { diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/HttpsTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/HttpsTests.cs index 4f2f7cde9190..3ef0c16d2a0e 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/HttpsTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/HttpsTests.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; -using System.IO; using System.Net.Security; using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; @@ -13,7 +12,6 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.Server.Kestrel.Core; -using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; using Microsoft.AspNetCore.Server.Kestrel.Https; using Microsoft.AspNetCore.Server.Kestrel.Https.Internal; using Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests.TestTransport; @@ -21,13 +19,14 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Internal; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Testing; using Xunit; namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests { public class HttpsTests : LoggedTest { + private static X509Certificate2 _x509Certificate2 = TestResources.GetTestCertificate(); + private KestrelServerOptions CreateServerOptions() { var serverOptions = new KestrelServerOptions(); @@ -41,8 +40,7 @@ private KestrelServerOptions CreateServerOptions() public void UseHttpsDefaultsToDefaultCert() { var serverOptions = CreateServerOptions(); - var defaultCert = TestResources.GetTestCertificate(); - serverOptions.DefaultCertificate = defaultCert; + serverOptions.DefaultCertificate = _x509Certificate2; serverOptions.ListenLocalhost(5000, options => { @@ -62,22 +60,51 @@ public void UseHttpsDefaultsToDefaultCert() Assert.False(serverOptions.IsDevCertLoaded); } + [Fact] + public async Task UseHttpsWithAsyncCallbackDoeNotFallBackToDefaultCert() + { + var loggerProvider = new HandshakeErrorLoggerProvider(); + LoggerFactory.AddProvider(loggerProvider); + + var testContext = new TestServiceContext(LoggerFactory); + + await using (var server = new TestServer(context => Task.CompletedTask, + testContext, + listenOptions => + { + listenOptions.UseHttps((stream, clientHelloInfo, state, cancellationToken) => + new ValueTask(new SslServerAuthenticationOptions()), state: null); + })) + { + using (var connection = server.CreateConnection()) + using (var sslStream = new SslStream(connection.Stream, true, (sender, certificate, chain, errors) => true)) + { + var ex = await Assert.ThrowsAnyAsync(() => + sslStream.AuthenticateAsClientAsync("127.0.0.1", clientCertificates: null, + enabledSslProtocols: SslProtocols.Tls, + checkCertificateRevocation: false)); + } + } + + var errorException = Assert.Single(loggerProvider.ErrorLogger.ErrorExceptions); + Assert.IsType(errorException); + } + [Fact] public void ConfigureHttpsDefaultsNeverLoadsDefaultCert() { var serverOptions = CreateServerOptions(); - var testCert = TestResources.GetTestCertificate(); serverOptions.ConfigureHttpsDefaults(options => { Assert.Null(options.ServerCertificate); - options.ServerCertificate = testCert; + options.ServerCertificate = _x509Certificate2; options.ClientCertificateMode = ClientCertificateMode.RequireCertificate; }); serverOptions.ListenLocalhost(5000, options => { options.UseHttps(opt => { - Assert.Equal(testCert, opt.ServerCertificate); + Assert.Equal(_x509Certificate2, opt.ServerCertificate); Assert.Equal(ClientCertificateMode.RequireCertificate, opt.ClientCertificateMode); }); }); @@ -90,14 +117,13 @@ public void ConfigureHttpsDefaultsNeverLoadsDefaultCert() public void ConfigureCertSelectorNeverLoadsDefaultCert() { var serverOptions = CreateServerOptions(); - var testCert = TestResources.GetTestCertificate(); serverOptions.ConfigureHttpsDefaults(options => { Assert.Null(options.ServerCertificate); Assert.Null(options.ServerCertificateSelector); options.ServerCertificateSelector = (features, name) => { - return testCert; + return _x509Certificate2; }; options.ClientCertificateMode = ClientCertificateMode.RequireCertificate; }); @@ -126,7 +152,7 @@ public async Task EmptyRequestLoggedAsDebug() new TestServiceContext(LoggerFactory), listenOptions => { - listenOptions.UseHttps(TestResources.GetTestCertificate()); + listenOptions.UseHttps(_x509Certificate2); })) { using (var connection = server.CreateConnection()) @@ -139,7 +165,7 @@ public async Task EmptyRequestLoggedAsDebug() Assert.Equal(1, loggerProvider.FilterLogger.LastEventId.Id); Assert.Equal(LogLevel.Debug, loggerProvider.FilterLogger.LastLogLevel); - Assert.True(loggerProvider.ErrorLogger.TotalErrorsLogged == 0, + Assert.True(loggerProvider.ErrorLogger.ErrorMessages.Count == 0, userMessage: string.Join(Environment.NewLine, loggerProvider.ErrorLogger.ErrorMessages)); } @@ -154,7 +180,7 @@ public async Task ClientHandshakeFailureLoggedAsDebug() new TestServiceContext(LoggerFactory), listenOptions => { - listenOptions.UseHttps(TestResources.GetTestCertificate()); + listenOptions.UseHttps(_x509Certificate2); })) { using (var connection = server.CreateConnection()) @@ -168,7 +194,7 @@ public async Task ClientHandshakeFailureLoggedAsDebug() Assert.Equal(1, loggerProvider.FilterLogger.LastEventId.Id); Assert.Equal(LogLevel.Debug, loggerProvider.FilterLogger.LastLogLevel); - Assert.True(loggerProvider.ErrorLogger.TotalErrorsLogged == 0, + Assert.True(loggerProvider.ErrorLogger.ErrorMessages.Count == 0, userMessage: string.Join(Environment.NewLine, loggerProvider.ErrorLogger.ErrorMessages)); } @@ -198,7 +224,7 @@ public async Task DoesNotThrowObjectDisposedExceptionOnConnectionAbort() new TestServiceContext(LoggerFactory), listenOptions => { - listenOptions.UseHttps(TestResources.GetTestCertificate()); + listenOptions.UseHttps(_x509Certificate2); })) { using (var connection = server.CreateConnection()) @@ -242,7 +268,7 @@ public async Task DoesNotThrowObjectDisposedExceptionFromWriteAsyncAfterConnecti new TestServiceContext(LoggerFactory), listenOptions => { - listenOptions.UseHttps(TestResources.GetTestCertificate()); + listenOptions.UseHttps(_x509Certificate2); })) { using (var connection = server.CreateConnection()) @@ -273,7 +299,7 @@ public async Task DoesNotThrowObjectDisposedExceptionOnEmptyConnection() new TestServiceContext(LoggerFactory), listenOptions => { - listenOptions.UseHttps(TestResources.GetTestCertificate()); + listenOptions.UseHttps(_x509Certificate2); })) { using (var connection = server.CreateConnection()) @@ -299,7 +325,7 @@ public async Task ConnectionFilterDoesNotLeakBlock() new TestServiceContext(LoggerFactory), listenOptions => { - listenOptions.UseHttps(TestResources.GetTestCertificate()); + listenOptions.UseHttps(_x509Certificate2); })) { using (var connection = server.CreateConnection()) @@ -316,10 +342,6 @@ public async Task HandshakeTimesOutAndIsLoggedAsDebug() LoggerFactory.AddProvider(loggerProvider); var testContext = new TestServiceContext(LoggerFactory); - var heartbeatManager = new HeartbeatManager(testContext.ConnectionManager); - - var handshakeStartedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - TimeSpan handshakeTimeout = default; await using (var server = new TestServer(context => Task.CompletedTask, testContext, @@ -327,26 +349,47 @@ public async Task HandshakeTimesOutAndIsLoggedAsDebug() { listenOptions.UseHttps(o => { - o.ServerCertificate = new X509Certificate2(TestResources.GetTestCertificate()); - o.OnAuthenticate = (_, __) => - { - handshakeStartedTcs.SetResult(); - }; - - handshakeTimeout = o.HandshakeTimeout; + o.ServerCertificate = new X509Certificate2(_x509Certificate2); + o.HandshakeTimeout = TimeSpan.FromMilliseconds(100); }); })) { using (var connection = server.CreateConnection()) { - // HttpsConnectionAdapter dispatches via Task.Run() before starting the handshake. - // Wait for the handshake to start before advancing the system clock. - await handshakeStartedTcs.Task.DefaultTimeout(); + Assert.Equal(0, await connection.Stream.ReadAsync(new byte[1], 0, 1).DefaultTimeout()); + } + } + + await loggerProvider.FilterLogger.LogTcs.Task.DefaultTimeout(); + Assert.Equal(2, loggerProvider.FilterLogger.LastEventId); + Assert.Equal(LogLevel.Debug, loggerProvider.FilterLogger.LastLogLevel); + } - // Min amount of time between requests that triggers a handshake timeout. - testContext.MockSystemClock.UtcNow += handshakeTimeout + Heartbeat.Interval + TimeSpan.FromTicks(1); - heartbeatManager.OnHeartbeat(testContext.SystemClock.UtcNow); + [Fact] + public async Task HandshakeTimesOutAndIsLoggedAsDebugWithAsyncCallback() + { + var loggerProvider = new HandshakeErrorLoggerProvider(); + LoggerFactory.AddProvider(loggerProvider); + + var testContext = new TestServiceContext(LoggerFactory); + + await using (var server = new TestServer(context => Task.CompletedTask, + testContext, + listenOptions => + { + listenOptions.UseHttps(async (stream, clientHelloInfo, state, cancellationToken) => + { + await Task.Yield(); + return new SslServerAuthenticationOptions + { + ServerCertificate = _x509Certificate2, + }; + }, state: null, handshakeTimeout: TimeSpan.FromMilliseconds(100)); + })) + { + using (var connection = server.CreateConnection()) + { Assert.Equal(0, await connection.Stream.ReadAsync(new byte[1], 0, 1).DefaultTimeout()); } } @@ -394,7 +437,7 @@ public async Task OnAuthenticate_SeesOtherSettings() var loggerProvider = new HandshakeErrorLoggerProvider(); LoggerFactory.AddProvider(loggerProvider); - var testCert = TestResources.GetTestCertificate(); + var testCert = _x509Certificate2; var onAuthenticateCalled = false; await using (var server = new TestServer(context => Task.CompletedTask, @@ -430,7 +473,7 @@ public async Task OnAuthenticate_CanSetSettings() var loggerProvider = new HandshakeErrorLoggerProvider(); LoggerFactory.AddProvider(loggerProvider); - var testCert = TestResources.GetTestCertificate(); + var testCert = _x509Certificate2; var onAuthenticateCalled = false; await using (var server = new TestServer(context => Task.CompletedTask, @@ -511,11 +554,8 @@ public IDisposable BeginScope(TState state) private class ApplicationErrorLogger : ILogger { - private List _errorMessages = new List(); - - public IEnumerable ErrorMessages => _errorMessages; - - public int TotalErrorsLogged => _errorMessages.Count; + public List ErrorMessages => new List(); + public List ErrorExceptions { get; } = new List(); public bool ObjectDisposedExceptionLogged { get; set; } @@ -524,7 +564,12 @@ public void Log(LogLevel logLevel, EventId eventId, TState state, Except if (logLevel == LogLevel.Error) { var log = $"Log {logLevel}[{eventId}]: {formatter(state, exception)} {exception}"; - _errorMessages.Add(log); + ErrorMessages.Add(log); + + if (exception != null) + { + ErrorExceptions.Add(exception); + } } if (exception is ObjectDisposedException)