diff --git a/build/Dependencies.props b/build/Dependencies.props index 9ede98e350..4afccaa748 100644 --- a/build/Dependencies.props +++ b/build/Dependencies.props @@ -8,7 +8,7 @@ 1.5.0 4.5.1 4.3.0 - 4.8.0 + 4.7.1 diff --git a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj index e8f7feace3..43ad73f248 100644 --- a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj +++ b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj @@ -11,10 +11,10 @@ - + diff --git a/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj b/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj index f3e7d96b59..8076f82ef9 100644 --- a/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj +++ b/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj @@ -10,7 +10,7 @@ - + diff --git a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs index 090a34049e..87e1ecdec8 100644 --- a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs @@ -5,8 +5,8 @@ using System; using System.Collections.Generic; using System.Threading; +using System.Threading.Channels; using System.Threading.Tasks; -using System.Threading.Tasks.Dataflow; using Microsoft.ML; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; @@ -486,13 +486,12 @@ private static readonly FuncInstanceMethodInfo1 _createGe private int _liveCount; private bool _doneConsuming; - private readonly BufferBlock _toProduce; - private readonly BufferBlock _toConsume; + private readonly Channel _toProduceChannel; + private readonly Channel _toConsumeChannel; private readonly Task _producerTask; private Exception _producerTaskException; private readonly int[] _colToActivesIndex; - private bool _disposed; public override DataViewSchema Schema => _input.Schema; @@ -541,46 +540,20 @@ public Cursor(IChannelProvider provider, int poolRows, DataViewRowCursor input, _liveCount = 1; // Set up the producer worker. - _toConsume = new BufferBlock(); - _toProduce = new BufferBlock(); + _toConsumeChannel = Channel.CreateUnbounded(new UnboundedChannelOptions { SingleWriter = true }); + _toProduceChannel = Channel.CreateUnbounded(new UnboundedChannelOptions { SingleWriter = true }); // First request the pool - 1 + block size rows, to get us going. - PostAssert(_toProduce, _poolRows - 1 + _blockSize); + PostAssert(_toProduceChannel, _poolRows - 1 + _blockSize); // Queue up the remaining capacity. for (int i = 1; i < _bufferDepth; ++i) - PostAssert(_toProduce, _blockSize); + PostAssert(_toProduceChannel, _blockSize); _producerTask = ProduceAsync(); } - protected override void Dispose(bool disposing) + public static void PostAssert(Channel target, T item) { - if (_disposed) - return; - - if (disposing) - { - _toProduce.Complete(); - _producerTask.Wait(); - - // Complete the consumer after the producerTask has finished, since producerTask could - // have posted more items to _toConsume. - _toConsume.Complete(); - - // Drain both BufferBlocks - this prevents what appears to be memory leaks when using the VS Debugger - // because if a BufferBlock still contains items, its underlying Tasks are not getting completed. - // See https://github.com/dotnet/corefx/issues/30582 for the VS Debugger issue. - // See also https://github.com/dotnet/machinelearning/issues/4399 - _toProduce.TryReceiveAll(out _); - _toConsume.TryReceiveAll(out _); - } - - _disposed = true; - base.Dispose(disposing); - } - - public static void PostAssert(ITargetBlock target, T item) - { - bool retval = target.Post(item); + bool retval = target.Writer.TryWrite(item); Contracts.Assert(retval); } @@ -594,12 +567,13 @@ private async Task ProduceAsync() try { int circularIndex = 0; - while (await _toProduce.OutputAvailableAsync().ConfigureAwait(false)) + while (await _toProduceChannel.Reader.WaitToReadAsync().ConfigureAwait(false)) { int requested; - if (!_toProduce.TryReceive(out requested)) + if (!_toProduceChannel.Reader.TryRead(out requested)) { - // OutputAvailableAsync returned true, but TryReceive returned false - + // The producer Channel's Reader.WaitToReadAsync returned true, + // but the Reader's TryRead returned false - // so loop back around and try again. continue; } @@ -618,14 +592,14 @@ private async Task ProduceAsync() if (circularIndex == _pipeIndices.Length) circularIndex = 0; } - PostAssert(_toConsume, numRows); + PostAssert(_toConsumeChannel, numRows); if (numRows < requested) { // We've reached the end of the cursor. Send the sentinel, then exit. // This assumes that the receiver will receive things in Post order // (so that the sentinel is received, after the last Post). if (numRows > 0) - PostAssert(_toConsume, 0); + PostAssert(_toConsumeChannel, 0); return; } } @@ -634,7 +608,7 @@ private async Task ProduceAsync() { _producerTaskException = ex; // Send the sentinel in this case as well, the field will be checked. - PostAssert(_toConsume, 0); + PostAssert(_toConsumeChannel, 0); } } @@ -651,26 +625,32 @@ protected override bool MoveNextCore() { // We should let the producer know it can give us more stuff. // It is possible for int values to be sent beyond the - // end of the sentinel, but we suppose this is irrelevant. - PostAssert(_toProduce, _deadCount); + // end of the Channel, but we suppose this is irrelevant. + PostAssert(_toProduceChannel, _deadCount); _deadCount = 0; } while (_liveCount < _poolRows && !_doneConsuming) { // We are under capacity. Try to get some more. - int got = _toConsume.Receive(); - if (got == 0) + while (_toConsumeChannel.Reader.WaitToReadAsync().GetAwaiter().GetResult()) { - // We've reached the end sentinel. There's no reason - // to attempt further communication with the producer. - // Check whether something horrible happened. - if (_producerTaskException != null) - throw Ch.Except(_producerTaskException, "Shuffle input cursor reader failed with an exception"); - _doneConsuming = true; - break; + var hasReadItem = _toConsumeChannel.Reader.TryRead(out int got); + if (hasReadItem) + { + if (got == 0) + { + // We've reached the end of the Channel. There's no reason + // to attempt further communication with the producer. + // Check whether something horrible happened. + if (_producerTaskException != null) + throw Ch.Except(_producerTaskException, "Shuffle input cursor reader failed with an exception"); + _doneConsuming = true; + break; + } + _liveCount += got; + } } - _liveCount += got; } if (_liveCount == 0) return false; diff --git a/src/Microsoft.ML.Sweeper/AsyncSweeper.cs b/src/Microsoft.ML.Sweeper/AsyncSweeper.cs index a300999ebc..6e71d13729 100644 --- a/src/Microsoft.ML.Sweeper/AsyncSweeper.cs +++ b/src/Microsoft.ML.Sweeper/AsyncSweeper.cs @@ -5,8 +5,8 @@ using System; using System.Collections.Generic; using System.Threading; +using System.Threading.Channels; using System.Threading.Tasks; -using System.Threading.Tasks.Dataflow; using Microsoft.ML; using Microsoft.ML.CommandLine; using Microsoft.ML.Internal.Utilities; @@ -168,7 +168,7 @@ public sealed class Options private readonly object _lock; private readonly CancellationTokenSource _cts; - private readonly BufferBlock _paramQueue; + private readonly Channel _paramChannel; private readonly int _relaxation; private readonly ISweeper _baseSweeper; private readonly IHost _host; @@ -208,7 +208,8 @@ public DeterministicSweeperAsync(IHostEnvironment env, Options options) _lock = new object(); _results = new List(); _nullRuns = new HashSet(); - _paramQueue = new BufferBlock(); + _paramChannel = Channel.CreateUnbounded( + new UnboundedChannelOptions { SingleWriter = true }); PrepareNextBatch(null); } @@ -220,12 +221,12 @@ private void PrepareNextBatch(IEnumerable results) if (Utils.Size(paramSets) == 0) { // Mark the queue as completed. - _paramQueue.Complete(); + _paramChannel.Writer.Complete(); return; } // Assign an id to each ParameterSet and enque it. foreach (var paramSet in paramSets) - _paramQueue.Post(new ParameterSetWithId(_numGenerated++, paramSet)); + _paramChannel.Writer.TryWrite(new ParameterSetWithId(_numGenerated++, paramSet)); EnsureResultsSize(); } @@ -278,7 +279,7 @@ public async Task ProposeAsync() return null; try { - return await _paramQueue.ReceiveAsync(_cts.Token); + return await _paramChannel.Reader.ReadAsync(_cts.Token); } catch (InvalidOperationException) { diff --git a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs index 292e16f30f..2c6b18811d 100644 --- a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs +++ b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs @@ -22,7 +22,6 @@ #r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Google.Protobuf.dll" #r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Newtonsoft.Json.dll" #r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/System.CodeDom.dll" -#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/System.Threading.Tasks.Dataflow.dll" #r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.CpuMath.dll" #r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Data.dll" #r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Transforms.dll"