diff --git a/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs b/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs index 8ecb2c0ced09..bd8e6681511c 100644 --- a/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs +++ b/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs @@ -20,7 +20,8 @@ internal static partial class Interop { internal static partial class OpenSsl { - private static Ssl.SslCtxSetVerifyCallback s_verifyClientCertificate = VerifyClientCertificate; + private static readonly Ssl.SslCtxSetVerifyCallback s_verifyClientCertificate = VerifyClientCertificate; + private static readonly Ssl.SslCtxSetAlpnCallback s_alpnServerCallback = AlpnServerSelectCallback; #region internal methods @@ -47,7 +48,7 @@ internal static SafeChannelBindingHandle QueryChannelBinding(SafeSslHandle conte return bindingHandle; } - internal static SafeSslHandle AllocateSslContext(SslProtocols protocols, SafeX509Handle certHandle, SafeEvpPKeyHandle certKeyHandle, EncryptionPolicy policy, bool isServer, bool remoteCertRequired) + internal static SafeSslHandle AllocateSslContext(SslProtocols protocols, SafeX509Handle certHandle, SafeEvpPKeyHandle certKeyHandle, EncryptionPolicy policy, SslAuthenticationOptions sslAuthenticationOptions) { SafeSslHandle context = null; @@ -88,17 +89,32 @@ internal static SafeSslHandle AllocateSslContext(SslProtocols protocols, SafeX50 SetSslCertificate(innerContext, certHandle, certKeyHandle); } - if (remoteCertRequired) + if (sslAuthenticationOptions.IsServer && sslAuthenticationOptions.RemoteCertRequired) { - Debug.Assert(isServer, "isServer flag should be true"); - Ssl.SslCtxSetVerify(innerContext, - s_verifyClientCertificate); + Ssl.SslCtxSetVerify(innerContext, s_verifyClientCertificate); //update the client CA list UpdateCAListFromRootStore(innerContext); } - context = SafeSslHandle.Create(innerContext, isServer); + if (sslAuthenticationOptions.ApplicationProtocols != null) + { + if (sslAuthenticationOptions.IsServer) + { + byte[] protos = Interop.Ssl.ConvertAlpnProtocolListToByteArray(sslAuthenticationOptions.ApplicationProtocols); + sslAuthenticationOptions.AlpnProtocolsHandle = GCHandle.Alloc(protos); + Interop.Ssl.SslCtxSetAlpnSelectCb(innerContext, s_alpnServerCallback, GCHandle.ToIntPtr(sslAuthenticationOptions.AlpnProtocolsHandle)); + } + else + { + if (Interop.Ssl.SslCtxSetAlpnProtos(innerContext, sslAuthenticationOptions.ApplicationProtocols) != 0) + { + throw CreateSslException(SR.net_alpn_config_failed); + } + } + } + + context = SafeSslHandle.Create(innerContext, sslAuthenticationOptions.IsServer); Debug.Assert(context != null, "Expected non-null return value from SafeSslHandle.Create"); if (context.IsInvalid) { @@ -314,6 +330,18 @@ private static int VerifyClientCertificate(int preverify_ok, IntPtr x509_ctx_ptr return OpenSslSuccess; } + private static unsafe int AlpnServerSelectCallback(IntPtr ssl, out IntPtr outp, out byte outlen, IntPtr inp, uint inlen, IntPtr arg) + { + GCHandle protocols = GCHandle.FromIntPtr(arg); + byte[] server = (byte[])protocols.Target; + + fixed (byte* sp = server) + { + return Interop.Ssl.SslSelectNextProto(out outp, out outlen, (IntPtr)sp, (uint)server.Length, inp, inlen) == Interop.Ssl.OPENSSL_NPN_NEGOTIATED ? + Interop.Ssl.SSL_TLSEXT_ERR_OK : Interop.Ssl.SSL_TLSEXT_ERR_NOACK; + } + } + private static void UpdateCAListFromRootStore(SafeSslContextHandle context) { using (SafeX509NameStackHandle nameStack = Crypto.NewX509NameStack()) diff --git a/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs b/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs index c468e1a34ad0..2828856c7e14 100644 --- a/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs +++ b/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs @@ -12,6 +12,10 @@ internal static partial class Interop { internal static partial class Ssl { + internal const int SSL_TLSEXT_ERR_OK = 0; + internal const int OPENSSL_NPN_NEGOTIATED = 1; + internal const int SSL_TLSEXT_ERR_NOACK = 3; + internal delegate int SslCtxSetVerifyCallback(int preverify_ok, IntPtr x509_ctx); [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_EnsureLibSslInitialized")] @@ -44,6 +48,26 @@ internal static partial class Ssl [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetVersion")] private static extern IntPtr SslGetVersion(SafeSslHandle ssl); + [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslSelectNextProto")] + internal static extern int SslSelectNextProto(out IntPtr outp, out byte outlen, IntPtr server, uint serverlen, IntPtr client, uint clientlen); + + [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGet0AlpnSelected")] + internal static extern void SslGetAlpnSelected(SafeSslHandle ssl, out IntPtr protocol, out int len); + + internal static byte[] SslGetAlpnSelected(SafeSslHandle ssl) + { + IntPtr protocol; + int len; + SslGetAlpnSelected(ssl, out protocol, out len); + + if (len == 0) + return null; + + byte[] result = new byte[len]; + Marshal.Copy(protocol, result, 0, len); + return result; + } + internal static string GetProtocolVersion(SafeSslHandle ssl) { return Marshal.PtrToStringAnsi(SslGetVersion(ssl)); @@ -156,12 +180,12 @@ internal enum SslErrorCode SSL_ERROR_WANT_WRITE = 3, SSL_ERROR_SYSCALL = 5, SSL_ERROR_ZERO_RETURN = 6, - + // NOTE: this SslErrorCode value doesn't exist in OpenSSL, but // we use it to distinguish when a renegotiation is pending. // Choosing an arbitrarily large value that shouldn't conflict // with any actual OpenSSL error codes - SSL_ERROR_RENEGOTIATE = 29304 + SSL_ERROR_RENEGOTIATE = 29304 } } } diff --git a/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.SslCtx.cs b/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.SslCtx.cs index 7512efe60915..229f31a74e76 100644 --- a/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.SslCtx.cs +++ b/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.SslCtx.cs @@ -3,7 +3,10 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; +using System.Net.Security; using System.Runtime.InteropServices; +using System.Text; using Microsoft.Win32.SafeHandles; internal static partial class Interop @@ -12,6 +15,7 @@ internal static partial class Ssl { internal delegate int AppVerifyCallback(IntPtr storeCtx, IntPtr arg); internal delegate int ClientCertCallback(IntPtr ssl, out IntPtr x509, out IntPtr pkey); + internal delegate int SslCtxSetAlpnCallback(IntPtr ssl, out IntPtr outp, out byte outlen, IntPtr inp, uint inlen, IntPtr arg); [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslCtxCreate")] internal static extern SafeSslContextHandle SslCtxCreate(IntPtr method); @@ -24,6 +28,46 @@ internal static partial class Ssl [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslCtxSetClientCertCallback")] internal static extern void SslCtxSetClientCertCallback(IntPtr ctx, ClientCertCallback callback); + + [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslCtxSetAlpnProtos")] + internal static extern int SslCtxSetAlpnProtos(SafeSslContextHandle ctx, IntPtr protos, int len); + + [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslCtxSetAlpnSelectCb")] + internal static unsafe extern void SslCtxSetAlpnSelectCb(SafeSslContextHandle ctx, SslCtxSetAlpnCallback callback, IntPtr arg); + + internal static unsafe int SslCtxSetAlpnProtos(SafeSslContextHandle ctx, List protocols) + { + byte[] buffer = ConvertAlpnProtocolListToByteArray(protocols); + fixed (byte* b = buffer) + { + return SslCtxSetAlpnProtos(ctx, (IntPtr)b, buffer.Length); + } + } + + internal static byte[] ConvertAlpnProtocolListToByteArray(List applicationProtocols) + { + int protocolSize = 0; + foreach (SslApplicationProtocol protocol in applicationProtocols) + { + if (protocol.Protocol.Length == 0 || protocol.Protocol.Length > byte.MaxValue) + { + throw new ArgumentException(SR.net_ssl_app_protocols_invalid, nameof(applicationProtocols)); + } + + protocolSize += protocol.Protocol.Length + 1; + } + + byte[] buffer = new byte[protocolSize]; + var offset = 0; + foreach (SslApplicationProtocol protocol in applicationProtocols) + { + buffer[offset++] = (byte)(protocol.Protocol.Length); + protocol.Protocol.Span.CopyTo(new Span(buffer).Slice(offset)); + offset += protocol.Protocol.Length; + } + + return buffer; + } } } diff --git a/src/Common/src/Interop/Windows/SChannel/Interop.SECURITY_STATUS.cs b/src/Common/src/Interop/Windows/SChannel/Interop.SECURITY_STATUS.cs index 5d27bddd51c7..435de74c59a9 100644 --- a/src/Common/src/Interop/Windows/SChannel/Interop.SECURITY_STATUS.cs +++ b/src/Common/src/Interop/Windows/SChannel/Interop.SECURITY_STATUS.cs @@ -48,7 +48,8 @@ internal enum SECURITY_STATUS SmartcardLogonRequired = unchecked((int)0x8009033E), UnsupportedPreauth = unchecked((int)0x80090343), BadBinding = unchecked((int)0x80090346), - DowngradeDetected = unchecked((int)0x80090350) + DowngradeDetected = unchecked((int)0x80090350), + ApplicationProtocolMismatch = unchecked((int)0x80090367), } #if TRACE_VERBOSE diff --git a/src/Common/src/Interop/Windows/SChannel/Interop.SecPkgContext_ApplicationProtocol.cs b/src/Common/src/Interop/Windows/SChannel/Interop.SecPkgContext_ApplicationProtocol.cs new file mode 100644 index 000000000000..447afb98dbbe --- /dev/null +++ b/src/Common/src/Interop/Windows/SChannel/Interop.SecPkgContext_ApplicationProtocol.cs @@ -0,0 +1,42 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Runtime.InteropServices; + +internal static partial class Interop +{ + internal enum ApplicationProtocolNegotiationStatus + { + None = 0, + Success, + SelectedClientOnly + } + + internal enum ApplicationProtocolNegotiationExt + { + None = 0, + NPN, + ALPN + } + + [StructLayout(LayoutKind.Sequential)] + internal class SecPkgContext_ApplicationProtocol + { + private const int MaxProtocolIdSize = 0xFF; + + public ApplicationProtocolNegotiationStatus ProtoNegoStatus; + public ApplicationProtocolNegotiationExt ProtoNegoExt; + public byte ProtocolIdSize; + [MarshalAs(UnmanagedType.ByValArray, SizeConst = MaxProtocolIdSize)] + public byte[] ProtocolId; + public byte[] Protocol + { + get + { + return new Span(ProtocolId, 0, ProtocolIdSize).ToArray(); + } + } + } +} diff --git a/src/Common/src/Interop/Windows/SChannel/Interop.Sec_Application_Protocols.cs b/src/Common/src/Interop/Windows/SChannel/Interop.Sec_Application_Protocols.cs new file mode 100644 index 000000000000..a18dac2ad949 --- /dev/null +++ b/src/Common/src/Interop/Windows/SChannel/Interop.Sec_Application_Protocols.cs @@ -0,0 +1,65 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Net.Security; +using System.Runtime.InteropServices; + +internal static partial class Interop +{ + [StructLayout(LayoutKind.Sequential, Pack = 1)] + internal struct Sec_Application_Protocols + { + private static readonly int ProtocolListOffset = Marshal.SizeOf(); + private static readonly int ProtocolListConstSize = ProtocolListOffset - (int)Marshal.OffsetOf(nameof(ProtocolExtenstionType)); + public uint ProtocolListsSize; + public ApplicationProtocolNegotiationExt ProtocolExtenstionType; + public short ProtocolListSize; + + public static unsafe byte[] ToByteArray(List applicationProtocols) + { + long protocolListSize = 0; + for (int i = 0; i < applicationProtocols.Count; i++) + { + if (applicationProtocols[i].Protocol.Length == 0 || applicationProtocols[i].Protocol.Length > byte.MaxValue) + { + throw new ArgumentException(SR.net_ssl_app_protocols_invalid, nameof(applicationProtocols)); + } + + protocolListSize += applicationProtocols[i].Protocol.Length + 1; + + if (protocolListSize > short.MaxValue) + { + throw new ArgumentException(SR.net_ssl_app_protocols_invalid, nameof(applicationProtocols)); + } + } + + Sec_Application_Protocols protocols = new Sec_Application_Protocols(); + protocols.ProtocolListsSize = (uint)(ProtocolListConstSize + protocolListSize); + protocols.ProtocolExtenstionType = ApplicationProtocolNegotiationExt.ALPN; + protocols.ProtocolListSize = (short)protocolListSize; + + Span pBuffer = new byte[protocolListSize]; + int index = 0; + for (int i = 0; i < applicationProtocols.Count; i++) + { + pBuffer[index++] = (byte)applicationProtocols[i].Protocol.Length; + applicationProtocols[i].Protocol.Span.CopyTo(pBuffer.Slice(index)); + index += applicationProtocols[i].Protocol.Length; + } + + byte[] buffer = new byte[ProtocolListOffset + protocolListSize]; + fixed (byte* bufferPtr = buffer) + { + Marshal.StructureToPtr(protocols, new IntPtr(bufferPtr), false); + byte* pList = bufferPtr + ProtocolListOffset; + pBuffer.CopyTo(new Span(pList, index)); + } + + return buffer; + } + } +} diff --git a/src/Common/src/Interop/Windows/sspicli/Interop.SSPI.cs b/src/Common/src/Interop/Windows/sspicli/Interop.SSPI.cs index a08b49dba7b6..99457c511bff 100644 --- a/src/Common/src/Interop/Windows/sspicli/Interop.SSPI.cs +++ b/src/Common/src/Interop/Windows/sspicli/Interop.SSPI.cs @@ -54,6 +54,7 @@ internal enum ContextAttribute SECPKG_ATTR_UNIQUE_BINDINGS = 25, SECPKG_ATTR_ENDPOINT_BINDINGS = 26, SECPKG_ATTR_CLIENT_SPECIFIED_TARGET = 27, + SECPKG_ATTR_APPLICATION_PROTOCOL = 35, // minschannel.h SECPKG_ATTR_REMOTE_CERT_CONTEXT = 0x53, // returns PCCERT_CONTEXT diff --git a/src/Common/src/Interop/Windows/sspicli/SSPIWrapper.cs b/src/Common/src/Interop/Windows/sspicli/SSPIWrapper.cs index 16e3f6231dca..ab47ef534411 100644 --- a/src/Common/src/Interop/Windows/sspicli/SSPIWrapper.cs +++ b/src/Common/src/Interop/Windows/sspicli/SSPIWrapper.cs @@ -476,6 +476,10 @@ public static object QueryContextAttributes(SSPIInterface secModule, SafeDeleteC nativeBlockSize = Marshal.SizeOf(); break; + case Interop.SspiCli.ContextAttribute.SECPKG_ATTR_APPLICATION_PROTOCOL: + nativeBlockSize = Marshal.SizeOf(); + break; + default: throw new ArgumentException(SR.Format(SR.net_invalid_enum, nameof(contextAttribute)), nameof(contextAttribute)); } @@ -540,6 +544,17 @@ public static object QueryContextAttributes(SSPIInterface secModule, SafeDeleteC case Interop.SspiCli.ContextAttribute.SECPKG_ATTR_CONNECTION_INFO: attribute = new SecPkgContext_ConnectionInfo(nativeBuffer); break; + + case Interop.SspiCli.ContextAttribute.SECPKG_ATTR_APPLICATION_PROTOCOL: + unsafe + { + fixed (void *ptr = nativeBuffer) + { + attribute = Marshal.PtrToStructure(new IntPtr(ptr)); + } + } + break; + default: // Will return null. break; diff --git a/src/Common/src/System/Net/Security/Unix/SafeDeleteSslContext.cs b/src/Common/src/System/Net/Security/Unix/SafeDeleteSslContext.cs index 5941fa56854f..ded5784ae878 100644 --- a/src/Common/src/System/Net/Security/Unix/SafeDeleteSslContext.cs +++ b/src/Common/src/System/Net/Security/Unix/SafeDeleteSslContext.cs @@ -5,6 +5,7 @@ using Microsoft.Win32.SafeHandles; using System.Diagnostics; +using System.Net.Security; using System.Runtime.InteropServices; using System.Security.Authentication; using System.Security.Authentication.ExtendedProtection; @@ -25,7 +26,7 @@ public SafeSslHandle SslContext } } - public SafeDeleteSslContext(SafeFreeSslCredentials credential, bool isServer, bool remoteCertRequired) + public SafeDeleteSslContext(SafeFreeSslCredentials credential, SslAuthenticationOptions sslAuthenticationOptions) : base(credential) { Debug.Assert((null != credential) && !credential.IsInvalid, "Invalid credential used in SafeDeleteSslContext"); @@ -37,8 +38,7 @@ public SafeDeleteSslContext(SafeFreeSslCredentials credential, bool isServer, bo credential.CertHandle, credential.CertKeyHandle, credential.Policy, - isServer, - remoteCertRequired); + sslAuthenticationOptions); } catch(Exception ex) { diff --git a/src/Common/src/System/Net/SecurityStatusAdapterPal.Windows.cs b/src/Common/src/System/Net/SecurityStatusAdapterPal.Windows.cs index f24ff1e18245..b7ce1e6b1eff 100644 --- a/src/Common/src/System/Net/SecurityStatusAdapterPal.Windows.cs +++ b/src/Common/src/System/Net/SecurityStatusAdapterPal.Windows.cs @@ -10,7 +10,7 @@ namespace System.Net { internal static class SecurityStatusAdapterPal { - private const int StatusDictionarySize = 40; + private const int StatusDictionarySize = 41; #if DEBUG static SecurityStatusAdapterPal() @@ -22,6 +22,7 @@ static SecurityStatusAdapterPal() private static readonly BidirectionalDictionary s_statusDictionary = new BidirectionalDictionary(StatusDictionarySize) { { Interop.SECURITY_STATUS.AlgorithmMismatch, SecurityStatusPalErrorCode.AlgorithmMismatch }, + { Interop.SECURITY_STATUS.ApplicationProtocolMismatch, SecurityStatusPalErrorCode.ApplicationProtocolMismatch }, { Interop.SECURITY_STATUS.BadBinding, SecurityStatusPalErrorCode.BadBinding }, { Interop.SECURITY_STATUS.BufferNotEnough, SecurityStatusPalErrorCode.BufferNotEnough }, { Interop.SECURITY_STATUS.CannotInstall, SecurityStatusPalErrorCode.CannotInstall }, diff --git a/src/Common/src/System/Net/SecurityStatusPal.cs b/src/Common/src/System/Net/SecurityStatusPal.cs index 3ebb6330716f..af113e824880 100644 --- a/src/Common/src/System/Net/SecurityStatusPal.cs +++ b/src/Common/src/System/Net/SecurityStatusPal.cs @@ -67,6 +67,7 @@ internal enum SecurityStatusPalErrorCode SmartcardLogonRequired, UnsupportedPreauth, BadBinding, - DowngradeDetected + DowngradeDetected, + ApplicationProtocolMismatch } } diff --git a/src/Native/Unix/System.Security.Cryptography.Native/configure.cmake b/src/Native/Unix/System.Security.Cryptography.Native/configure.cmake index 3560657985d2..809ffe318e29 100644 --- a/src/Native/Unix/System.Security.Cryptography.Native/configure.cmake +++ b/src/Native/Unix/System.Security.Cryptography.Native/configure.cmake @@ -13,6 +13,10 @@ check_function_exists( EC_GF2m_simple_method HAVE_OPENSSL_EC2M) +check_function_exists( + SSL_get0_alpn_selected + HAVE_OPENSSL_ALPN) + configure_file( ${CMAKE_CURRENT_SOURCE_DIR}/pal_crypto_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/pal_crypto_config.h) diff --git a/src/Native/Unix/System.Security.Cryptography.Native/opensslshim.h b/src/Native/Unix/System.Security.Cryptography.Native/opensslshim.h index c77d4bc0f6f8..ee211be126b3 100644 --- a/src/Native/Unix/System.Security.Cryptography.Native/opensslshim.h +++ b/src/Native/Unix/System.Security.Cryptography.Native/opensslshim.h @@ -50,6 +50,20 @@ int EC_POINT_set_affine_coordinates_GF2m(const EC_GROUP *group, EC_POINT *p, const BIGNUM *x, const BIGNUM *y, BN_CTX *ctx); #endif +#if !HAVE_OPENSSL_ALPN +#undef HAVE_OPENSSL_ALPN +#define HAVE_OPENSSL_ALPN 1 +int SSL_CTX_set_alpn_protos(SSL_CTX* ctx, const unsigned char* protos, unsigned int protos_len); +void SSL_CTX_set_alpn_select_cb(SSL_CTX* ctx, int (*cb) (SSL *ssl, + const unsigned char **out, + unsigned char *outlen, + const unsigned char *in, + unsigned int inlen, + void *arg), void *arg); +void SSL_get0_alpn_selected(const SSL* ssl, const unsigned char** protocol, unsigned int* len); +int32_t SSL_select_next_proto(unsigned char** out, unsigned char* outlen, const unsigned char* server, unsigned int server_len, const unsigned char* client, unsigned int client_len); +#endif + #define API_EXISTS(fn) (fn != nullptr) // List of all functions from the libssl that are used in the System.Security.Cryptography.Native. @@ -251,6 +265,8 @@ int EC_POINT_set_affine_coordinates_GF2m(const EC_GROUP *group, EC_POINT *p, PER_FUNCTION_BLOCK(SSL_CTX_ctrl, true) \ PER_FUNCTION_BLOCK(SSL_CTX_free, true) \ PER_FUNCTION_BLOCK(SSL_CTX_new, true) \ + PER_FUNCTION_BLOCK(SSL_CTX_set_alpn_protos, false) \ + PER_FUNCTION_BLOCK(SSL_CTX_set_alpn_select_cb, false) \ PER_FUNCTION_BLOCK(SSL_CTX_set_cert_verify_callback, true) \ PER_FUNCTION_BLOCK(SSL_CTX_set_cipher_list, true) \ PER_FUNCTION_BLOCK(SSL_CTX_set_client_CA_list, true) \ @@ -270,11 +286,13 @@ int EC_POINT_set_affine_coordinates_GF2m(const EC_GROUP *group, EC_POINT *p, PER_FUNCTION_BLOCK(SSL_get_peer_finished, true) \ PER_FUNCTION_BLOCK(SSL_get_SSL_CTX, true) \ PER_FUNCTION_BLOCK(SSL_get_version, true) \ + PER_FUNCTION_BLOCK(SSL_get0_alpn_selected, false) \ PER_FUNCTION_BLOCK(SSL_library_init, true) \ PER_FUNCTION_BLOCK(SSL_load_error_strings, true) \ PER_FUNCTION_BLOCK(SSL_new, true) \ PER_FUNCTION_BLOCK(SSL_read, true) \ PER_FUNCTION_BLOCK(SSL_renegotiate_pending, true) \ + PER_FUNCTION_BLOCK(SSL_select_next_proto, false) \ PER_FUNCTION_BLOCK(SSL_set_accept_state, true) \ PER_FUNCTION_BLOCK(SSL_set_bio, true) \ PER_FUNCTION_BLOCK(SSL_set_connect_state, true) \ @@ -541,6 +559,8 @@ FOR_ALL_OPENSSL_FUNCTIONS #define SSL_CTX_ctrl SSL_CTX_ctrl_ptr #define SSL_CTX_free SSL_CTX_free_ptr #define SSL_CTX_new SSL_CTX_new_ptr +#define SSL_CTX_set_alpn_protos SSL_CTX_set_alpn_protos_ptr +#define SSL_CTX_set_alpn_select_cb SSL_CTX_set_alpn_select_cb_ptr #define SSL_CTX_set_cert_verify_callback SSL_CTX_set_cert_verify_callback_ptr #define SSL_CTX_set_cipher_list SSL_CTX_set_cipher_list_ptr #define SSL_CTX_set_client_CA_list SSL_CTX_set_client_CA_list_ptr @@ -560,11 +580,13 @@ FOR_ALL_OPENSSL_FUNCTIONS #define SSL_get_peer_finished SSL_get_peer_finished_ptr #define SSL_get_SSL_CTX SSL_get_SSL_CTX_ptr #define SSL_get_version SSL_get_version_ptr +#define SSL_get0_alpn_selected SSL_get0_alpn_selected_ptr #define SSL_library_init SSL_library_init_ptr #define SSL_load_error_strings SSL_load_error_strings_ptr #define SSL_new SSL_new_ptr #define SSL_read SSL_read_ptr #define SSL_renegotiate_pending SSL_renegotiate_pending_ptr +#define SSL_select_next_proto SSL_select_next_proto_ptr #define SSL_set_accept_state SSL_set_accept_state_ptr #define SSL_set_bio SSL_set_bio_ptr #define SSL_set_connect_state SSL_set_connect_state_ptr diff --git a/src/Native/Unix/System.Security.Cryptography.Native/pal_crypto_config.h.in b/src/Native/Unix/System.Security.Cryptography.Native/pal_crypto_config.h.in index 3922df324ff6..6aea13f2c6de 100644 --- a/src/Native/Unix/System.Security.Cryptography.Native/pal_crypto_config.h.in +++ b/src/Native/Unix/System.Security.Cryptography.Native/pal_crypto_config.h.in @@ -3,3 +3,4 @@ #cmakedefine01 HAVE_TLS_V1_1 #cmakedefine01 HAVE_TLS_V1_2 #cmakedefine01 HAVE_OPENSSL_EC2M +#cmakedefine01 HAVE_OPENSSL_ALPN diff --git a/src/Native/Unix/System.Security.Cryptography.Native/pal_ssl.cpp b/src/Native/Unix/System.Security.Cryptography.Native/pal_ssl.cpp index fda1f4da9537..dc4602437d96 100644 --- a/src/Native/Unix/System.Security.Cryptography.Native/pal_ssl.cpp +++ b/src/Native/Unix/System.Security.Cryptography.Native/pal_ssl.cpp @@ -524,3 +524,51 @@ extern "C" int32_t CryptoNative_SslAddExtraChainCert(SSL* ssl, X509* x509) return 0; } + +extern "C" int32_t CryptoNative_SslSelectNextProto(uint8_t** out, uint8_t* outlen, const uint8_t* server, uint32_t server_len, const uint8_t* client, uint32_t client_len) +{ +#ifdef HAVE_OPENSSL_ALPN + if (API_EXISTS(SSL_select_next_proto)) + { + return SSL_select_next_proto(out, outlen, server, server_len, client, client_len); + } + else +#endif + { + return -1; + } +} + +extern "C" void CryptoNative_SslCtxSetAlpnSelectCb(SSL_CTX* ctx, SslCtxSetAlpnCallback cb, void* arg) +{ +#ifdef HAVE_OPENSSL_ALPN + if (API_EXISTS(SSL_CTX_set_alpn_select_cb)) + { + SSL_CTX_set_alpn_select_cb(ctx, cb, arg); + } +#endif +} + +extern "C" int32_t CryptoNative_SslCtxSetAlpnProtos(SSL_CTX* ctx, const uint8_t* protos, uint32_t protos_len) +{ +#ifdef HAVE_OPENSSL_ALPN + if (API_EXISTS(SSL_CTX_set_alpn_protos)) + { + return SSL_CTX_set_alpn_protos(ctx, protos, protos_len); + } + else +#endif + { + return 0; + } +} + +extern "C" void CryptoNative_SslGet0AlpnSelected(SSL* ssl, const uint8_t** protocol, uint32_t* len) +{ +#ifdef HAVE_OPENSSL_ALPN + if (API_EXISTS(SSL_get0_alpn_selected)) + { + SSL_get0_alpn_selected(ssl, protocol, len); + } +#endif +} diff --git a/src/Native/Unix/System.Security.Cryptography.Native/pal_ssl.h b/src/Native/Unix/System.Security.Cryptography.Native/pal_ssl.h index 4fd7fb59c051..10e6aca0f8b7 100644 --- a/src/Native/Unix/System.Security.Cryptography.Native/pal_ssl.h +++ b/src/Native/Unix/System.Security.Cryptography.Native/pal_ssl.h @@ -117,6 +117,14 @@ typedef int32_t (*SslCtxSetCertVerifyCallbackCallback)(X509_STORE_CTX*, void* ar // the function pointer definition for the callback used in SslCtxSetClientCertCallback typedef int32_t (*SslClientCertCallback)(SSL* ssl, X509** x509, EVP_PKEY** pkey); + +// the function pointer definition for the callback used in SslCtxSetAlpnSelectCb +typedef int32_t (*SslCtxSetAlpnCallback)(SSL* ssl, + const uint8_t** out, + uint8_t* outlen, + const uint8_t* in, + uint32_t inlen, + void* arg); /* Ensures that libssl is correctly initialized and ready to use. */ @@ -365,3 +373,25 @@ libssl frees the x509 object. Returns 1 if success and 0 in case of failure */ extern "C" int32_t CryptoNative_SslAddExtraChainCert(SSL* ssl, X509* x509); + +/* +Shims the SSL_select_next_proto method. +Returns 1 on success, 0 on failure. +*/ +extern "C" int32_t CryptoNative_SslSelectNextProto(uint8_t** out, uint8_t* outlen, const uint8_t* server, uint32_t server_len, const uint8_t* client, uint32_t client_len); + +/* +Shims the ssl_ctx_set_alpn_select_cb method. +*/ +extern "C" void CryptoNative_SslCtxSetAlpnSelectCb(SSL_CTX* ctx, SslCtxSetAlpnCallback cb, void *arg); + +/* +Shims the ssl_ctx_set_alpn_protos method. +Returns 0 on success, non-zero on failure. +*/ +extern "C" int32_t CryptoNative_SslCtxSetAlpnProtos(SSL_CTX* ctx, const uint8_t* protos, uint32_t protos_len); + +/* +Shims the ssl_get0_alpn_selected method. +*/ +extern "C" void CryptoNative_SslGet0AlpnSelected(SSL* ssl, const uint8_t** protocol, uint32_t* len); diff --git a/src/System.Data.SqlClient/src/System.Data.SqlClient.csproj b/src/System.Data.SqlClient/src/System.Data.SqlClient.csproj index 458751a21c59..d9ef8bad4cb7 100644 --- a/src/System.Data.SqlClient/src/System.Data.SqlClient.csproj +++ b/src/System.Data.SqlClient/src/System.Data.SqlClient.csproj @@ -228,6 +228,10 @@ + + + Common\Interop\Windows\SChannel\Interop.SecPkgContext_ApplicationProtocol.cs + Common\System\Net\Security\NegotiateStreamPal.Windows.cs diff --git a/src/System.Net.Http/src/Resources/Strings.resx b/src/System.Net.Http/src/Resources/Strings.resx index f73a721dc8f6..015ce6cf0d24 100644 --- a/src/System.Net.Http/src/Resources/Strings.resx +++ b/src/System.Net.Http/src/Resources/Strings.resx @@ -387,4 +387,7 @@ An attempt was made to move the position before the beginning of the stream. + + The application protocol list is invalid. + diff --git a/src/System.Net.HttpListener/src/System.Net.HttpListener.csproj b/src/System.Net.HttpListener/src/System.Net.HttpListener.csproj index 16307ee5c9af..821d6b4c2355 100644 --- a/src/System.Net.HttpListener/src/System.Net.HttpListener.csproj +++ b/src/System.Net.HttpListener/src/System.Net.HttpListener.csproj @@ -143,6 +143,9 @@ Common\Interop\Windows\Interop.Libraries.cs + + Common\Interop\Windows\SChannel\Interop.SecPkgContext_ApplicationProtocol.cs + Common\Interop\Windows\Interop.BOOL.cs diff --git a/src/System.Net.Mail/src/System.Net.Mail.csproj b/src/System.Net.Mail/src/System.Net.Mail.csproj index 531a297f37f2..9d5a0a364e83 100644 --- a/src/System.Net.Mail/src/System.Net.Mail.csproj +++ b/src/System.Net.Mail/src/System.Net.Mail.csproj @@ -203,6 +203,9 @@ + + Common\Interop\Windows\SChannel\Interop.SecPkgContext_ApplicationProtocol.cs + Common\System\Net\Security\SecurityContextTokenHandle.cs diff --git a/src/System.Net.Security/ref/System.Net.Security.cs b/src/System.Net.Security/ref/System.Net.Security.cs index eca66781740d..b47abcf906f6 100644 --- a/src/System.Net.Security/ref/System.Net.Security.cs +++ b/src/System.Net.Security/ref/System.Net.Security.cs @@ -5,6 +5,11 @@ // Changes to this file must follow the http://aka.ms/api-review process. // ------------------------------------------------------------------------------ +using System.Collections.Generic; +using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; namespace System.Net.Security { @@ -94,6 +99,46 @@ public enum ProtectionLevel EncryptAndSign = 2 } public delegate bool RemoteCertificateValidationCallback(object sender, System.Security.Cryptography.X509Certificates.X509Certificate certificate, System.Security.Cryptography.X509Certificates.X509Chain chain, System.Net.Security.SslPolicyErrors sslPolicyErrors); + public class SslServerAuthenticationOptions + { + public bool AllowRenegotiation { get { throw null; } set { } } + public X509Certificate ServerCertificate { get { throw null; } set { } } + public bool ClientCertificateRequired { get { throw null; } set { } } + public SslProtocols EnabledSslProtocols { get { throw null; } set { } } + public X509RevocationMode CertificateRevocationCheckMode { get { throw null; } set { } } + public List ApplicationProtocols { get { throw null; } set { } } + public RemoteCertificateValidationCallback RemoteCertificateValidationCallback { get { throw null; } set { } } + public EncryptionPolicy EncryptionPolicy { get { throw null; } set { } } + } + public partial class SslClientAuthenticationOptions + { + public bool AllowRenegotiation { get { throw null; } set { } } + public string TargetHost { get { throw null; } set { } } + public X509CertificateCollection ClientCertificates { get { throw null; } set { } } + public LocalCertificateSelectionCallback LocalCertificateSelectionCallback { get { throw null; } set { } } + public SslProtocols EnabledSslProtocols { get { throw null; } set { } } + public X509RevocationMode CertificateRevocationCheckMode { get { throw null; } set { } } + public List ApplicationProtocols { get { throw null; } set { } } + public RemoteCertificateValidationCallback RemoteCertificateValidationCallback { get { throw null; } set { } } + public EncryptionPolicy EncryptionPolicy { get { throw null; } set { } } + } + public partial struct SslApplicationProtocol : IEquatable + { + public static readonly SslApplicationProtocol Http2; + public static readonly SslApplicationProtocol Http11; + + public SslApplicationProtocol(byte[] protocol) { } + public SslApplicationProtocol(string protocol) { } + + public ReadOnlyMemory Protocol { get { throw null; } } + + public bool Equals(SslApplicationProtocol other) { throw null; } + public override bool Equals(object obj) { throw null; } + public override int GetHashCode() { throw null; } + public override string ToString() { throw null; } + public static bool operator ==(SslApplicationProtocol left, SslApplicationProtocol right) { throw null; } + public static bool operator !=(SslApplicationProtocol left, SslApplicationProtocol right) { throw null; } + } public partial class SslStream : AuthenticatedStream { public SslStream(System.IO.Stream innerStream) : base(innerStream, false) { } @@ -101,6 +146,7 @@ public SslStream(System.IO.Stream innerStream, bool leaveInnerStreamOpen) : base public SslStream(System.IO.Stream innerStream, bool leaveInnerStreamOpen, System.Net.Security.RemoteCertificateValidationCallback userCertificateValidationCallback) : base(innerStream, leaveInnerStreamOpen) { } public SslStream(System.IO.Stream innerStream, bool leaveInnerStreamOpen, System.Net.Security.RemoteCertificateValidationCallback userCertificateValidationCallback, System.Net.Security.LocalCertificateSelectionCallback userCertificateSelectionCallback) : base(innerStream, leaveInnerStreamOpen) { } public SslStream(System.IO.Stream innerStream, bool leaveInnerStreamOpen, System.Net.Security.RemoteCertificateValidationCallback userCertificateValidationCallback, System.Net.Security.LocalCertificateSelectionCallback userCertificateSelectionCallback, System.Net.Security.EncryptionPolicy encryptionPolicy) : base(innerStream, leaveInnerStreamOpen) { } + public SslApplicationProtocol NegotiatedApplicationProtocol { get { throw null; } } public override bool CanRead { get { throw null; } } public override bool CanSeek { get { throw null; } } public override bool CanTimeout { get { throw null; } } @@ -134,9 +180,11 @@ public virtual void AuthenticateAsServer(System.Security.Cryptography.X509Certif public virtual System.Threading.Tasks.Task AuthenticateAsClientAsync(string targetHost) { throw null; } public virtual System.Threading.Tasks.Task AuthenticateAsClientAsync(string targetHost, System.Security.Cryptography.X509Certificates.X509CertificateCollection clientCertificates, System.Security.Authentication.SslProtocols enabledSslProtocols, bool checkCertificateRevocation) { throw null; } public virtual System.Threading.Tasks.Task AuthenticateAsClientAsync(string targetHost, System.Security.Cryptography.X509Certificates.X509CertificateCollection clientCertificates, bool checkCertificateRevocation) { throw null; } + public Task AuthenticateAsClientAsync(SslClientAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken) { throw null; } public virtual System.Threading.Tasks.Task AuthenticateAsServerAsync(System.Security.Cryptography.X509Certificates.X509Certificate serverCertificate) { throw null; } public virtual System.Threading.Tasks.Task AuthenticateAsServerAsync(System.Security.Cryptography.X509Certificates.X509Certificate serverCertificate, bool clientCertificateRequired, System.Security.Authentication.SslProtocols enabledSslProtocols, bool checkCertificateRevocation) { throw null; } public virtual System.Threading.Tasks.Task AuthenticateAsServerAsync(System.Security.Cryptography.X509Certificates.X509Certificate serverCertificate, bool clientCertificateRequired, bool checkCertificateRevocation) { throw null; } + public Task AuthenticateAsServerAsync(SslServerAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken) { throw null; } public virtual System.IAsyncResult BeginAuthenticateAsClient(string targetHost, System.AsyncCallback asyncCallback, object asyncState) { throw null; } public virtual System.IAsyncResult BeginAuthenticateAsClient(string targetHost, System.Security.Cryptography.X509Certificates.X509CertificateCollection clientCertificates, System.Security.Authentication.SslProtocols enabledSslProtocols, bool checkCertificateRevocation, System.AsyncCallback asyncCallback, object asyncState) { throw null; } public virtual System.IAsyncResult BeginAuthenticateAsClient(string targetHost, System.Security.Cryptography.X509Certificates.X509CertificateCollection clientCertificates, bool checkCertificateRevocation, System.AsyncCallback asyncCallback, object asyncState) { throw null; } diff --git a/src/System.Net.Security/ref/System.Net.Security.csproj b/src/System.Net.Security/ref/System.Net.Security.csproj index 7de67fd7159f..b67b18765f05 100644 --- a/src/System.Net.Security/ref/System.Net.Security.csproj +++ b/src/System.Net.Security/ref/System.Net.Security.csproj @@ -13,6 +13,7 @@ + @@ -21,5 +22,8 @@ + + + \ No newline at end of file diff --git a/src/System.Net.Security/src/Resources/Strings.resx b/src/System.Net.Security/src/Resources/Strings.resx index fb3a593245f6..6fafaafd89d3 100644 --- a/src/System.Net.Security/src/Resources/Strings.resx +++ b/src/System.Net.Security/src/Resources/Strings.resx @@ -361,4 +361,16 @@ The '{0}' encryption policy is not supported on this platform. + + ALPN configuration failed on this platform. + + + The application protocol list is invalid. + + + The application protocol value is invalid. + + + The '{0}' option was already set in the SslStream constructor. + diff --git a/src/System.Net.Security/src/System.Net.Security.csproj b/src/System.Net.Security/src/System.Net.Security.csproj index 03f8063a06aa..4ad9991cd68b 100644 --- a/src/System.Net.Security/src/System.Net.Security.csproj +++ b/src/System.Net.Security/src/System.Net.Security.csproj @@ -22,6 +22,10 @@ + + + + @@ -161,6 +165,12 @@ Common\Interop\Windows\SChannel\SecPkgContext_ConnectionInfo.cs + + Common\Interop\Windows\SChannel\Interop.SecPkgContext_ApplicationProtocol.cs + + + Common\Interop\Windows\SChannel\Interop.Sec_Application_Protocols.cs + Common\Interop\Windows\SChannel\UnmanagedCertificateContext.cs @@ -396,6 +406,7 @@ + diff --git a/src/System.Net.Security/src/System/Net/Security/Pal.OSX/SafeDeleteSslContext.cs b/src/System.Net.Security/src/System/Net/Security/Pal.OSX/SafeDeleteSslContext.cs index bb6d29c6f8b0..2007763ab64f 100644 --- a/src/System.Net.Security/src/System/Net/Security/Pal.OSX/SafeDeleteSslContext.cs +++ b/src/System.Net.Security/src/System/Net/Security/Pal.OSX/SafeDeleteSslContext.cs @@ -22,7 +22,7 @@ internal sealed class SafeDeleteSslContext : SafeDeleteContext public SafeSslHandle SslContext => _sslContext; - public SafeDeleteSslContext(SafeFreeSslCredentials credential, bool isServer) + public SafeDeleteSslContext(SafeFreeSslCredentials credential, SslAuthenticationOptions sslAuthenticationOptions) : base(credential) { Debug.Assert((null != credential) && !credential.IsInvalid, "Invalid credential used in SafeDeleteSslContext"); @@ -35,7 +35,7 @@ public SafeDeleteSslContext(SafeFreeSslCredentials credential, bool isServer) _writeCallback = WriteToConnection; } - _sslContext = CreateSslContext(credential, isServer); + _sslContext = CreateSslContext(credential, sslAuthenticationOptions.IsServer); int osStatus = Interop.AppleCrypto.SslSetIoCallbacks( _sslContext, diff --git a/src/System.Net.Security/src/System/Net/Security/SecureChannel.cs b/src/System.Net.Security/src/System/Net/Security/SecureChannel.cs index 6f4cadecd2fd..2045f66284c3 100644 --- a/src/System.Net.Security/src/System/Net/Security/SecureChannel.cs +++ b/src/System.Net.Security/src/System/Net/Security/SecureChannel.cs @@ -6,6 +6,7 @@ using System.Diagnostics; using System.Globalization; using System.Runtime.ExceptionServices; +using System.Runtime.InteropServices; using System.Security; using System.Security.Authentication; using System.Security.Authentication.ExtendedProtection; @@ -23,68 +24,45 @@ internal class SecureChannel private SafeFreeCredentials _credentialsHandle; private SafeDeleteContext _securityContext; - private readonly string _destination; - private readonly string _hostName; - private readonly bool _serverMode; - private readonly bool _remoteCertRequired; - private readonly SslProtocols _sslProtocols; - private readonly EncryptionPolicy _encryptionPolicy; private SslConnectionInfo _connectionInfo; - - private X509Certificate _serverCertificate; private X509Certificate _selectedClientCertificate; private bool _isRemoteCertificateAvailable; - private X509CertificateCollection _clientCertificates; - private LocalCertSelectionCallback _certSelectionDelegate; - // These are the MAX encrypt buffer output sizes, not the actual sizes. private int _headerSize = 5; //ATTN must be set to at least 5 by default private int _trailerSize = 16; private int _maxDataSize = 16354; - private bool _checkCertRevocation; - private bool _checkCertName; - private bool _refreshCredentialNeeded; + private SslAuthenticationOptions _sslAuthenticationOptions; + private SslApplicationProtocol _negotiatedApplicationProtocol; + private readonly Oid _serverAuthOid = new Oid("1.3.6.1.5.5.7.3.1", "1.3.6.1.5.5.7.3.1"); private readonly Oid _clientAuthOid = new Oid("1.3.6.1.5.5.7.3.2", "1.3.6.1.5.5.7.3.2"); - internal SecureChannel(string hostname, bool serverMode, SslProtocols sslProtocols, X509Certificate serverCertificate, X509CertificateCollection clientCertificates, bool remoteCertRequired, bool checkCertName, - bool checkCertRevocationStatus, EncryptionPolicy encryptionPolicy, LocalCertSelectionCallback certSelectionDelegate) + internal SecureChannel(SslAuthenticationOptions sslAuthenticationOptions) { if (NetEventSource.IsEnabled) { - NetEventSource.Enter(this, hostname, clientCertificates); - NetEventSource.Log.SecureChannelCtor(this, hostname, clientCertificates, encryptionPolicy); + NetEventSource.Enter(this, sslAuthenticationOptions.TargetHost, sslAuthenticationOptions.ClientCertificates); + NetEventSource.Log.SecureChannelCtor(this, sslAuthenticationOptions.TargetHost, sslAuthenticationOptions.ClientCertificates, sslAuthenticationOptions.EncryptionPolicy); } SslStreamPal.VerifyPackageInfo(); - _destination = hostname; - - if (hostname == null) + if (sslAuthenticationOptions.TargetHost == null) { - NetEventSource.Fail(this, "hostname == null"); + NetEventSource.Fail(this, "sslAuthenticationOptions.TargetHost == null"); } - _hostName = hostname; - _serverMode = serverMode; - - _sslProtocols = sslProtocols; - _serverCertificate = serverCertificate; - _clientCertificates = clientCertificates; - _remoteCertRequired = remoteCertRequired; _securityContext = null; - _checkCertRevocation = checkCertRevocationStatus; - _checkCertName = checkCertName; - _certSelectionDelegate = certSelectionDelegate; _refreshCredentialNeeded = true; - _encryptionPolicy = encryptionPolicy; - - if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + _sslAuthenticationOptions = sslAuthenticationOptions; + + if (NetEventSource.IsEnabled) + NetEventSource.Exit(this); } // @@ -100,7 +78,7 @@ internal X509Certificate LocalServerCertificate { get { - return _serverCertificate; + return _sslAuthenticationOptions.ServerCertificate; } } @@ -122,7 +100,8 @@ internal bool IsRemoteCertificateAvailable internal ChannelBinding GetChannelBinding(ChannelBindingKind kind) { - if (NetEventSource.IsEnabled) NetEventSource.Enter(this, kind); + if (NetEventSource.IsEnabled) + NetEventSource.Enter(this, kind); ChannelBinding result = null; if (_securityContext != null) @@ -130,15 +109,16 @@ internal ChannelBinding GetChannelBinding(ChannelBindingKind kind) result = SslStreamPal.QueryContextChannelBinding(_securityContext, kind); } - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, result); + if (NetEventSource.IsEnabled) + NetEventSource.Exit(this, result); return result; } - internal bool CheckCertRevocationStatus + internal X509RevocationMode CheckCertRevocationStatus { get { - return _checkCertRevocation; + return _sslAuthenticationOptions.CertificateRevocationCheckMode; } } @@ -178,7 +158,7 @@ internal bool IsServer { get { - return _serverMode; + return _sslAuthenticationOptions.IsServer; } } @@ -186,7 +166,15 @@ internal bool RemoteCertRequired { get { - return _remoteCertRequired; + return _sslAuthenticationOptions.RemoteCertRequired; + } + } + + internal SslApplicationProtocol NegotiatedApplicationProtocol + { + get + { + return _negotiatedApplicationProtocol; } } @@ -197,6 +185,11 @@ internal void SetRefreshCredentialNeeded() internal void Close() { + if (_sslAuthenticationOptions.AlpnProtocolsHandle.IsAllocated) + { + _sslAuthenticationOptions.AlpnProtocolsHandle.Free(); + } + if (_securityContext != null) { _securityContext.Dispose(); @@ -219,7 +212,8 @@ private X509Certificate2 EnsurePrivateKey(X509Certificate certificate) return null; } - if (NetEventSource.IsEnabled) NetEventSource.Log.LocatingPrivateKey(certificate, this); + if (NetEventSource.IsEnabled) + NetEventSource.Log.LocatingPrivateKey(certificate, this); try { @@ -234,7 +228,8 @@ private X509Certificate2 EnsurePrivateKey(X509Certificate certificate) { if (certEx.HasPrivateKey) { - if (NetEventSource.IsEnabled) NetEventSource.Log.CertIsType2(this); + if (NetEventSource.IsEnabled) + NetEventSource.Log.CertIsType2(this); return certEx; } @@ -249,24 +244,26 @@ private X509Certificate2 EnsurePrivateKey(X509Certificate certificate) // ELSE Try the MY user and machine stores for private key check. // For server side mode MY machine store takes priority. - X509Store store = CertificateValidationPal.EnsureStoreOpened(_serverMode); + X509Store store = CertificateValidationPal.EnsureStoreOpened(_sslAuthenticationOptions.IsServer); if (store != null) { collectionEx = store.Certificates.Find(X509FindType.FindByThumbprint, certHash, false); if (collectionEx.Count > 0 && collectionEx[0].HasPrivateKey) { - if (NetEventSource.IsEnabled) NetEventSource.Log.FoundCertInStore(_serverMode, this); + if (NetEventSource.IsEnabled) + NetEventSource.Log.FoundCertInStore(_sslAuthenticationOptions.IsServer, this); return collectionEx[0]; } } - store = CertificateValidationPal.EnsureStoreOpened(!_serverMode); + store = CertificateValidationPal.EnsureStoreOpened(!_sslAuthenticationOptions.IsServer); if (store != null) { collectionEx = store.Certificates.Find(X509FindType.FindByThumbprint, certHash, false); if (collectionEx.Count > 0 && collectionEx[0].HasPrivateKey) { - if (NetEventSource.IsEnabled) NetEventSource.Log.FoundCertInStore(_serverMode, this); + if (NetEventSource.IsEnabled) + NetEventSource.Log.FoundCertInStore(_sslAuthenticationOptions.IsServer, this); return collectionEx[0]; } } @@ -275,7 +272,8 @@ private X509Certificate2 EnsurePrivateKey(X509Certificate certificate) { } - if (NetEventSource.IsEnabled) NetEventSource.Log.NotFoundCertInStore(this); + if (NetEventSource.IsEnabled) + NetEventSource.Log.NotFoundCertInStore(this); return null; } @@ -356,7 +354,8 @@ This will not restart a session but helps minimizing the number of handles we cr private bool AcquireClientCredentials(ref byte[] thumbPrint) { - if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + if (NetEventSource.IsEnabled) + NetEventSource.Enter(this); // Acquire possible Client Certificate information and set it on the handle. X509Certificate clientCertificate = null; // This is a candidate that can come from the user callback or be guessed when targeting a session restart. @@ -365,22 +364,23 @@ private bool AcquireClientCredentials(ref byte[] thumbPrint) bool sessionRestartAttempt = false; // If true and no cached creds we will use anonymous creds. - if (_certSelectionDelegate != null) + if (_sslAuthenticationOptions.CertSelectionDelegate != null) { issuers = GetRequestCertificateAuthorities(); - if (NetEventSource.IsEnabled) NetEventSource.Info(this, "Calling CertificateSelectionCallback"); + if (NetEventSource.IsEnabled) + NetEventSource.Info(this, "Calling CertificateSelectionCallback"); X509Certificate2 remoteCert = null; try { X509Certificate2Collection dummyCollection; remoteCert = CertificateValidationPal.GetRemoteCertificate(_securityContext, out dummyCollection); - if (_clientCertificates == null) + if (_sslAuthenticationOptions.ClientCertificates == null) { - _clientCertificates = new X509CertificateCollection(); + _sslAuthenticationOptions.ClientCertificates = new X509CertificateCollection(); } - clientCertificate = _certSelectionDelegate(_hostName, _clientCertificates, remoteCert, issuers); + clientCertificate = _sslAuthenticationOptions.CertSelectionDelegate(_sslAuthenticationOptions.TargetHost, _sslAuthenticationOptions.ClientCertificates, remoteCert, issuers); } finally { @@ -399,36 +399,40 @@ private bool AcquireClientCredentials(ref byte[] thumbPrint) } filteredCerts.Add(clientCertificate); - if (NetEventSource.IsEnabled) NetEventSource.Log.CertificateFromDelegate(this); + if (NetEventSource.IsEnabled) + NetEventSource.Log.CertificateFromDelegate(this); } else { - if (_clientCertificates == null || _clientCertificates.Count == 0) + if (_sslAuthenticationOptions.ClientCertificates == null || _sslAuthenticationOptions.ClientCertificates.Count == 0) { - if (NetEventSource.IsEnabled) NetEventSource.Log.NoDelegateNoClientCert(this); + if (NetEventSource.IsEnabled) + NetEventSource.Log.NoDelegateNoClientCert(this); sessionRestartAttempt = true; } else { - if (NetEventSource.IsEnabled) NetEventSource.Log.NoDelegateButClientCert(this); + if (NetEventSource.IsEnabled) + NetEventSource.Log.NoDelegateButClientCert(this); } } } - else if (_credentialsHandle == null && _clientCertificates != null && _clientCertificates.Count > 0) + else if (_credentialsHandle == null && _sslAuthenticationOptions.ClientCertificates != null && _sslAuthenticationOptions.ClientCertificates.Count > 0) { // This is where we attempt to restart a session by picking the FIRST cert from the collection. // Otherwise it is either server sending a client cert request or the session is renegotiated. - clientCertificate = _clientCertificates[0]; + clientCertificate = _sslAuthenticationOptions.ClientCertificates[0]; sessionRestartAttempt = true; if (clientCertificate != null) { filteredCerts.Add(clientCertificate); } - if (NetEventSource.IsEnabled) NetEventSource.Log.AttemptingRestartUsingCert(clientCertificate, this); + if (NetEventSource.IsEnabled) + NetEventSource.Log.AttemptingRestartUsingCert(clientCertificate, this); } - else if (_clientCertificates != null && _clientCertificates.Count > 0) + else if (_sslAuthenticationOptions.ClientCertificates != null && _sslAuthenticationOptions.ClientCertificates.Count > 0) { // // This should be a server request for the client cert sent over currently anonymous sessions. @@ -447,7 +451,7 @@ private bool AcquireClientCredentials(ref byte[] thumbPrint) } } - for (int i = 0; i < _clientCertificates.Count; ++i) + for (int i = 0; i < _sslAuthenticationOptions.ClientCertificates.Count; ++i) { // // Make sure we add only if the cert matches one of the issuers. @@ -459,13 +463,14 @@ private bool AcquireClientCredentials(ref byte[] thumbPrint) X509Chain chain = null; try { - certificateEx = MakeEx(_clientCertificates[i]); + certificateEx = MakeEx(_sslAuthenticationOptions.ClientCertificates[i]); if (certificateEx == null) { continue; } - if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"Root cert: {certificateEx}"); + if (NetEventSource.IsEnabled) + NetEventSource.Info(this, $"Root cert: {certificateEx}"); chain = new X509Chain(); @@ -485,10 +490,12 @@ private bool AcquireClientCredentials(ref byte[] thumbPrint) found = Array.IndexOf(issuers, issuer) != -1; if (found) { - if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"Matched {issuer}"); + if (NetEventSource.IsEnabled) + NetEventSource.Info(this, $"Matched {issuer}"); break; } - if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"No match: {issuer}"); + if (NetEventSource.IsEnabled) + NetEventSource.Info(this, $"No match: {issuer}"); } } @@ -504,16 +511,17 @@ private bool AcquireClientCredentials(ref byte[] thumbPrint) chain.Dispose(); } - if (certificateEx != null && (object)certificateEx != (object)_clientCertificates[i]) + if (certificateEx != null && (object)certificateEx != (object)_sslAuthenticationOptions.ClientCertificates[i]) { certificateEx.Dispose(); } } } - if (NetEventSource.IsEnabled) NetEventSource.Log.SelectedCert(_clientCertificates[i], this); + if (NetEventSource.IsEnabled) + NetEventSource.Log.SelectedCert(_sslAuthenticationOptions.ClientCertificates[i], this); - filteredCerts.Add(_clientCertificates[i]); + filteredCerts.Add(_sslAuthenticationOptions.ClientCertificates[i]); } } @@ -555,7 +563,8 @@ private bool AcquireClientCredentials(ref byte[] thumbPrint) NetEventSource.Fail(this, "'selectedCert' does not match 'clientCertificate'."); } - if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"Selected cert = {selectedCert}"); + if (NetEventSource.IsEnabled) + NetEventSource.Info(this, $"Selected cert = {selectedCert}"); try { @@ -564,7 +573,7 @@ private bool AcquireClientCredentials(ref byte[] thumbPrint) // SECURITY: selectedCert ref if not null is a safe object that does not depend on possible **user** inherited X509Certificate type. // byte[] guessedThumbPrint = selectedCert == null ? null : selectedCert.GetCertHash(); - SafeFreeCredentials cachedCredentialHandle = SslSessionsCache.TryCachedCredential(guessedThumbPrint, _sslProtocols, _serverMode, _encryptionPolicy); + SafeFreeCredentials cachedCredentialHandle = SslSessionsCache.TryCachedCredential(guessedThumbPrint, _sslAuthenticationOptions.EnabledSslProtocols, _sslAuthenticationOptions.IsServer, _sslAuthenticationOptions.EncryptionPolicy); // We can probably do some optimization here. If the selectedCert is returned by the delegate // we can always go ahead and use the certificate to create our credential @@ -574,7 +583,8 @@ private bool AcquireClientCredentials(ref byte[] thumbPrint) selectedCert != null && SslStreamPal.StartMutualAuthAsAnonymous) { - if (NetEventSource.IsEnabled) NetEventSource.Info(this, "Reset to anonymous session."); + if (NetEventSource.IsEnabled) + NetEventSource.Info(this, "Reset to anonymous session."); // IIS does not renegotiate a restarted session if client cert is needed. // So we don't want to reuse **anonymous** cached credential for a new SSL connection if the client has passed some certificate. @@ -592,7 +602,8 @@ private bool AcquireClientCredentials(ref byte[] thumbPrint) if (cachedCredentialHandle != null) { - if (NetEventSource.IsEnabled) NetEventSource.Log.UsingCachedCredential(this); + if (NetEventSource.IsEnabled) + NetEventSource.Log.UsingCachedCredential(this); _credentialsHandle = cachedCredentialHandle; _selectedClientCertificate = clientCertificate; @@ -600,7 +611,7 @@ private bool AcquireClientCredentials(ref byte[] thumbPrint) } else { - _credentialsHandle = SslStreamPal.AcquireCredentialsHandle(selectedCert, _sslProtocols, _encryptionPolicy, _serverMode); + _credentialsHandle = SslStreamPal.AcquireCredentialsHandle(selectedCert, _sslAuthenticationOptions.EnabledSslProtocols, _sslAuthenticationOptions.EncryptionPolicy, _sslAuthenticationOptions.IsServer); thumbPrint = guessedThumbPrint; // Delay until here in case something above threw. _selectedClientCertificate = clientCertificate; @@ -615,7 +626,8 @@ private bool AcquireClientCredentials(ref byte[] thumbPrint) } } - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, cachedCred, _credentialsHandle); + if (NetEventSource.IsEnabled) + NetEventSource.Exit(this, cachedCred, _credentialsHandle); return cachedCred; } @@ -625,21 +637,23 @@ private bool AcquireClientCredentials(ref byte[] thumbPrint) // private bool AcquireServerCredentials(ref byte[] thumbPrint) { - if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + if (NetEventSource.IsEnabled) + NetEventSource.Enter(this); X509Certificate localCertificate = null; bool cachedCred = false; - if (_certSelectionDelegate != null) + if (_sslAuthenticationOptions.CertSelectionDelegate != null) { X509CertificateCollection tempCollection = new X509CertificateCollection(); - tempCollection.Add(_serverCertificate); - localCertificate = _certSelectionDelegate(string.Empty, tempCollection, null, Array.Empty()); - if (NetEventSource.IsEnabled) NetEventSource.Info(this, "Use delegate selected Cert"); + tempCollection.Add(_sslAuthenticationOptions.ServerCertificate); + localCertificate = _sslAuthenticationOptions.CertSelectionDelegate(string.Empty, tempCollection, null, Array.Empty()); + if (NetEventSource.IsEnabled) + NetEventSource.Info(this, "Use delegate selected Cert"); } else { - localCertificate = _serverCertificate; + localCertificate = _sslAuthenticationOptions.ServerCertificate; } if (localCertificate == null) @@ -668,19 +682,19 @@ private bool AcquireServerCredentials(ref byte[] thumbPrint) byte[] guessedThumbPrint = selectedCert.GetCertHash(); try { - SafeFreeCredentials cachedCredentialHandle = SslSessionsCache.TryCachedCredential(guessedThumbPrint, _sslProtocols, _serverMode, _encryptionPolicy); + SafeFreeCredentials cachedCredentialHandle = SslSessionsCache.TryCachedCredential(guessedThumbPrint, _sslAuthenticationOptions.EnabledSslProtocols, _sslAuthenticationOptions.IsServer, _sslAuthenticationOptions.EncryptionPolicy); if (cachedCredentialHandle != null) { _credentialsHandle = cachedCredentialHandle; - _serverCertificate = localCertificate; + _sslAuthenticationOptions.ServerCertificate = localCertificate; cachedCred = true; } else { - _credentialsHandle = SslStreamPal.AcquireCredentialsHandle(selectedCert, _sslProtocols, _encryptionPolicy, _serverMode); + _credentialsHandle = SslStreamPal.AcquireCredentialsHandle(selectedCert, _sslAuthenticationOptions.EnabledSslProtocols, _sslAuthenticationOptions.EncryptionPolicy, _sslAuthenticationOptions.IsServer); thumbPrint = guessedThumbPrint; - _serverCertificate = localCertificate; + _sslAuthenticationOptions.ServerCertificate = localCertificate; } } finally @@ -692,28 +706,32 @@ private bool AcquireServerCredentials(ref byte[] thumbPrint) } } - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, cachedCred, _credentialsHandle); + if (NetEventSource.IsEnabled) + NetEventSource.Exit(this, cachedCred, _credentialsHandle); return cachedCred; } // internal ProtocolToken NextMessage(byte[] incoming, int offset, int count) { - if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + if (NetEventSource.IsEnabled) + NetEventSource.Enter(this); byte[] nextmsg = null; SecurityStatusPal status = GenerateToken(incoming, offset, count, ref nextmsg); - if (!_serverMode && status.ErrorCode == SecurityStatusPalErrorCode.CredentialsNeeded) + if (!_sslAuthenticationOptions.IsServer && status.ErrorCode == SecurityStatusPalErrorCode.CredentialsNeeded) { - if (NetEventSource.IsEnabled) NetEventSource.Info(this, "NextMessage() returned SecurityStatusPal.CredentialsNeeded"); + if (NetEventSource.IsEnabled) + NetEventSource.Info(this, "NextMessage() returned SecurityStatusPal.CredentialsNeeded"); SetRefreshCredentialNeeded(); status = GenerateToken(incoming, offset, count, ref nextmsg); } ProtocolToken token = new ProtocolToken(nextmsg, status); - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, token); + if (NetEventSource.IsEnabled) + NetEventSource.Exit(this, token); return token; } @@ -756,13 +774,10 @@ private SecurityStatusPal GenerateToken(byte[] input, int offset, int count, ref if (input != null) { incomingSecurity = new SecurityBuffer(input, offset, count, SecurityBufferType.SECBUFFER_TOKEN); - incomingSecurityBuffers = new SecurityBuffer[] - { - incomingSecurity, - new SecurityBuffer(null, 0, 0, SecurityBufferType.SECBUFFER_EMPTY) - }; } + incomingSecurityBuffers = SslStreamPal.GetIncomingSecurityBuffers(_sslAuthenticationOptions, ref incomingSecurity); + SecurityBuffer outgoingSecurity = new SecurityBuffer(null, SecurityBufferType.SECBUFFER_TOKEN); SecurityStatusPal status = default(SecurityStatusPal); @@ -781,45 +796,52 @@ private SecurityStatusPal GenerateToken(byte[] input, int offset, int count, ref thumbPrint = null; if (_refreshCredentialNeeded) { - cachedCreds = _serverMode + cachedCreds = _sslAuthenticationOptions.IsServer ? AcquireServerCredentials(ref thumbPrint) : AcquireClientCredentials(ref thumbPrint); } - if (_serverMode) + if (_sslAuthenticationOptions.IsServer) { status = SslStreamPal.AcceptSecurityContext( ref _credentialsHandle, ref _securityContext, - incomingSecurity, + incomingSecurityBuffers, outgoingSecurity, - _remoteCertRequired); + _sslAuthenticationOptions); } else { - if (incomingSecurity == null) + if (incomingSecurityBuffers == null) { status = SslStreamPal.InitializeSecurityContext( ref _credentialsHandle, ref _securityContext, - _destination, + _sslAuthenticationOptions.TargetHost, incomingSecurity, - outgoingSecurity); + outgoingSecurity, + _sslAuthenticationOptions); } else { status = SslStreamPal.InitializeSecurityContext( _credentialsHandle, ref _securityContext, - _destination, + _sslAuthenticationOptions.TargetHost, incomingSecurityBuffers, - outgoingSecurity); + outgoingSecurity, + _sslAuthenticationOptions); } } } while (cachedCreds && _credentialsHandle == null); } finally { + if (_sslAuthenticationOptions.AlpnProtocolsHandle.IsAllocated) + { + _sslAuthenticationOptions.AlpnProtocolsHandle.Free(); + } + if (_refreshCredentialNeeded) { _refreshCredentialNeeded = false; @@ -839,13 +861,16 @@ private SecurityStatusPal GenerateToken(byte[] input, int offset, int count, ref // if (!cachedCreds && _securityContext != null && !_securityContext.IsInvalid && _credentialsHandle != null && !_credentialsHandle.IsInvalid) { - SslSessionsCache.CacheCredential(_credentialsHandle, thumbPrint, _sslProtocols, _serverMode, _encryptionPolicy); + SslSessionsCache.CacheCredential(_credentialsHandle, thumbPrint, _sslAuthenticationOptions.EnabledSslProtocols, _sslAuthenticationOptions.IsServer, _sslAuthenticationOptions.EncryptionPolicy); } } } output = outgoingSecurity.token; - + + byte[] alpnResult = SslStreamPal.GetNegotiatedApplicationProtocol(_securityContext); + _negotiatedApplicationProtocol = alpnResult == null ? default : new SslApplicationProtocol(alpnResult, false); + return status; } @@ -858,7 +883,8 @@ Fills in the information about established protocol --*/ internal void ProcessHandshakeSuccess() { - if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + if (NetEventSource.IsEnabled) + NetEventSource.Enter(this); StreamSizes streamSizes; SslStreamPal.QueryContextStreamSizes(_securityContext, out streamSizes); @@ -882,7 +908,8 @@ internal void ProcessHandshakeSuccess() SslStreamPal.QueryContextConnectionInfo(_securityContext, out _connectionInfo); - if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + if (NetEventSource.IsEnabled) + NetEventSource.Exit(this); } /*++ @@ -917,12 +944,14 @@ internal SecurityStatusPal Encrypt(ReadOnlyMemory buffer, ref byte[] outpu if (secStatus.ErrorCode != SecurityStatusPalErrorCode.OK) { - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, $"ERROR {secStatus}"); + if (NetEventSource.IsEnabled) + NetEventSource.Exit(this, $"ERROR {secStatus}"); } else { output = writeBuffer; - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, $"OK data size:{resultSize}"); + if (NetEventSource.IsEnabled) + NetEventSource.Exit(this, $"OK data size:{resultSize}"); } return secStatus; @@ -930,7 +959,8 @@ internal SecurityStatusPal Encrypt(ReadOnlyMemory buffer, ref byte[] outpu internal SecurityStatusPal Decrypt(byte[] payload, ref int offset, ref int count) { - if (NetEventSource.IsEnabled) NetEventSource.Enter(this, payload, offset, count); + if (NetEventSource.IsEnabled) + NetEventSource.Enter(this, payload, offset, count); if (offset < 0 || offset > (payload == null ? 0 : payload.Length)) { @@ -962,7 +992,8 @@ internal SecurityStatusPal Decrypt(byte[] payload, ref int offset, ref int count // internal bool VerifyRemoteCertificate(RemoteCertValidationCallback remoteCertValidationCallback, ref ProtocolToken alertToken) { - if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + if (NetEventSource.IsEnabled) + NetEventSource.Enter(this); SslPolicyErrors sslPolicyErrors = SslPolicyErrors.None; @@ -979,17 +1010,18 @@ internal bool VerifyRemoteCertificate(RemoteCertValidationCallback remoteCertVal if (remoteCertificateEx == null) { - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, "(no remote cert)", !_remoteCertRequired); + if (NetEventSource.IsEnabled) + NetEventSource.Exit(this, "(no remote cert)", !_sslAuthenticationOptions.RemoteCertRequired); sslPolicyErrors |= SslPolicyErrors.RemoteCertificateNotAvailable; } else { chain = new X509Chain(); - chain.ChainPolicy.RevocationMode = _checkCertRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck; + chain.ChainPolicy.RevocationMode = _sslAuthenticationOptions.CertificateRevocationCheckMode; chain.ChainPolicy.RevocationFlag = X509RevocationFlag.ExcludeRoot; // Authenticate the remote party: (e.g. when operating in server mode, authenticate the client). - chain.ChainPolicy.ApplicationPolicy.Add(_serverMode ? _clientAuthOid : _serverAuthOid); + chain.ChainPolicy.ApplicationPolicy.Add(_sslAuthenticationOptions.IsServer ? _clientAuthOid : _serverAuthOid); if (remoteCertificateStore != null) { @@ -1000,18 +1032,18 @@ internal bool VerifyRemoteCertificate(RemoteCertValidationCallback remoteCertVal _securityContext, chain, remoteCertificateEx, - _checkCertName, - _serverMode, - _hostName); + _sslAuthenticationOptions.CheckCertName, + _sslAuthenticationOptions.IsServer, + _sslAuthenticationOptions.TargetHost); } if (remoteCertValidationCallback != null) { - success = remoteCertValidationCallback(_hostName, remoteCertificateEx, chain, sslPolicyErrors); + success = remoteCertValidationCallback(_sslAuthenticationOptions.TargetHost, remoteCertificateEx, chain, sslPolicyErrors); } else { - if (sslPolicyErrors == SslPolicyErrors.RemoteCertificateNotAvailable && !_remoteCertRequired) + if (sslPolicyErrors == SslPolicyErrors.RemoteCertificateNotAvailable && !_sslAuthenticationOptions.RemoteCertRequired) { success = true; } @@ -1024,7 +1056,8 @@ internal bool VerifyRemoteCertificate(RemoteCertValidationCallback remoteCertVal if (NetEventSource.IsEnabled) { LogCertificateValidation(remoteCertValidationCallback, sslPolicyErrors, success, chain); - if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"Cert validation, remote cert = {remoteCertificateEx}"); + if (NetEventSource.IsEnabled) + NetEventSource.Info(this, $"Cert validation, remote cert = {remoteCertificateEx}"); } if (!success) @@ -1047,14 +1080,16 @@ internal bool VerifyRemoteCertificate(RemoteCertValidationCallback remoteCertVal } } - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, success); + if (NetEventSource.IsEnabled) + NetEventSource.Exit(this, success); return success; } public ProtocolToken CreateFatalHandshakeAlertToken(SslPolicyErrors sslPolicyErrors, X509Chain chain) { - if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + if (NetEventSource.IsEnabled) + NetEventSource.Enter(this); TlsAlertMessage alertMessage; @@ -1072,14 +1107,16 @@ public ProtocolToken CreateFatalHandshakeAlertToken(SslPolicyErrors sslPolicyErr break; } - if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"alertMessage:{alertMessage}"); + if (NetEventSource.IsEnabled) + NetEventSource.Info(this, $"alertMessage:{alertMessage}"); SecurityStatusPal status; status = SslStreamPal.ApplyAlertToken(ref _credentialsHandle, _securityContext, TlsAlertType.Fatal, alertMessage); if (status.ErrorCode != SecurityStatusPalErrorCode.OK) { - if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"ApplyAlertToken() returned {status.ErrorCode}"); + if (NetEventSource.IsEnabled) + NetEventSource.Info(this, $"ApplyAlertToken() returned {status.ErrorCode}"); if (status.Exception != null) { @@ -1090,20 +1127,23 @@ public ProtocolToken CreateFatalHandshakeAlertToken(SslPolicyErrors sslPolicyErr } ProtocolToken token = GenerateAlertToken(); - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, token); + if (NetEventSource.IsEnabled) + NetEventSource.Exit(this, token); return token; } public ProtocolToken CreateShutdownToken() { - if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + if (NetEventSource.IsEnabled) + NetEventSource.Enter(this); SecurityStatusPal status; status = SslStreamPal.ApplyShutdownToken(ref _credentialsHandle, _securityContext); if (status.ErrorCode != SecurityStatusPalErrorCode.OK) { - if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"ApplyAlertToken() returned {status.ErrorCode}"); + if (NetEventSource.IsEnabled) + NetEventSource.Info(this, $"ApplyAlertToken() returned {status.ErrorCode}"); if (status.Exception != null) { @@ -1114,7 +1154,8 @@ public ProtocolToken CreateShutdownToken() } ProtocolToken token = GenerateAlertToken(); - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, token); + if (NetEventSource.IsEnabled) + NetEventSource.Exit(this, token); return token; } @@ -1129,7 +1170,7 @@ private ProtocolToken GenerateAlertToken() return token; } - + private static TlsAlertMessage GetAlertMessageFromChain(X509Chain chain) { foreach (X509ChainStatus chainStatus in chain.ChainStatus) @@ -1147,7 +1188,7 @@ private static TlsAlertMessage GetAlertMessageFromChain(X509Chain chain) } if ((chainStatus.Status & - (X509ChainStatusFlags.Revoked | X509ChainStatusFlags.OfflineRevocation )) != 0) + (X509ChainStatusFlags.Revoked | X509ChainStatusFlags.OfflineRevocation)) != 0) { return TlsAlertMessage.CertificateRevoked; } @@ -1161,7 +1202,7 @@ private static TlsAlertMessage GetAlertMessageFromChain(X509Chain chain) if ((chainStatus.Status & X509ChainStatusFlags.CtlNotValidForUsage) != 0) { - return TlsAlertMessage.UnsupportedCert; + return TlsAlertMessage.UnsupportedCert; } if ((chainStatus.Status & @@ -1182,7 +1223,8 @@ private static TlsAlertMessage GetAlertMessageFromChain(X509Chain chain) private void LogCertificateValidation(RemoteCertValidationCallback remoteCertValidationCallback, SslPolicyErrors sslPolicyErrors, bool success, X509Chain chain) { - if (!NetEventSource.IsEnabled) return; + if (!NetEventSource.IsEnabled) + return; if (sslPolicyErrors != SslPolicyErrors.None) { diff --git a/src/System.Net.Security/src/System/Net/Security/SslApplicationProtocol.cs b/src/System.Net.Security/src/System/Net/Security/SslApplicationProtocol.cs new file mode 100644 index 000000000000..2e3fcdff4b17 --- /dev/null +++ b/src/System.Net.Security/src/System/Net/Security/SslApplicationProtocol.cs @@ -0,0 +1,140 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Text; + +namespace System.Net.Security +{ + public struct SslApplicationProtocol : IEquatable + { + private readonly ReadOnlyMemory _readOnlyProtocol; + private static readonly Encoding s_utf8 = Encoding.GetEncoding(Encoding.UTF8.CodePage, EncoderFallback.ExceptionFallback, DecoderFallback.ExceptionFallback); + + // Refer IANA on ApplicationProtocols: https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids + // h2 + public static readonly SslApplicationProtocol Http2 = new SslApplicationProtocol(new byte[] { 0x68, 0x32 }, false); + // http/1.1 + public static readonly SslApplicationProtocol Http11 = new SslApplicationProtocol(new byte[] { 0x68, 0x74, 0x74, 0x70, 0x2f, 0x31, 0x2e, 0x31 }, false); + + internal SslApplicationProtocol(byte[] protocol, bool copy) + { + if (protocol == null) + { + throw new ArgumentNullException(nameof(protocol)); + } + + // RFC 7301 states protocol size <= 255 bytes. + if (protocol.Length == 0 || protocol.Length > 255) + { + throw new ArgumentException(SR.net_ssl_app_protocol_invalid, nameof(protocol)); + } + + if (copy) + { + byte[] temp = new byte[protocol.Length]; + Array.Copy(protocol, temp, protocol.Length); + _readOnlyProtocol = new ReadOnlyMemory(protocol); + } + else + { + _readOnlyProtocol = new ReadOnlyMemory(protocol); + } + } + + public SslApplicationProtocol(byte[] protocol) : this(protocol, true) { } + + public SslApplicationProtocol(string protocol) : this(s_utf8.GetBytes(protocol), copy: false) { } + + public ReadOnlyMemory Protocol + { + get => _readOnlyProtocol; + } + + public bool Equals(SslApplicationProtocol other) + { + if (_readOnlyProtocol.Length != other._readOnlyProtocol.Length) + return false; + + return (_readOnlyProtocol.IsEmpty && other._readOnlyProtocol.IsEmpty) || + _readOnlyProtocol.Span.SequenceEqual(other._readOnlyProtocol.Span); + } + + public override bool Equals(object obj) + { + if (obj is SslApplicationProtocol protocol) + { + return Equals(protocol); + } + + return false; + } + + public override int GetHashCode() + { + if (_readOnlyProtocol.Length == 0) + return 0; + + int hash1 = 0; + ReadOnlySpan pSpan = _readOnlyProtocol.Span; + for (int i = 0; i < _readOnlyProtocol.Length; i++) + { + hash1 = ((hash1 << 5) + hash1) ^ pSpan[i]; + } + + return hash1; + } + + public override string ToString() + { + try + { + if (_readOnlyProtocol.Length == 0) + { + return null; + } + + return s_utf8.GetString(_readOnlyProtocol.Span); + } + catch + { + // In case of decoding errors, return the byte values as hex string. + int byteCharsLength = _readOnlyProtocol.Length * 5; + char[] byteChars = new char[byteCharsLength]; + int index = 0; + + ReadOnlySpan pSpan = _readOnlyProtocol.Span; + for (int i = 0; i < byteCharsLength; i += 5) + { + byte b = pSpan[index++]; + byteChars[i] = '0'; + byteChars[i + 1] = 'x'; + byteChars[i + 2] = GetHexValue(Math.DivRem(b, 16, out int rem)); + byteChars[i + 3] = GetHexValue(rem); + byteChars[i + 4] = ' '; + } + + return new string(byteChars, 0, byteCharsLength - 1); + + char GetHexValue(int i) + { + if (i < 10) + return (char)(i + '0'); + + return (char)(i - 10 + 'a'); + } + } + } + + public static bool operator ==(SslApplicationProtocol left, SslApplicationProtocol right) + { + return left.Equals(right); + } + + public static bool operator !=(SslApplicationProtocol left, SslApplicationProtocol right) + { + return !(left == right); + } + } +} + diff --git a/src/System.Net.Security/src/System/Net/Security/SslAuthenticationOptions.cs b/src/System.Net.Security/src/System/Net/Security/SslAuthenticationOptions.cs new file mode 100644 index 000000000000..d2d5bb9a79ea --- /dev/null +++ b/src/System.Net.Security/src/System/Net/Security/SslAuthenticationOptions.cs @@ -0,0 +1,72 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; + +namespace System.Net.Security +{ + internal class SslAuthenticationOptions + { + internal SslAuthenticationOptions(SslClientAuthenticationOptions sslClientAuthenticationOptions) + { + // Common options. + AllowRenegotiation = sslClientAuthenticationOptions.AllowRenegotiation; + ApplicationProtocols = sslClientAuthenticationOptions.ApplicationProtocols; + CertValidationDelegate = sslClientAuthenticationOptions._certValidationDelegate; + CheckCertName = true; + EnabledSslProtocols = sslClientAuthenticationOptions.EnabledSslProtocols; + EncryptionPolicy = sslClientAuthenticationOptions.EncryptionPolicy; + IsServer = false; + RemoteCertRequired = true; + RemoteCertificateValidationCallback = sslClientAuthenticationOptions.RemoteCertificateValidationCallback; + TargetHost = sslClientAuthenticationOptions.TargetHost; + + // Client specific options. + CertSelectionDelegate = sslClientAuthenticationOptions._certSelectionDelegate; + CertificateRevocationCheckMode = sslClientAuthenticationOptions.CertificateRevocationCheckMode; + ClientCertificates = sslClientAuthenticationOptions.ClientCertificates; + LocalCertificateSelectionCallback = sslClientAuthenticationOptions.LocalCertificateSelectionCallback; + } + + internal SslAuthenticationOptions(SslServerAuthenticationOptions sslServerAuthenticationOptions) + { + // Common options. + AllowRenegotiation = sslServerAuthenticationOptions.AllowRenegotiation; + ApplicationProtocols = sslServerAuthenticationOptions.ApplicationProtocols; + CertValidationDelegate = sslServerAuthenticationOptions._certValidationDelegate; + CheckCertName = false; + EnabledSslProtocols = sslServerAuthenticationOptions.EnabledSslProtocols; + EncryptionPolicy = sslServerAuthenticationOptions.EncryptionPolicy; + IsServer = true; + RemoteCertRequired = sslServerAuthenticationOptions.ClientCertificateRequired; + RemoteCertificateValidationCallback = sslServerAuthenticationOptions.RemoteCertificateValidationCallback; + TargetHost = string.Empty; + + // Server specific options. + CertificateRevocationCheckMode = sslServerAuthenticationOptions.CertificateRevocationCheckMode; + ServerCertificate = sslServerAuthenticationOptions.ServerCertificate; + } + + internal bool AllowRenegotiation { get; set; } + internal string TargetHost { get; set; } + internal X509CertificateCollection ClientCertificates { get; set; } + internal List ApplicationProtocols { get; } + internal bool IsServer { get; set; } + internal RemoteCertificateValidationCallback RemoteCertificateValidationCallback { get; set; } + internal LocalCertificateSelectionCallback LocalCertificateSelectionCallback { get; set; } + internal X509Certificate ServerCertificate { get; set; } + internal SslProtocols EnabledSslProtocols { get; set; } + internal X509RevocationMode CertificateRevocationCheckMode { get; set; } + internal EncryptionPolicy EncryptionPolicy { get; set; } + internal bool RemoteCertRequired { get; set; } + internal bool CheckCertName { get; set; } + internal RemoteCertValidationCallback CertValidationDelegate { get; set; } + internal LocalCertSelectionCallback CertSelectionDelegate { get; set; } + internal GCHandle AlpnProtocolsHandle { get; set; } + } +} + diff --git a/src/System.Net.Security/src/System/Net/Security/SslClientAuthenticationOptions.cs b/src/System.Net.Security/src/System/Net/Security/SslClientAuthenticationOptions.cs new file mode 100644 index 000000000000..72c822f124b1 --- /dev/null +++ b/src/System.Net.Security/src/System/Net/Security/SslClientAuthenticationOptions.cs @@ -0,0 +1,68 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; + +namespace System.Net.Security +{ + public class SslClientAuthenticationOptions + { + private EncryptionPolicy _encryptionPolicy = EncryptionPolicy.RequireEncryption; + private X509RevocationMode _checkCertificateRevocation = X509RevocationMode.NoCheck; + private SslProtocols _enabledSslProtocols = SecurityProtocol.SystemDefaultSecurityProtocols; + + internal RemoteCertValidationCallback _certValidationDelegate; + internal LocalCertSelectionCallback _certSelectionDelegate; + + public bool AllowRenegotiation { get; set; } + + public LocalCertificateSelectionCallback LocalCertificateSelectionCallback { get; set; } + + public RemoteCertificateValidationCallback RemoteCertificateValidationCallback { get; set; } + + public List ApplicationProtocols { get; set; } + + public string TargetHost { get; set; } + + public X509CertificateCollection ClientCertificates { get; set; } + + public X509RevocationMode CertificateRevocationCheckMode + { + get => _checkCertificateRevocation; + set + { + if (value != X509RevocationMode.NoCheck && value != X509RevocationMode.Offline && value != X509RevocationMode.Online) + { + throw new ArgumentException(SR.Format(SR.net_invalid_enum, nameof(X509RevocationMode)), nameof(value)); + } + + _checkCertificateRevocation = value; + } + } + + public EncryptionPolicy EncryptionPolicy + { + get => _encryptionPolicy; + set + { + if (value != EncryptionPolicy.RequireEncryption && value != EncryptionPolicy.AllowNoEncryption && value != EncryptionPolicy.NoEncryption) + { + throw new ArgumentException(SR.Format(SR.net_invalid_enum, nameof(EncryptionPolicy)), nameof(value)); + } + + _encryptionPolicy = value; + } + } + + public SslProtocols EnabledSslProtocols + { + get => _enabledSslProtocols; + set => _enabledSslProtocols = value; + } + } +} + diff --git a/src/System.Net.Security/src/System/Net/Security/SslServerAuthenticationOptions.cs b/src/System.Net.Security/src/System/Net/Security/SslServerAuthenticationOptions.cs new file mode 100644 index 000000000000..dea8860a8ae7 --- /dev/null +++ b/src/System.Net.Security/src/System/Net/Security/SslServerAuthenticationOptions.cs @@ -0,0 +1,64 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; + +namespace System.Net.Security +{ + public class SslServerAuthenticationOptions + { + private X509RevocationMode _checkCertificateRevocation = X509RevocationMode.NoCheck; + private SslProtocols _enabledSslProtocols = SecurityProtocol.SystemDefaultSecurityProtocols; + private EncryptionPolicy _encryptionPolicy = EncryptionPolicy.RequireEncryption; + + internal RemoteCertValidationCallback _certValidationDelegate; + + public bool AllowRenegotiation { get; set; } + + public bool ClientCertificateRequired { get; set; } + + public List ApplicationProtocols { get; set; } + + public RemoteCertificateValidationCallback RemoteCertificateValidationCallback { get; set; } + + public X509Certificate ServerCertificate { get; set; } + + public SslProtocols EnabledSslProtocols + { + get => _enabledSslProtocols; + set => _enabledSslProtocols = value; + } + + public X509RevocationMode CertificateRevocationCheckMode + { + get => _checkCertificateRevocation; + set + { + if (value != X509RevocationMode.NoCheck && value != X509RevocationMode.Offline && value != X509RevocationMode.Online) + { + throw new ArgumentException(SR.Format(SR.net_invalid_enum, nameof(X509RevocationMode)), nameof(value)); + } + + _checkCertificateRevocation = value; + } + } + + public EncryptionPolicy EncryptionPolicy + { + get => _encryptionPolicy; + set + { + if (value != EncryptionPolicy.RequireEncryption && value != EncryptionPolicy.AllowNoEncryption && value != EncryptionPolicy.NoEncryption) + { + throw new ArgumentException(SR.Format(SR.net_invalid_enum, nameof(EncryptionPolicy)), nameof(value)); + } + + _encryptionPolicy = value; + } + } + } +} + diff --git a/src/System.Net.Security/src/System/Net/Security/SslState.cs b/src/System.Net.Security/src/System/Net/Security/SslState.cs index b905f92fcaf9..fe368691c07e 100644 --- a/src/System.Net.Security/src/System/Net/Security/SslState.cs +++ b/src/System.Net.Security/src/System/Net/Security/SslState.cs @@ -22,8 +22,7 @@ internal class SslState private static AsyncProtocolCallback s_readFrameCallback = new AsyncProtocolCallback(ReadFrameCallback); private static AsyncCallback s_writeCallback = new AsyncCallback(WriteCallback); - private RemoteCertValidationCallback _certValidationDelegate; - private LocalCertSelectionCallback _certSelectionDelegate; + private SslAuthenticationOptions _sslAuthenticationOptions; private Stream _innerStream; @@ -70,31 +69,17 @@ private enum CachedSessionStatus : byte private int _lockReadState; private object _queuedReadStateRequest; - private readonly EncryptionPolicy _encryptionPolicy; - // // The public Client and Server classes enforce the parameters rules before // calling into this .ctor. // - internal SslState(Stream innerStream, RemoteCertValidationCallback certValidationCallback, LocalCertSelectionCallback certSelectionCallback, EncryptionPolicy encryptionPolicy) + internal SslState(Stream innerStream) { _innerStream = innerStream; - _certValidationDelegate = certValidationCallback; - _certSelectionDelegate = certSelectionCallback; - _encryptionPolicy = encryptionPolicy; - } - - internal void ValidateCreateContext(bool isServer, string targetHost, SslProtocols enabledSslProtocols, X509Certificate serverCertificate, X509CertificateCollection clientCertificates, bool remoteCertRequired, bool checkCertRevocationStatus) - { - ValidateCreateContext(isServer, targetHost, enabledSslProtocols, serverCertificate, clientCertificates, remoteCertRequired, - checkCertRevocationStatus, !isServer); } - internal void ValidateCreateContext(bool isServer, string targetHost, SslProtocols enabledSslProtocols, X509Certificate serverCertificate, X509CertificateCollection clientCertificates, bool remoteCertRequired, bool checkCertRevocationStatus, bool checkCertName) + internal void ValidateCreateContext(SslClientAuthenticationOptions sslClientAuthenticationOptions) { - // - // We don't support SSL alerts right now, hence any exception is fatal and cannot be retried. - // if (_exception != null) { _exception.Throw(); @@ -105,31 +90,60 @@ internal void ValidateCreateContext(bool isServer, string targetHost, SslProtoco throw new InvalidOperationException(SR.net_auth_reauth); } - if (Context != null && IsServer != isServer) + if (Context != null && IsServer) { throw new InvalidOperationException(SR.net_auth_client_server); } - if (targetHost == null) + if (sslClientAuthenticationOptions.TargetHost == null) + { + throw new ArgumentNullException(nameof(sslClientAuthenticationOptions.TargetHost)); + } + + if (sslClientAuthenticationOptions.TargetHost.Length == 0) + { + sslClientAuthenticationOptions.TargetHost = "?" + Interlocked.Increment(ref s_uniqueNameInteger).ToString(NumberFormatInfo.InvariantInfo); + } + + _exception = null; + try { - throw new ArgumentNullException(nameof(targetHost)); + _sslAuthenticationOptions = new SslAuthenticationOptions(sslClientAuthenticationOptions); + _context = new SecureChannel(_sslAuthenticationOptions); } + catch (Win32Exception e) + { + throw new AuthenticationException(SR.net_auth_SSPI, e); + } + } - if (isServer && serverCertificate == null) + internal void ValidateCreateContext(SslServerAuthenticationOptions sslServerAuthenticationOptions) + { + if (_exception != null) { - throw new ArgumentNullException(nameof(serverCertificate)); + _exception.Throw(); } - if (targetHost.Length == 0) + if (Context != null && Context.IsValidContext) { - targetHost = "?" + Interlocked.Increment(ref s_uniqueNameInteger).ToString(NumberFormatInfo.InvariantInfo); + throw new InvalidOperationException(SR.net_auth_reauth); + } + + if (Context != null && !IsServer) + { + throw new InvalidOperationException(SR.net_auth_client_server); + } + + if (sslServerAuthenticationOptions.ServerCertificate == null) + { + throw new ArgumentNullException(nameof(sslServerAuthenticationOptions.ServerCertificate)); } _exception = null; try { - _context = new SecureChannel(targetHost, isServer, enabledSslProtocols, serverCertificate, clientCertificates, remoteCertRequired, - checkCertName, checkCertRevocationStatus, _encryptionPolicy, _certSelectionDelegate); + _sslAuthenticationOptions = new SslAuthenticationOptions(sslServerAuthenticationOptions); + _context = new SecureChannel(_sslAuthenticationOptions); } catch (Win32Exception e) { @@ -137,6 +151,17 @@ internal void ValidateCreateContext(bool isServer, string targetHost, SslProtoco } } + internal SslApplicationProtocol NegotiatedApplicationProtocol + { + get + { + if (Context == null) + return default; + + return Context.NegotiatedApplicationProtocol; + } + } + internal bool IsAuthenticated { get @@ -172,14 +197,6 @@ internal bool IsServer } } - // - // SSL related properties - // - internal void SetCertValidationDelegate(RemoteCertValidationCallback certValidationCallback) - { - _certValidationDelegate = certValidationCallback; - } - // // This will return selected local cert for both client/server streams // @@ -209,7 +226,7 @@ internal bool CheckCertRevocationStatus { get { - return Context != null && Context.CheckCertRevocationStatus; + return Context != null && Context.CheckCertRevocationStatus != X509RevocationMode.NoCheck; } } @@ -580,14 +597,15 @@ internal void ProcessAuthentication(LazyAsyncResult lazyResult) // Not aync so the connection is completed at this point. if (lazyResult == null && NetEventSource.IsEnabled) { - if (NetEventSource.IsEnabled) NetEventSource.Log.SspiSelectedCipherSuite(nameof(ProcessAuthentication), - SslProtocol, - CipherAlgorithm, - CipherStrength, - HashAlgorithm, - HashStrength, - KeyExchangeAlgorithm, - KeyExchangeStrength); + if (NetEventSource.IsEnabled) + NetEventSource.Log.SspiSelectedCipherSuite(nameof(ProcessAuthentication), + SslProtocol, + CipherAlgorithm, + CipherStrength, + HashAlgorithm, + HashStrength, + KeyExchangeAlgorithm, + KeyExchangeStrength); } } catch (Exception) @@ -714,14 +732,15 @@ internal void EndProcessAuthentication(IAsyncResult result) // Connection is completed at this point. if (NetEventSource.IsEnabled) { - if (NetEventSource.IsEnabled) NetEventSource.Log.SspiSelectedCipherSuite(nameof(EndProcessAuthentication), - SslProtocol, - CipherAlgorithm, - CipherStrength, - HashAlgorithm, - HashStrength, - KeyExchangeAlgorithm, - KeyExchangeStrength); + if (NetEventSource.IsEnabled) + NetEventSource.Log.SspiSelectedCipherSuite(nameof(EndProcessAuthentication), + SslProtocol, + CipherAlgorithm, + CipherStrength, + HashAlgorithm, + HashStrength, + KeyExchangeAlgorithm, + KeyExchangeStrength); } } @@ -1014,23 +1033,26 @@ private void StartSendAuthResetSignal(ProtocolToken message, AsyncProtocolReques // private bool CompleteHandshake(ref ProtocolToken alertToken) { - if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + if (NetEventSource.IsEnabled) + NetEventSource.Enter(this); Context.ProcessHandshakeSuccess(); - if (!Context.VerifyRemoteCertificate(_certValidationDelegate, ref alertToken)) + if (!Context.VerifyRemoteCertificate(_sslAuthenticationOptions.CertValidationDelegate, ref alertToken)) { _handshakeCompleted = false; _certValidationFailed = true; - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, false); + if (NetEventSource.IsEnabled) + NetEventSource.Exit(this, false); return false; } _certValidationFailed = false; _handshakeCompleted = true; - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, true); + if (NetEventSource.IsEnabled) + NetEventSource.Exit(this, true); return true; } @@ -1088,7 +1110,8 @@ private static void WriteCallback(IAsyncResult transportResult) private static void PartialFrameCallback(AsyncProtocolRequest asyncRequest) { - if (NetEventSource.IsEnabled) NetEventSource.Enter(null); + if (NetEventSource.IsEnabled) + NetEventSource.Enter(null); // Async ONLY completion. SslState sslState = (SslState)asyncRequest.AsyncObject; @@ -1112,7 +1135,8 @@ private static void PartialFrameCallback(AsyncProtocolRequest asyncRequest) // private static void ReadFrameCallback(AsyncProtocolRequest asyncRequest) { - if (NetEventSource.IsEnabled) NetEventSource.Enter(null); + if (NetEventSource.IsEnabled) + NetEventSource.Enter(null); // Async ONLY completion. SslState sslState = (SslState)asyncRequest.AsyncObject; @@ -1679,7 +1703,8 @@ private Framing DetectFraming(byte[] bytes, int length) // This is called from SslStream class too. internal int GetRemainingFrameSize(byte[] buffer, int offset, int dataSize) { - if (NetEventSource.IsEnabled) NetEventSource.Enter(this, buffer, offset, dataSize); + if (NetEventSource.IsEnabled) + NetEventSource.Enter(this, buffer, offset, dataSize); int payloadSize = -1; switch (_Framing) @@ -1719,7 +1744,8 @@ internal int GetRemainingFrameSize(byte[] buffer, int offset, int dataSize) break; } - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, payloadSize); + if (NetEventSource.IsEnabled) + NetEventSource.Exit(this, payloadSize); return payloadSize; } @@ -1841,7 +1867,7 @@ private void RehandshakeCompleteCallback(IAsyncResult result) internal IAsyncResult BeginShutdown(AsyncCallback asyncCallback, object asyncState) { - CheckThrow(authSuccessCheck:true, shutdownCheck:true); + CheckThrow(authSuccessCheck: true, shutdownCheck: true); ProtocolToken message = Context.CreateShutdownToken(); return TaskToApm.Begin(InnerStream.WriteAsync(message.Payload, 0, message.Payload.Length), asyncCallback, asyncState); @@ -1849,7 +1875,7 @@ internal IAsyncResult BeginShutdown(AsyncCallback asyncCallback, object asyncSta internal void EndShutdown(IAsyncResult result) { - CheckThrow(authSuccessCheck: true, shutdownCheck:true); + CheckThrow(authSuccessCheck: true, shutdownCheck: true); TaskToApm.End(result); _shutdown = true; diff --git a/src/System.Net.Security/src/System/Net/Security/SslStream.cs b/src/System.Net.Security/src/System/Net/Security/SslStream.cs index c110f932abfd..aa30e5e3ca20 100644 --- a/src/System.Net.Security/src/System/Net/Security/SslStream.cs +++ b/src/System.Net.Security/src/System/Net/Security/SslStream.cs @@ -2,7 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Collections.Generic; using System.IO; +using System.Runtime.InteropServices; using System.Security.Authentication; using System.Security.Authentication.ExtendedProtection; using System.Security.Cryptography.X509Certificates; @@ -36,10 +38,14 @@ public enum EncryptionPolicy public class SslStream : AuthenticatedStream { private SslState _sslState; - private RemoteCertificateValidationCallback _userCertificateValidationCallback; - private LocalCertificateSelectionCallback _userCertificateSelectionCallback; private object _remoteCertificateOrBytes; + internal RemoteCertificateValidationCallback _userCertificateValidationCallback; + internal LocalCertificateSelectionCallback _userCertificateSelectionCallback; + internal RemoteCertValidationCallback _certValidationDelegate; + internal LocalCertSelectionCallback _certSelectionDelegate; + internal EncryptionPolicy _encryptionPolicy; + public SslStream(Stream innerStream) : this(innerStream, false, null, null) { @@ -72,9 +78,44 @@ public SslStream(Stream innerStream, bool leaveInnerStreamOpen, RemoteCertificat _userCertificateValidationCallback = userCertificateValidationCallback; _userCertificateSelectionCallback = userCertificateSelectionCallback; - RemoteCertValidationCallback _userCertValidationCallbackWrapper = new RemoteCertValidationCallback(UserCertValidationCallbackWrapper); - LocalCertSelectionCallback _userCertSelectionCallbackWrapper = userCertificateSelectionCallback == null ? null : new LocalCertSelectionCallback(UserCertSelectionCallbackWrapper); - _sslState = new SslState(innerStream, _userCertValidationCallbackWrapper, _userCertSelectionCallbackWrapper, encryptionPolicy); + _encryptionPolicy = encryptionPolicy; + _certValidationDelegate = new RemoteCertValidationCallback(UserCertValidationCallbackWrapper); + _certSelectionDelegate = userCertificateSelectionCallback == null ? null : new LocalCertSelectionCallback(UserCertSelectionCallbackWrapper); + _sslState = new SslState(innerStream); + } + + public SslApplicationProtocol NegotiatedApplicationProtocol + { + get + { + return _sslState.NegotiatedApplicationProtocol; + } + } + + private void SetAndVerifyValidationCallback(RemoteCertificateValidationCallback callback) + { + if (_userCertificateValidationCallback == null) + { + _userCertificateValidationCallback = callback; + _certValidationDelegate = new RemoteCertValidationCallback(UserCertValidationCallbackWrapper); + } + else if (callback != null && _userCertificateValidationCallback != callback) + { + throw new InvalidOperationException(SR.Format(SR.net_conflicting_options, nameof(RemoteCertificateValidationCallback))); + } + } + + private void SetAndVerifySelectionCallback(LocalCertificateSelectionCallback callback) + { + if (_userCertificateSelectionCallback == null) + { + _userCertificateSelectionCallback = callback; + _certSelectionDelegate = _userCertificateSelectionCallback == null ? null : new LocalCertSelectionCallback(UserCertSelectionCallbackWrapper); + } + else if (callback != null && _userCertificateSelectionCallback != callback) + { + throw new InvalidOperationException(SR.Format(SR.net_conflicting_options, nameof(LocalCertificateSelectionCallback))); + } } private bool UserCertValidationCallbackWrapper(string hostName, X509Certificate2 certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors) @@ -119,8 +160,29 @@ public virtual IAsyncResult BeginAuthenticateAsClient(string targetHost, X509Cer SslProtocols enabledSslProtocols, bool checkCertificateRevocation, AsyncCallback asyncCallback, object asyncState) { - SecurityProtocol.ThrowOnNotAllowed(enabledSslProtocols); - _sslState.ValidateCreateContext(false, targetHost, enabledSslProtocols, null, clientCertificates, true, checkCertificateRevocation); + SslClientAuthenticationOptions options = new SslClientAuthenticationOptions + { + TargetHost = targetHost, + ClientCertificates = clientCertificates, + EnabledSslProtocols = enabledSslProtocols, + CertificateRevocationCheckMode = checkCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck, + EncryptionPolicy = _encryptionPolicy, + }; + + return BeginAuthenticateAsClient(options, CancellationToken.None, asyncCallback, asyncState); + } + + internal virtual IAsyncResult BeginAuthenticateAsClient(SslClientAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken, AsyncCallback asyncCallback, object asyncState) + { + SecurityProtocol.ThrowOnNotAllowed(sslClientAuthenticationOptions.EnabledSslProtocols); + SetAndVerifyValidationCallback(sslClientAuthenticationOptions.RemoteCertificateValidationCallback); + SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback); + + // Set the delegates on the options. + sslClientAuthenticationOptions._certValidationDelegate = _certValidationDelegate; + sslClientAuthenticationOptions._certSelectionDelegate = _certSelectionDelegate; + + _sslState.ValidateCreateContext(sslClientAuthenticationOptions); LazyAsyncResult result = new LazyAsyncResult(_sslState, asyncState, asyncCallback); _sslState.ProcessAuthentication(result); @@ -154,8 +216,28 @@ public virtual IAsyncResult BeginAuthenticateAsServer(X509Certificate serverCert AsyncCallback asyncCallback, object asyncState) { - SecurityProtocol.ThrowOnNotAllowed(enabledSslProtocols); - _sslState.ValidateCreateContext(true, string.Empty, enabledSslProtocols, serverCertificate, null, clientCertificateRequired, checkCertificateRevocation); + SslServerAuthenticationOptions options = new SslServerAuthenticationOptions + { + ServerCertificate = serverCertificate, + ClientCertificateRequired = clientCertificateRequired, + EnabledSslProtocols = enabledSslProtocols, + CertificateRevocationCheckMode = checkCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck, + EncryptionPolicy = _encryptionPolicy, + }; + + return BeginAuthenticateAsServer(options, CancellationToken.None, asyncCallback, asyncState); + } + + private IAsyncResult BeginAuthenticateAsServer(SslServerAuthenticationOptions sslServerAuthenticationOptions, CancellationToken cancellationToken, AsyncCallback asyncCallback, object asyncState) + { + SecurityProtocol.ThrowOnNotAllowed(sslServerAuthenticationOptions.EnabledSslProtocols); + SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback); + + // Set the delegate on the options. + sslServerAuthenticationOptions._certValidationDelegate = _certValidationDelegate; + + _sslState.ValidateCreateContext(sslServerAuthenticationOptions); + LazyAsyncResult result = new LazyAsyncResult(_sslState, asyncState, asyncCallback); _sslState.ProcessAuthentication(result); return result; @@ -202,8 +284,29 @@ public virtual void AuthenticateAsClient(string targetHost, X509CertificateColle public virtual void AuthenticateAsClient(string targetHost, X509CertificateCollection clientCertificates, SslProtocols enabledSslProtocols, bool checkCertificateRevocation) { - SecurityProtocol.ThrowOnNotAllowed(enabledSslProtocols); - _sslState.ValidateCreateContext(false, targetHost, enabledSslProtocols, null, clientCertificates, true, checkCertificateRevocation); + SslClientAuthenticationOptions options = new SslClientAuthenticationOptions + { + TargetHost = targetHost, + ClientCertificates = clientCertificates, + EnabledSslProtocols = enabledSslProtocols, + CertificateRevocationCheckMode = checkCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck, + EncryptionPolicy = _encryptionPolicy, + }; + + AuthenticateAsClient(options); + } + + private void AuthenticateAsClient(SslClientAuthenticationOptions sslClientAuthenticationOptions) + { + SecurityProtocol.ThrowOnNotAllowed(sslClientAuthenticationOptions.EnabledSslProtocols); + SetAndVerifyValidationCallback(sslClientAuthenticationOptions.RemoteCertificateValidationCallback); + SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback); + + // Set the delegates on the options. + sslClientAuthenticationOptions._certValidationDelegate = _certValidationDelegate; + sslClientAuthenticationOptions._certSelectionDelegate = _certSelectionDelegate; + + _sslState.ValidateCreateContext(sslClientAuthenticationOptions); _sslState.ProcessAuthentication(null); } @@ -219,8 +322,27 @@ public virtual void AuthenticateAsServer(X509Certificate serverCertificate, bool public virtual void AuthenticateAsServer(X509Certificate serverCertificate, bool clientCertificateRequired, SslProtocols enabledSslProtocols, bool checkCertificateRevocation) { - SecurityProtocol.ThrowOnNotAllowed(enabledSslProtocols); - _sslState.ValidateCreateContext(true, string.Empty, enabledSslProtocols, serverCertificate, null, clientCertificateRequired, checkCertificateRevocation); + SslServerAuthenticationOptions options = new SslServerAuthenticationOptions + { + ServerCertificate = serverCertificate, + ClientCertificateRequired = clientCertificateRequired, + EnabledSslProtocols = enabledSslProtocols, + CertificateRevocationCheckMode = checkCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck, + EncryptionPolicy = _encryptionPolicy, + }; + + AuthenticateAsServer(options); + } + + private void AuthenticateAsServer(SslServerAuthenticationOptions sslServerAuthenticationOptions) + { + SecurityProtocol.ThrowOnNotAllowed(sslServerAuthenticationOptions.EnabledSslProtocols); + SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback); + + // Set the delegate on the options. + sslServerAuthenticationOptions._certValidationDelegate = _certValidationDelegate; + + _sslState.ValidateCreateContext(sslServerAuthenticationOptions); _sslState.ProcessAuthentication(null); } #endregion @@ -252,6 +374,15 @@ public virtual Task AuthenticateAsClientAsync(string targetHost, X509Certificate this); } + public Task AuthenticateAsClientAsync(SslClientAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken) + { + return Task.Factory.FromAsync( + (arg1, arg2, callback, state) => ((SslStream)state).BeginAuthenticateAsClient(arg1, arg2, callback, state), + iar => ((SslStream)iar.AsyncState).EndAuthenticateAsClient(iar), + sslClientAuthenticationOptions, cancellationToken, + this); + } + public virtual Task AuthenticateAsServerAsync(X509Certificate serverCertificate) => Task.Factory.FromAsync( (arg1, callback, state) => ((SslStream)state).BeginAuthenticateAsServer(arg1, callback, state), @@ -278,6 +409,15 @@ public virtual Task AuthenticateAsServerAsync(X509Certificate serverCertificate, this); } + public Task AuthenticateAsServerAsync(SslServerAuthenticationOptions sslServerAuthenticationOptions, CancellationToken cancellationToken) + { + return Task.Factory.FromAsync( + (arg1, arg2, callback, state) => ((SslStream)state).BeginAuthenticateAsServer(arg1, arg2, callback, state), + iar => ((SslStream)iar.AsyncState).EndAuthenticateAsServer(iar), + sslServerAuthenticationOptions, cancellationToken, + this); + } + public virtual Task ShutdownAsync() => Task.Factory.FromAsync( (callback, state) => ((SslStream)state).BeginShutdown(callback, state), diff --git a/src/System.Net.Security/src/System/Net/Security/SslStreamPal.OSX.cs b/src/System.Net.Security/src/System/Net/Security/SslStreamPal.OSX.cs index b426ebe466b8..eecf4b365009 100644 --- a/src/System.Net.Security/src/System/Net/Security/SslStreamPal.OSX.cs +++ b/src/System.Net.Security/src/System/Net/Security/SslStreamPal.OSX.cs @@ -37,11 +37,21 @@ public static void VerifyPackageInfo() public static SecurityStatusPal AcceptSecurityContext( ref SafeFreeCredentials credential, ref SafeDeleteContext context, - SecurityBuffer inputBuffer, + SecurityBuffer[] inputBuffers, SecurityBuffer outputBuffer, - bool remoteCertRequired) + SslAuthenticationOptions sslAuthenticationOptions) { - return HandshakeInternal(credential, ref context, inputBuffer, outputBuffer, true, remoteCertRequired, null); + if (inputBuffers != null) + { + Debug.Assert(inputBuffers.Length == 2); + Debug.Assert(inputBuffers[1].token == null); + + return HandshakeInternal(credential, ref context, inputBuffers[0], outputBuffer, sslAuthenticationOptions); + } + else + { + return HandshakeInternal(credential, ref context, inputBuffer: null, outputBuffer, sslAuthenticationOptions); + } } public static SecurityStatusPal InitializeSecurityContext( @@ -49,9 +59,10 @@ public static SecurityStatusPal InitializeSecurityContext( ref SafeDeleteContext context, string targetName, SecurityBuffer inputBuffer, - SecurityBuffer outputBuffer) + SecurityBuffer outputBuffer, + SslAuthenticationOptions sslAuthenticationOptions) { - return HandshakeInternal(credential, ref context, inputBuffer, outputBuffer, false, false, targetName); + return HandshakeInternal(credential, ref context, inputBuffer, outputBuffer, sslAuthenticationOptions); } public static SecurityStatusPal InitializeSecurityContext( @@ -59,11 +70,28 @@ public static SecurityStatusPal InitializeSecurityContext( ref SafeDeleteContext context, string targetName, SecurityBuffer[] inputBuffers, - SecurityBuffer outputBuffer) + SecurityBuffer outputBuffer, + SslAuthenticationOptions sslAuthenticationOptions) { Debug.Assert(inputBuffers.Length == 2); Debug.Assert(inputBuffers[1].token == null); - return HandshakeInternal(credential, ref context, inputBuffers[0], outputBuffer, false, false, targetName); + return HandshakeInternal(credential, ref context, inputBuffers[0], outputBuffer, sslAuthenticationOptions); + } + + public static SecurityBuffer[] GetIncomingSecurityBuffers(SslAuthenticationOptions options, ref SecurityBuffer incomingSecurity) + { + SecurityBuffer[] incomingSecurityBuffers = null; + + if (incomingSecurity != null) + { + incomingSecurityBuffers = new SecurityBuffer[] + { + incomingSecurity, + new SecurityBuffer(null, 0, 0, SecurityBufferType.SECBUFFER_EMPTY) + }; + } + + return incomingSecurityBuffers; } public static SafeFreeCredentials AcquireCredentialsHandle( @@ -75,6 +103,12 @@ public static SafeFreeCredentials AcquireCredentialsHandle( return new SafeFreeSslCredentials(certificate, protocols, policy); } + internal static byte[] GetNegotiatedApplicationProtocol(SafeDeleteContext context) + { + // OSX SecureTransport does not export APIs to support ALPN, no-op. + return null; + } + public static SecurityStatusPal EncryptMessage( SafeDeleteContext securityContext, ReadOnlyMemory input, @@ -226,9 +260,7 @@ private static SecurityStatusPal HandshakeInternal( ref SafeDeleteContext context, SecurityBuffer inputBuffer, SecurityBuffer outputBuffer, - bool isServer, - bool remoteCertRequired, - string targetName) + SslAuthenticationOptions sslAuthenticationOptions) { Debug.Assert(!credential.IsInvalid); @@ -238,18 +270,17 @@ private static SecurityStatusPal HandshakeInternal( if ((null == context) || context.IsInvalid) { - sslContext = new SafeDeleteSslContext(credential as SafeFreeSslCredentials, isServer); + sslContext = new SafeDeleteSslContext(credential as SafeFreeSslCredentials, sslAuthenticationOptions); context = sslContext; - if (!string.IsNullOrEmpty(targetName)) + if (!string.IsNullOrEmpty(sslAuthenticationOptions.TargetHost)) { - Debug.Assert(!isServer, "targetName should not be set for server-side handshakes"); - Interop.AppleCrypto.SslSetTargetName(sslContext.SslContext, targetName); + Debug.Assert(!sslAuthenticationOptions.IsServer, "targetName should not be set for server-side handshakes"); + Interop.AppleCrypto.SslSetTargetName(sslContext.SslContext, sslAuthenticationOptions.TargetHost); } - if (remoteCertRequired) + if (sslAuthenticationOptions.IsServer && sslAuthenticationOptions.RemoteCertRequired) { - Debug.Assert(isServer, "remoteCertRequired should not be set for client-side handshakes"); Interop.AppleCrypto.SslSetAcceptClientCert(sslContext.SslContext); } } diff --git a/src/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs b/src/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs index 6eb1b6296e4d..7833be6c5838 100644 --- a/src/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs +++ b/src/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Collections.Generic; using System.Diagnostics; using System.Net.Security; using System.Runtime.InteropServices; @@ -29,22 +30,49 @@ public static void VerifyPackageInfo() } public static SecurityStatusPal AcceptSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteContext context, - SecurityBuffer inputBuffer, SecurityBuffer outputBuffer, bool remoteCertRequired) + SecurityBuffer[] inputBuffers, SecurityBuffer outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { - return HandshakeInternal(credential, ref context, inputBuffer, outputBuffer, true, remoteCertRequired); + if (inputBuffers != null) + { + Debug.Assert(inputBuffers.Length == 2); + Debug.Assert(inputBuffers[1].token == null); + + return HandshakeInternal(credential, ref context, inputBuffers[0], outputBuffer, sslAuthenticationOptions); + } + else + { + return HandshakeInternal(credential, ref context, inputBuffer: null, outputBuffer, sslAuthenticationOptions); + } } public static SecurityStatusPal InitializeSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteContext context, - string targetName, SecurityBuffer inputBuffer, SecurityBuffer outputBuffer) + string targetName, SecurityBuffer inputBuffer, SecurityBuffer outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { - return HandshakeInternal(credential, ref context, inputBuffer, outputBuffer, false, false); + return HandshakeInternal(credential, ref context, inputBuffer, outputBuffer, sslAuthenticationOptions); } - public static SecurityStatusPal InitializeSecurityContext(SafeFreeCredentials credential, ref SafeDeleteContext context, string targetName, SecurityBuffer[] inputBuffers, SecurityBuffer outputBuffer) + public static SecurityStatusPal InitializeSecurityContext(SafeFreeCredentials credential, ref SafeDeleteContext context, string targetName, SecurityBuffer[] inputBuffers, SecurityBuffer outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { Debug.Assert(inputBuffers.Length == 2); Debug.Assert(inputBuffers[1].token == null); - return HandshakeInternal(credential, ref context, inputBuffers[0], outputBuffer, false, false); + + return HandshakeInternal(credential, ref context, inputBuffers[0], outputBuffer, sslAuthenticationOptions); + } + + public static SecurityBuffer[] GetIncomingSecurityBuffers(SslAuthenticationOptions options, ref SecurityBuffer incomingSecurity) + { + SecurityBuffer[] incomingSecurityBuffers = null; + + if (incomingSecurity != null) + { + incomingSecurityBuffers = new SecurityBuffer[] + { + incomingSecurity, + new SecurityBuffer(null, 0, 0, SecurityBufferType.SECBUFFER_EMPTY) + }; + } + + return incomingSecurityBuffers; } public static SafeFreeCredentials AcquireCredentialsHandle(X509Certificate certificate, @@ -102,8 +130,13 @@ public static void QueryContextConnectionInfo(SafeDeleteContext securityContext, connectionInfo = new SslConnectionInfo(((SafeDeleteSslContext)securityContext).SslContext); } - private static SecurityStatusPal HandshakeInternal(SafeFreeCredentials credential, ref SafeDeleteContext context, - SecurityBuffer inputBuffer, SecurityBuffer outputBuffer, bool isServer, bool remoteCertRequired) + public static byte[] ConvertAlpnProtocolListToByteArray(List applicationProtocols) + { + return Interop.Ssl.ConvertAlpnProtocolListToByteArray(applicationProtocols); + } + + private static SecurityStatusPal HandshakeInternal(SafeFreeCredentials credential, ref SafeDeleteContext context, SecurityBuffer inputBuffer, + SecurityBuffer outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { Debug.Assert(!credential.IsInvalid); @@ -111,7 +144,7 @@ private static SecurityStatusPal HandshakeInternal(SafeFreeCredentials credentia { if ((null == context) || context.IsInvalid) { - context = new SafeDeleteSslContext(credential as SafeFreeSslCredentials, isServer, remoteCertRequired); + context = new SafeDeleteSslContext(credential as SafeFreeSslCredentials, sslAuthenticationOptions); } byte[] output = null; @@ -139,6 +172,14 @@ private static SecurityStatusPal HandshakeInternal(SafeFreeCredentials credentia } } + internal static byte[] GetNegotiatedApplicationProtocol(SafeDeleteContext context) + { + if (context == null) + return null; + + return Interop.Ssl.SslGetAlpnSelected(((SafeDeleteSslContext)context).SslContext); + } + private static SecurityStatusPal EncryptDecryptHelper(SafeDeleteContext securityContext, ReadOnlyMemory input, int offset, int size, bool encrypt, ref byte[] output, out int resultSize) { resultSize = 0; diff --git a/src/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs b/src/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs index 54e91d66e870..d2b822b84de9 100644 --- a/src/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs +++ b/src/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs @@ -2,15 +2,15 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.Win32.SafeHandles; +using System.Collections.Generic; using System.ComponentModel; using System.Diagnostics; -using System.Net.Security; using System.Runtime.InteropServices; using System.Security.Authentication; using System.Security.Authentication.ExtendedProtection; using System.Security.Cryptography.X509Certificates; using System.Security.Principal; +using Microsoft.Win32.SafeHandles; namespace System.Net.Security { @@ -41,24 +41,29 @@ public static void VerifyPackageInfo() SSPIWrapper.GetVerifyPackageInfo(GlobalSSPI.SSPISecureChannel, SecurityPackage, true); } - public static SecurityStatusPal AcceptSecurityContext(ref SafeFreeCredentials credentialsHandle, ref SafeDeleteContext context, SecurityBuffer inputBuffer, SecurityBuffer outputBuffer, bool remoteCertRequired) + public static byte[] ConvertAlpnProtocolListToByteArray(List protocols) + { + return Interop.Sec_Application_Protocols.ToByteArray(protocols); + } + + public static SecurityStatusPal AcceptSecurityContext(ref SafeFreeCredentials credentialsHandle, ref SafeDeleteContext context, SecurityBuffer[] inputBuffers, SecurityBuffer outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { Interop.SspiCli.ContextFlags unusedAttributes = default(Interop.SspiCli.ContextFlags); int errorCode = SSPIWrapper.AcceptSecurityContext( GlobalSSPI.SSPISecureChannel, - ref credentialsHandle, + credentialsHandle, ref context, - ServerRequiredFlags | (remoteCertRequired ? Interop.SspiCli.ContextFlags.MutualAuth : Interop.SspiCli.ContextFlags.Zero), + ServerRequiredFlags | (sslAuthenticationOptions.RemoteCertRequired ? Interop.SspiCli.ContextFlags.MutualAuth : Interop.SspiCli.ContextFlags.Zero), Interop.SspiCli.Endianness.SECURITY_NATIVE_DREP, - inputBuffer, + inputBuffers, outputBuffer, ref unusedAttributes); return SecurityStatusAdapterPal.GetSecurityStatusPalFromNativeInt(errorCode); } - public static SecurityStatusPal InitializeSecurityContext(ref SafeFreeCredentials credentialsHandle, ref SafeDeleteContext context, string targetName, SecurityBuffer inputBuffer, SecurityBuffer outputBuffer) + public static SecurityStatusPal InitializeSecurityContext(ref SafeFreeCredentials credentialsHandle, ref SafeDeleteContext context, string targetName, SecurityBuffer inputBuffer, SecurityBuffer outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { Interop.SspiCli.ContextFlags unusedAttributes = default(Interop.SspiCli.ContextFlags); @@ -76,7 +81,7 @@ public static SecurityStatusPal InitializeSecurityContext(ref SafeFreeCredential return SecurityStatusAdapterPal.GetSecurityStatusPalFromNativeInt(errorCode); } - public static SecurityStatusPal InitializeSecurityContext(SafeFreeCredentials credentialsHandle, ref SafeDeleteContext context, string targetName, SecurityBuffer[] inputBuffers, SecurityBuffer outputBuffer) + public static SecurityStatusPal InitializeSecurityContext(SafeFreeCredentials credentialsHandle, ref SafeDeleteContext context, string targetName, SecurityBuffer[] inputBuffers, SecurityBuffer outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { Interop.SspiCli.ContextFlags unusedAttributes = default(Interop.SspiCli.ContextFlags); @@ -94,6 +99,45 @@ public static SecurityStatusPal InitializeSecurityContext(SafeFreeCredentials cr return SecurityStatusAdapterPal.GetSecurityStatusPalFromNativeInt(errorCode); } + public static SecurityBuffer[] GetIncomingSecurityBuffers(SslAuthenticationOptions options, ref SecurityBuffer incomingSecurity) + { + SecurityBuffer alpnBuffer = null; + SecurityBuffer[] incomingSecurityBuffers = null; + + if (options.ApplicationProtocols != null && options.ApplicationProtocols.Count != 0) + { + byte[] alpnBytes = SslStreamPal.ConvertAlpnProtocolListToByteArray(options.ApplicationProtocols); + alpnBuffer = new SecurityBuffer(alpnBytes, 0, alpnBytes.Length, SecurityBufferType.SECBUFFER_APPLICATION_PROTOCOLS); + } + + if (incomingSecurity != null) + { + if (alpnBuffer != null) + { + incomingSecurityBuffers = new SecurityBuffer[] + { + incomingSecurity, + new SecurityBuffer(null, 0, 0, SecurityBufferType.SECBUFFER_EMPTY), + alpnBuffer + }; + } + else + { + incomingSecurityBuffers = new SecurityBuffer[] + { + incomingSecurity, + new SecurityBuffer(null, 0, 0, SecurityBufferType.SECBUFFER_EMPTY) + }; + } + } + else if (alpnBuffer != null) + { + incomingSecurity = alpnBuffer; + } + + return incomingSecurityBuffers; + } + public static SafeFreeCredentials AcquireCredentialsHandle(X509Certificate certificate, SslProtocols protocols, EncryptionPolicy policy, bool isServer) { int protocolFlags = GetProtocolFlagsFromSslProtocols(protocols, isServer); @@ -103,9 +147,9 @@ public static SafeFreeCredentials AcquireCredentialsHandle(X509Certificate certi if (!isServer) { direction = Interop.SspiCli.CredentialUse.SECPKG_CRED_OUTBOUND; - flags = - Interop.SspiCli.SCHANNEL_CRED.Flags.SCH_CRED_MANUAL_CRED_VALIDATION | - Interop.SspiCli.SCHANNEL_CRED.Flags.SCH_CRED_NO_DEFAULT_CREDS | + flags = + Interop.SspiCli.SCHANNEL_CRED.Flags.SCH_CRED_MANUAL_CRED_VALIDATION | + Interop.SspiCli.SCHANNEL_CRED.Flags.SCH_CRED_NO_DEFAULT_CREDS | Interop.SspiCli.SCHANNEL_CRED.Flags.SCH_SEND_AUX_RECORD; // CoreFX: always opt-in SCH_USE_STRONG_CRYPTO except for SSL3. @@ -131,6 +175,24 @@ public static SafeFreeCredentials AcquireCredentialsHandle(X509Certificate certi return AcquireCredentialsHandle(direction, secureCredential); } + internal static byte[] GetNegotiatedApplicationProtocol(SafeDeleteContext context) + { + Interop.SecPkgContext_ApplicationProtocol alpnContext = SSPIWrapper.QueryContextAttributes( + GlobalSSPI.SSPISecureChannel, + context, + Interop.SspiCli.ContextAttribute.SECPKG_ATTR_APPLICATION_PROTOCOL) as Interop.SecPkgContext_ApplicationProtocol; + + // Check if the context returned is alpn data, with successful negotiation. + if (alpnContext == null || + alpnContext.ProtoNegoExt != Interop.ApplicationProtocolNegotiationExt.ALPN || + alpnContext.ProtoNegoStatus != Interop.ApplicationProtocolNegotiationStatus.Success) + { + return null; + } + + return alpnContext.Protocol; + } + public static unsafe SecurityStatusPal EncryptMessage(SafeDeleteContext securityContext, ReadOnlyMemory input, int headerSize, int trailerSize, ref byte[] output, out int resultSize) { // Ensure that there is sufficient space for the message output. @@ -182,7 +244,8 @@ public static unsafe SecurityStatusPal EncryptMessage(SafeDeleteContext security if (errorCode != 0) { - if (NetEventSource.IsEnabled) NetEventSource.Info(securityContext, $"Encrypt ERROR {errorCode:X}"); + if (NetEventSource.IsEnabled) + NetEventSource.Info(securityContext, $"Encrypt ERROR {errorCode:X}"); resultSize = 0; return SecurityStatusAdapterPal.GetSecurityStatusPalFromNativeInt(errorCode); } @@ -266,7 +329,7 @@ public static SecurityStatusPal ApplyAlertToken(ref SafeFreeCredentials credenti ref securityContext, bufferDesc); - return SecurityStatusAdapterPal.GetSecurityStatusPalFromInterop(errorCode, attachException:true); + return SecurityStatusAdapterPal.GetSecurityStatusPalFromInterop(errorCode, attachException: true); } finally { @@ -287,7 +350,7 @@ public static SecurityStatusPal ApplyShutdownToken(ref SafeFreeCredentials crede ref securityContext, bufferDesc); - return SecurityStatusAdapterPal.GetSecurityStatusPalFromInterop(errorCode, attachException:true); + return SecurityStatusAdapterPal.GetSecurityStatusPalFromInterop(errorCode, attachException: true); } public static unsafe SafeFreeContextBufferChannelBinding QueryContextChannelBinding(SafeDeleteContext securityContext, ChannelBindingKind attribute) @@ -402,7 +465,7 @@ private static SafeFreeCredentials AcquireCredentialsHandle(Interop.SspiCli.Cred return SSPIWrapper.AcquireCredentialsHandle(GlobalSSPI.SSPISecureChannel, SecurityPackage, credUsage, secureCredential); }); } - catch(Exception ex) + catch (Exception ex) { Debug.Fail("AcquireCredentialsHandle failed.", ex.ToString()); return SSPIWrapper.AcquireCredentialsHandle(GlobalSSPI.SSPISecureChannel, SecurityPackage, credUsage, secureCredential); diff --git a/src/System.Net.Security/tests/FunctionalTests/SslStreamAlpnTests.cs b/src/System.Net.Security/tests/FunctionalTests/SslStreamAlpnTests.cs new file mode 100644 index 000000000000..bcf10160df61 --- /dev/null +++ b/src/System.Net.Security/tests/FunctionalTests/SslStreamAlpnTests.cs @@ -0,0 +1,174 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Linq; +using System.Net.Test.Common; +using System.Runtime.InteropServices; +using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +using Xunit; + +namespace System.Net.Security.Tests +{ + using Configuration = System.Net.Test.Common.Configuration; + + public class SslStreamAlpnTests + { + private bool DoHandshakeWithOptions(SslStream clientSslStream, SslStream serverSslStream, SslClientAuthenticationOptions clientOptions, SslServerAuthenticationOptions serverOptions) + { + using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate()) + { + clientOptions.RemoteCertificateValidationCallback = AllowAnyServerCertificate; + clientOptions.TargetHost = certificate.GetNameInfo(X509NameType.SimpleName, false); + serverOptions.ServerCertificate = certificate; + + Task t1 = clientSslStream.AuthenticateAsClientAsync(clientOptions, CancellationToken.None); + Task t2 = serverSslStream.AuthenticateAsServerAsync(serverOptions, CancellationToken.None); + + return Task.WaitAll(new[] { t1, t2 }, TestConfiguration.PassingTestTimeoutMilliseconds); + } + } + + protected bool AllowAnyServerCertificate( + object sender, + X509Certificate certificate, + X509Chain chain, + SslPolicyErrors sslPolicyErrors) + { + SslPolicyErrors expectedSslPolicyErrors = SslPolicyErrors.None; + + if (!Capability.IsTrustedRootCertificateInstalled()) + { + expectedSslPolicyErrors = SslPolicyErrors.RemoteCertificateChainErrors; + } + + Assert.Equal(expectedSslPolicyErrors, sslPolicyErrors); + + if (sslPolicyErrors == expectedSslPolicyErrors) + { + return true; + } + else + { + return false; + } + } + + [Fact] + public void SslStream_StreamToStream_DuplicateOptions_Throws() + { + RemoteCertificateValidationCallback rCallback = (sender, certificate, chain, errors) => { return true; }; + LocalCertificateSelectionCallback lCallback = (sender, host, localCertificates, remoteCertificate, issuers) => { return null; }; + + VirtualNetwork network = new VirtualNetwork(); + using (var clientStream = new VirtualNetworkStream(network, false)) + using (var serverStream = new VirtualNetworkStream(network, true)) + using (var client = new SslStream(clientStream, false, rCallback, lCallback, EncryptionPolicy.RequireEncryption)) + using (var server = new SslStream(serverStream, false, rCallback)) + using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate()) + { + SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions(); + clientOptions.RemoteCertificateValidationCallback = AllowAnyServerCertificate; + clientOptions.TargetHost = certificate.GetNameInfo(X509NameType.SimpleName, false); + + SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions(); + serverOptions.ServerCertificate = certificate; + serverOptions.RemoteCertificateValidationCallback = AllowAnyServerCertificate; + + Assert.ThrowsAsync(() => { return client.AuthenticateAsClientAsync(clientOptions, CancellationToken.None); }); + Assert.ThrowsAsync(() => { return server.AuthenticateAsServerAsync(serverOptions, CancellationToken.None); }); + } + } + + [Theory] + [ActiveIssue(24722, TestPlatforms.Linux)] + [MemberData(nameof(Alpn_TestData))] + public void SslStream_StreamToStream_Alpn_Success(List clientProtocols, List serverProtocols, SslApplicationProtocol expected) + { + VirtualNetwork network = new VirtualNetwork(); + using (var clientStream = new VirtualNetworkStream(network, false)) + using (var serverStream = new VirtualNetworkStream(network, true)) + using (var client = new SslStream(clientStream, false)) + using (var server = new SslStream(serverStream, false)) + { + SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions + { + ApplicationProtocols = clientProtocols, + }; + + SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions + { + ApplicationProtocols = serverProtocols, + }; + + Assert.True(DoHandshakeWithOptions(client, server, clientOptions, serverOptions)); + + Assert.Equal(expected, client.NegotiatedApplicationProtocol); + Assert.Equal(expected, server.NegotiatedApplicationProtocol); + } + } + + [Fact] + [PlatformSpecific(~TestPlatforms.OSX)] + public void SslStream_StreamToStream_Alpn_NonMatchingProtocols_Fail() + { + VirtualNetwork network = new VirtualNetwork(); + using (var clientStream = new VirtualNetworkStream(network, false)) + using (var serverStream = new VirtualNetworkStream(network, true)) + using (var client = new SslStream(clientStream, false)) + using (var server = new SslStream(serverStream, false)) + using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate()) + { + SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions + { + ApplicationProtocols = new List { SslApplicationProtocol.Http11 }, + RemoteCertificateValidationCallback = AllowAnyServerCertificate, + TargetHost = certificate.GetNameInfo(X509NameType.SimpleName, false), + }; + + SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions + { + ApplicationProtocols = new List { SslApplicationProtocol.Http2 }, + ServerCertificate = certificate, + }; + + Assert.ThrowsAsync(() => { return client.AuthenticateAsClientAsync(clientOptions, CancellationToken.None); }); + Assert.ThrowsAsync(() => { return server.AuthenticateAsServerAsync(serverOptions, CancellationToken.None); }); + } + } + + internal static IEnumerable Alpn_TestData() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + yield return new object[] { new List { SslApplicationProtocol.Http11, SslApplicationProtocol.Http2 }, new List { SslApplicationProtocol.Http2 }, null }; + yield return new object[] { new List { SslApplicationProtocol.Http11 }, new List { SslApplicationProtocol.Http11, SslApplicationProtocol.Http2 }, null }; + yield return new object[] { new List { SslApplicationProtocol.Http11, SslApplicationProtocol.Http2 }, new List { SslApplicationProtocol.Http11, SslApplicationProtocol.Http2 }, null }; + yield return new object[] { null, new List { SslApplicationProtocol.Http11, SslApplicationProtocol.Http2 }, null }; + yield return new object[] { new List { SslApplicationProtocol.Http11, SslApplicationProtocol.Http2 }, null, null }; + yield return new object[] { new List { SslApplicationProtocol.Http11 }, new List { SslApplicationProtocol.Http2 }, null }; + yield return new object[] { null, null, null }; + } + else + { + // Works on linux distros with openssl 1.0.2, CI machines Ubuntu14.04 and Debian 87 don't have openssl 1.0.2 + // Works on Windows OSes > 7.0 + bool featureWorks = (RuntimeInformation.IsOSPlatform(OSPlatform.Linux) && !(PlatformDetection.IsUbuntu1404 || PlatformDetection.IsDebian)) || + (PlatformDetection.IsWindows && !PlatformDetection.IsWindows7); + + yield return new object[] { new List { SslApplicationProtocol.Http11, SslApplicationProtocol.Http2 }, new List { SslApplicationProtocol.Http2 }, featureWorks ? SslApplicationProtocol.Http2 : default }; + yield return new object[] { new List { SslApplicationProtocol.Http11 }, new List { SslApplicationProtocol.Http11, SslApplicationProtocol.Http2 }, featureWorks ? SslApplicationProtocol.Http11 : default }; + yield return new object[] { new List { SslApplicationProtocol.Http11, SslApplicationProtocol.Http2 }, new List { SslApplicationProtocol.Http11, SslApplicationProtocol.Http2 }, featureWorks ? SslApplicationProtocol.Http11 : default }; + yield return new object[] { null, new List { SslApplicationProtocol.Http11, SslApplicationProtocol.Http2 }, default(SslApplicationProtocol) }; + yield return new object[] { new List { SslApplicationProtocol.Http11, SslApplicationProtocol.Http2 }, null, default(SslApplicationProtocol) }; + yield return new object[] { null, null, default(SslApplicationProtocol) }; + } + } + } +} diff --git a/src/System.Net.Security/tests/FunctionalTests/SslStreamStreamToStreamTest.cs b/src/System.Net.Security/tests/FunctionalTests/SslStreamStreamToStreamTest.cs index a0277c5bfa56..c625af8420df 100644 --- a/src/System.Net.Security/tests/FunctionalTests/SslStreamStreamToStreamTest.cs +++ b/src/System.Net.Security/tests/FunctionalTests/SslStreamStreamToStreamTest.cs @@ -145,8 +145,8 @@ public void SslStream_StreamToStream_LargeWrites_Sync_Success(bool randomizedDat { VirtualNetwork network = new VirtualNetwork(); - using (var clientStream = new VirtualNetworkStream(network, isServer:false)) - using (var serverStream = new VirtualNetworkStream(network, isServer:true)) + using (var clientStream = new VirtualNetworkStream(network, isServer: false)) + using (var serverStream = new VirtualNetworkStream(network, isServer: true)) using (var clientSslStream = new SslStream(clientStream, false, AllowAnyServerCertificate)) using (var serverSslStream = new SslStream(serverStream)) { @@ -232,8 +232,8 @@ public void SslStream_StreamToStream_Write_ReadByte_Success() { VirtualNetwork network = new VirtualNetwork(); - using (var clientStream = new VirtualNetworkStream(network, isServer:false)) - using (var serverStream = new VirtualNetworkStream(network, isServer:true)) + using (var clientStream = new VirtualNetworkStream(network, isServer: false)) + using (var serverStream = new VirtualNetworkStream(network, isServer: true)) using (var clientSslStream = new SslStream(clientStream, false, AllowAnyServerCertificate)) using (var serverSslStream = new SslStream(serverStream)) { @@ -354,7 +354,7 @@ private bool VerifyOutput(byte[] actualBuffer, byte[] expectedBuffer) return expectedBuffer.SequenceEqual(actualBuffer); } - private bool AllowAnyServerCertificate( + protected bool AllowAnyServerCertificate( object sender, X509Certificate certificate, X509Chain chain, diff --git a/src/System.Net.Security/tests/FunctionalTests/System.Net.Security.Tests.csproj b/src/System.Net.Security/tests/FunctionalTests/System.Net.Security.Tests.csproj index cd80ae70901a..bbe7f1e6f34d 100644 --- a/src/System.Net.Security/tests/FunctionalTests/System.Net.Security.Tests.csproj +++ b/src/System.Net.Security/tests/FunctionalTests/System.Net.Security.Tests.csproj @@ -7,7 +7,6 @@ true - @@ -20,13 +19,11 @@ - - @@ -40,12 +37,10 @@ - - Common\System\Net\Capability.Security.cs @@ -96,6 +91,7 @@ + diff --git a/src/System.Net.Security/tests/UnitTests/Configurations.props b/src/System.Net.Security/tests/UnitTests/Configurations.props index cfc57211a43f..7bdd3a170b90 100644 --- a/src/System.Net.Security/tests/UnitTests/Configurations.props +++ b/src/System.Net.Security/tests/UnitTests/Configurations.props @@ -2,8 +2,10 @@ - netstandard-Unix; - netstandard-Windows_NT; + netcoreapp-OSX; + netcoreapp-Unix; + netcoreapp-Windows_NT; + uap-Windows_NT; \ No newline at end of file diff --git a/src/System.Net.Security/tests/UnitTests/Fakes/FakeAuthenticatedStream.cs b/src/System.Net.Security/tests/UnitTests/Fakes/FakeAuthenticatedStream.cs index 89e892b2cd33..95f06aebe508 100644 --- a/src/System.Net.Security/tests/UnitTests/Fakes/FakeAuthenticatedStream.cs +++ b/src/System.Net.Security/tests/UnitTests/Fakes/FakeAuthenticatedStream.cs @@ -33,7 +33,7 @@ protected override void Dispose(bool disposing) public abstract bool IsSigned { get; } public abstract bool IsServer { get; } - public abstract Task WriteAsync(ReadOnlyMemory buffer, CancellationToken token); + public new abstract Task WriteAsync(ReadOnlyMemory buffer, CancellationToken token); } } diff --git a/src/System.Net.Security/tests/UnitTests/Fakes/FakeSslState.cs b/src/System.Net.Security/tests/UnitTests/Fakes/FakeSslState.cs index caae4c792302..98f481df35cb 100644 --- a/src/System.Net.Security/tests/UnitTests/Fakes/FakeSslState.cs +++ b/src/System.Net.Security/tests/UnitTests/Fakes/FakeSslState.cs @@ -17,16 +17,26 @@ internal class SslState // The public Client and Server classes enforce the parameters rules before // calling into this .ctor. // - internal SslState(Stream innerStream, RemoteCertValidationCallback certValidationCallback, LocalCertSelectionCallback certSelectionCallback, EncryptionPolicy encryptionPolicy) + internal SslState(Stream innerStream) { } - // - // - // - internal void ValidateCreateContext(bool isServer, string targetHost, SslProtocols enabledSslProtocols, X509Certificate serverCertificate, X509CertificateCollection clientCertificates, bool remoteCertRequired, bool checkCertRevocationStatus) + + internal void ValidateCreateContext(SslClientAuthenticationOptions sslClientAuthenticationOptions) + { + } + + internal void ValidateCreateContext(SslServerAuthenticationOptions sslServerAuthenticationOptions) { } + internal SslApplicationProtocol NegotiatedApplicationProtocol + { + get + { + return default; + } + } + internal bool IsAuthenticated { get @@ -269,7 +279,7 @@ public override void Write(byte[] buffer, int offset, int count) throw new NotImplementedException(); } - public Task WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) + public new Task WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) { throw new NotImplementedException(); } diff --git a/src/System.Net.Security/tests/UnitTests/SslApplicationProtocolTests.cs b/src/System.Net.Security/tests/UnitTests/SslApplicationProtocolTests.cs new file mode 100644 index 000000000000..24cb8f5fa100 --- /dev/null +++ b/src/System.Net.Security/tests/UnitTests/SslApplicationProtocolTests.cs @@ -0,0 +1,94 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Text; +using Xunit; + +namespace System.Net.Security.Tests +{ + public class SslApplicationProtocolTests + { + [Fact] + public void Constants_Values_AreCorrect() + { + Assert.Equal(new SslApplicationProtocol(new byte[] { 0x68, 0x32 }), SslApplicationProtocol.Http2); + Assert.Equal(new SslApplicationProtocol(new byte[] { 0x68, 0x74, 0x74, 0x70, 0x2f, 0x31, 0x2e, 0x31 }), SslApplicationProtocol.Http11); + } + + [Fact] + public void Constructor_Overloads_Succeeds() + { + const string hello = "hello"; + byte[] expected = Encoding.UTF8.GetBytes(hello); + SslApplicationProtocol byteProtocol = new SslApplicationProtocol(expected); + SslApplicationProtocol stringProtocol = new SslApplicationProtocol(hello); + Assert.Equal(byteProtocol, stringProtocol); + + SslApplicationProtocol defaultProtocol = default; + Assert.True(defaultProtocol.Protocol.IsEmpty); + + Assert.Throws(() => { new SslApplicationProtocol((byte[])null); }); + Assert.Throws(() => { new SslApplicationProtocol((string)null); }); + Assert.Throws(() => { new SslApplicationProtocol(new byte[] { }); }); + Assert.Throws(() => { new SslApplicationProtocol(string.Empty); }); + Assert.Throws(() => { new SslApplicationProtocol(Encoding.UTF8.GetBytes(new string('a', 256))); }); + Assert.Throws(() => { new SslApplicationProtocol(new string('a', 256)); }); + Assert.Throws(() => { new SslApplicationProtocol("\uDC00"); }); + } + + [Theory] + [MemberData(nameof(Protocol_Equality_TestData))] + public void Equality_Tests_Succeeds(SslApplicationProtocol left, SslApplicationProtocol right) + { + Assert.Equal(left, right); + Assert.True(left == right); + Assert.False(left != right); + Assert.Equal(left.GetHashCode(), right.GetHashCode()); + } + + [Theory] + [MemberData(nameof(Protocol_InEquality_TestData))] + public void InEquality_Tests_Succeeds(SslApplicationProtocol left, SslApplicationProtocol right) + { + Assert.NotEqual(left, right); + Assert.True(left != right); + Assert.False(left == right); + Assert.NotEqual(left.GetHashCode(), right.GetHashCode()); + } + + [Fact] + public void ToString_Rendering_Succeeds() + { + const string expected = "hello"; + SslApplicationProtocol protocol = new SslApplicationProtocol(expected); + Assert.Equal(expected, protocol.ToString()); + + byte[] bytes = new byte[] { 0x0B, 0xEE }; + protocol = new SslApplicationProtocol(bytes); + Assert.Equal("0x0b 0xee", protocol.ToString()); + + protocol = default; + Assert.Null(protocol.ToString()); + } + + public static IEnumerable Protocol_Equality_TestData() + { + yield return new object[] { new SslApplicationProtocol("hello"), new SslApplicationProtocol("hello") }; + yield return new object[] { new SslApplicationProtocol(new byte[] { 0x42 }), new SslApplicationProtocol(new byte[] { 0x42 }) }; + yield return new object[] { null, null }; + yield return new object[] { default, default }; + yield return new object[] { null, default }; + yield return new object[] { default, null }; + } + + public static IEnumerable Protocol_InEquality_TestData() + { + yield return new object[] { new SslApplicationProtocol("hello"), new SslApplicationProtocol("world") }; + yield return new object[] { new SslApplicationProtocol(new byte[] { 0x42 }), new SslApplicationProtocol(new byte[] { 0x52, 0x62 }) }; + yield return new object[] { null, new SslApplicationProtocol(new byte[] { 0x42 }) }; + yield return new object[] { new SslApplicationProtocol(new byte[] { 0x42 }), null }; + } + } +} diff --git a/src/System.Net.Security/tests/UnitTests/SslAuthenticationOptionsTests.cs b/src/System.Net.Security/tests/UnitTests/SslAuthenticationOptionsTests.cs new file mode 100644 index 000000000000..edc4160df277 --- /dev/null +++ b/src/System.Net.Security/tests/UnitTests/SslAuthenticationOptionsTests.cs @@ -0,0 +1,162 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; +using Xunit; + +namespace System.Net.Security.Tests +{ + public class SslAuthenticationOptionsTests + { + private readonly SslClientAuthenticationOptions _clientOptions = new SslClientAuthenticationOptions(); + private readonly SslServerAuthenticationOptions _serverOptions = new SslServerAuthenticationOptions(); + + [Fact] + public void AllowRenegotiation_Get_Set_Succeeds() + { + Assert.False(_clientOptions.AllowRenegotiation); + Assert.False(_serverOptions.AllowRenegotiation); + + _clientOptions.AllowRenegotiation = true; + _serverOptions.AllowRenegotiation = true; + + Assert.True(_clientOptions.AllowRenegotiation); + Assert.True(_serverOptions.AllowRenegotiation); + } + + [Fact] + public void ClientCertificateRequired_Get_Set_Succeeds() + { + Assert.False(_serverOptions.ClientCertificateRequired); + + _serverOptions.ClientCertificateRequired = true; + Assert.True(_serverOptions.ClientCertificateRequired); + } + + [Fact] + public void ApplicationProtocols_Get_Set_Succeeds() + { + Assert.Null(_clientOptions.ApplicationProtocols); + Assert.Null(_serverOptions.ApplicationProtocols); + + List applnProtos = new List { SslApplicationProtocol.Http2, SslApplicationProtocol.Http11 }; + _clientOptions.ApplicationProtocols = applnProtos; + _serverOptions.ApplicationProtocols = applnProtos; + + Assert.Equal(applnProtos, _clientOptions.ApplicationProtocols); + Assert.Equal(applnProtos, _serverOptions.ApplicationProtocols); + } + + [Fact] + public void RemoteCertificateValidationCallback_Get_Set_Succeeds() + { + Assert.Null(_clientOptions.RemoteCertificateValidationCallback); + Assert.Null(_serverOptions.RemoteCertificateValidationCallback); + + RemoteCertificateValidationCallback callback = (sender, certificate, chain, errors) => { return true; }; + _clientOptions.RemoteCertificateValidationCallback = callback; + _serverOptions.RemoteCertificateValidationCallback = callback; + + Assert.Equal(callback, _clientOptions.RemoteCertificateValidationCallback); + Assert.Equal(callback, _serverOptions.RemoteCertificateValidationCallback); + } + + [Fact] + public void LocalCertificateSelectionCallback_Get_Set_Succeeds() + { + Assert.Null(_clientOptions.LocalCertificateSelectionCallback); + + LocalCertificateSelectionCallback callback = (sender, host, localCertificates, remoteCertificate, issuers) => { return new X509Certificate(); }; + _clientOptions.LocalCertificateSelectionCallback = callback; + + Assert.Equal(callback, _clientOptions.LocalCertificateSelectionCallback); + } + + [Theory] + [InlineData("")] + [InlineData("\u0bee")] + [InlineData("hello")] + [InlineData(" \t")] + [InlineData(null)] + public void TargetHost_Get_Set_Succeeds(string expected) + { + Assert.Null(_clientOptions.TargetHost); + _clientOptions.TargetHost = expected; + Assert.Equal(expected, _clientOptions.TargetHost); + } + + [Fact] + public void ClientCertificates_Get_Set_Succeeds() + { + Assert.Null(_clientOptions.ClientCertificates); + + _clientOptions.ClientCertificates = null; + Assert.Null(_clientOptions.ClientCertificates); + + X509CertificateCollection expected = new X509CertificateCollection(); + _clientOptions.ClientCertificates = expected; + Assert.Equal(expected, _clientOptions.ClientCertificates); + } + + [Fact] + public void ServerCertificate_Get_Set_Succeeds() + { + Assert.Null(_serverOptions.ServerCertificate); + _serverOptions.ServerCertificate = null; + + Assert.Null(_serverOptions.ServerCertificate); + X509Certificate cert = new X509Certificate(); + _serverOptions.ServerCertificate = cert; + + Assert.Equal(cert, _serverOptions.ServerCertificate); + } + + [Fact] + public void EnabledSslProtocols_Get_Set_Succeeds() + { + Assert.Equal(SslProtocols.None, _clientOptions.EnabledSslProtocols); + Assert.Equal(SslProtocols.None, _serverOptions.EnabledSslProtocols); + + _clientOptions.EnabledSslProtocols = SslProtocols.Tls12; + _serverOptions.EnabledSslProtocols = SslProtocols.Tls12; + + Assert.Equal(SslProtocols.Tls12, _clientOptions.EnabledSslProtocols); + Assert.Equal(SslProtocols.Tls12, _serverOptions.EnabledSslProtocols); + } + + [Fact] + public void CheckCertificateRevocation_Get_Set_Succeeds() + { + Assert.Equal(X509RevocationMode.NoCheck, _clientOptions.CertificateRevocationCheckMode); + Assert.Equal(X509RevocationMode.NoCheck, _serverOptions.CertificateRevocationCheckMode); + + _clientOptions.CertificateRevocationCheckMode = X509RevocationMode.Online; + _serverOptions.CertificateRevocationCheckMode = X509RevocationMode.Offline; + + Assert.Equal(X509RevocationMode.Online, _clientOptions.CertificateRevocationCheckMode); + Assert.Equal(X509RevocationMode.Offline, _serverOptions.CertificateRevocationCheckMode); + + Assert.Throws(() => _clientOptions.CertificateRevocationCheckMode = (X509RevocationMode)3); + Assert.Throws(() => _serverOptions.CertificateRevocationCheckMode = (X509RevocationMode)3); + } + + [Fact] + public void EncryptionPolicy_Get_Set_Succeeds() + { + Assert.Equal(EncryptionPolicy.RequireEncryption, _clientOptions.EncryptionPolicy); + Assert.Equal(EncryptionPolicy.RequireEncryption, _serverOptions.EncryptionPolicy); + + _clientOptions.EncryptionPolicy = EncryptionPolicy.AllowNoEncryption; + _serverOptions.EncryptionPolicy = EncryptionPolicy.NoEncryption; + + Assert.Equal(EncryptionPolicy.AllowNoEncryption, _clientOptions.EncryptionPolicy); + Assert.Equal(EncryptionPolicy.NoEncryption, _serverOptions.EncryptionPolicy); + + Assert.Throws(() => _clientOptions.EncryptionPolicy = (EncryptionPolicy)3); + Assert.Throws(() => _serverOptions.EncryptionPolicy = (EncryptionPolicy)3); + } + } +} diff --git a/src/System.Net.Security/tests/UnitTests/System.Net.Security.Unit.Tests.csproj b/src/System.Net.Security/tests/UnitTests/System.Net.Security.Unit.Tests.csproj index db48c79c5374..d7a7ab2d6c0b 100644 --- a/src/System.Net.Security/tests/UnitTests/System.Net.Security.Unit.Tests.csproj +++ b/src/System.Net.Security/tests/UnitTests/System.Net.Security.Unit.Tests.csproj @@ -12,11 +12,17 @@ --> 436 - - - - + + + + + + + + + + @@ -34,6 +40,18 @@ ProductionCode\System\Net\Security\SslStream.cs + + ProductionCode\System\Net\Security\SslClientAuthenticationOptions.cs + + + ProductionCode\System\Net\Security\SslServerAuthenticationOptions.cs + + + ProductionCode\System\Net\Security\SslAuthenticationOptions.cs + + + ProductionCode\System\Net\Security\SslApplicationProtocol.cs + ProductionCode\System\Net\SslStreamContext.cs @@ -57,5 +75,13 @@ ProductionCode\Common\System\Net\InternalException.cs + + + + + + + +