diff --git a/build/Dependencies.props b/build/Dependencies.props
index 9ede98e350..4afccaa748 100644
--- a/build/Dependencies.props
+++ b/build/Dependencies.props
@@ -8,7 +8,7 @@
1.5.0
4.5.1
4.3.0
- 4.8.0
+ 4.7.1
diff --git a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj
index e8f7feace3..43ad73f248 100644
--- a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj
+++ b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj
@@ -11,10 +11,10 @@
-
+
diff --git a/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj b/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj
index f3e7d96b59..8076f82ef9 100644
--- a/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj
+++ b/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj
@@ -10,7 +10,7 @@
-
+
diff --git a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs
index 090a34049e..87e1ecdec8 100644
--- a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs
+++ b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs
@@ -5,8 +5,8 @@
using System;
using System.Collections.Generic;
using System.Threading;
+using System.Threading.Channels;
using System.Threading.Tasks;
-using System.Threading.Tasks.Dataflow;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
@@ -486,13 +486,12 @@ private static readonly FuncInstanceMethodInfo1 _createGe
private int _liveCount;
private bool _doneConsuming;
- private readonly BufferBlock _toProduce;
- private readonly BufferBlock _toConsume;
+ private readonly Channel _toProduceChannel;
+ private readonly Channel _toConsumeChannel;
private readonly Task _producerTask;
private Exception _producerTaskException;
private readonly int[] _colToActivesIndex;
- private bool _disposed;
public override DataViewSchema Schema => _input.Schema;
@@ -541,46 +540,20 @@ public Cursor(IChannelProvider provider, int poolRows, DataViewRowCursor input,
_liveCount = 1;
// Set up the producer worker.
- _toConsume = new BufferBlock();
- _toProduce = new BufferBlock();
+ _toConsumeChannel = Channel.CreateUnbounded(new UnboundedChannelOptions { SingleWriter = true });
+ _toProduceChannel = Channel.CreateUnbounded(new UnboundedChannelOptions { SingleWriter = true });
// First request the pool - 1 + block size rows, to get us going.
- PostAssert(_toProduce, _poolRows - 1 + _blockSize);
+ PostAssert(_toProduceChannel, _poolRows - 1 + _blockSize);
// Queue up the remaining capacity.
for (int i = 1; i < _bufferDepth; ++i)
- PostAssert(_toProduce, _blockSize);
+ PostAssert(_toProduceChannel, _blockSize);
_producerTask = ProduceAsync();
}
- protected override void Dispose(bool disposing)
+ public static void PostAssert(Channel target, T item)
{
- if (_disposed)
- return;
-
- if (disposing)
- {
- _toProduce.Complete();
- _producerTask.Wait();
-
- // Complete the consumer after the producerTask has finished, since producerTask could
- // have posted more items to _toConsume.
- _toConsume.Complete();
-
- // Drain both BufferBlocks - this prevents what appears to be memory leaks when using the VS Debugger
- // because if a BufferBlock still contains items, its underlying Tasks are not getting completed.
- // See https://github.com/dotnet/corefx/issues/30582 for the VS Debugger issue.
- // See also https://github.com/dotnet/machinelearning/issues/4399
- _toProduce.TryReceiveAll(out _);
- _toConsume.TryReceiveAll(out _);
- }
-
- _disposed = true;
- base.Dispose(disposing);
- }
-
- public static void PostAssert(ITargetBlock target, T item)
- {
- bool retval = target.Post(item);
+ bool retval = target.Writer.TryWrite(item);
Contracts.Assert(retval);
}
@@ -594,12 +567,13 @@ private async Task ProduceAsync()
try
{
int circularIndex = 0;
- while (await _toProduce.OutputAvailableAsync().ConfigureAwait(false))
+ while (await _toProduceChannel.Reader.WaitToReadAsync().ConfigureAwait(false))
{
int requested;
- if (!_toProduce.TryReceive(out requested))
+ if (!_toProduceChannel.Reader.TryRead(out requested))
{
- // OutputAvailableAsync returned true, but TryReceive returned false -
+ // The producer Channel's Reader.WaitToReadAsync returned true,
+ // but the Reader's TryRead returned false -
// so loop back around and try again.
continue;
}
@@ -618,14 +592,14 @@ private async Task ProduceAsync()
if (circularIndex == _pipeIndices.Length)
circularIndex = 0;
}
- PostAssert(_toConsume, numRows);
+ PostAssert(_toConsumeChannel, numRows);
if (numRows < requested)
{
// We've reached the end of the cursor. Send the sentinel, then exit.
// This assumes that the receiver will receive things in Post order
// (so that the sentinel is received, after the last Post).
if (numRows > 0)
- PostAssert(_toConsume, 0);
+ PostAssert(_toConsumeChannel, 0);
return;
}
}
@@ -634,7 +608,7 @@ private async Task ProduceAsync()
{
_producerTaskException = ex;
// Send the sentinel in this case as well, the field will be checked.
- PostAssert(_toConsume, 0);
+ PostAssert(_toConsumeChannel, 0);
}
}
@@ -651,26 +625,32 @@ protected override bool MoveNextCore()
{
// We should let the producer know it can give us more stuff.
// It is possible for int values to be sent beyond the
- // end of the sentinel, but we suppose this is irrelevant.
- PostAssert(_toProduce, _deadCount);
+ // end of the Channel, but we suppose this is irrelevant.
+ PostAssert(_toProduceChannel, _deadCount);
_deadCount = 0;
}
while (_liveCount < _poolRows && !_doneConsuming)
{
// We are under capacity. Try to get some more.
- int got = _toConsume.Receive();
- if (got == 0)
+ while (_toConsumeChannel.Reader.WaitToReadAsync().GetAwaiter().GetResult())
{
- // We've reached the end sentinel. There's no reason
- // to attempt further communication with the producer.
- // Check whether something horrible happened.
- if (_producerTaskException != null)
- throw Ch.Except(_producerTaskException, "Shuffle input cursor reader failed with an exception");
- _doneConsuming = true;
- break;
+ var hasReadItem = _toConsumeChannel.Reader.TryRead(out int got);
+ if (hasReadItem)
+ {
+ if (got == 0)
+ {
+ // We've reached the end of the Channel. There's no reason
+ // to attempt further communication with the producer.
+ // Check whether something horrible happened.
+ if (_producerTaskException != null)
+ throw Ch.Except(_producerTaskException, "Shuffle input cursor reader failed with an exception");
+ _doneConsuming = true;
+ break;
+ }
+ _liveCount += got;
+ }
}
- _liveCount += got;
}
if (_liveCount == 0)
return false;
diff --git a/src/Microsoft.ML.Sweeper/AsyncSweeper.cs b/src/Microsoft.ML.Sweeper/AsyncSweeper.cs
index a300999ebc..6e71d13729 100644
--- a/src/Microsoft.ML.Sweeper/AsyncSweeper.cs
+++ b/src/Microsoft.ML.Sweeper/AsyncSweeper.cs
@@ -5,8 +5,8 @@
using System;
using System.Collections.Generic;
using System.Threading;
+using System.Threading.Channels;
using System.Threading.Tasks;
-using System.Threading.Tasks.Dataflow;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Internal.Utilities;
@@ -168,7 +168,7 @@ public sealed class Options
private readonly object _lock;
private readonly CancellationTokenSource _cts;
- private readonly BufferBlock _paramQueue;
+ private readonly Channel _paramChannel;
private readonly int _relaxation;
private readonly ISweeper _baseSweeper;
private readonly IHost _host;
@@ -208,7 +208,8 @@ public DeterministicSweeperAsync(IHostEnvironment env, Options options)
_lock = new object();
_results = new List();
_nullRuns = new HashSet();
- _paramQueue = new BufferBlock();
+ _paramChannel = Channel.CreateUnbounded(
+ new UnboundedChannelOptions { SingleWriter = true });
PrepareNextBatch(null);
}
@@ -220,12 +221,12 @@ private void PrepareNextBatch(IEnumerable results)
if (Utils.Size(paramSets) == 0)
{
// Mark the queue as completed.
- _paramQueue.Complete();
+ _paramChannel.Writer.Complete();
return;
}
// Assign an id to each ParameterSet and enque it.
foreach (var paramSet in paramSets)
- _paramQueue.Post(new ParameterSetWithId(_numGenerated++, paramSet));
+ _paramChannel.Writer.TryWrite(new ParameterSetWithId(_numGenerated++, paramSet));
EnsureResultsSize();
}
@@ -278,7 +279,7 @@ public async Task ProposeAsync()
return null;
try
{
- return await _paramQueue.ReceiveAsync(_cts.Token);
+ return await _paramChannel.Reader.ReadAsync(_cts.Token);
}
catch (InvalidOperationException)
{
diff --git a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs
index 292e16f30f..2c6b18811d 100644
--- a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs
+++ b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs
@@ -22,7 +22,6 @@
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Google.Protobuf.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Newtonsoft.Json.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/System.CodeDom.dll"
-#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/System.Threading.Tasks.Dataflow.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.CpuMath.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Data.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Transforms.dll"