Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ internal sealed class State

// These exists to prevent GC of the MsQuicConnection in the middle of an async op (Connect or Shutdown).
public MsQuicConnection? Connection;
public MsQuicListener.State? ListenerState;

public TaskCompletionSource<uint>? ConnectTcs;
// TODO: only allocate these when there is an outstanding shutdown.
Expand Down Expand Up @@ -135,11 +136,10 @@ public void SetClosing()
internal string TraceId() => _state.TraceId;

// constructor for inbound connections
public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, SafeMsQuicConnectionHandle handle, bool remoteCertificateRequired = false, X509RevocationMode revocationMode = X509RevocationMode.Offline, RemoteCertificateValidationCallback? remoteCertificateValidationCallback = null, ServerCertificateSelectionCallback? serverCertificateSelectionCallback = null)
public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, MsQuicListener.State listenerState, SafeMsQuicConnectionHandle handle, bool remoteCertificateRequired = false, X509RevocationMode revocationMode = X509RevocationMode.Offline, RemoteCertificateValidationCallback? remoteCertificateValidationCallback = null, ServerCertificateSelectionCallback? serverCertificateSelectionCallback = null)
{
_state.Handle = handle;
_state.StateGCHandle = GCHandle.Alloc(_state);
_state.Connected = true;
_state.RemoteCertificateRequired = remoteCertificateRequired;
_state.RevocationMode = revocationMode;
_state.RemoteCertificateValidationCallback = remoteCertificateValidationCallback;
Expand All @@ -161,6 +161,7 @@ public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, Saf
throw;
}

