From 732b9449d279be287aeed01db03203cd329dbf43 Mon Sep 17 00:00:00 2001 From: Cory Nelson Date: Fri, 16 Apr 2021 02:44:16 -0700 Subject: [PATCH 1/8] WIP --- .../Common/src/System/Net/StreamBuffer.cs | 16 + .../System/Net/Http/Http3LoopbackStream.cs | 2 +- .../SocketsHttpHandler/Http3Connection.cs | 2 +- .../SocketsHttpHandler/Http3RequestStream.cs | 14 +- .../System.Net.Quic/System.Net.Quic.sln | 35 +- .../System.Net.Quic/ref/System.Net.Quic.cs | 16 +- .../src/Resources/Strings.resx | 57 +- .../Quic/Implementations/Mock/MockStream.cs | 69 +-- .../MsQuic/Interop/MsQuicStatusCodes.cs | 1 + .../Implementations/MsQuic/MsQuicStream.cs | 544 +++++++++--------- .../Implementations/QuicStreamProvider.cs | 16 +- .../src/System/Net/Quic/QuicAbortDirection.cs | 13 + .../src/System/Net/Quic/QuicStream.cs | 27 +- .../tests/FunctionalTests/MsQuicTests.cs | 6 +- .../tests/FunctionalTests/QuicStreamTests.cs | 353 +++++++++--- .../tests/FunctionalTests/QuicTestBase.cs | 40 ++ 16 files changed, 726 insertions(+), 485 deletions(-) create mode 100644 src/libraries/System.Net.Quic/src/System/Net/Quic/QuicAbortDirection.cs diff --git a/src/libraries/Common/src/System/Net/StreamBuffer.cs b/src/libraries/Common/src/System/Net/StreamBuffer.cs index 6759fcdd8e20b0..5ccab63f1f2e12 100644 --- a/src/libraries/Common/src/System/Net/StreamBuffer.cs +++ b/src/libraries/Common/src/System/Net/StreamBuffer.cs @@ -18,6 +18,7 @@ internal sealed class StreamBuffer : IDisposable private bool _readAborted; private readonly ResettableValueTaskSource _readTaskSource; private readonly ResettableValueTaskSource _writeTaskSource; + private readonly TaskCompletionSource _shutdownTaskSource; public const int DefaultInitialBufferSize = 4 * 1024; public const int DefaultMaxBufferSize = 32 * 1024; @@ -28,10 +29,13 @@ public StreamBuffer(int initialBufferSize = DefaultInitialBufferSize, int maxBuf _maxBufferSize = maxBufferSize; _readTaskSource = new ResettableValueTaskSource(); _writeTaskSource = new ResettableValueTaskSource(); + _shutdownTaskSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); } private object SyncObject => _readTaskSource; + public Task Completed => _shutdownTaskSource.Task; + public bool IsComplete { get @@ -187,6 +191,11 @@ public void EndWrite() _writeEnded = true; _readTaskSource.SignalWaiter(); + + if (_buffer.IsEmpty) + { + _shutdownTaskSource.TrySetResult(); + } } } @@ -210,10 +219,16 @@ public void EndWrite() _writeTaskSource.SignalWaiter(); + if (_buffer.IsEmpty && _writeEnded) + { + _shutdownTaskSource.TrySetResult(); + } + return (false, bytesRead); } else if (_writeEnded) { + _shutdownTaskSource.TrySetResult(); return (false, 0); } @@ -280,6 +295,7 @@ public void AbortRead() _readTaskSource.SignalWaiter(); _writeTaskSource.SignalWaiter(); + _shutdownTaskSource.TrySetResult(); } } diff --git a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs index faccfab64b27eb..294dbb89ce5b5d 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs @@ -118,7 +118,7 @@ public async Task SendFrameAsync(long frameType, ReadOnlyMemory framePaylo public async Task ShutdownSendAsync() { - _stream.Shutdown(); + await _stream.CompleteWritesAsync().ConfigureAwait(false); await _stream.ShutdownWriteCompleted().ConfigureAwait(false); } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs index ca6dd5df9fc830..a9df7b6994f6cf 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs @@ -614,7 +614,7 @@ private async Task ProcessServerStreamAsync(QuicStream stream) NetEventSource.Info(this, $"Ignoring server-initiated stream of unknown type {unknownStreamType}."); } - stream.AbortWrite((long)Http3ErrorCode.StreamCreationError); + stream.Abort((long)Http3ErrorCode.StreamCreationError, QuicAbortDirection.Read); return; } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs index 970890ddcd9fbe..5c53ebacc97d11 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs @@ -152,7 +152,7 @@ public async Task SendAsync(CancellationToken cancellationT } else { - _stream.Shutdown(); + _stream.CompleteWrites(); } } @@ -263,7 +263,7 @@ public async Task SendAsync(CancellationToken cancellationT if (cancellationToken.IsCancellationRequested) { - _stream.AbortWrite((long)Http3ErrorCode.RequestCancelled); + _stream.Abort((long)Http3ErrorCode.RequestCancelled); throw new OperationCanceledException(ex.Message, ex, cancellationToken); } else @@ -280,7 +280,7 @@ public async Task SendAsync(CancellationToken cancellationT } catch (Exception ex) { - _stream.AbortWrite((long)Http3ErrorCode.InternalError); + _stream.Abort((long)Http3ErrorCode.InternalError); if (ex is HttpRequestException) { throw; @@ -372,7 +372,7 @@ private async Task SendContentAsync(HttpContent content, CancellationToken cance _sendBuffer.Discard(_sendBuffer.ActiveLength); } - _stream.Shutdown(); + _stream.CompleteWrites(); } private async ValueTask WriteRequestContentAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) @@ -777,7 +777,7 @@ private async ValueTask ReadHeadersAsync(long headersLength, CancellationToken c // https://tools.ietf.org/html/draft-ietf-quic-http-24#section-4.1.1 if (headersLength > _headerBudgetRemaining) { - _stream.AbortWrite((long)Http3ErrorCode.ExcessiveLoad); + _stream.Abort((long)Http3ErrorCode.ExcessiveLoad); throw new HttpRequestException(SR.Format(SR.net_http_response_headers_exceeded_length, _connection.Pool.Settings._maxResponseHeadersLength * 1024L)); } @@ -1114,11 +1114,11 @@ private void HandleReadResponseContentException(Exception ex, CancellationToken _connection.Abort(ex); throw new IOException(SR.net_http_client_execution_error, new HttpRequestException(SR.net_http_client_execution_error, ex)); case OperationCanceledException oce when oce.CancellationToken == cancellationToken: - _stream.AbortWrite((long)Http3ErrorCode.RequestCancelled); + _stream.Abort((long)Http3ErrorCode.RequestCancelled); ExceptionDispatchInfo.Throw(ex); // Rethrow. return; // Never reached. default: - _stream.AbortWrite((long)Http3ErrorCode.InternalError); + _stream.Abort((long)Http3ErrorCode.InternalError); throw new IOException(SR.net_http_client_execution_error, new HttpRequestException(SR.net_http_client_execution_error, ex)); } } diff --git a/src/libraries/System.Net.Quic/System.Net.Quic.sln b/src/libraries/System.Net.Quic/System.Net.Quic.sln index 3d6fa4fc85246c..3ad2d96bdcfbca 100644 --- a/src/libraries/System.Net.Quic/System.Net.Quic.sln +++ b/src/libraries/System.Net.Quic/System.Net.Quic.sln @@ -1,4 +1,8 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 16 +VisualStudioVersion = 16.0.31220.234 +MinimumVisualStudioVersion = 10.0.40219.1 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TestUtilities", "..\Common\tests\TestUtilities\TestUtilities.csproj", "{55C933AA-2735-4B38-A1DD-01A27467AB18}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Win32.Registry", "..\Microsoft.Win32.Registry\ref\Microsoft.Win32.Registry.csproj", "{69CDCFD5-AA35-40D8-A437-ED1C06E9CA95}" @@ -23,18 +27,9 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ref", "ref", "{4BABFE90-C81 EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DAC0D00A-6EB0-4A72-94BB-EB90B3EE72A9}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "StreamConformanceTests", "..\Common\tests\StreamConformanceTests\StreamConformanceTests.csproj", "{CCE2D0B0-BDBE-4750-B215-2517286510EB}" +EndProject Global - GlobalSection(NestedProjects) = preSolution - {55C933AA-2735-4B38-A1DD-01A27467AB18} = {BDA10542-BE94-4A73-9B5B-6BE5CE57F883} - {E8E7DD3A-EC3F-4472-9F70-B515A3D11038} = {BDA10542-BE94-4A73-9B5B-6BE5CE57F883} - {69CDCFD5-AA35-40D8-A437-ED1C06E9CA95} = {4BABFE90-C818-4772-9D2E-B92F69E1FCDF} - {D7A52855-C6DE-4FD0-9CAF-E55F292C69E5} = {4BABFE90-C818-4772-9D2E-B92F69E1FCDF} - {7BB8C50D-4770-42CB-BE15-76AD623A5AE8} = {4BABFE90-C818-4772-9D2E-B92F69E1FCDF} - {833418C5-FEC9-482F-A0D6-69DFC332C1B6} = {4BABFE90-C818-4772-9D2E-B92F69E1FCDF} - {E1CABA2F-48AD-49FA-B872-BEED78C51980} = {4BABFE90-C818-4772-9D2E-B92F69E1FCDF} - {4F87758B-D1AF-4DE3-A9A2-68B1558C02B7} = {DAC0D00A-6EB0-4A72-94BB-EB90B3EE72A9} - {9D56BA9E-1B0D-4320-9FE9-A2D326A32BE0} = {DAC0D00A-6EB0-4A72-94BB-EB90B3EE72A9} - EndGlobalSection GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU Release|Any CPU = Release|Any CPU @@ -76,10 +71,26 @@ Global {E1CABA2F-48AD-49FA-B872-BEED78C51980}.Debug|Any CPU.Build.0 = Debug|Any CPU {E1CABA2F-48AD-49FA-B872-BEED78C51980}.Release|Any CPU.ActiveCfg = Release|Any CPU {E1CABA2F-48AD-49FA-B872-BEED78C51980}.Release|Any CPU.Build.0 = Release|Any CPU + {CCE2D0B0-BDBE-4750-B215-2517286510EB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {CCE2D0B0-BDBE-4750-B215-2517286510EB}.Debug|Any CPU.Build.0 = Debug|Any CPU + {CCE2D0B0-BDBE-4750-B215-2517286510EB}.Release|Any CPU.ActiveCfg = Release|Any CPU + {CCE2D0B0-BDBE-4750-B215-2517286510EB}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE EndGlobalSection + GlobalSection(NestedProjects) = preSolution + {55C933AA-2735-4B38-A1DD-01A27467AB18} = {BDA10542-BE94-4A73-9B5B-6BE5CE57F883} + {69CDCFD5-AA35-40D8-A437-ED1C06E9CA95} = {4BABFE90-C818-4772-9D2E-B92F69E1FCDF} + {D7A52855-C6DE-4FD0-9CAF-E55F292C69E5} = {4BABFE90-C818-4772-9D2E-B92F69E1FCDF} + {4F87758B-D1AF-4DE3-A9A2-68B1558C02B7} = {DAC0D00A-6EB0-4A72-94BB-EB90B3EE72A9} + {E8E7DD3A-EC3F-4472-9F70-B515A3D11038} = {BDA10542-BE94-4A73-9B5B-6BE5CE57F883} + {7BB8C50D-4770-42CB-BE15-76AD623A5AE8} = {4BABFE90-C818-4772-9D2E-B92F69E1FCDF} + {9D56BA9E-1B0D-4320-9FE9-A2D326A32BE0} = {DAC0D00A-6EB0-4A72-94BB-EB90B3EE72A9} + {833418C5-FEC9-482F-A0D6-69DFC332C1B6} = {4BABFE90-C818-4772-9D2E-B92F69E1FCDF} + {E1CABA2F-48AD-49FA-B872-BEED78C51980} = {4BABFE90-C818-4772-9D2E-B92F69E1FCDF} + {CCE2D0B0-BDBE-4750-B215-2517286510EB} = {BDA10542-BE94-4A73-9B5B-6BE5CE57F883} + EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {4B59ACCA-7F0C-4062-AA79-B3D75EFACCCD} EndGlobalSection diff --git a/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs b/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs index 4f4abcf7e5236e..96299c6fd3ee6e 100644 --- a/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs +++ b/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs @@ -6,6 +6,13 @@ namespace System.Net.Quic { + [System.FlagsAttribute] + public enum QuicAbortDirection + { + Read = 1, + Write = 2, + Both = 3, + } public partial class QuicClientConnectionOptions : System.Net.Quic.QuicOptions { public QuicClientConnectionOptions() { } @@ -85,11 +92,13 @@ internal QuicStream() { } public override long Length { get { throw null; } } public override long Position { get { throw null; } set { } } public long StreamId { get { throw null; } } - public void AbortRead(long errorCode) { } - public void AbortWrite(long errorCode) { } + public void Abort(long errorCode, System.Net.Quic.QuicAbortDirection abortDirection = System.Net.Quic.QuicAbortDirection.Both) { } public override System.IAsyncResult BeginRead(byte[] buffer, int offset, int count, System.AsyncCallback? callback, object? state) { throw null; } public override System.IAsyncResult BeginWrite(byte[] buffer, int offset, int count, System.AsyncCallback? callback, object? state) { throw null; } + public System.Threading.Tasks.ValueTask CloseAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public void CompleteWrites() { } protected override void Dispose(bool disposing) { } + public override System.Threading.Tasks.ValueTask DisposeAsync() { throw null; } public override int EndRead(System.IAsyncResult asyncResult) { throw null; } public override void EndWrite(System.IAsyncResult asyncResult) { } public override void Flush() { } @@ -100,9 +109,6 @@ public override void Flush() { } public override System.Threading.Tasks.ValueTask ReadAsync(System.Memory buffer, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public override long Seek(long offset, System.IO.SeekOrigin origin) { throw null; } public override void SetLength(long value) { } - public void Shutdown() { } - public System.Threading.Tasks.ValueTask ShutdownWriteCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public System.Threading.Tasks.ValueTask ShutdownCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public override void Write(byte[] buffer, int offset, int count) { } public override void Write(System.ReadOnlySpan buffer) { } public System.Threading.Tasks.ValueTask WriteAsync(System.Buffers.ReadOnlySequence buffers, bool endStream, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } diff --git a/src/libraries/System.Net.Quic/src/Resources/Strings.resx b/src/libraries/System.Net.Quic/src/Resources/Strings.resx index a29352a0578f59..a702b67aeb9eaa 100644 --- a/src/libraries/System.Net.Quic/src/Resources/Strings.resx +++ b/src/libraries/System.Net.Quic/src/Resources/Strings.resx @@ -1,17 +1,17 @@  - @@ -150,5 +150,4 @@ Writing is not allowed on stream. - - + \ No newline at end of file diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs index 14ecead9a7f88e..2aa698a72df075 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs @@ -151,46 +151,28 @@ internal override Task FlushAsync(CancellationToken cancellationToken) return Task.CompletedTask; } - internal override void AbortRead(long errorCode) + internal override void Abort(long errorCode, QuicAbortDirection abortDirection = QuicAbortDirection.Both) { - throw new NotImplementedException(); - } + // TODO: support abort read direction. - internal override void AbortWrite(long errorCode) - { - if (_isInitiator) - { - _streamState._outboundErrorCode = errorCode; - } - else + if (abortDirection.HasFlag(QuicAbortDirection.Write)) { - _streamState._inboundErrorCode = errorCode; - } - - WriteStreamBuffer?.EndWrite(); - } - - - internal override ValueTask ShutdownWriteCompleted(CancellationToken cancellationToken = default) - { - CheckDisposed(); - - return default; - } - - - internal override ValueTask ShutdownCompleted(CancellationToken cancellationToken = default) - { - CheckDisposed(); + if (_isInitiator) + { + _streamState._outboundErrorCode = errorCode; + } + else + { + _streamState._inboundErrorCode = errorCode; + } - return default; + WriteStreamBuffer?.EndWrite(); + } } - internal override void Shutdown() + public override void CompleteWrites() { CheckDisposed(); - - // This seems to mean shutdown send, in particular, not both. WriteStreamBuffer?.EndWrite(); } @@ -206,29 +188,38 @@ public override void Dispose() { if (!_disposed) { - Shutdown(); + CompleteWrites(); + + _streamState._outboundStreamBuffer.Completed.GetAwaiter().GetResult(); _disposed = true; } } - public override ValueTask DisposeAsync() + public override async ValueTask DisposeAsync(CancellationToken cancellationToken) { if (!_disposed) { - Shutdown(); + CompleteWrites(); + + if (ReadStreamBuffer is StreamBuffer readStreamBuffer) + { + await ReadStreamBuffer.Completed.WaitAsync(cancellationToken).ConfigureAwait(false); + } + else + { + cancellationToken.ThrowIfCancellationRequested(); + } _disposed = true; } - - return default; } internal sealed class StreamState { public readonly long _streamId; - public StreamBuffer _outboundStreamBuffer; - public StreamBuffer? _inboundStreamBuffer; + public readonly StreamBuffer _outboundStreamBuffer; + public readonly StreamBuffer? _inboundStreamBuffer; public long _outboundErrorCode; public long _inboundErrorCode; diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/MsQuicStatusCodes.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/MsQuicStatusCodes.cs index 50f736d429f7f6..b27491eab4cad6 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/MsQuicStatusCodes.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/MsQuicStatusCodes.cs @@ -7,6 +7,7 @@ internal static class MsQuicStatusCodes { internal static uint Success => OperatingSystem.IsWindows() ? Windows.Success : Posix.Success; internal static uint Pending => OperatingSystem.IsWindows() ? Windows.Pending : Posix.Pending; + internal static uint Continue => OperatingSystem.IsWindows() ? Windows.Continue : Posix.Continue; internal static uint InternalError => OperatingSystem.IsWindows() ? Windows.InternalError : Posix.InternalError; internal static uint InvalidState => OperatingSystem.IsWindows() ? Windows.InvalidState : Posix.InvalidState; internal static uint HandshakeFailure => OperatingSystem.IsWindows() ? Windows.HandshakeFailure : Posix.HandshakeFailure; diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs index f1bfb2efdc86db..722cd24896c964 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs @@ -2,9 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Buffers; -using System.Collections.Generic; using System.Diagnostics; using System.Net.Quic.Implementations.MsQuic.Internal; +using System.Reflection; using System.Runtime.ExceptionServices; using System.Runtime.InteropServices; using System.Threading; @@ -40,11 +40,20 @@ private sealed class State public SafeMsQuicStreamHandle Handle = null!; // set in ctor. public ReadState ReadState; + + // set when ReadState.Aborted: public long ReadErrorCode = -1; - public readonly List ReceiveQuicBuffers = new List(); - // Resettable completions to be used for multiple calls to receive. - public readonly ResettableCompletionSource ReceiveResettableCompletionSource = new ResettableCompletionSource(); + // filled when ReadState.BuffersAvailable: + public QuicBuffer[] ReceiveQuicBuffers = Array.Empty(); + public int ReceiveQuicBuffersCount; + public int ReceiveQuicBuffersTotalBytes; + + // set when ReadState.PendingRead: + public Memory ReceiveUserBuffer; + public CancellationTokenRegistration ReceiveCancellationRegistration; + public MsQuicStream? RootedReceiveStream; // roots the stream in the pinned state to prevent GC during an async read I/O. + public readonly ResettableCompletionSource ReceiveResettableCompletionSource = new ResettableCompletionSource(); public SendState SendState; public long SendErrorCode = -1; @@ -55,18 +64,13 @@ private sealed class State public int SendBufferMaxCount; public int SendBufferCount; - // Resettable completions to be used for multiple calls to send, start, and shutdown. - public readonly ResettableCompletionSource SendResettableCompletionSource = new ResettableCompletionSource(); - - public ShutdownWriteState ShutdownWriteState; + // Roots the stream in the pinned state to prevent GC during an async dispose. + public MsQuicStream? RootedDisposeStream; - // Set once writes have been shutdown. - public readonly TaskCompletionSource ShutdownWriteCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - - public ShutdownState ShutdownState; + // Resettable completions to be used for multiple calls to send, start. + public readonly ResettableCompletionSource SendResettableCompletionSource = new ResettableCompletionSource(); - // Set once stream have been shutdown. + // Set once both peers have fully shut down their side of the stream. public readonly TaskCompletionSource ShutdownCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); } @@ -257,7 +261,7 @@ private void HandleWriteFailedState() } } - internal override async ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) + internal override ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) { ThrowIfDisposed(); @@ -271,188 +275,176 @@ internal override async ValueTask ReadAsync(Memory destination, Cance NetEventSource.Info(this, $"[{GetHashCode()}] reading into Memory of '{destination.Length}' bytes."); } + ReadState readState; + long abortError = -1; + bool canceledSynchronously = false; + lock (_state) { - if (_state.ReadState == ReadState.ReadsCompleted) + readState = _state.ReadState; + abortError = _state.ReadErrorCode; + + if (readState != ReadState.PendingRead && cancellationToken.IsCancellationRequested) { - return 0; + readState = ReadState.Aborted; + _state.ReadState = ReadState.Aborted; + canceledSynchronously = true; } - else if (_state.ReadState == ReadState.Aborted) + else if (readState == ReadState.None) { - throw _state.ReadErrorCode switch - { - -1 => new QuicOperationAbortedException(), - long err => new QuicStreamAbortedException(err) - }; - } - } + Debug.Assert(_state.RootedReceiveStream is null); - using CancellationTokenRegistration registration = cancellationToken.UnsafeRegister(static (s, token) => - { - var state = (State)s!; - bool shouldComplete = false; - lock (state) - { - if (state.ReadState == ReadState.None) + _state.ReceiveUserBuffer = destination; + _state.RootedReceiveStream = this; + _state.ReadState = ReadState.PendingRead; + + if (cancellationToken.CanBeCanceled) { - shouldComplete = true; + _state.ReceiveCancellationRegistration = cancellationToken.UnsafeRegister(static (obj, token) => + { + var state = (State)obj!; + bool completePendingRead; + + lock (state) + { + completePendingRead = state.ReadState == ReadState.PendingRead; + state.RootedReceiveStream = null; + state.ReadState = ReadState.Aborted; + } + + if (completePendingRead) + { + state.ReceiveResettableCompletionSource.CompleteException(ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException(token))); + } + }, _state); + } + else + { + _state.ReceiveCancellationRegistration = default; } - state.ReadState = ReadState.Aborted; + return _state.ReceiveResettableCompletionSource.GetValueTask(); } - - if (shouldComplete) + else if (readState == ReadState.BuffersAvailable) { - state.ReceiveResettableCompletionSource.CompleteException( - ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException("Read was canceled", token))); - } - }, _state); - - // TODO there could potentially be a perf gain by storing the buffer from the initial read - // This reduces the amount of async calls, however it makes it so MsQuic holds onto the buffers - // longer than it needs to. We will need to benchmark this. - int length = (int)await _state.ReceiveResettableCompletionSource.GetValueTask().ConfigureAwait(false); + _state.ReadState = ReadState.None; - int actual = Math.Min(length, destination.Length); + int taken = CopyMsQuicBuffersToUserBuffer(_state.ReceiveQuicBuffers.AsSpan(0, _state.ReceiveQuicBuffersCount), destination.Span); + ReceiveComplete(taken); - static unsafe void CopyToBuffer(Span destinationBuffer, List sourceBuffers) - { - Span slicedBuffer = destinationBuffer; - for (int i = 0; i < sourceBuffers.Count; i++) - { - QuicBuffer nativeBuffer = sourceBuffers[i]; - int length = Math.Min((int)nativeBuffer.Length, slicedBuffer.Length); - new Span(nativeBuffer.Buffer, length).CopyTo(slicedBuffer); - if (length < nativeBuffer.Length) + if (taken != _state.ReceiveQuicBuffersTotalBytes) { - // The buffer passed in was larger that the received data, return - return; + // Need to re-enable receives because MsQuic will pause them when we don't consume the entire buffer. + EnableReceive(); } - slicedBuffer = slicedBuffer.Slice(length); - } - } - CopyToBuffer(destination.Span, _state.ReceiveQuicBuffers); - - lock (_state) - { - if (_state.ReadState == ReadState.IndividualReadComplete) - { - _state.ReceiveQuicBuffers.Clear(); - ReceiveComplete(actual); - EnableReceive(); - _state.ReadState = ReadState.None; + return new ValueTask(taken); } } - return actual; - } - - // TODO do we want this to be a synchronization mechanism to cancel a pending read - // If so, we need to complete the read here as well. - internal override void AbortRead(long errorCode) - { - ThrowIfDisposed(); + Exception? ex = null; - lock (_state) + switch (readState) { - _state.ReadState = ReadState.Aborted; + case ReadState.EndOfReadStream: + return new ValueTask(0); + case ReadState.PendingRead: + ex = new InvalidOperationException("Only one read is supported at a time."); + break; + case ReadState.Aborted: + default: + Debug.Assert(readState == ReadState.Aborted, $"{nameof(ReadState)} of '{readState}' is unaccounted for in {nameof(ReadAsync)}."); + + ex = + canceledSynchronously ? new OperationCanceledException(cancellationToken) : // aborted by token being canceled before the async op started. + abortError == -1 ? new QuicOperationAbortedException() : // aborted by user via some other operation. + new QuicStreamAbortedException(abortError); // aborted by peer. + break; } - StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_RECEIVE, errorCode); + return ValueTask.FromException(ExceptionDispatchInfo.SetCurrentStackTrace(ex!)); } - internal override void AbortWrite(long errorCode) + /// The number of bytes copied. + private static unsafe int CopyMsQuicBuffersToUserBuffer(ReadOnlySpan sourceBuffers, Span destinationBuffer) { - ThrowIfDisposed(); + Debug.Assert(sourceBuffers.Length != 0); - bool shouldComplete = false; + int originalDestinationLength = destinationBuffer.Length; + QuicBuffer nativeBuffer; + int takeLength = 0; + int i = 0; - lock (_state) + do { - if (_state.ShutdownWriteState == ShutdownWriteState.None) - { - _state.ShutdownWriteState = ShutdownWriteState.Canceled; - shouldComplete = true; - } - } + nativeBuffer = sourceBuffers[i]; + takeLength = Math.Min((int)nativeBuffer.Length, destinationBuffer.Length); - if (shouldComplete) - { - _state.ShutdownWriteCompletionSource.SetException( - ExceptionDispatchInfo.SetCurrentStackTrace(new QuicStreamAbortedException("Shutdown was aborted.", errorCode))); + new Span(nativeBuffer.Buffer, takeLength).CopyTo(destinationBuffer); + destinationBuffer = destinationBuffer.Slice(takeLength); } + while (destinationBuffer.Length != 0 && ++i < sourceBuffers.Length); - StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_SEND, errorCode); + return originalDestinationLength - destinationBuffer.Length; } - private void StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS flags, long errorCode) + // We don't wait for QUIC_STREAM_EVENT_SEND_SHUTDOWN_COMPLETE event here, + // because it is only sent to us once the peer has acknowledged the shutdown. + // Instead, this method acts more like shutdown(SD_SEND) in that it only "queues" + // the shutdown packet to be sent without any waiting for completion. + public override void CompleteWrites() { - uint status = MsQuicApi.Api.StreamShutdownDelegate(_state.Handle, flags, errorCode); - QuicExceptionHelpers.ThrowIfFailed(status, "StreamShutdown failed."); + ThrowIfDisposed(); + + // Error code is ignored for graceful shutdown. + StartShutdownOrAbort(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); } - internal override async ValueTask ShutdownWriteCompleted(CancellationToken cancellationToken = default) + internal override void Abort(long errorCode, QuicAbortDirection abortDirection = QuicAbortDirection.Both) { ThrowIfDisposed(); - // TODO do anything to stop writes? - using CancellationTokenRegistration registration = cancellationToken.UnsafeRegister(static (s, token) => + QUIC_STREAM_SHUTDOWN_FLAGS flags = QUIC_STREAM_SHUTDOWN_FLAGS.NONE; + bool completeWrites = false; + bool completeReads = false; + + lock (_state) { - var state = (State)s!; - bool shouldComplete = false; - lock (state) + if (abortDirection.HasFlag(QuicAbortDirection.Write)) { - if (state.ShutdownWriteState == ShutdownWriteState.None) - { - state.ShutdownWriteState = ShutdownWriteState.Canceled; // TODO: should we separate states for cancelling here vs calling Abort? - shouldComplete = true; - } + completeWrites = _state.SendState == SendState.None; + _state.SendState = SendState.Aborted; + flags |= QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_SEND; } - if (shouldComplete) + if (abortDirection.HasFlag(QuicAbortDirection.Read)) { - state.ShutdownWriteCompletionSource.SetException( - ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException("Wait for shutdown write was canceled", token))); + completeReads = _state.ReadState == ReadState.PendingRead; + _state.RootedReceiveStream = null; + _state.ReadState = ReadState.Aborted; + flags |= QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_RECEIVE; } - }, _state); - - await _state.ShutdownWriteCompletionSource.Task.ConfigureAwait(false); - } + } - internal override async ValueTask ShutdownCompleted(CancellationToken cancellationToken = default) - { - ThrowIfDisposed(); + StartShutdownOrAbort(flags, errorCode); - // TODO do anything to stop writes? - using CancellationTokenRegistration registration = cancellationToken.UnsafeRegister(static (s, token) => + if (completeWrites) { - var state = (State)s!; - bool shouldComplete = false; - lock (state) - { - if (state.ShutdownState == ShutdownState.None) - { - state.ShutdownState = ShutdownState.Canceled; - shouldComplete = true; - } - } - - if (shouldComplete) - { - state.ShutdownWriteCompletionSource.SetException( - ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException("Wait for shutdown was canceled", token))); - } - }, _state); + _state.SendResettableCompletionSource.Complete(0); + } - await _state.ShutdownCompletionSource.Task.ConfigureAwait(false); + if (completeReads) + { + _state.ReceiveResettableCompletionSource.CompleteException(ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException())); + } } - internal override void Shutdown() + /// + /// For abortive flags, the error code sent to peer. Otherwise, ignored. + private void StartShutdownOrAbort(QUIC_STREAM_SHUTDOWN_FLAGS flags, long errorCode) { - ThrowIfDisposed(); - // it is ok to send shutdown several times, MsQuic will ignore it - StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); + uint status = MsQuicApi.Api.StreamShutdownDelegate(_state.Handle, flags, errorCode); + QuicExceptionHelpers.ThrowIfFailed(status, "StreamShutdown failed."); } // TODO consider removing sync-over-async with blocking calls. @@ -485,42 +477,69 @@ internal override Task FlushAsync(CancellationToken cancellationToken = default) return Task.CompletedTask; } - public override ValueTask DisposeAsync() - { - // TODO: perform a graceful shutdown and wait for completion? + public override ValueTask DisposeAsync(CancellationToken cancellationToken) => + DisposeAsync(cancellationToken, async: true, immediate: _state.SendState == SendState.Aborted); - Dispose(true); - return default; - } + public override void Dispose() => + Dispose(immediate: _state.SendState == SendState.Aborted); - public override void Dispose() - { - Dispose(true); - GC.SuppressFinalize(this); - } + ~MsQuicStream() => + Dispose(immediate: true); - ~MsQuicStream() + private void Dispose(bool immediate) { - Dispose(false); + ValueTask t = DisposeAsync(cancellationToken: default, async: false, immediate); + Debug.Assert(t.IsCompleted); + t.GetAwaiter().GetResult(); } - private void Dispose(bool disposing) + /// + /// + /// When true, causes immediate disposal without waiting for peer ACKs. + /// + private async ValueTask DisposeAsync(CancellationToken cancellationToken, bool async, bool immediate) { if (_disposed) { return; } + QUIC_STREAM_SHUTDOWN_FLAGS flags = immediate + ? (QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL | QUIC_STREAM_SHUTDOWN_FLAGS.IMMEDIATE) + : QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL; + + StartShutdownOrAbort(flags, errorCode: 0); + + if (async) + { + _state.RootedDisposeStream = this; + try + { + await _state.ShutdownCompletionSource.Task.WaitAsync(cancellationToken).ConfigureAwait(false); + } + finally + { + _state.RootedDisposeStream = null; + } + } + else + { + _state.ShutdownCompletionSource.Task.GetAwaiter().GetResult(); + } + _disposed = true; _state.Handle.Dispose(); Marshal.FreeHGlobal(_state.SendQuicBuffers); if (_stateHandle.IsAllocated) _stateHandle.Free(); CleanupSendState(_state); + + GC.SuppressFinalize(this); } private void EnableReceive() { - MsQuicApi.Api.StreamReceiveSetEnabledDelegate(_state.Handle, enabled: true); + uint status = MsQuicApi.Api.StreamReceiveSetEnabledDelegate(_state.Handle, enabled: true); + QuicExceptionHelpers.ThrowIfFailed(status, "StreamReceiveSetEnabled failed."); } private static uint NativeCallbackHandler( @@ -534,11 +553,6 @@ private static uint NativeCallbackHandler( private static uint HandleEvent(State state, ref StreamEvent evt) { - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Info(state, $"[{state.GetHashCode()}] received event {evt.Type}"); - } - try { switch ((QUIC_STREAM_EVENT_TYPE)evt.Type) @@ -563,11 +577,6 @@ private static uint HandleEvent(State state, ref StreamEvent evt) // Peer has stopped receiving data, don't send anymore. case QUIC_STREAM_EVENT_TYPE.PEER_RECEIVE_ABORTED: return HandleEventPeerRecvAborted(state, ref evt); - // Occurs when shutdown is completed for the send side. - // This only happens for shutdown on sending, not receiving - // Receive shutdown can only be abortive. - case QUIC_STREAM_EVENT_TYPE.SEND_SHUTDOWN_COMPLETE: - return HandleEventSendShutdownComplete(state, ref evt); // Shutdown for both sending and receiving is completed. case QUIC_STREAM_EVENT_TYPE.SHUTDOWN_COMPLETE: return HandleEventShutdownComplete(state); @@ -583,47 +592,85 @@ private static uint HandleEvent(State state, ref StreamEvent evt) private static unsafe uint HandleEventRecv(State state, ref StreamEvent evt) { - StreamEventDataReceive receiveEvent = evt.Data.Receive; - for (int i = 0; i < receiveEvent.BufferCount; i++) - { - state.ReceiveQuicBuffers.Add(receiveEvent.Buffers[i]); - } + ref StreamEventDataReceive receiveEvent = ref evt.Data.Receive; + + int readLength; - bool shouldComplete = false; lock (state) { - if (state.ReadState == ReadState.None) + switch (state.ReadState) { - shouldComplete = true; + case ReadState.None: + // ReadAsync() hasn't been called yet. Stash the buffer so the next ReadAsync call completes synchronously. + + if ((uint)state.ReceiveQuicBuffers.Length < receiveEvent.BufferCount) + { + QuicBuffer[] oldReceiveBuffers = state.ReceiveQuicBuffers; + state.ReceiveQuicBuffers = ArrayPool.Shared.Rent((int)receiveEvent.BufferCount); + + if (oldReceiveBuffers.Length != 0) // don't return Array.Empty. + { + ArrayPool.Shared.Return(oldReceiveBuffers); + } + } + + for (uint i = 0; i < receiveEvent.BufferCount; ++i) + { + state.ReceiveQuicBuffers[i] = receiveEvent.Buffers[i]; + } + + state.ReceiveQuicBuffersCount = (int)receiveEvent.BufferCount; + state.ReceiveQuicBuffersTotalBytes = checked((int)receiveEvent.TotalBufferLength); + state.ReadState = ReadState.BuffersAvailable; + return MsQuicStatusCodes.Pending; + case ReadState.PendingRead: + // There is a pending ReadAsync(). + + state.ReceiveCancellationRegistration.Unregister(); + state.RootedReceiveStream = null; + state.ReadState = ReadState.None; + + readLength = CopyMsQuicBuffersToUserBuffer(new ReadOnlySpan(receiveEvent.Buffers, (int)receiveEvent.BufferCount), state.ReceiveUserBuffer.Span); + break; + case ReadState.Aborted: + default: + Debug.Assert(state.ReadState == ReadState.Aborted, $"Unexpected {nameof(ReadState)} '{state.ReadState}' in {nameof(HandleEventRecv)}."); + + // There was a race between a user aborting the read stream and the callback being ran. + // This will eat any received data. + return MsQuicStatusCodes.Success; } - state.ReadState = ReadState.IndividualReadComplete; } - if (shouldComplete) - { - state.ReceiveResettableCompletionSource.Complete((uint)receiveEvent.TotalBufferLength); - } + // We're completing a pending read. + + state.ReceiveResettableCompletionSource.Complete(readLength); - return MsQuicStatusCodes.Pending; + // Returning Success when the entire buffer hasn't been consumed will cause MsQuic to disable further receive events until EnableReceive() is called. + // Returning Continue will cause a second receive event to fire immediately after this returns, but allows MsQuic to clean up its buffers. + + uint ret = (uint)readLength == receiveEvent.TotalBufferLength + ? MsQuicStatusCodes.Success + : MsQuicStatusCodes.Continue; + + receiveEvent.TotalBufferLength = (uint)readLength; + return ret; } private static uint HandleEventPeerRecvAborted(State state, ref StreamEvent evt) { - bool shouldComplete = false; + bool shouldComplete; + lock (state) { - if (state.SendState == SendState.None || state.SendState == SendState.Pending) - { - shouldComplete = true; - } + shouldComplete = state.SendState == SendState.None || state.SendState == SendState.Pending; state.SendState = SendState.Aborted; - state.SendErrorCode = (long)evt.Data.PeerSendAborted.ErrorCode; + state.SendErrorCode = evt.Data.PeerSendAborted.ErrorCode; } if (shouldComplete) { - state.SendResettableCompletionSource.CompleteException( - ExceptionDispatchInfo.SetCurrentStackTrace(new QuicStreamAbortedException(state.SendErrorCode))); + state.SendResettableCompletionSource.CompleteException(new QuicStreamAbortedException(state.SendErrorCode)); } return MsQuicStatusCodes.Success; @@ -649,72 +696,9 @@ private static uint HandleStartComplete(State state) return MsQuicStatusCodes.Success; } - private static uint HandleEventSendShutdownComplete(State state, ref StreamEvent evt) - { - bool shouldComplete = false; - lock (state) - { - if (state.ShutdownWriteState == ShutdownWriteState.None) - { - state.ShutdownWriteState = ShutdownWriteState.Finished; - shouldComplete = true; - } - } - - if (shouldComplete) - { - state.ShutdownWriteCompletionSource.SetResult(); - } - - return MsQuicStatusCodes.Success; - } - private static uint HandleEventShutdownComplete(State state) { - bool shouldReadComplete = false; - bool shouldShutdownWriteComplete = false; - bool shouldShutdownComplete = false; - - lock (state) - { - // This event won't occur within the middle of a receive. - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info("Completing resettable event source."); - - if (state.ReadState == ReadState.None) - { - shouldReadComplete = true; - } - - state.ReadState = ReadState.ReadsCompleted; - - if (state.ShutdownWriteState == ShutdownWriteState.None) - { - state.ShutdownWriteState = ShutdownWriteState.Finished; - shouldShutdownWriteComplete = true; - } - - if (state.ShutdownState == ShutdownState.None) - { - state.ShutdownState = ShutdownState.Finished; - shouldShutdownComplete = true; - } - } - - if (shouldReadComplete) - { - state.ReceiveResettableCompletionSource.Complete(0); - } - - if (shouldShutdownWriteComplete) - { - state.ShutdownWriteCompletionSource.SetResult(); - } - - if (shouldShutdownComplete) - { - state.ShutdownCompletionSource.SetResult(); - } - + state.ShutdownCompletionSource.TrySetResult(); return MsQuicStatusCodes.Success; } @@ -742,22 +726,27 @@ private static uint HandleEventPeerSendAborted(State state, ref StreamEvent evt) private static uint HandleEventPeerSendShutdown(State state) { - bool shouldComplete = false; + bool completePendingRead = false; lock (state) { // This event won't occur within the middle of a receive. if (NetEventSource.Log.IsEnabled()) NetEventSource.Info("Completing resettable event source."); - if (state.ReadState == ReadState.None) + + if (state.ReadState == ReadState.PendingRead) { - shouldComplete = true; + completePendingRead = true; + state.RootedReceiveStream = null; + state.ReadState = ReadState.EndOfReadStream; + } + else if (state.ReadState == ReadState.None) + { + state.ReadState = ReadState.EndOfReadStream; } - - state.ReadState = ReadState.ReadsCompleted; } - if (shouldComplete) + if (completePendingRead) { state.ReceiveResettableCompletionSource.Complete(0); } @@ -818,7 +807,7 @@ private unsafe ValueTask SendReadOnlyMemoryAsync( if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN) { // Start graceful shutdown sequence if passed in the fin flag and there is an empty buffer. - StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); + StartShutdownOrAbort(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); } return default; } @@ -873,7 +862,7 @@ private unsafe ValueTask SendReadOnlySequenceAsync( if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN) { // Start graceful shutdown sequence if passed in the fin flag and there is an empty buffer. - StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); + StartShutdownOrAbort(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); } return default; } @@ -942,7 +931,7 @@ private unsafe ValueTask SendReadOnlyMemoryListAsync( if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN) { // Start graceful shutdown sequence if passed in the fin flag and there is an empty buffer. - StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); + StartShutdownOrAbort(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); } return default; } @@ -1016,19 +1005,24 @@ private void ThrowIfDisposed() private enum ReadState { /// - /// The stream is open, but there is no data available. + /// The stream is open, but there is no pending operation and no data available. /// None, /// - /// Data is available in . + /// There is a pending operation on the stream. + /// + PendingRead, + + /// + /// There is data available. /// - IndividualReadComplete, + BuffersAvailable, /// /// The peer has gracefully shutdown their sends / our receives; the stream's reads are complete. /// - ReadsCompleted, + EndOfReadStream, /// /// User has aborted the stream, either via a cancellation token on ReadAsync(), or via AbortRead(). @@ -1036,20 +1030,6 @@ private enum ReadState Aborted } - private enum ShutdownWriteState - { - None, - Canceled, - Finished - } - - private enum ShutdownState - { - None, - Canceled, - Finished - } - private enum SendState { None, diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicStreamProvider.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicStreamProvider.cs index 2be277a61252a4..e97cfc1613cbe8 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicStreamProvider.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicStreamProvider.cs @@ -7,7 +7,7 @@ namespace System.Net.Quic.Implementations { - internal abstract class QuicStreamProvider : IDisposable, IAsyncDisposable + internal abstract class QuicStreamProvider { internal abstract long StreamId { get; } @@ -17,9 +17,7 @@ internal abstract class QuicStreamProvider : IDisposable, IAsyncDisposable internal abstract ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default); - internal abstract void AbortRead(long errorCode); - - internal abstract void AbortWrite(long errorCode); + internal abstract void Abort(long errorCode, QuicAbortDirection abortDirection = QuicAbortDirection.Both); internal abstract bool CanWrite { get; } @@ -37,18 +35,14 @@ internal abstract class QuicStreamProvider : IDisposable, IAsyncDisposable internal abstract ValueTask WriteAsync(ReadOnlyMemory> buffers, bool endStream, CancellationToken cancellationToken = default); - internal abstract ValueTask ShutdownWriteCompleted(CancellationToken cancellationToken = default); - - internal abstract ValueTask ShutdownCompleted(CancellationToken cancellationToken = default); - - internal abstract void Shutdown(); - internal abstract void Flush(); internal abstract Task FlushAsync(CancellationToken cancellationToken); + public abstract void CompleteWrites(); + public abstract void Dispose(); - public abstract ValueTask DisposeAsync(); + public abstract ValueTask DisposeAsync(CancellationToken cancellationToken); } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicAbortDirection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicAbortDirection.cs new file mode 100644 index 00000000000000..f43051991ea1d7 --- /dev/null +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicAbortDirection.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Net.Quic +{ + [Flags] + public enum QuicAbortDirection + { + Read = 1, + Write = 2, + Both = Read | Write + } +} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs index e1724eee535754..f364864364215e 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs @@ -85,9 +85,17 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati public override Task FlushAsync(CancellationToken cancellationToken) => _provider.FlushAsync(cancellationToken); - public void AbortRead(long errorCode) => _provider.AbortRead(errorCode); + /// + /// Completes the write direction of the stream, notifying the peer of end-of-stream. + /// + public void CompleteWrites() => _provider.CompleteWrites(); - public void AbortWrite(long errorCode) => _provider.AbortWrite(errorCode); + /// + /// Aborts the . + /// + /// The error code to abort with. + /// The direction of the abort. + public void Abort(long errorCode, QuicAbortDirection abortDirection = QuicAbortDirection.Both) => _provider.Abort(errorCode, abortDirection); public ValueTask WriteAsync(ReadOnlyMemory buffer, bool endStream, CancellationToken cancellationToken = default) => _provider.WriteAsync(buffer, endStream, cancellationToken); @@ -99,12 +107,6 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati public ValueTask WriteAsync(ReadOnlyMemory> buffers, bool endStream, CancellationToken cancellationToken = default) => _provider.WriteAsync(buffers, endStream, cancellationToken); - public ValueTask ShutdownWriteCompleted(CancellationToken cancellationToken = default) => _provider.ShutdownWriteCompleted(cancellationToken); - - public ValueTask ShutdownCompleted(CancellationToken cancellationToken = default) => _provider.ShutdownCompleted(cancellationToken); - - public void Shutdown() => _provider.Shutdown(); - protected override void Dispose(bool disposing) { if (disposing) @@ -112,5 +114,14 @@ protected override void Dispose(bool disposing) _provider.Dispose(); } } + + public override ValueTask DisposeAsync() => CloseAsync(); + + /// + /// Gracefully shuts down and closes the , leaving it in a disposed state. + /// + /// If triggered, an will be thrown and the stream will be left undisposed. + /// A representing the asynchronous closure of the . + public ValueTask CloseAsync(CancellationToken cancellationToken = default) => _provider.DisposeAsync(cancellationToken); } } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs index 446daf8021d8b2..d3c08bf96b691a 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs @@ -122,8 +122,7 @@ await RunClientServer( } } - stream.Shutdown(); - await stream.ShutdownCompleted(); + stream.CompleteWrites(); }, async serverConnection => { @@ -140,8 +139,7 @@ await RunClientServer( int expectedTotalBytes = writes.SelectMany(x => x).Sum(); Assert.Equal(expectedTotalBytes, totalBytes); - stream.Shutdown(); - await stream.ShutdownCompleted(); + stream.CompleteWrites(); }); } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs index b08f93f94486ec..a97b5104c1b4ee 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs @@ -1,11 +1,11 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; using System.Buffers; using System.Collections.Generic; using System.Linq; using System.Text; +using System.Threading; using System.Threading.Tasks; using Xunit; @@ -19,12 +19,10 @@ public abstract class QuicStreamTests : QuicTestBase [Fact] public async Task BasicTest() { - await RunClientServer( + await RunBidirectionalClientServer( iterations: 100, - serverFunction: async connection => + serverFunction: async stream => { - await using QuicStream stream = await connection.AcceptStreamAsync(); - byte[] buffer = new byte[s_data.Length]; int bytesRead = await ReadAll(stream, buffer); @@ -32,12 +30,9 @@ await RunClientServer( Assert.Equal(s_data, buffer); await stream.WriteAsync(s_data, endStream: true); - await stream.ShutdownCompleted(); }, - clientFunction: async connection => + clientFunction: async stream => { - await using QuicStream stream = connection.OpenBidirectionalStream(); - await stream.WriteAsync(s_data, endStream: true); byte[] buffer = new byte[s_data.Length]; @@ -45,8 +40,6 @@ await RunClientServer( Assert.Equal(s_data.Length, bytesRead); Assert.Equal(s_data, buffer); - - await stream.ShutdownCompleted(); } ); } @@ -64,12 +57,10 @@ public async Task MultipleReadsAndWrites() m = m[s_data.Length..]; } - await RunClientServer( + await RunBidirectionalClientServer( iterations: 100, - serverFunction: async connection => + serverFunction: async stream => { - await using QuicStream stream = await connection.AcceptStreamAsync(); - byte[] buffer = new byte[expectedBytesCount]; int bytesRead = await ReadAll(stream, buffer); Assert.Equal(expectedBytesCount, bytesRead); @@ -80,13 +71,9 @@ await RunClientServer( await stream.WriteAsync(s_data); } await stream.WriteAsync(Memory.Empty, endStream: true); - - await stream.ShutdownCompleted(); }, - clientFunction: async connection => + clientFunction: async stream => { - await using QuicStream stream = connection.OpenBidirectionalStream(); - for (int i = 0; i < sendCount; i++) { await stream.WriteAsync(s_data); @@ -97,8 +84,6 @@ await RunClientServer( int bytesRead = await ReadAll(stream, buffer); Assert.Equal(expectedBytesCount, bytesRead); Assert.Equal(expected, buffer); - - await stream.ShutdownCompleted(); } ); } @@ -125,9 +110,6 @@ await RunClientServer( await stream.WriteAsync(s_data, endStream: true); await stream2.WriteAsync(s_data, endStream: true); - - await stream.ShutdownCompleted(); - await stream2.ShutdownCompleted(); }, clientFunction: async connection => { @@ -147,9 +129,6 @@ await RunClientServer( int bytesRead2 = await ReadAll(stream2, buffer2); Assert.Equal(s_data.Length, bytesRead2); Assert.Equal(s_data, buffer2); - - await stream.ShutdownCompleted(); - await stream2.ShutdownCompleted(); } ); } @@ -157,20 +136,24 @@ await RunClientServer( [Fact] public async Task GetStreamIdWithoutStartWorks() { - using QuicListener listener = CreateQuicListener(); - using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); - - ValueTask clientTask = clientConnection.ConnectAsync(); - using QuicConnection serverConnection = await listener.AcceptConnectionAsync(); - await clientTask; + using SemaphoreSlim sem = new SemaphoreSlim(0); + await RunClientServer( + async clientConnection => + { + await using QuicStream clientStream = clientConnection.OpenBidirectionalStream(); + Assert.Equal(0, clientStream.StreamId); + sem.Release(); - using QuicStream clientStream = clientConnection.OpenBidirectionalStream(); - Assert.Equal(0, clientStream.StreamId); + }, + async serverConnection => + { + await sem.WaitAsync(); - // TODO: stream that is opened by client but left unaccepted by server may cause AccessViolationException in its Finalizer - // explicitly closing the connections seems to help, but the problem should still be investigated, we should have a meaningful - // exception instead of AccessViolationException - await clientConnection.CloseAsync(0); + // TODO: stream that is opened by client but left unaccepted by server may cause AccessViolationException in its Finalizer + // explicitly closing the connections seems to help, but the problem should still be investigated, we should have a meaningful + // exception instead of AccessViolationException + await serverConnection.CloseAsync(0); + }); } [Fact] @@ -180,12 +163,10 @@ public async Task LargeDataSentAndReceived() const int NumberOfWrites = 256; // total sent = 16M byte[] data = Enumerable.Range(0, writeSize * NumberOfWrites).Select(x => (byte)x).ToArray(); - await RunClientServer( + await RunBidirectionalClientServer( iterations: 5, - serverFunction: async connection => + serverFunction: async stream => { - await using QuicStream stream = await connection.AcceptStreamAsync(); - byte[] buffer = new byte[data.Length]; int bytesRead = await ReadAll(stream, buffer); Assert.Equal(data.Length, bytesRead); @@ -196,13 +177,9 @@ await RunClientServer( await stream.WriteAsync(data[pos..(pos + writeSize)]); } await stream.WriteAsync(Memory.Empty, endStream: true); - - await stream.ShutdownCompleted(); }, - clientFunction: async connection => + clientFunction: async stream => { - await using QuicStream stream = connection.OpenBidirectionalStream(); - for (int pos = 0; pos < data.Length; pos += writeSize) { await stream.WriteAsync(data[pos..(pos + writeSize)]); @@ -213,8 +190,6 @@ await RunClientServer( int bytesRead = await ReadAll(stream, buffer); Assert.Equal(data.Length, bytesRead); AssertArrayEqual(data, buffer); - - await stream.ShutdownCompleted(); } ); } @@ -291,9 +266,6 @@ private static async Task TestBidirectionalStream(QuicStream s1, QuicStream s2) await SendAndReceiveEOFAsync(s1, s2); await SendAndReceiveEOFAsync(s2, s1); - - await s1.ShutdownCompleted(); - await s2.ShutdownCompleted(); } private static async Task TestUnidirectionalStream(QuicStream s1, QuicStream s2) @@ -308,9 +280,6 @@ private static async Task TestUnidirectionalStream(QuicStream s1, QuicStream s2) await SendAndReceiveDataAsync(s_data, s1, s2); await SendAndReceiveEOFAsync(s1, s2); - - await s1.ShutdownCompleted(); - await s2.ShutdownCompleted(); } private static async Task SendAndReceiveDataAsync(byte[] data, QuicStream s1, QuicStream s2) @@ -354,11 +323,9 @@ public async Task ReadWrite_Random_Success(int readSize, int writeSize) byte[] testBuffer = new byte[8192]; Random.Shared.NextBytes(testBuffer); - await RunClientServer( - async clientConnection => + await RunUnidirectionalClientServer( + async clientStream => { - await using QuicStream clientStream = clientConnection.OpenUnidirectionalStream(); - ReadOnlyMemory sendBuffer = testBuffer; while (sendBuffer.Length != 0) { @@ -367,17 +334,14 @@ await RunClientServer( sendBuffer = sendBuffer.Slice(chunk.Length); } - await clientStream.WriteAsync(Memory.Empty, endStream: true); - await clientStream.ShutdownCompleted(); + clientStream.CompleteWrites(); }, - async serverConnection => + async serverStream => { - await using QuicStream serverStream = await serverConnection.AcceptStreamAsync(); - byte[] receiveBuffer = new byte[testBuffer.Length]; int totalBytesRead = 0; - while (true) // TODO: if you don't read until 0-byte read, ShutdownCompleted sometimes may not trigger - why? + while (true) { Memory recieveChunkBuffer = receiveBuffer.AsMemory(totalBytesRead, Math.Min(receiveBuffer.Length - totalBytesRead, readSize)); int bytesRead = await serverStream.ReadAsync(recieveChunkBuffer); @@ -391,8 +355,6 @@ await RunClientServer( Assert.Equal(testBuffer.Length, totalBytesRead); AssertArrayEqual(testBuffer, receiveBuffer); - - await serverStream.ShutdownCompleted(); }); } @@ -407,32 +369,116 @@ from writeSize in sizes } [Fact] - public async Task Read_StreamAborted_Throws() + public async Task Read_WriteAborted_Throws() { const int ExpectedErrorCode = 0xfffffff; - await Task.Run(async () => - { - using QuicListener listener = CreateQuicListener(); - ValueTask serverConnectionTask = listener.AcceptConnectionAsync(); + using SemaphoreSlim sem = new SemaphoreSlim(0); - using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); - await clientConnection.ConnectAsync(); + await RunBidirectionalClientServer( + async clientStream => + { + await clientStream.WriteAsync(new byte[1]); - using QuicConnection serverConnection = await serverConnectionTask; + await sem.WaitAsync(); + clientStream.Abort(ExpectedErrorCode, QuicAbortDirection.Write); + }, + async serverStream => + { + int received = await serverStream.ReadAsync(new byte[1]); + Assert.Equal(1, received); - await using QuicStream clientStream = clientConnection.OpenBidirectionalStream(); - await clientStream.WriteAsync(new byte[1]); + sem.Release(); - await using QuicStream serverStream = await serverConnection.AcceptStreamAsync(); - await serverStream.ReadAsync(new byte[1]); + byte[] buffer = new byte[100]; + QuicStreamAbortedException ex = await Assert.ThrowsAsync(() => serverStream.ReadAsync(buffer).AsTask()); + Assert.Equal(ExpectedErrorCode, ex.ErrorCode); + }); + } - clientStream.AbortWrite(ExpectedErrorCode); + [Fact] + public async Task Read_SynchronousCompletion_Success() + { + using SemaphoreSlim sem = new SemaphoreSlim(0); - byte[] buffer = new byte[100]; - QuicStreamAbortedException ex = await Assert.ThrowsAsync(() => serverStream.ReadAsync(buffer).AsTask()); - Assert.Equal(ExpectedErrorCode, ex.ErrorCode); - }).WaitAsync(TimeSpan.FromSeconds(15)); + await RunBidirectionalClientServer( + async clientStream => + { + await clientStream.WriteAsync(new byte[1]); + sem.Release(); + clientStream.CompleteWrites(); + sem.Release(); + }, + async serverStream => + { + await sem.WaitAsync(); + await Task.Delay(1000); + + ValueTask task = serverStream.ReadAsync(new byte[1]); + Assert.True(task.IsCompleted); + + int received = await task; + Assert.Equal(1, received); + + await sem.WaitAsync(); + await Task.Delay(1000); + + task = serverStream.ReadAsync(new byte[1]); + Assert.True(task.IsCompleted); + + received = await task; + Assert.Equal(0, received); + }); + } + + [Fact] + public async Task ReadOutstanding_ReadAborted_Throws() + { + const int ExpectedErrorCode = 0xfffffff; + + using SemaphoreSlim sem = new SemaphoreSlim(0); + + await RunBidirectionalClientServer( + async clientStream => + { + await sem.WaitAsync(); + }, + async serverStream => + { + Task exTask = Assert.ThrowsAsync(() => serverStream.ReadAsync(new byte[1]).AsTask()); + + Assert.False(exTask.IsCompleted); + + serverStream.Abort(ExpectedErrorCode, QuicAbortDirection.Read); + + await exTask; + + sem.Release(); + }); + } + + [Fact] + public async Task Read_ConcurrentReads_Throws() + { + using SemaphoreSlim sem = new SemaphoreSlim(0); + + await RunBidirectionalClientServer( + async clientStream => + { + await sem.WaitAsync(); + }, + async serverStream => + { + ValueTask readTask = serverStream.ReadAsync(new byte[1]); + Assert.False(readTask.IsCompleted); + + await Assert.ThrowsAsync(async () => await serverStream.ReadAsync(new byte[1])); + + sem.Release(); + + int res = await readTask; + Assert.Equal(0, res); + }); } [ActiveIssue("https://github.com/dotnet/runtime/issues/32050")] @@ -464,6 +510,141 @@ await Task.Run(async () => Assert.Equal(ExpectedErrorCode, ex.ErrorCode); }).WaitAsync(TimeSpan.FromSeconds(5)); } + + [Fact] + public async Task CloseAsync_Cancelled_Then_CloseAsync_Success() + { + using SemaphoreSlim sem = new SemaphoreSlim(0); + + await RunBidirectionalClientServer( + async clientStream => + { + // Make sure the first task throws an OCE. + + using CancellationTokenSource cts = new CancellationTokenSource(); + CancellationToken cancellationToken = cts.Token; + + ValueTask closeTask = clientStream.CloseAsync(cancellationToken); + + await Task.Delay(500); + Assert.False(closeTask.IsCompleted); + + cts.Cancel(); + OperationCanceledException oce = await Assert.ThrowsAnyAsync(async () => await closeTask); + Assert.Equal(cancellationToken, oce.CancellationToken); + + // Release before closing the stream, to allow the server to close its write stream. + + sem.Release(); + await clientStream.CloseAsync(); + }, + async serverStream => + { + // Wait before closing the stream, which will cause the client's CloseAsync to finish. + + await sem.WaitAsync(); + await serverStream.CloseAsync(); + }); + } + + [Fact] + public async Task CloseAsync_Cancelled_Then_Abort_Success() + { + const int AbortCode = 1234; + + await RunBidirectionalClientServer( + async clientStream => + { + // Make sure the first task throws an OCE. + + using CancellationTokenSource cts = new CancellationTokenSource(); + CancellationToken cancellationToken = cts.Token; + + ValueTask closeTask = clientStream.CloseAsync(cancellationToken); + + await Task.Delay(500); + Assert.False(closeTask.IsCompleted); + + cts.Cancel(); + OperationCanceledException oce = await Assert.ThrowsAnyAsync(async () => await closeTask); + Assert.Equal(cancellationToken, oce.CancellationToken); + + // Abort the stream, causing the other side to close. + + clientStream.Abort(AbortCode); + await clientStream.CloseAsync(); + }, + async serverStream => + { + // Wait for the client to abort its stream before closing our stream. + + var buffer = new byte[8]; + QuicStreamAbortedException ae = await Assert.ThrowsAnyAsync(async () => await serverStream.ReadAsync(buffer)); + Assert.Equal(AbortCode, ae.ErrorCode); + + await serverStream.CloseAsync(); + }); + } + + [Fact] + public async Task QuicStream_CatchPattern_Success() + { + const int ExpectedErrorCode = 0xfffffff; + + using SemaphoreSlim sem = new SemaphoreSlim(0); + + await RunBidirectionalClientServer( + async clientStream => + { + await sem.WaitAsync(); + await Task.Delay(500); // wait for the shutdown to reach this side. + + QuicStreamAbortedException ex = await Assert.ThrowsAsync(async () => + { + await clientStream.WriteAsync(new byte[1]); + }); + + Assert.Equal(ExpectedErrorCode, ex.ErrorCode); + + ex = await Assert.ThrowsAsync(async () => + { + await clientStream.ReadAsync(new byte[1]); + }); + + Assert.Equal(ExpectedErrorCode, ex.ErrorCode); + }, + async serverStream => + { + using var cts = new CancellationTokenSource(); + CancellationToken token = cts.Token; + + try + { + ValueTask readTask = serverStream.ReadAsync(new byte[1], token); + + Assert.False(readTask.IsCompleted); + + cts.Cancel(); + await readTask; + + Assert.False(true, "This point should never be reached."); + } + catch (OperationCanceledException ex) when (ex.CancellationToken == token) + { + serverStream.Abort(ExpectedErrorCode); + } + catch + { + Assert.False(true, "This point should never be reached."); + } + + // Because the stream is aborted, this should not wait for the other side to close its stream. + await serverStream.DisposeAsync(); + + // Only allow the other side to close its stream after the dispose compleats. + sem.Release(); + }); + } } public sealed class QuicStreamTests_MockProvider : QuicStreamTests { } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs index 027d0adb258adf..3d34e48611d5a0 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs @@ -63,6 +63,46 @@ internal QuicListener CreateQuicListener(IPEndPoint endpoint) return new QuicListener(ImplementationProvider, endpoint, GetSslServerAuthenticationOptions()); } + internal Task RunUnidirectionalClientServer(Func clientFunction, Func serverFunction, int iterations = 1, int millisecondsTimeout = 10_000) + => RunClientServerStream(clientFunction, serverFunction, iterations, millisecondsTimeout, bidi: false); + + internal Task RunBidirectionalClientServer(Func clientFunction, Func serverFunction, int iterations = 1, int millisecondsTimeout = 10_000) + => RunClientServerStream(clientFunction, serverFunction, iterations, millisecondsTimeout, bidi: true); + + private async Task RunClientServerStream(Func clientFunction, Func serverFunction, int iterations, int millisecondsTimeout, bool bidi) + { + const long ClientThrewAbortCode = 1234567890; + const long ServerThrewAbortCode = 2345678901; + + await RunClientServer( + async clientConnection => + { + await using QuicStream clientStream = bidi ? clientConnection.OpenBidirectionalStream() : clientConnection.OpenUnidirectionalStream(); + try + { + await clientFunction(clientStream); + } + catch + { + clientStream.Abort(ClientThrewAbortCode); + throw; + } + }, + async serverConnection => + { + await using QuicStream serverStream = await serverConnection.AcceptStreamAsync(); + try + { + await serverFunction(serverStream); + } + catch + { + serverStream.Abort(ServerThrewAbortCode); + throw; + } + }, iterations, millisecondsTimeout); + } + internal async Task RunClientServer(Func clientFunction, Func serverFunction, int iterations = 1, int millisecondsTimeout = 10_000) { using QuicListener listener = CreateQuicListener(); From 45607a21c5f66c11cd1dfe8cf7df43e0c0482dfe Mon Sep 17 00:00:00 2001 From: Cory Nelson Date: Tue, 25 May 2021 01:11:25 -0700 Subject: [PATCH 2/8] Add Immediate to abort flags. --- .../System.Net.Quic/ref/System.Net.Quic.cs | 3 +- .../Implementations/MsQuic/MsQuicStream.cs | 55 +++++++++++++------ .../src/System/Net/Quic/QuicAbortDirection.cs | 18 +++++- .../src/System/Net/Quic/QuicStream.cs | 10 +++- .../tests/FunctionalTests/MsQuicTests.cs | 4 -- .../tests/FunctionalTests/QuicStreamTests.cs | 6 +- .../tests/FunctionalTests/QuicTestBase.cs | 21 ++++++- 7 files changed, 86 insertions(+), 31 deletions(-) diff --git a/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs b/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs index 96299c6fd3ee6e..275a6e9582050c 100644 --- a/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs +++ b/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs @@ -12,6 +12,7 @@ public enum QuicAbortDirection Read = 1, Write = 2, Both = 3, + Immediate = 7 } public partial class QuicClientConnectionOptions : System.Net.Quic.QuicOptions { @@ -92,7 +93,7 @@ internal QuicStream() { } public override long Length { get { throw null; } } public override long Position { get { throw null; } set { } } public long StreamId { get { throw null; } } - public void Abort(long errorCode, System.Net.Quic.QuicAbortDirection abortDirection = System.Net.Quic.QuicAbortDirection.Both) { } + public void Abort(long errorCode, System.Net.Quic.QuicAbortDirection abortDirection = System.Net.Quic.QuicAbortDirection.Immediate) { } public override System.IAsyncResult BeginRead(byte[] buffer, int offset, int count, System.AsyncCallback? callback, object? state) { throw null; } public override System.IAsyncResult BeginWrite(byte[] buffer, int offset, int count, System.AsyncCallback? callback, object? state) { throw null; } public System.Threading.Tasks.ValueTask CloseAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs index 722cd24896c964..b82f307506bd10 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs @@ -426,6 +426,11 @@ internal override void Abort(long errorCode, QuicAbortDirection abortDirection = } } + if ((abortDirection & QuicAbortDirection.Immediate) == QuicAbortDirection.Immediate) + { + flags |= QUIC_STREAM_SHUTDOWN_FLAGS.IMMEDIATE; + } + StartShutdownOrAbort(flags, errorCode); if (completeWrites) @@ -443,6 +448,8 @@ internal override void Abort(long errorCode, QuicAbortDirection abortDirection = /// For abortive flags, the error code sent to peer. Otherwise, ignored. private void StartShutdownOrAbort(QUIC_STREAM_SHUTDOWN_FLAGS flags, long errorCode) { + Debug.Assert(!_disposed); + uint status = MsQuicApi.Api.StreamShutdownDelegate(_state.Handle, flags, errorCode); QuicExceptionHelpers.ThrowIfFailed(status, "StreamShutdown failed."); } @@ -478,37 +485,42 @@ internal override Task FlushAsync(CancellationToken cancellationToken = default) } public override ValueTask DisposeAsync(CancellationToken cancellationToken) => - DisposeAsync(cancellationToken, async: true, immediate: _state.SendState == SendState.Aborted); - - public override void Dispose() => - Dispose(immediate: _state.SendState == SendState.Aborted); + DisposeAsync(cancellationToken, async: true); - ~MsQuicStream() => - Dispose(immediate: true); - - private void Dispose(bool immediate) + public override void Dispose() { - ValueTask t = DisposeAsync(cancellationToken: default, async: false, immediate); + ValueTask t = DisposeAsync(cancellationToken: default, async: false); Debug.Assert(t.IsCompleted); t.GetAwaiter().GetResult(); } - /// - /// - /// When true, causes immediate disposal without waiting for peer ACKs. - /// - private async ValueTask DisposeAsync(CancellationToken cancellationToken, bool async, bool immediate) + ~MsQuicStream() + { + DisposeAsyncThrowaway(this); + + // This is weird due to needing to keep _state alive for MsQuic's callback. + // See DisposeAsync implementation for details. + + static async void DisposeAsyncThrowaway(MsQuicStream stream) + { + await stream.DisposeAsync(cancellationToken: default, async: true).ConfigureAwait(false); + } + } + + private async ValueTask DisposeAsync(CancellationToken cancellationToken, bool async) { if (_disposed) { return; } - QUIC_STREAM_SHUTDOWN_FLAGS flags = immediate - ? (QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL | QUIC_STREAM_SHUTDOWN_FLAGS.IMMEDIATE) - : QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL; + // MsQuic will ignore this call if it was already shutdown elsewhere. + // PERF TODO: update write loop to make it so we don't need to call this. it queues an event to the MsQuic thread pool. + StartShutdownOrAbort(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); - StartShutdownOrAbort(flags, errorCode: 0); + // MsQuic will continue sending us events, so we need to wait for shutdown + // completion (the final event) before freeing _stateHandle's GCHandle. + // If Abort() wasn't called with "immediate", this will wait for peer to shut down their write side. if (async) { @@ -594,6 +606,13 @@ private static unsafe uint HandleEventRecv(State state, ref StreamEvent evt) { ref StreamEventDataReceive receiveEvent = ref evt.Data.Receive; + if (receiveEvent.BufferCount == 0) + { + // This is a 0-length receive that happens once reads are finished (via abort or otherwise). + // State changes for this are handled elsewhere. + return MsQuicStatusCodes.Success; + } + int readLength; lock (state) diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicAbortDirection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicAbortDirection.cs index f43051991ea1d7..e441e947bd6f9b 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicAbortDirection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicAbortDirection.cs @@ -6,8 +6,24 @@ namespace System.Net.Quic [Flags] public enum QuicAbortDirection { + /// + /// Aborts the read direction of the stream. + /// Read = 1, + + /// + /// Aborts the write direction of the stream. + /// Write = 2, - Both = Read | Write + + /// + /// Aborts both the read and write direction of the stream. + /// + Both = Read | Write, + + /// + /// Aborts both the read and write direction of the stream, without waiting for the peer to shutdown their write direction. + /// + Immediate = Both | 4 } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs index f364864364215e..a7ddd4d1812052 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs @@ -95,7 +95,7 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati /// /// The error code to abort with. /// The direction of the abort. - public void Abort(long errorCode, QuicAbortDirection abortDirection = QuicAbortDirection.Both) => _provider.Abort(errorCode, abortDirection); + public void Abort(long errorCode, QuicAbortDirection abortDirection = QuicAbortDirection.Immediate) => _provider.Abort(errorCode, abortDirection); public ValueTask WriteAsync(ReadOnlyMemory buffer, bool endStream, CancellationToken cancellationToken = default) => _provider.WriteAsync(buffer, endStream, cancellationToken); @@ -107,6 +107,7 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati public ValueTask WriteAsync(ReadOnlyMemory> buffers, bool endStream, CancellationToken cancellationToken = default) => _provider.WriteAsync(buffers, endStream, cancellationToken); + /// protected override void Dispose(bool disposing) { if (disposing) @@ -115,13 +116,18 @@ protected override void Dispose(bool disposing) } } + /// public override ValueTask DisposeAsync() => CloseAsync(); /// - /// Gracefully shuts down and closes the , leaving it in a disposed state. + /// Shuts down and closes the , leaving it in a disposed state. /// /// If triggered, an will be thrown and the stream will be left undisposed. /// A representing the asynchronous closure of the . + /// + /// When the stream has been been aborted with , this will complete independent of the peer. + /// Otherwise, this will wait for the peer to complete their write side (gracefully or abortive) and drain any bytes received in the mean time. + /// public ValueTask CloseAsync(CancellationToken cancellationToken = default) => _provider.DisposeAsync(cancellationToken); } } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs index d3c08bf96b691a..2da1932c150e85 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs @@ -323,8 +323,6 @@ await RunClientServer( await stream.WriteAsync(data[pos..(pos + writeSize)]); } await stream.WriteAsync(Memory.Empty, endStream: true); - - await stream.ShutdownCompleted(); }, clientFunction: async connection => { @@ -340,8 +338,6 @@ await RunClientServer( int bytesRead = await ReadAll(stream, buffer); Assert.Equal(data.Length, bytesRead); AssertArrayEqual(data, buffer); - - await stream.ShutdownCompleted(); } ); } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs index a97b5104c1b4ee..0c1cca9ca78504 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs @@ -560,13 +560,13 @@ await RunBidirectionalClientServer( using CancellationTokenSource cts = new CancellationTokenSource(); CancellationToken cancellationToken = cts.Token; - ValueTask closeTask = clientStream.CloseAsync(cancellationToken); + Task oceTask = Assert.ThrowsAnyAsync(async () => await clientStream.CloseAsync(cancellationToken)); await Task.Delay(500); - Assert.False(closeTask.IsCompleted); + Assert.False(oceTask.IsCompleted); cts.Cancel(); - OperationCanceledException oce = await Assert.ThrowsAnyAsync(async () => await closeTask); + OperationCanceledException oce = await oceTask; Assert.Equal(cancellationToken, oce.CancellationToken); // Abort the stream, causing the other side to close. diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs index 3d34e48611d5a0..a57b1cf0c075a2 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs @@ -84,7 +84,16 @@ await RunClientServer( } catch { - clientStream.Abort(ClientThrewAbortCode); + try + { + // abort the stream to give the peer a chance to tear down. + clientStream.Abort(ClientThrewAbortCode); + } + catch(ObjectDisposedException) + { + // do nothing. + } + throw; } }, @@ -97,7 +106,15 @@ await RunClientServer( } catch { - serverStream.Abort(ServerThrewAbortCode); + try + { + // abort the stream to give the peer a chance to tear down. + serverStream.Abort(ServerThrewAbortCode); + } + catch (ObjectDisposedException) + { + // do nothing. + } throw; } }, iterations, millisecondsTimeout); From b65200d9bc20e05d6affb67012b448af841dc18c Mon Sep 17 00:00:00 2001 From: Cory Nelson Date: Mon, 21 Jun 2021 02:57:00 -0700 Subject: [PATCH 3/8] WIP --- .../Implementations/MsQuic/MsQuicStream.cs | 6 +- .../src/System/Net/Quic/QuicAbortDirection.cs | 2 +- .../tests/FunctionalTests/QuicStreamTests.cs | 83 ++++++++++++++----- 3 files changed, 65 insertions(+), 26 deletions(-) diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs index b82f307506bd10..891bcd7648c4b8 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs @@ -412,7 +412,7 @@ internal override void Abort(long errorCode, QuicAbortDirection abortDirection = { if (abortDirection.HasFlag(QuicAbortDirection.Write)) { - completeWrites = _state.SendState == SendState.None; + completeWrites = _state.SendState is SendState.None or SendState.Pending; _state.SendState = SendState.Aborted; flags |= QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_SEND; } @@ -435,7 +435,7 @@ internal override void Abort(long errorCode, QuicAbortDirection abortDirection = if (completeWrites) { - _state.SendResettableCompletionSource.Complete(0); + _state.SendResettableCompletionSource.CompleteException(ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException())); } if (completeReads) @@ -1044,7 +1044,7 @@ private enum ReadState EndOfReadStream, /// - /// User has aborted the stream, either via a cancellation token on ReadAsync(), or via AbortRead(). + /// User has aborted the stream, either via a cancellation token on ReadAsync(), or via Abort(read). /// Aborted } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicAbortDirection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicAbortDirection.cs index e441e947bd6f9b..788f6db7d01ff1 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicAbortDirection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicAbortDirection.cs @@ -22,7 +22,7 @@ public enum QuicAbortDirection Both = Read | Write, /// - /// Aborts both the read and write direction of the stream, without waiting for the peer to shutdown their write direction. + /// Aborts both the read and write direction of the stream, without waiting for the peer to acknowledge the shutdown. /// Immediate = Both | 4 } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs index 0c1cca9ca78504..c14847c4377990 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs @@ -3,6 +3,7 @@ using System.Buffers; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Text; using System.Threading; @@ -536,14 +537,12 @@ await RunBidirectionalClientServer( // Release before closing the stream, to allow the server to close its write stream. sem.Release(); - await clientStream.CloseAsync(); }, async serverStream => { - // Wait before closing the stream, which will cause the client's CloseAsync to finish. + // Wait before closing the stream, which would otherwise cause the client's CloseAsync to finish. await sem.WaitAsync(); - await serverStream.CloseAsync(); }); } @@ -552,10 +551,13 @@ public async Task CloseAsync_Cancelled_Then_Abort_Success() { const int AbortCode = 1234; + using var sem = new SemaphoreSlim(0); + await RunBidirectionalClientServer( async clientStream => { - // Make sure the first task throws an OCE. + // We use the fact that a graceful CloseAsync() won't complete until the + // other side also does a graceful CloseAsync() to force an OperationCanceledException. using CancellationTokenSource cts = new CancellationTokenSource(); CancellationToken cancellationToken = cts.Token; @@ -569,23 +571,39 @@ await RunBidirectionalClientServer( OperationCanceledException oce = await oceTask; Assert.Equal(cancellationToken, oce.CancellationToken); - // Abort the stream, causing the other side to close. + // Abort the stream, causing CloseAsync to complete not synchronously but "immediately". - clientStream.Abort(AbortCode); + clientStream.Abort(AbortCode, QuicAbortDirection.Immediate); await clientStream.CloseAsync(); + + sem.Release(); }, async serverStream => { - // Wait for the client to abort its stream before closing our stream. + // Wait for the client to gracefully close. - var buffer = new byte[8]; - QuicStreamAbortedException ae = await Assert.ThrowsAnyAsync(async () => await serverStream.ReadAsync(buffer)); - Assert.Equal(AbortCode, ae.ErrorCode); + await sem.WaitAsync(); - await serverStream.CloseAsync(); + int readLen = await serverStream.ReadAsync(new byte[1]); + Assert.Equal(0, readLen); + + // Wait for the client to send STOP_SENDING. + + await Task.Delay(500); + + QuicStreamAbortedException ex = await Assert.ThrowsAnyAsync(async () => await serverStream.WriteAsync(new byte[1])); + Assert.Equal(AbortCode, ex.ErrorCode); }); } + // The server portion of this method tests the full version of the "catch and close pattern", which is + // required to prevent a DoS attack. The pattern is: + // 1. Wrap all your ops in a try/catch, and have the catch call Abort() then Close() with a "shutdown timeout" cancellation token. + // - This Close() will wait for the peer to ACK the shutdown. + // 2. Wrap that try/catch in another try/catch(OperationCanceledException) which calls Abort(Immediate). + // - This causes the next Close()/Dispose() to not wait for ACK. + // + // TODO: we should revisit this because it's a very easy to screw up pattern. [Fact] public async Task QuicStream_CatchPattern_Success() { @@ -616,11 +634,13 @@ await RunBidirectionalClientServer( async serverStream => { using var cts = new CancellationTokenSource(); - CancellationToken token = cts.Token; try { - ValueTask readTask = serverStream.ReadAsync(new byte[1], token); + // We just need to throw an exception here + // Cancel reads, causing an OperationCanceledException + + ValueTask readTask = serverStream.ReadAsync(new byte[1], cts.Token); Assert.False(readTask.IsCompleted); @@ -629,19 +649,38 @@ await RunBidirectionalClientServer( Assert.False(true, "This point should never be reached."); } - catch (OperationCanceledException ex) when (ex.CancellationToken == token) + catch (Exception ex) { - serverStream.Abort(ExpectedErrorCode); - } - catch - { - Assert.False(true, "This point should never be reached."); + Assert.True(ex is OperationCanceledException oce && oce.CancellationToken == cts.Token); + + // Abort here. The CloseAsync that follows will still wait for an ACK of the shutdown, + // so a cancellation token with a shutdown timeout is passed in. + + serverStream.Abort(ExpectedErrorCode, QuicAbortDirection.Both); + + using var shutdownCts = new CancellationTokenSource(500); + try + { + await serverStream.CloseAsync(shutdownCts.Token); + } + catch(Exception ex2) + { + // TODO: this catch block will basically never be executed right now -- we need a way to + // block the MsQuic from ACKing the abort. + + Assert.True(ex2 is OperationCanceledException oce2 && oce2.CancellationToken == shutdownCts.Token); + + // Abort again. The exit code is not important, because we gave it above already. + // This time, Immediate is used which will cause CloseAsync() to not wait for a shutdown ACK. + serverStream.Abort(0, QuicAbortDirection.Immediate); + } } - // Because the stream is aborted, this should not wait for the other side to close its stream. - await serverStream.DisposeAsync(); + // Either the CloseAsync above worked, in which case this is a no-op, + // or the stream has been re-aborted with Immediate, in which case this will complete "immediately" but not synchronously. + await serverStream.CloseAsync(); - // Only allow the other side to close its stream after the dispose compleats. + // Only allow the other side to close its stream after the dispose completes. sem.Release(); }); } From 52b2814b202fcc1aa00ed1377d1d835c51cff027 Mon Sep 17 00:00:00 2001 From: Cory Nelson Date: Mon, 21 Jun 2021 21:34:20 -0700 Subject: [PATCH 4/8] WIP --- .../tests/FunctionalTests/QuicStreamTests.cs | 139 ++++++------------ 1 file changed, 43 insertions(+), 96 deletions(-) diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs index c14847c4377990..9b6bdb2f1a4eff 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs @@ -522,17 +522,14 @@ await RunBidirectionalClientServer( { // Make sure the first task throws an OCE. - using CancellationTokenSource cts = new CancellationTokenSource(); - CancellationToken cancellationToken = cts.Token; + using var cts = new CancellationTokenSource(500); - ValueTask closeTask = clientStream.CloseAsync(cancellationToken); - - await Task.Delay(500); - Assert.False(closeTask.IsCompleted); + OperationCanceledException oce = await Assert.ThrowsAnyAsync(async () => + { + await clientStream.CloseAsync(cts.Token); + }); - cts.Cancel(); - OperationCanceledException oce = await Assert.ThrowsAnyAsync(async () => await closeTask); - Assert.Equal(cancellationToken, oce.CancellationToken); + Assert.Equal(cts.Token, oce.CancellationToken); // Release before closing the stream, to allow the server to close its write stream. @@ -546,67 +543,25 @@ await RunBidirectionalClientServer( }); } - [Fact] - public async Task CloseAsync_Cancelled_Then_Abort_Success() - { - const int AbortCode = 1234; - - using var sem = new SemaphoreSlim(0); - - await RunBidirectionalClientServer( - async clientStream => - { - // We use the fact that a graceful CloseAsync() won't complete until the - // other side also does a graceful CloseAsync() to force an OperationCanceledException. - - using CancellationTokenSource cts = new CancellationTokenSource(); - CancellationToken cancellationToken = cts.Token; - - Task oceTask = Assert.ThrowsAnyAsync(async () => await clientStream.CloseAsync(cancellationToken)); - - await Task.Delay(500); - Assert.False(oceTask.IsCompleted); - - cts.Cancel(); - OperationCanceledException oce = await oceTask; - Assert.Equal(cancellationToken, oce.CancellationToken); - - // Abort the stream, causing CloseAsync to complete not synchronously but "immediately". - - clientStream.Abort(AbortCode, QuicAbortDirection.Immediate); - await clientStream.CloseAsync(); - - sem.Release(); - }, - async serverStream => - { - // Wait for the client to gracefully close. - - await sem.WaitAsync(); - - int readLen = await serverStream.ReadAsync(new byte[1]); - Assert.Equal(0, readLen); - - // Wait for the client to send STOP_SENDING. - - await Task.Delay(500); - - QuicStreamAbortedException ex = await Assert.ThrowsAnyAsync(async () => await serverStream.WriteAsync(new byte[1])); - Assert.Equal(AbortCode, ex.ErrorCode); - }); - } - - // The server portion of this method tests the full version of the "catch and close pattern", which is - // required to prevent a DoS attack. The pattern is: - // 1. Wrap all your ops in a try/catch, and have the catch call Abort() then Close() with a "shutdown timeout" cancellation token. - // - This Close() will wait for the peer to ACK the shutdown. - // 2. Wrap that try/catch in another try/catch(OperationCanceledException) which calls Abort(Immediate). - // - This causes the next Close()/Dispose() to not wait for ACK. + // This tests the pattern needed to safely control shutdown of a QuicStream. + // 1. Normal stream usage happens inside try. + // 2. Call Abort(Both) in the catch. + // 3. Call Close() with a cancellation token in the finally. + // 4. If that Close() fails, call Abort(Immediate). // - // TODO: we should revisit this because it's a very easy to screw up pattern. - [Fact] - public async Task QuicStream_CatchPattern_Success() + // This is important to avoid a DoS if the peer doesn't shutdown their sends but otherwise leaves the connection open. + // TODO: we should rework the API to make this a lot more foolproof. + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task QuicStream_ClosePattern_Success(bool abortive) { + while (!Debugger.IsAttached) + { + Console.WriteLine($"Attach to process {Process.GetCurrentProcess().Id}."); + await Task.Delay(100); + } + const int ExpectedErrorCode = 0xfffffff; using SemaphoreSlim sem = new SemaphoreSlim(0); @@ -614,8 +569,11 @@ public async Task QuicStream_CatchPattern_Success() await RunBidirectionalClientServer( async clientStream => { + // Don't shutdown client side until server side has 100% completed. await sem.WaitAsync(); - await Task.Delay(500); // wait for the shutdown to reach this side. + + // Wait for server's aborts to reach us. + await Task.Delay(500); QuicStreamAbortedException ex = await Assert.ThrowsAsync(async () => { @@ -633,46 +591,35 @@ await RunBidirectionalClientServer( }, async serverStream => { - using var cts = new CancellationTokenSource(); - try { - // We just need to throw an exception here - // Cancel reads, causing an OperationCanceledException - - ValueTask readTask = serverStream.ReadAsync(new byte[1], cts.Token); - - Assert.False(readTask.IsCompleted); - - cts.Cancel(); - await readTask; + // All the usual stream usage happens inside a try block. + // Just a dummy throw here to demonstrate the pattern... - Assert.False(true, "This point should never be reached."); + if (abortive) + { + throw new Exception(); + } } - catch (Exception ex) + catch { - Assert.True(ex is OperationCanceledException oce && oce.CancellationToken == cts.Token); - - // Abort here. The CloseAsync that follows will still wait for an ACK of the shutdown, - // so a cancellation token with a shutdown timeout is passed in. - + // Abort here. The CloseAsync that follows will still wait for an ACK of the shutdown. serverStream.Abort(ExpectedErrorCode, QuicAbortDirection.Both); + } + finally + { + // Call CloseAsync() with a cancellation token to allow it to time out when peer doesn't shutdown. using var shutdownCts = new CancellationTokenSource(500); try { await serverStream.CloseAsync(shutdownCts.Token); } - catch(Exception ex2) + catch { - // TODO: this catch block will basically never be executed right now -- we need a way to - // block the MsQuic from ACKing the abort. - - Assert.True(ex2 is OperationCanceledException oce2 && oce2.CancellationToken == shutdownCts.Token); - - // Abort again. The exit code is not important, because we gave it above already. + // Abort (possibly again, which will ignore error code and not queue any new I/O). // This time, Immediate is used which will cause CloseAsync() to not wait for a shutdown ACK. - serverStream.Abort(0, QuicAbortDirection.Immediate); + serverStream.Abort(ExpectedErrorCode, QuicAbortDirection.Immediate); } } @@ -682,7 +629,7 @@ await RunBidirectionalClientServer( // Only allow the other side to close its stream after the dispose completes. sem.Release(); - }); + }, millisecondsTimeout: 1_000_000_000); } } From 97ea324de05d22e13660ef1fb77c29ec35900b82 Mon Sep 17 00:00:00 2001 From: Cory Nelson Date: Tue, 22 Jun 2021 01:35:32 -0700 Subject: [PATCH 5/8] WIP --- .../Implementations/MsQuic/MsQuicStream.cs | 80 +++++------- .../tests/FunctionalTests/QuicStreamTests.cs | 122 +++++++----------- .../tests/FunctionalTests/QuicTestBase.cs | 15 +-- 3 files changed, 86 insertions(+), 131 deletions(-) diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs index 6c33fa9ebd7b14..88c5cae43306ec 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs @@ -133,6 +133,9 @@ internal MsQuicStream(MsQuicConnection.State connectionState, QUIC_STREAM_OPEN_F QuicExceptionHelpers.ThrowIfFailed(status, "Failed to open stream to peer."); + // TODO: StreamStart is blocking on another thread here. + // We should refactor this to use the ASYNC flag. + status = MsQuicApi.Api.StreamStartDelegate(_state.Handle, QUIC_STREAM_START_FLAGS.FAIL_BLOCKED); QuicExceptionHelpers.ThrowIfFailed(status, "Could not start stream."); } @@ -143,12 +146,15 @@ internal MsQuicStream(MsQuicConnection.State connectionState, QUIC_STREAM_OPEN_F throw; } + // TODO: our callback starts getting called as soon as we call StreamStart. + // Should this stuff be moved before that call? + if (!connectionState.TryAddStream(this)) { _state.Handle?.Dispose(); _stateHandle.Free(); throw new ObjectDisposedException(nameof(QuicConnection)); - } + } _state.ConnectionState = connectionState; @@ -325,8 +331,8 @@ internal override ValueTask ReadAsync(Memory destination, Cancellatio if (readState != ReadState.PendingRead && cancellationToken.IsCancellationRequested) { - readState = ReadState.Aborted; - _state.ReadState = ReadState.Aborted; + readState = ReadState.StreamAborted; + _state.ReadState = ReadState.StreamAborted; canceledSynchronously = true; } else if (readState == ReadState.None) @@ -348,7 +354,7 @@ internal override ValueTask ReadAsync(Memory destination, Cancellatio { completePendingRead = state.ReadState == ReadState.PendingRead; state.RootedReceiveStream = null; - state.ReadState = ReadState.Aborted; + state.ReadState = ReadState.StreamAborted; } if (completePendingRead) @@ -390,14 +396,17 @@ internal override ValueTask ReadAsync(Memory destination, Cancellatio case ReadState.PendingRead: ex = new InvalidOperationException("Only one read is supported at a time."); break; - case ReadState.Aborted: - default: - Debug.Assert(readState == ReadState.Aborted, $"{nameof(ReadState)} of '{readState}' is unaccounted for in {nameof(ReadAsync)}."); - + case ReadState.StreamAborted: ex = canceledSynchronously ? new OperationCanceledException(cancellationToken) : // aborted by token being canceled before the async op started. abortError == -1 ? new QuicOperationAbortedException() : // aborted by user via some other operation. new QuicStreamAbortedException(abortError); // aborted by peer. + + break; + case ReadState.ConnectionAborted: + default: + Debug.Assert(readState == ReadState.ConnectionAborted, $"{nameof(ReadState)} of '{readState}' is unaccounted for in {nameof(ReadAsync)}."); + ex = GetConnectionAbortedException(_state); break; } @@ -460,7 +469,7 @@ internal override void Abort(long errorCode, QuicAbortDirection abortDirection = { completeReads = _state.ReadState == ReadState.PendingRead; _state.RootedReceiveStream = null; - _state.ReadState = ReadState.Aborted; + _state.ReadState = ReadState.StreamAborted; flags |= QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_RECEIVE; } } @@ -542,13 +551,13 @@ public override void Dispose() t.GetAwaiter().GetResult(); } + // TODO: there's a bug here where the safe handle is no longer valid. + // This shouldn't happen because the safe handle is effectively pinned + // until after disposal completes. ~MsQuicStream() { DisposeAsyncThrowaway(this); - // This is weird due to needing to keep _state alive for MsQuic's callback. - // See DisposeAsync implementation for details. - static async void DisposeAsyncThrowaway(MsQuicStream stream) { await stream.DisposeAsync(cancellationToken: default, async: true).ConfigureAwait(false); @@ -600,7 +609,6 @@ private async ValueTask DisposeAsync(CancellationToken cancellationToken, bool a { NetEventSource.Info(_state, $"[Stream#{_state.GetHashCode()}] disposed"); } - } private void EnableReceive() @@ -692,7 +700,6 @@ private static unsafe uint HandleEventRecv(State state, ref StreamEvent evt) ArrayPool.Shared.Return(oldReceiveBuffers); } } - } for (uint i = 0; i < receiveEvent.BufferCount; ++i) { @@ -712,9 +719,8 @@ private static unsafe uint HandleEventRecv(State state, ref StreamEvent evt) readLength = CopyMsQuicBuffersToUserBuffer(new ReadOnlySpan(receiveEvent.Buffers, (int)receiveEvent.BufferCount), state.ReceiveUserBuffer.Span); break; - case ReadState.Aborted: default: - Debug.Assert(state.ReadState == ReadState.Aborted, $"Unexpected {nameof(ReadState)} '{state.ReadState}' in {nameof(HandleEventRecv)}."); + Debug.Assert(state.ReadState is ReadState.StreamAborted or ReadState.ConnectionAborted, $"Unexpected {nameof(ReadState)} '{state.ReadState}' in {nameof(HandleEventRecv)}."); // There was a race between a user aborting the read stream and the callback being ran. // This will eat any received data. @@ -791,7 +797,7 @@ private static uint HandleEventPeerSendAborted(State state, ref StreamEvent evt) { shouldComplete = true; } - state.ReadState = ReadState.Aborted; + state.ReadState = ReadState.StreamAborted; state.ReadErrorCode = (long)evt.Data.PeerSendAborted.ErrorCode; } @@ -1101,6 +1107,7 @@ private void ThrowIfDisposed() private static uint HandleEventConnectionClose(State state) { long errorCode = state.ConnectionState.AbortErrorCode; + if (NetEventSource.Log.IsEnabled()) { NetEventSource.Info(state, $"[Stream#{state.GetHashCode()}] handling Connection#{state.ConnectionState.GetHashCode()} close" + @@ -1109,34 +1116,19 @@ private static uint HandleEventConnectionClose(State state) bool shouldCompleteRead = false; bool shouldCompleteSend = false; - bool shouldCompleteShutdownWrite = false; bool shouldCompleteShutdown = false; lock (state) { - if (state.ReadState == ReadState.None) - { - shouldCompleteRead = true; - } - state.ReadState = ReadState.ConnectionClosed; + shouldCompleteRead = state.ReadState == ReadState.PendingRead; + shouldCompleteSend = state.SendState is SendState.None or SendState.Pending; - if (state.SendState == SendState.None || state.SendState == SendState.Pending) + if (state.ReadState is not ReadState.EndOfReadStream or ReadState.StreamAborted) { - shouldCompleteSend = true; + state.ReadState = ReadState.ConnectionAborted; } - state.SendState = SendState.ConnectionClosed; - if (state.ShutdownWriteState == ShutdownWriteState.None) - { - shouldCompleteShutdownWrite = true; - } - state.ShutdownWriteState = ShutdownWriteState.ConnectionClosed; - - if (state.ShutdownState == ShutdownState.None) - { - shouldCompleteShutdown = true; - } - state.ShutdownState = ShutdownState.ConnectionClosed; + state.SendState = SendState.ConnectionClosed; } if (shouldCompleteRead) @@ -1151,12 +1143,6 @@ private static uint HandleEventConnectionClose(State state) ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state))); } - if (shouldCompleteShutdownWrite) - { - state.ShutdownWriteCompletionSource.SetException( - ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state))); - } - if (shouldCompleteShutdown) { state.ShutdownCompletionSource.SetException( @@ -1192,14 +1178,14 @@ private enum ReadState EndOfReadStream, /// - /// User has aborted the stream, either via a cancellation token on ReadAsync(), or via Abort(read). + /// The stream has been aborted, either by user or by peer. /// - Aborted, + StreamAborted, /// - /// Connection was closed, either by user or by the peer. + /// The connection has been aborted, either by user or by peer. /// - ConnectionClosed + ConnectionAborted } private enum SendState diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs index f9d5b4078a6244..e1693565d32d99 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs @@ -488,24 +488,18 @@ public async Task StreamAbortedWithoutWriting_ReadThrows() { long expectedErrorCode = 1234; - await RunClientServer( - clientFunction: async connection => + await RunUnidirectionalClientServer( + clientStream => { - await using QuicStream stream = connection.OpenUnidirectionalStream(); - stream.AbortWrite(expectedErrorCode); - - await stream.ShutdownCompleted(); + clientStream.Abort(expectedErrorCode, QuicAbortDirection.Write); + return Task.CompletedTask; }, - serverFunction: async connection => + async serverStream => { - await using QuicStream stream = await connection.AcceptStreamAsync(); - byte[] buffer = new byte[1]; - QuicStreamAbortedException ex = await Assert.ThrowsAsync(() => ReadAll(stream, buffer)); + QuicStreamAbortedException ex = await Assert.ThrowsAsync(() => ReadAll(serverStream, buffer)); Assert.Equal(expectedErrorCode, ex.ErrorCode); - - await stream.ShutdownCompleted(); } ); } @@ -515,39 +509,31 @@ public async Task WritePreCanceled_Throws() { long expectedErrorCode = 1234; - await RunClientServer( - clientFunction: async connection => + await RunUnidirectionalClientServer( + async clientStream => { - await using QuicStream stream = connection.OpenUnidirectionalStream(); - CancellationTokenSource cts = new CancellationTokenSource(); cts.Cancel(); - await Assert.ThrowsAsync(() => stream.WriteAsync(new byte[1], cts.Token).AsTask()); + await Assert.ThrowsAsync(() => clientStream.WriteAsync(new byte[1], cts.Token).AsTask()); // next write would also throw - await Assert.ThrowsAsync(() => stream.WriteAsync(new byte[1]).AsTask()); + await Assert.ThrowsAsync(() => clientStream.WriteAsync(new byte[1]).AsTask()); // manual write abort is still required - stream.AbortWrite(expectedErrorCode); - - await stream.ShutdownCompleted(); + clientStream.Abort(expectedErrorCode, QuicAbortDirection.Write); }, - serverFunction: async connection => + async serverStream => { - await using QuicStream stream = await connection.AcceptStreamAsync(); - byte[] buffer = new byte[1024 * 1024]; // TODO: it should always throw QuicStreamAbortedException, but sometimes it does not https://github.com/dotnet/runtime/issues/53530 //QuicStreamAbortedException ex = await Assert.ThrowsAsync(() => ReadAll(stream, buffer)); try { - await ReadAll(stream, buffer); + await ReadAll(serverStream, buffer); } catch (QuicStreamAbortedException) { } - - await stream.ShutdownCompleted(); } ); } @@ -557,11 +543,9 @@ public async Task WriteCanceled_NextWriteThrows() { long expectedErrorCode = 1234; - await RunClientServer( - clientFunction: async connection => + await RunUnidirectionalClientServer( + async clientStream => { - await using QuicStream stream = connection.OpenUnidirectionalStream(); - CancellationTokenSource cts = new CancellationTokenSource(500); async Task WriteUntilCanceled() @@ -569,7 +553,7 @@ async Task WriteUntilCanceled() var buffer = new byte[64 * 1024]; while (true) { - await stream.WriteAsync(buffer, cancellationToken: cts.Token); + await clientStream.WriteAsync(buffer, cancellationToken: cts.Token); } } @@ -577,23 +561,19 @@ async Task WriteUntilCanceled() await Assert.ThrowsAsync(() => WriteUntilCanceled().WaitAsync(TimeSpan.FromSeconds(3))); // next write would also throw - await Assert.ThrowsAsync(() => stream.WriteAsync(new byte[1]).AsTask()); + await Assert.ThrowsAsync(() => clientStream.WriteAsync(new byte[1]).AsTask()); // manual write abort is still required - stream.AbortWrite(expectedErrorCode); - - await stream.ShutdownCompleted(); + clientStream.Abort(expectedErrorCode, QuicAbortDirection.Write); }, - serverFunction: async connection => + async serverStream => { - await using QuicStream stream = await connection.AcceptStreamAsync(); - async Task ReadUntilAborted() { var buffer = new byte[1024]; while (true) { - int res = await stream.ReadAsync(buffer); + int res = await serverStream.ReadAsync(buffer); if (res == 0) { break; @@ -608,8 +588,6 @@ async Task ReadUntilAborted() await ReadUntilAborted().WaitAsync(TimeSpan.FromSeconds(3)); } catch (QuicStreamAbortedException) { } - - await stream.ShutdownCompleted(); } ); } @@ -658,40 +636,12 @@ await RunBidirectionalClientServer( [InlineData(true)] public async Task QuicStream_ClosePattern_Success(bool abortive) { - while (!Debugger.IsAttached) - { - Console.WriteLine($"Attach to process {Process.GetCurrentProcess().Id}."); - await Task.Delay(100); - } - const int ExpectedErrorCode = 0xfffffff; using SemaphoreSlim sem = new SemaphoreSlim(0); await RunBidirectionalClientServer( async clientStream => - { - // Don't shutdown client side until server side has 100% completed. - await sem.WaitAsync(); - - // Wait for server's aborts to reach us. - await Task.Delay(500); - - QuicStreamAbortedException ex = await Assert.ThrowsAsync(async () => - { - await clientStream.WriteAsync(new byte[1]); - }); - - Assert.Equal(ExpectedErrorCode, ex.ErrorCode); - - ex = await Assert.ThrowsAsync(async () => - { - await clientStream.ReadAsync(new byte[1]); - }); - - Assert.Equal(ExpectedErrorCode, ex.ErrorCode); - }, - async serverStream => { try { @@ -706,7 +656,7 @@ await RunBidirectionalClientServer( catch { // Abort here. The CloseAsync that follows will still wait for an ACK of the shutdown. - serverStream.Abort(ExpectedErrorCode, QuicAbortDirection.Both); + clientStream.Abort(ExpectedErrorCode, QuicAbortDirection.Both); } finally { @@ -715,23 +665,45 @@ await RunBidirectionalClientServer( using var shutdownCts = new CancellationTokenSource(500); try { - await serverStream.CloseAsync(shutdownCts.Token); + await clientStream.CloseAsync(shutdownCts.Token); } catch { // Abort (possibly again, which will ignore error code and not queue any new I/O). // This time, Immediate is used which will cause CloseAsync() to not wait for a shutdown ACK. - serverStream.Abort(ExpectedErrorCode, QuicAbortDirection.Immediate); + clientStream.Abort(ExpectedErrorCode, QuicAbortDirection.Immediate); } } // Either the CloseAsync above worked, in which case this is a no-op, // or the stream has been re-aborted with Immediate, in which case this will complete "immediately" but not synchronously. - await serverStream.CloseAsync(); + await clientStream.CloseAsync(); // Only allow the other side to close its stream after the dispose completes. sem.Release(); - }, millisecondsTimeout: 1_000_000_000); + }, + async serverStream => + { + // Don't shutdown client side until server side has 100% completed. + await sem.WaitAsync(); + + // Wait for server's abort to reach us. + await Task.Delay(500); + + QuicStreamAbortedException ex = await Assert.ThrowsAsync(async () => + { + await serverStream.WriteAsync(new byte[1]); + }); + + Assert.Equal(ExpectedErrorCode, ex.ErrorCode); + + ex = await Assert.ThrowsAsync(async () => + { + await serverStream.ReadAsync(new byte[1]); + }); + + Assert.Equal(ExpectedErrorCode, ex.ErrorCode); + }); } } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs index ca8dccbee2fc8e..9f2e367e56fc55 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs @@ -138,22 +138,19 @@ internal async Task RunClientServer(Func clientFunction, F { using QuicListener listener = CreateQuicListener(); - var serverFinished = new ManualResetEventSlim(); - var clientFinished = new ManualResetEventSlim(); + using var serverFinished = new SemaphoreSlim(0); + using var clientFinished = new SemaphoreSlim(0); for (int i = 0; i < iterations; ++i) { - serverFinished.Reset(); - clientFinished.Reset(); - await new[] { Task.Run(async () => { using QuicConnection serverConnection = await listener.AcceptConnectionAsync(); await serverFunction(serverConnection); - serverFinished.Set(); - clientFinished.Wait(); + serverFinished.Release(); + await clientFinished.WaitAsync(); await serverConnection.CloseAsync(0); }), Task.Run(async () => @@ -161,8 +158,8 @@ await new[] using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); await clientConnection.ConnectAsync(); await clientFunction(clientConnection); - clientFinished.Set(); - serverFinished.Wait(); + clientFinished.Release(); + await serverFinished.WaitAsync(); await clientConnection.CloseAsync(0); }) }.WhenAllOrAnyFailed(millisecondsTimeout); From d100b6fd1121c4e713f949c0a9e16e33b18158df Mon Sep 17 00:00:00 2001 From: Cory Nelson Date: Tue, 22 Jun 2021 01:45:51 -0700 Subject: [PATCH 6/8] WIP --- .../System/Net/Http/Http3LoopbackStream.cs | 17 +++++++---------- .../tests/FunctionalTests/MsQuicTests.cs | 4 ---- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs index 294dbb89ce5b5d..6c6a06c85291c9 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs @@ -15,7 +15,7 @@ namespace System.Net.Test.Common { - internal sealed class Http3LoopbackStream : IDisposable + internal sealed class Http3LoopbackStream : IDisposable, IAsyncDisposable { private const int MaximumVarIntBytes = 8; private const long VarIntMax = (1L << 62) - 1; @@ -43,6 +43,10 @@ public void Dispose() { _stream.Dispose(); } + + public ValueTask DisposeAsync() => + _stream.DisposeAsync(); + public async Task HandleRequestAsync(HttpStatusCode statusCode = HttpStatusCode.OK, IList headers = null, string content = "") { HttpRequestData request = await ReadRequestDataAsync().ConfigureAwait(false); @@ -116,12 +120,6 @@ public async Task SendFrameAsync(long frameType, ReadOnlyMemory framePaylo await _stream.WriteAsync(framePayload).ConfigureAwait(false); } - public async Task ShutdownSendAsync() - { - await _stream.CompleteWritesAsync().ConfigureAwait(false); - await _stream.ShutdownWriteCompleted().ConfigureAwait(false); - } - static int EncodeHttpInteger(long longToEncode, Span buffer) { Debug.Assert(longToEncode >= 0); @@ -226,9 +224,8 @@ public async Task SendResponseBodyAsync(byte[] content, bool isFinal = true) if (isFinal) { - await ShutdownSendAsync().ConfigureAwait(false); - await _stream.ShutdownCompleted().ConfigureAwait(false); - Dispose(); + _stream.CompleteWrites(); + await DisposeAsync(); } } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs index 5c9ec7304d30e5..d19e339d58792e 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs @@ -230,8 +230,6 @@ await RunClientServer( break; } } - - stream.CompleteWrites(); }, async serverConnection => { @@ -247,8 +245,6 @@ await RunClientServer( int expectedTotalBytes = writes.SelectMany(x => x).Sum(); Assert.Equal(expectedTotalBytes, totalBytes); - - stream.CompleteWrites(); }); } From b22e1d886751d0e91b4354803014f09065e858b3 Mon Sep 17 00:00:00 2001 From: Cory Nelson Date: Thu, 24 Jun 2021 22:00:12 -0700 Subject: [PATCH 7/8] WIP --- .../System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs index 88c5cae43306ec..8ece072ec156ed 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs @@ -552,8 +552,8 @@ public override void Dispose() } // TODO: there's a bug here where the safe handle is no longer valid. - // This shouldn't happen because the safe handle is effectively pinned - // until after disposal completes. + // This shouldn't happen because the safe handle *should be* rooted + // until after our disposal completes. ~MsQuicStream() { DisposeAsyncThrowaway(this); From d9e42a78971af90afde38758b90395604a1c98ea Mon Sep 17 00:00:00 2001 From: Cory Nelson Date: Thu, 24 Jun 2021 22:01:28 -0700 Subject: [PATCH 8/8] WIP --- .../System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs index 4604bac5c26ad0..d93d200f586b35 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs @@ -91,7 +91,10 @@ public async ValueTask DisposeAsync() if (!_disposed) { _disposed = true; + + // TODO: use CloseAsync() with a cancellation token to prevent a DoS await _stream.DisposeAsync().ConfigureAwait(false); + DisposeSyncHelper(); } }