_state.ListenerState = listenerState;
_state.TraceId = MsQuicTraceHelper.GetTraceId(_state.Handle);
if (NetEventSource.Log.IsEnabled())
{
Expand Down Expand Up @@ -223,7 +224,34 @@ public MsQuicConnection(QuicClientConnectionOptions options)

private static uint HandleEventConnected(State state, ref ConnectionEvent connectionEvent)
{
if (!state.Connected)
if (state.Connected)
{
return MsQuicStatusCodes.Success;
}

if (state.IsServer)
{
state.Connected = true;
MsQuicListener.State? listenerState = state.ListenerState;
state.ListenerState = null;

if (listenerState != null)
{
if (listenerState.PendingConnections.TryRemove(state.Handle.DangerousGetHandle(), out MsQuicConnection? connection))
{
// Move connection from pending to Accept queue and hand it out.
if (listenerState.AcceptConnectionQueue.Writer.TryWrite(connection))
{
return MsQuicStatusCodes.Success;
}
// Listener is closed
connection.Dispose();
}
}

return MsQuicStatusCodes.UserCanceled;
}
else
{
// Connected will already be true for connections accepted from a listener.
Debug.Assert(!Monitor.IsEntered(state));
Expand Down Expand Up @@ -271,6 +299,18 @@ private static uint HandleEventShutdownComplete(State state, ref ConnectionEvent
// This is the final event on the connection, so free the GCHandle used by the event callback.
state.StateGCHandle.Free();

if (state.ListenerState != null)
{
// This is inbound connection that never got connected - becasue of TLS validation or some other reason.
// Remove connection from pending queue and dispose it.
if (state.ListenerState.PendingConnections.TryRemove(state.Handle.DangerousGetHandle(), out MsQuicConnection? connection))
{
connection.Dispose();
}

state.ListenerState = null;
}

state.Connection = null;

state.ShutdownTcs.SetResult(MsQuicStatusCodes.Success);
Expand All @@ -297,6 +337,7 @@ private static uint HandleEventShutdownComplete(State state, ref ConnectionEvent
{
bidirectionalTcs.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException()));
}

return MsQuicStatusCodes.Success;
}

Expand Down Expand Up @@ -418,6 +459,11 @@ private static uint HandleEventPeerCertificateReceived(State state, ref Connecti

if (!success)
{
if (state.IsServer)
{
return MsQuicStatusCodes.UserCanceled;
}

throw new AuthenticationException(SR.net_quic_cert_custom_validation);
}

Expand All @@ -430,6 +476,11 @@ private static uint HandleEventPeerCertificateReceived(State state, ref Connecti

if (sslPolicyErrors != SslPolicyErrors.None)
{
if (state.IsServer)
{
return MsQuicStatusCodes.HandshakeFailure;
}

throw new AuthenticationException(SR.Format(SR.net_quic_cert_chain_validation, sslPolicyErrors));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Buffers;
using System.Collections.Generic;
using System.Collections.Concurrent;
using System.Diagnostics;
using System.Net.Quic.Implementations.MsQuic.Internal;
using System.Net.Security;
Expand All @@ -25,14 +26,15 @@ internal sealed class MsQuicListener : QuicListenerProvider, IDisposable

private readonly IPEndPoint _listenEndPoint;

private sealed class State
internal sealed class State
{
// set immediately in ctor, but we need a GCHandle to State in order to create the handle.
public SafeMsQuicListenerHandle Handle = null!;
public string TraceId = null!; // set in ctor.

public readonly SafeMsQuicConfigurationHandle? ConnectionConfiguration;
public readonly Channel<MsQuicConnection> AcceptConnectionQueue;
public readonly ConcurrentDictionary<IntPtr, MsQuicConnection> PendingConnections;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're not cleaning the dictionary in case the listener gets disposed. I see that we're not cleaning AcceptConnectionQueue either, so I guess it's not an issue. I'm just making sure it's intentional.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was going back & forth on this one. With Connection & Stream we have FlushAcceptQueue so when the Connection is disposed, we would actively nuke all the stream.

I think for Listener everything would decompose natural since we are not playing tricks and we would release the _stateHandle. We could certainly speed it up by cleaning both. But I'm not sure if that matters as much as unlike connection, I would expect Listener would last long so this would be rare operation
I think it would be easy to add. adding @stephentoub and @geoffkizer in case they have suggestions.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With Connection & Stream we had to make sure that we release Connection's SafeHandle only after all Stream handles are released. I don't think Connection has the same relation to the Listener, so the order of finalization doesn't matter...


public QuicOptions ConnectionOptions = new QuicOptions();
public SslServerAuthenticationOptions AuthenticationOptions = new SslServerAuthenticationOptions();
Expand Down Expand Up @@ -66,6 +68,7 @@ public State(QuicListenerOptions options)
ConnectionConfiguration = SafeMsQuicConfigurationHandle.Create(options, options.ServerAuthenticationOptions);
}

PendingConnections = new ConcurrentDictionary<IntPtr, MsQuicConnection>();
AcceptConnectionQueue = Channel.CreateBounded<MsQuicConnection>(new BoundedChannelOptions(options.ListenBacklog)
{
SingleReader = true,
Expand Down Expand Up @@ -229,7 +232,6 @@ private static unsafe uint NativeCallbackHandler(

SafeMsQuicConnectionHandle? connectionHandle = null;
MsQuicConnection? msQuicConnection = null;

try
{
ref NewConnectionInfo connectionInfo = ref *evt.Data.NewConnection.Info;
Expand Down Expand Up @@ -273,13 +275,15 @@ private static unsafe uint NativeCallbackHandler(
uint status = MsQuicApi.Api.ConnectionSetConfigurationDelegate(connectionHandle, connectionConfiguration);
if (MsQuicStatusHelper.SuccessfulStatusCode(status))
{
msQuicConnection = new MsQuicConnection(localEndPoint, remoteEndPoint, connectionHandle, state.AuthenticationOptions.ClientCertificateRequired, state.AuthenticationOptions.CertificateRevocationCheckMode, state.AuthenticationOptions.RemoteCertificateValidationCallback);
msQuicConnection = new MsQuicConnection(localEndPoint, remoteEndPoint, state, connectionHandle, state.AuthenticationOptions.ClientCertificateRequired, state.AuthenticationOptions.CertificateRevocationCheckMode, state.AuthenticationOptions.RemoteCertificateValidationCallback);
msQuicConnection.SetNegotiatedAlpn(connectionInfo.NegotiatedAlpn, connectionInfo.NegotiatedAlpnLength);

if (state.AcceptConnectionQueue.Writer.TryWrite(msQuicConnection))
if (!state.PendingConnections.TryAdd(connectionHandle.DangerousGetHandle(), msQuicConnection))
{
return MsQuicStatusCodes.Success;
msQuicConnection.Dispose();
}

return MsQuicStatusCodes.Success;
}

// If we fall-through here something wrong happened.
Expand Down
70 changes: 61 additions & 9 deletions src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,48 @@ public async Task ConnectWithCertificateChain()
clientConnection.Dispose();
}

[Fact]
[PlatformSpecific(TestPlatforms.Windows)]
public async Task UntrustedClientCertificateFails()
{
var listenerOptions = new QuicListenerOptions();
listenerOptions.ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0);
listenerOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
listenerOptions.ServerAuthenticationOptions.ClientCertificateRequired = true;
listenerOptions.ServerAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
{
return false;
};

using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, listenerOptions);
QuicClientConnectionOptions clientOptions = CreateQuicClientOptions();
clientOptions.RemoteEndPoint = listener.ListenEndPoint;
clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate };
QuicConnection clientConnection = CreateQuicConnection(clientOptions);

using CancellationTokenSource cts = new CancellationTokenSource();
cts.CancelAfter(500); //Some delay to see if we would get failed connection.
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask();

ValueTask t = clientConnection.ConnectAsync(cts.Token);

t.AsTask().Wait(PassingTestTimeout);
await Assert.ThrowsAsync<OperationCanceledException>(() => serverTask);
// The task will likely succed but we don't really care.
// It may fail if the server aborts quickly.
try
{
await t;
}
catch { };
}

[Fact]
public async Task CertificateCallbackThrowPropagates()
{
using CancellationTokenSource cts = new CancellationTokenSource(PassingTestTimeout);
X509Certificate? receivedCertificate = null;
bool validationResult = false;

var listenerOptions = new QuicListenerOptions();
listenerOptions.ListenEndPoint = new IPEndPoint(Socket.OSSupportsIPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback, 0);
Expand All @@ -118,18 +155,26 @@ public async Task CertificateCallbackThrowPropagates()
clientOptions.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
{
receivedCertificate = cert;
if (validationResult)
{
return validationResult;
}

throw new ArithmeticException("foobar");
};

clientOptions.ClientAuthenticationOptions.TargetHost = "foobar1";
QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, clientOptions);

Task<QuicConnection> serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask();
await Assert.ThrowsAsync<ArithmeticException>(() => clientConnection.ConnectAsync(cts.Token).AsTask());
QuicConnection serverConnection = await serverTask;

Assert.Equal(listenerOptions.ServerAuthenticationOptions.ServerCertificate, receivedCertificate);
clientConnection.Dispose();

// Make sure the listner is still usable and there is no lingering bad conenction
validationResult = true;
(clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(listener);
await PingPong(clientConnection, serverConnection);
clientConnection.Dispose();
serverConnection.Dispose();
}
Expand Down Expand Up @@ -226,7 +271,6 @@ public async Task ConnectWithCertificateForDifferentName_Throws()
using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, clientOptions);
ValueTask clientTask = clientConnection.ConnectAsync();

using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await Assert.ThrowsAsync<AuthenticationException>(async () => await clientTask);
}

Expand Down Expand Up @@ -257,9 +301,11 @@ public async Task ConnectWithCertificateForLoopbackIP_IndicatesExpectedError(str
(QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listenerOptions);
}

[Fact]
[Theory]
[PlatformSpecific(TestPlatforms.Windows)]
public async Task ConnectWithClientCertificate()
[InlineData(true)]
// [InlineData(false)] [ActiveIssue("https://github.com/dotnet/runtime/issues/57308")]
public async Task ConnectWithClientCertificate(bool sendCerttificate)
{
bool clientCertificateOK = false;

Expand All @@ -269,17 +315,23 @@ public async Task ConnectWithClientCertificate()
listenerOptions.ServerAuthenticationOptions.ClientCertificateRequired = true;
listenerOptions.ServerAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
{
_output.WriteLine("client certificate {0}", cert);
Assert.NotNull(cert);
Assert.Equal(ClientCertificate.Thumbprint, ((X509Certificate2)cert).Thumbprint);
if (sendCerttificate)
{
_output.WriteLine("client certificate {0}", cert);
Assert.NotNull(cert);
Assert.Equal(ClientCertificate.Thumbprint, ((X509Certificate2)cert).Thumbprint);
}

clientCertificateOK = true;
return true;
};

using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, listenerOptions);
QuicClientConnectionOptions clientOptions = CreateQuicClientOptions();
clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate };
if (sendCerttificate)
{
clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate };
}
(QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listener);

// Verify functionality of the connections.
Expand Down