From 62875ba49723e96c9b1f516a226a81c3bdb196b9 Mon Sep 17 00:00:00 2001 From: Alexey Sokolov Date: Sat, 11 Oct 2025 19:44:32 +0700 Subject: [PATCH 1/9] Add SIMD Mersenne Twister RNG + pluggable RNG source; wire through Host/MLContext - Introduce internal IRandomSource and IRandomBulkSource - Add adapters/shims: RandomSourceAdapter, RandomFromRandomSource, RandomShim - Implement SIMD-backed MersenneTwisterRandomSource (MT19937) and core MersenneTwister - Wire IRandomSource through HostEnvironmentBase, ConsoleEnvironment, LocalEnvironment, MLContext - Add tests for determinism and mixed-call consumption --- .../Data/IHostEnvironment.cs | 11 + .../Environment/ConsoleEnvironment.cs | 18 +- .../Environment/HostEnvironmentBase.cs | 35 +- .../Microsoft.ML.Core.csproj | 2 +- .../Utilities/FuncInstanceMethodInfo1`2.cs | 22 +- .../Utilities/FuncInstanceMethodInfo1`3.cs | 22 +- .../Utilities/FuncInstanceMethodInfo1`4.cs | 22 +- .../Utilities/FuncInstanceMethodInfo2`4.cs | 22 +- .../Utilities/FuncInstanceMethodInfo3`3.cs | 22 +- .../Utilities/FuncInstanceMethodInfo3`4.cs | 22 +- .../Utilities/IRandomBulkSource.cs | 21 ++ .../Utilities/IRandomSource.cs | 28 ++ .../Utilities/MersenneTwister.cs | 334 ++++++++++++++++++ .../Utilities/MersenneTwisterRandomSource.cs | 156 ++++++++ src/Microsoft.ML.Core/Utilities/Random.cs | 4 + .../Utilities/RandomFromRandomSource.cs | 45 +++ src/Microsoft.ML.Core/Utilities/RandomShim.cs | 83 +++++ .../Utilities/RandomSourceAdapter.cs | 72 ++++ .../Utilities/ResourceManagerUtils.cs | 62 +++- src/Microsoft.ML.Data/MLContext.cs | 12 +- .../Utilities/LocalEnvironment.cs | 19 +- .../Microsoft.ML.AutoML.Tests/AutoFitTests.cs | 3 +- .../Helpers/AdditionalMetadataReferences.cs | 5 +- .../Helpers/CSharpCodeFixVerifier`2.cs | 13 +- .../Helpers/CompatibleXUnitVerifier.cs | 224 ++++++++++++ .../UnitTests/MersenneTwisterTests.cs | 141 ++++++++ .../ContractsCheckAnalyzer.cs | 5 +- 27 files changed, 1335 insertions(+), 90 deletions(-) create mode 100644 src/Microsoft.ML.Core/Utilities/IRandomBulkSource.cs create mode 100644 src/Microsoft.ML.Core/Utilities/IRandomSource.cs create mode 100644 src/Microsoft.ML.Core/Utilities/MersenneTwister.cs create mode 100644 src/Microsoft.ML.Core/Utilities/MersenneTwisterRandomSource.cs create mode 100644 src/Microsoft.ML.Core/Utilities/RandomFromRandomSource.cs create mode 100644 src/Microsoft.ML.Core/Utilities/RandomShim.cs create mode 100644 src/Microsoft.ML.Core/Utilities/RandomSourceAdapter.cs create mode 100644 test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/CompatibleXUnitVerifier.cs create mode 100644 test/Microsoft.ML.Core.Tests/UnitTests/MersenneTwisterTests.cs diff --git a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs index f9849c9a6c..5a2eef5952 100644 --- a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs +++ b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using Microsoft.ML.Data; +using Microsoft.ML; namespace Microsoft.ML.Runtime; @@ -115,6 +116,11 @@ internal interface IHostEnvironmentInternal : IHostEnvironment T GetOptionOrDefault(string name); bool RemoveOption(string name); + + /// + /// Global random source underpinning this environment. + /// + IRandomSource RandomSource { get; } } /// @@ -129,6 +135,11 @@ public interface IHost : IHostEnvironment /// generators are NOT thread safe. /// Random Rand { get; } + + /// + /// The random source backing . + /// + IRandomSource RandomSource { get; } } /// diff --git a/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs b/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs index da07122321..79d4e9e197 100644 --- a/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs +++ b/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs @@ -6,6 +6,7 @@ using System.IO; using System.Linq; using System.Threading; +using Microsoft.ML; namespace Microsoft.ML.Runtime; @@ -359,14 +360,16 @@ protected override void Dispose(bool disposing) /// /// Random seed. Set to null for a non-deterministic environment. /// Set to true for fully verbose logging. + /// Optional random source backing this environment. /// Allowed message sensitivity. /// Text writer to print normal messages to. /// Text writer to print error messages to. /// Optional TextWriter to write messages if the host is a test environment. public ConsoleEnvironment(int? seed = null, bool verbose = false, + IRandomSource randomSource = null, MessageSensitivity sensitivity = MessageSensitivity.All, TextWriter outWriter = null, TextWriter errWriter = null, TextWriter testWriter = null) - : base(seed, verbose, nameof(ConsoleEnvironment)) + : base(seed, verbose, randomSource, nameof(ConsoleEnvironment)) { Contracts.CheckValueOrNull(outWriter); Contracts.CheckValueOrNull(errWriter); @@ -391,13 +394,14 @@ private void PrintMessage(IMessageSource src, ChannelMessage msg) Root._consoleWriter.PrintMessage(src, msg); } - protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose) + protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, IRandomSource randomSource, bool verbose) { Contracts.AssertValue(rand); + Contracts.AssertValue(randomSource); Contracts.AssertValueOrNull(parentFullName); Contracts.AssertNonEmpty(shortName); Contracts.Assert(source == this || source is Host); - return new Host(source, shortName, parentFullName, rand, verbose); + return new Host(source, shortName, parentFullName, rand, randomSource, verbose); } protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name) @@ -462,8 +466,8 @@ public void Dispose() private sealed class Host : HostBase { - public Host(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose) - : base(source, shortName, parentFullName, rand, verbose) + public Host(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, IRandomSource randomSource, bool verbose) + : base(source, shortName, parentFullName, rand, randomSource, verbose) { IsCanceled = source.IsCanceled; } @@ -484,9 +488,9 @@ protected override IPipe CreatePipe(ChannelProviderBase pare return new Pipe(parent, name, GetDispatchDelegate()); } - protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose) + protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, IRandomSource randomSource, bool verbose) { - return new Host(source, shortName, parentFullName, rand, verbose); + return new Host(source, shortName, parentFullName, rand, randomSource, verbose); } } } diff --git a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs index 7ab2620169..de8506e91f 100644 --- a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs +++ b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs @@ -6,6 +6,8 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.IO; +using Microsoft.ML; +using Microsoft.ML.Internal.Utilities; namespace Microsoft.ML.Runtime; @@ -121,8 +123,8 @@ public abstract class HostBase : HostEnvironmentBase, IHost public Random Rand => _rand; - public HostBase(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose) - : base(source, rand, verbose, shortName, parentFullName) + public HostBase(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, IRandomSource randomSource, bool verbose) + : base(source, rand, randomSource ?? new RandomSourceAdapter(rand), verbose, shortName, parentFullName) { Depth = source.Depth + 1; } @@ -140,7 +142,8 @@ public HostBase(HostEnvironmentBase source, string shortName, string paren { _children.RemoveAll(r => r.TryGetTarget(out IHost _) == false); Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand); - host = RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose); + IRandomSource randomSource = new RandomSourceAdapter(rand); + host = RegisterCore(this, name, Master?.FullName, rand, randomSource, verbose ?? Verbose); if (!IsCanceled) _children.Add(new WeakReference(host)); } @@ -338,6 +341,8 @@ public void RemoveListener(Action listenerFunc) protected Dictionary Options { get; } = []; #pragma warning restore MSML_NoInstanceInitializers + public IRandomSource RandomSource => _randomSource; + protected readonly TEnv Root; // This is non-null iff this environment was a fork of another. Disposing a fork // doesn't free temp files. That is handled when the master is disposed. @@ -348,6 +353,7 @@ public void RemoveListener(Action listenerFunc) // The random number generator for this host. private readonly Random _rand; + private readonly IRandomSource _randomSource; public int? Seed { get; } @@ -369,11 +375,22 @@ public void RemoveListener(Action listenerFunc) /// The main constructor. /// protected HostEnvironmentBase(int? seed, bool verbose, + IRandomSource randomSource = null, string shortName = null, string parentFullName = null) : base(shortName, parentFullName, verbose) { Seed = seed; - _rand = RandomUtils.Create(Seed); + if (randomSource is null) + { + var baseRandom = RandomUtils.Create(Seed); + _rand = baseRandom; + _randomSource = new RandomSourceAdapter(baseRandom); + } + else + { + _randomSource = randomSource; + _rand = randomSource as Random ?? new RandomFromRandomSource(randomSource); + } ListenerDict = new ConcurrentDictionary(); ProgressTracker = new ProgressReporting.ProgressTracker(this); _cancelLock = new object(); @@ -385,13 +402,14 @@ protected HostEnvironmentBase(int? seed, bool verbose, /// /// This constructor is for forking. /// - protected HostEnvironmentBase(HostEnvironmentBase source, Random rand, bool verbose, + protected HostEnvironmentBase(HostEnvironmentBase source, Random rand, IRandomSource randomSource, bool verbose, string shortName = null, string parentFullName = null) : base(shortName, parentFullName, verbose) { Contracts.CheckValue(source, nameof(source)); Contracts.CheckValueOrNull(rand); - _rand = rand ?? RandomUtils.Create(); + _randomSource = randomSource ?? (rand != null ? new RandomSourceAdapter(rand) : new RandomSourceAdapter(RandomUtils.Create())); + _rand = rand ?? (_randomSource as Random ?? new RandomFromRandomSource(_randomSource)); _cancelLock = new object(); // This fork shares some stuff with the master. @@ -419,7 +437,8 @@ public IHost Register(string name, int? seed = null, bool? verbose = null) { _children.RemoveAll(r => r.TryGetTarget(out IHost _) == false); Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand); - host = RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose); + IRandomSource randomSource = new RandomSourceAdapter(rand); + host = RegisterCore(this, name, Master?.FullName, rand, randomSource, verbose ?? Verbose); // Need to manually copy over the parameters //((IHostEnvironmentInternal)host).Seed = this.Seed; @@ -433,7 +452,7 @@ public IHost Register(string name, int? seed = null, bool? verbose = null) } protected abstract IHost RegisterCore(HostEnvironmentBase source, string shortName, - string parentFullName, Random rand, bool verbose); + string parentFullName, Random rand, IRandomSource randomSource, bool verbose); public IProgressChannel StartProgressChannel(string name) { diff --git a/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj b/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj index 36429997c5..7f5f8b4a02 100644 --- a/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj +++ b/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj @@ -1,7 +1,7 @@  - netstandard2.0 + netstandard2.0;net8.0 true CORECLR Microsoft.ML diff --git a/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`2.cs b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`2.cs index 162130dd91..c0747778f2 100644 --- a/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`2.cs +++ b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`2.cs @@ -70,16 +70,26 @@ public static FuncInstanceMethodInfo1 Create(Expression Contracts.CheckParam(methodCallExpression.Arguments.Count == 2, nameof(expression), "Unexpected expression form"); Contracts.CheckParam(methodCallExpression.Arguments[0] is ConstantExpression, nameof(expression), "Unexpected expression form"); - Contracts.CheckParam(((ConstantExpression)methodCallExpression.Arguments[0]).Type == typeof(Type), nameof(expression), "Unexpected expression form"); - Contracts.CheckParam((Type)((ConstantExpression)methodCallExpression.Arguments[0]).Value == typeof(Func), nameof(expression), "Unexpected expression form"); + var delegateTypeExpression = (ConstantExpression)methodCallExpression.Arguments[0]; + Contracts.CheckParam(delegateTypeExpression.Type == typeof(Type), nameof(expression), "Unexpected expression form"); + if (delegateTypeExpression.Value is not Type delegateType) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } + Contracts.CheckParam(delegateType == typeof(Func), nameof(expression), "Unexpected expression form"); Contracts.CheckParam(methodCallExpression.Arguments[1] is ParameterExpression, nameof(expression), "Unexpected expression form"); Contracts.CheckParam(methodCallExpression.Arguments[1] == expression.Parameters[0], nameof(expression), "Unexpected expression form"); // Check the MethodInfo - Contracts.CheckParam(methodCallExpression.Object is ConstantExpression, nameof(expression), "Unexpected expression form"); - Contracts.CheckParam(((ConstantExpression)methodCallExpression.Object).Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); - - var methodInfo = (MethodInfo)((ConstantExpression)methodCallExpression.Object).Value; + if (methodCallExpression.Object is not ConstantExpression methodInfoExpression) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } + Contracts.CheckParam(methodInfoExpression.Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); + if (methodInfoExpression.Value is not MethodInfo methodInfo) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } Contracts.CheckParam(expression.Body is UnaryExpression, nameof(expression), "Unexpected expression form"); Contracts.CheckParam(((UnaryExpression)expression.Body).Operand is MethodCallExpression, nameof(expression), "Unexpected expression form"); diff --git a/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`3.cs b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`3.cs index 9117a1b488..ce058142ce 100644 --- a/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`3.cs +++ b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`3.cs @@ -71,16 +71,26 @@ public static FuncInstanceMethodInfo1 Create(Expression Contracts.CheckParam(methodCallExpression.Arguments.Count == 2, nameof(expression), "Unexpected expression form"); Contracts.CheckParam(methodCallExpression.Arguments[0] is ConstantExpression, nameof(expression), "Unexpected expression form"); - Contracts.CheckParam(((ConstantExpression)methodCallExpression.Arguments[0]).Type == typeof(Type), nameof(expression), "Unexpected expression form"); - Contracts.CheckParam((Type)((ConstantExpression)methodCallExpression.Arguments[0]).Value == typeof(Func), nameof(expression), "Unexpected expression form"); + var delegateTypeExpression = (ConstantExpression)methodCallExpression.Arguments[0]; + Contracts.CheckParam(delegateTypeExpression.Type == typeof(Type), nameof(expression), "Unexpected expression form"); + if (delegateTypeExpression.Value is not Type delegateType) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } + Contracts.CheckParam(delegateType == typeof(Func), nameof(expression), "Unexpected expression form"); Contracts.CheckParam(methodCallExpression.Arguments[1] is ParameterExpression, nameof(expression), "Unexpected expression form"); Contracts.CheckParam(methodCallExpression.Arguments[1] == expression.Parameters[0], nameof(expression), "Unexpected expression form"); // Check the MethodInfo - Contracts.CheckParam(methodCallExpression.Object is ConstantExpression, nameof(expression), "Unexpected expression form"); - Contracts.CheckParam(((ConstantExpression)methodCallExpression.Object).Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); - - var methodInfo = (MethodInfo)((ConstantExpression)methodCallExpression.Object).Value; + if (methodCallExpression.Object is not ConstantExpression methodInfoExpression) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } + Contracts.CheckParam(methodInfoExpression.Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); + if (methodInfoExpression.Value is not MethodInfo methodInfo) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } Contracts.CheckParam(expression.Body is UnaryExpression, nameof(expression), "Unexpected expression form"); Contracts.CheckParam(((UnaryExpression)expression.Body).Operand is MethodCallExpression, nameof(expression), "Unexpected expression form"); diff --git a/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`4.cs b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`4.cs index cb0ae6451b..e9044acc3f 100644 --- a/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`4.cs +++ b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`4.cs @@ -72,16 +72,26 @@ public static FuncInstanceMethodInfo1 Create(Expressio // Verify that we are creating a delegate of type Func Contracts.CheckParam(methodCallExpression.Arguments.Count == 2, nameof(expression), "Unexpected expression form"); Contracts.CheckParam(methodCallExpression.Arguments[0] is ConstantExpression, nameof(expression), "Unexpected expression form"); - Contracts.CheckParam(((ConstantExpression)methodCallExpression.Arguments[0]).Type == typeof(Type), nameof(expression), "Unexpected expression form"); - Contracts.CheckParam((Type)((ConstantExpression)methodCallExpression.Arguments[0]).Value == typeof(Func), nameof(expression), "Unexpected expression form"); + var delegateTypeExpression = (ConstantExpression)methodCallExpression.Arguments[0]; + Contracts.CheckParam(delegateTypeExpression.Type == typeof(Type), nameof(expression), "Unexpected expression form"); + if (delegateTypeExpression.Value is not Type delegateType) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } + Contracts.CheckParam(delegateType == typeof(Func), nameof(expression), "Unexpected expression form"); Contracts.CheckParam(methodCallExpression.Arguments[1] is ParameterExpression, nameof(expression), "Unexpected expression form"); Contracts.CheckParam(methodCallExpression.Arguments[1] == expression.Parameters[0], nameof(expression), "Unexpected expression form"); // Check the MethodInfo - Contracts.CheckParam(methodCallExpression.Object is ConstantExpression, nameof(expression), "Unexpected expression form"); - Contracts.CheckParam(((ConstantExpression)methodCallExpression.Object).Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); - - var methodInfo = (MethodInfo)((ConstantExpression)methodCallExpression.Object).Value; + if (methodCallExpression.Object is not ConstantExpression methodInfoExpression) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } + Contracts.CheckParam(methodInfoExpression.Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); + if (methodInfoExpression.Value is not MethodInfo methodInfo) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } Contracts.CheckParam(expression.Body is UnaryExpression, nameof(expression), "Unexpected expression form"); Contracts.CheckParam(((UnaryExpression)expression.Body).Operand is MethodCallExpression, nameof(expression), "Unexpected expression form"); diff --git a/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo2`4.cs b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo2`4.cs index 1a4f9c72ff..c76bd3ff67 100644 --- a/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo2`4.cs +++ b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo2`4.cs @@ -72,16 +72,26 @@ public static FuncInstanceMethodInfo2 Create(Expressio // Verify that we are creating a delegate of type Func Contracts.CheckParam(methodCallExpression.Arguments.Count == 2, nameof(expression), "Unexpected expression form"); Contracts.CheckParam(methodCallExpression.Arguments[0] is ConstantExpression, nameof(expression), "Unexpected expression form"); - Contracts.CheckParam(((ConstantExpression)methodCallExpression.Arguments[0]).Type == typeof(Type), nameof(expression), "Unexpected expression form"); - Contracts.CheckParam((Type)((ConstantExpression)methodCallExpression.Arguments[0]).Value == typeof(Func), nameof(expression), "Unexpected expression form"); + var delegateTypeExpression = (ConstantExpression)methodCallExpression.Arguments[0]; + Contracts.CheckParam(delegateTypeExpression.Type == typeof(Type), nameof(expression), "Unexpected expression form"); + if (delegateTypeExpression.Value is not Type delegateType) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } + Contracts.CheckParam(delegateType == typeof(Func), nameof(expression), "Unexpected expression form"); Contracts.CheckParam(methodCallExpression.Arguments[1] is ParameterExpression, nameof(expression), "Unexpected expression form"); Contracts.CheckParam(methodCallExpression.Arguments[1] == expression.Parameters[0], nameof(expression), "Unexpected expression form"); // Check the MethodInfo - Contracts.CheckParam(methodCallExpression.Object is ConstantExpression, nameof(expression), "Unexpected expression form"); - Contracts.CheckParam(((ConstantExpression)methodCallExpression.Object).Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); - - var methodInfo = (MethodInfo)((ConstantExpression)methodCallExpression.Object).Value; + if (methodCallExpression.Object is not ConstantExpression methodInfoExpression) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } + Contracts.CheckParam(methodInfoExpression.Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); + if (methodInfoExpression.Value is not MethodInfo methodInfo) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } Contracts.CheckParam(expression.Body is UnaryExpression, nameof(expression), "Unexpected expression form"); Contracts.CheckParam(((UnaryExpression)expression.Body).Operand is MethodCallExpression, nameof(expression), "Unexpected expression form"); diff --git a/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo3`3.cs b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo3`3.cs index 91c1c0d747..acceb608fc 100644 --- a/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo3`3.cs +++ b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo3`3.cs @@ -71,16 +71,26 @@ public static FuncInstanceMethodInfo3 Create(Expression Contracts.CheckParam(methodCallExpression.Arguments.Count == 2, nameof(expression), "Unexpected expression form"); Contracts.CheckParam(methodCallExpression.Arguments[0] is ConstantExpression, nameof(expression), "Unexpected expression form"); - Contracts.CheckParam(((ConstantExpression)methodCallExpression.Arguments[0]).Type == typeof(Type), nameof(expression), "Unexpected expression form"); - Contracts.CheckParam((Type)((ConstantExpression)methodCallExpression.Arguments[0]).Value == typeof(Func), nameof(expression), "Unexpected expression form"); + var delegateTypeExpression = (ConstantExpression)methodCallExpression.Arguments[0]; + Contracts.CheckParam(delegateTypeExpression.Type == typeof(Type), nameof(expression), "Unexpected expression form"); + if (delegateTypeExpression.Value is not Type delegateType) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } + Contracts.CheckParam(delegateType == typeof(Func), nameof(expression), "Unexpected expression form"); Contracts.CheckParam(methodCallExpression.Arguments[1] is ParameterExpression, nameof(expression), "Unexpected expression form"); Contracts.CheckParam(methodCallExpression.Arguments[1] == expression.Parameters[0], nameof(expression), "Unexpected expression form"); // Check the MethodInfo - Contracts.CheckParam(methodCallExpression.Object is ConstantExpression, nameof(expression), "Unexpected expression form"); - Contracts.CheckParam(((ConstantExpression)methodCallExpression.Object).Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); - - var methodInfo = (MethodInfo)((ConstantExpression)methodCallExpression.Object).Value; + if (methodCallExpression.Object is not ConstantExpression methodInfoExpression) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } + Contracts.CheckParam(methodInfoExpression.Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); + if (methodInfoExpression.Value is not MethodInfo methodInfo) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } Contracts.CheckParam(expression.Body is UnaryExpression, nameof(expression), "Unexpected expression form"); Contracts.CheckParam(((UnaryExpression)expression.Body).Operand is MethodCallExpression, nameof(expression), "Unexpected expression form"); diff --git a/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo3`4.cs b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo3`4.cs index 51b8dbbe34..a539727a96 100644 --- a/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo3`4.cs +++ b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo3`4.cs @@ -72,16 +72,26 @@ public static FuncInstanceMethodInfo3 Create(Expressio // Verify that we are creating a delegate of type Func Contracts.CheckParam(methodCallExpression.Arguments.Count == 2, nameof(expression), "Unexpected expression form"); Contracts.CheckParam(methodCallExpression.Arguments[0] is ConstantExpression, nameof(expression), "Unexpected expression form"); - Contracts.CheckParam(((ConstantExpression)methodCallExpression.Arguments[0]).Type == typeof(Type), nameof(expression), "Unexpected expression form"); - Contracts.CheckParam((Type)((ConstantExpression)methodCallExpression.Arguments[0]).Value == typeof(Func), nameof(expression), "Unexpected expression form"); + var delegateTypeExpression = (ConstantExpression)methodCallExpression.Arguments[0]; + Contracts.CheckParam(delegateTypeExpression.Type == typeof(Type), nameof(expression), "Unexpected expression form"); + if (delegateTypeExpression.Value is not Type delegateType) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } + Contracts.CheckParam(delegateType == typeof(Func), nameof(expression), "Unexpected expression form"); Contracts.CheckParam(methodCallExpression.Arguments[1] is ParameterExpression, nameof(expression), "Unexpected expression form"); Contracts.CheckParam(methodCallExpression.Arguments[1] == expression.Parameters[0], nameof(expression), "Unexpected expression form"); // Check the MethodInfo - Contracts.CheckParam(methodCallExpression.Object is ConstantExpression, nameof(expression), "Unexpected expression form"); - Contracts.CheckParam(((ConstantExpression)methodCallExpression.Object).Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); - - var methodInfo = (MethodInfo)((ConstantExpression)methodCallExpression.Object).Value; + if (methodCallExpression.Object is not ConstantExpression methodInfoExpression) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } + Contracts.CheckParam(methodInfoExpression.Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); + if (methodInfoExpression.Value is not MethodInfo methodInfo) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } Contracts.CheckParam(expression.Body is UnaryExpression, nameof(expression), "Unexpected expression form"); Contracts.CheckParam(((UnaryExpression)expression.Body).Operand is MethodCallExpression, nameof(expression), "Unexpected expression form"); diff --git a/src/Microsoft.ML.Core/Utilities/IRandomBulkSource.cs b/src/Microsoft.ML.Core/Utilities/IRandomBulkSource.cs new file mode 100644 index 0000000000..a7f061cc7e --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/IRandomBulkSource.cs @@ -0,0 +1,21 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.Internal.Utilities +{ + /// + /// Optional extension for RNG engines that can produce bulk sequences efficiently. + /// + internal interface IRandomBulkSource + { + /// Fills with independent U[0,1) doubles. + void NextDoubles(Span destination); + + /// Fills with independent uint values covering the full 32-bit range. + void NextUInt32(Span destination); + } +} + diff --git a/src/Microsoft.ML.Core/Utilities/IRandomSource.cs b/src/Microsoft.ML.Core/Utilities/IRandomSource.cs new file mode 100644 index 0000000000..022b5a22ef --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/IRandomSource.cs @@ -0,0 +1,28 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML +{ + /// + /// Abstraction for RNG engines that expose the standard surface. + /// + public interface IRandomSource + { + int Next(); + int Next(int maxValue); + int Next(int minValue, int maxValue); + + long NextInt64(); + long NextInt64(long maxValue); + long NextInt64(long minValue, long maxValue); + + double NextDouble(); + float NextSingle(); + + void NextBytes(Span buffer); + } +} + diff --git a/src/Microsoft.ML.Core/Utilities/MersenneTwister.cs b/src/Microsoft.ML.Core/Utilities/MersenneTwister.cs new file mode 100644 index 0000000000..f1473ec1d9 --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/MersenneTwister.cs @@ -0,0 +1,334 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Runtime.CompilerServices; +#if NET8_0_OR_GREATER +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.Arm; +using System.Runtime.Intrinsics.X86; +#endif + +namespace Microsoft.ML.Internal.Utilities +{ + /// + /// Highly optimized SIMD-enabled Mersenne Twister implementation. + /// + internal sealed class MersenneTwister + { + private const int N = 624; + private const int M = 397; + private const uint MatrixA = 0x9908B0DFU; + private const uint UpperMask = 0x80000000U; + private const uint LowerMask = 0x7FFFFFFFU; + + private readonly uint[] _mt = new uint[N]; + private int _mti = N + 1; + + private readonly uint[] _buf = new uint[N]; + private uint _carry; + private bool _hasCarry; + + public MersenneTwister(uint seed) + { + InitGenrand(seed); + } + + private void InitGenrand(uint s) + { + unchecked + { + _mt[0] = s; + for (_mti = 1; _mti < N; _mti++) + { + var x = _mt[_mti - 1]; + _mt[_mti] = 1812433253U * (x ^ (x >> 30)) + (uint)_mti; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void Twist() + { + var kk = 0; + uint y; + + for (; kk < N - M; kk++) + { + y = (_mt[kk] & UpperMask) | (_mt[kk + 1] & LowerMask); + _mt[kk] = _mt[kk + M] ^ (y >> 1) ^ ((y & 1U) != 0 ? MatrixA : 0U); + } + + for (; kk < N - 1; kk++) + { + y = (_mt[kk] & UpperMask) | (_mt[kk + 1] & LowerMask); + _mt[kk] = _mt[kk - (N - M)] ^ (y >> 1) ^ ((y & 1U) != 0 ? MatrixA : 0U); + } + + y = (_mt[N - 1] & UpperMask) | (_mt[0] & LowerMask); + _mt[N - 1] = _mt[M - 1] ^ (y >> 1) ^ ((y & 1U) != 0 ? MatrixA : 0U); + _mti = 0; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void TemperScalar(ReadOnlySpan src, Span dst) + { + for (var i = 0; i < src.Length; i++) + { + var y = src[i]; + y ^= (y >> 11); + y ^= (y << 7) & 0x9D2C5680U; + y ^= (y << 15) & 0xEFC60000U; + y ^= (y >> 18); + dst[i] = y; + } + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe void TemperAvx2(ReadOnlySpan src, Span dst) + { + var len = src.Length; + var i = 0; + var c7 = Vector256.Create(0x9D2C5680u); + var c15 = Vector256.Create(0xEFC60000u); + + fixed (uint* pSrc = src) + fixed (uint* pDst = dst) + { + for (; i + 8 <= len; i += 8) + { + var y = Avx.LoadVector256(pSrc + i); + y = Avx2.Xor(y, Avx2.ShiftRightLogical(y, 11)); + + var t = Avx2.And(Avx2.ShiftLeftLogical(y, 7), c7); + y = Avx2.Xor(y, t); + + t = Avx2.And(Avx2.ShiftLeftLogical(y, 15), c15); + y = Avx2.Xor(y, t); + + y = Avx2.Xor(y, Avx2.ShiftRightLogical(y, 18)); + Avx.Store(pDst + i, y); + } + } + + for (; i < len; i++) + { + var y = src[i]; + y ^= (y >> 11); + y ^= (y << 7) & 0x9D2C5680u; + y ^= (y << 15) & 0xEFC60000u; + y ^= (y >> 18); + dst[i] = y; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe void TemperAdvSimd(ReadOnlySpan src, Span dst) + { + if (!AdvSimd.IsSupported) + throw new PlatformNotSupportedException("AdvSimd is not supported on this CPU."); + + var len = src.Length; + var i = 0; + var c7 = Vector128.Create(0x9D2C5680u); + var c15 = Vector128.Create(0xEFC60000u); + + fixed (uint* pSrc = src) + fixed (uint* pDst = dst) + { + for (; i + 4 <= len; i += 4) + { + var y = Unsafe.ReadUnaligned>((byte*)(pSrc + i)); + y = AdvSimd.Xor(y, AdvSimd.ShiftRightLogical(y, 11)); + + var t = AdvSimd.And(AdvSimd.ShiftLeftLogical(y, 7), c7); + y = AdvSimd.Xor(y, t); + + t = AdvSimd.And(AdvSimd.ShiftLeftLogical(y, 15), c15); + y = AdvSimd.Xor(y, t); + + y = AdvSimd.Xor(y, AdvSimd.ShiftRightLogical(y, 18)); + Unsafe.WriteUnaligned((byte*)(pDst + i), y); + } + } + + for (; i < len; i++) + { + var y = src[i]; + y ^= (y >> 11); + y ^= (y << 7) & 0x9D2C5680u; + y ^= (y << 15) & 0xEFC60000u; + y ^= (y >> 18); + dst[i] = y; + } + } +#endif + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void Temper(ReadOnlySpan src, Span dst) + { +#if NET8_0_OR_GREATER + if (src.Length >= 8 && Avx2.IsSupported) + { + TemperAvx2(src, dst); + return; + } + + if (src.Length >= 4 && AdvSimd.IsSupported) + { + TemperAdvSimd(src, dst); + return; + } +#endif + TemperScalar(src, dst); + } + + private const double DoubleDivisor = 1.0 / 9007199254740992.0; // 1 / 2^53 + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static double DoubleFromMant53(ulong mant53) + { + return mant53 * DoubleDivisor; + } + + public double NextDouble() + { + Span buffer = stackalloc double[1]; + NextDoubles(buffer); + return buffer[0]; + } + + public unsafe void NextDoubles(Span destination) + { + var n = destination.Length; + var filled = 0; + + fixed (uint* pMt = _mt) + fixed (uint* pBuf = _buf) + fixed (double* pDst = destination) + { + while (filled < n) + { + if (!_hasCarry) + { + if (_mti >= N) + Twist(); + + Temper(new ReadOnlySpan(pMt + _mti, 1), new Span(pBuf, 1)); + _carry = pBuf[0]; + _mti += 1; + _hasCarry = true; + } + + if (_mti >= N) + Twist(); + + var pairsRemaining = n - filled; + var availInts = N - _mti; + + if (availInts == 0) + { + Twist(); + continue; + } + + var maxPairsFromAvail = 1 + ((availInts - 1) >> 1); + var makePairs = Math.Min(pairsRemaining, maxPairsFromAvail); + var wantInts = (makePairs << 1) - 1; + + Temper(new ReadOnlySpan(pMt + _mti, wantInts), new Span(pBuf, wantInts)); + _mti += wantInts; + + var j = 0; + var a = (ulong)(_carry >> 5); + var b = (ulong)(pBuf[j++] >> 6); + pDst[filled++] = DoubleFromMant53((a << 26) | b); + _hasCarry = false; + + var remainingPairs = makePairs - 1; + for (var p = 0; p < remainingPairs; p++) + { + a = (ulong)(pBuf[j++] >> 5); + b = (ulong)(pBuf[j++] >> 6); + pDst[filled++] = DoubleFromMant53((a << 26) | b); + } + + if (filled < n) + { + var intsLeftBeforeTwist = N - _mti; + if (intsLeftBeforeTwist == 1) + { + Temper(new ReadOnlySpan(pMt + _mti, 1), new Span(pBuf, 1)); + _carry = pBuf[0]; + _mti += 1; + _hasCarry = true; + } + } + } + } + } + + public unsafe void NextTemperedUInt32(Span destination) + { + var n = destination.Length; + var filled = 0; + + if (_hasCarry && n != 0) + { + destination[filled++] = _carry; + _hasCarry = false; + } + + if (filled >= n) + return; + + fixed (uint* pMt = _mt) + fixed (uint* pBuf = _buf) + fixed (uint* pDst = destination) + { + while (filled < n) + { + if (_mti >= N) + Twist(); + + var avail = N - _mti; + + if (avail == 0) + { + Twist(); + continue; + } + + var toProduce = Math.Min(avail, n - filled); + + Temper(new ReadOnlySpan(pMt + _mti, toProduce), new Span(pBuf, toProduce)); + _mti += toProduce; + + new ReadOnlySpan(pBuf, toProduce).CopyTo(new Span(pDst + filled, toProduce)); + filled += toProduce; + } + } + } + + public uint NextTemperedUInt32() + { + if (_hasCarry) + { + _hasCarry = false; + return _carry; + } + + if (_mti >= N) + Twist(); + + var y = _mt[_mti++]; + y ^= (y >> 11); + y ^= (y << 7) & 0x9D2C5680u; + y ^= (y << 15) & 0xEFC60000u; + y ^= (y >> 18); + return y; + } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/MersenneTwisterRandomSource.cs b/src/Microsoft.ML.Core/Utilities/MersenneTwisterRandomSource.cs new file mode 100644 index 0000000000..e5ed5f571d --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/MersenneTwisterRandomSource.cs @@ -0,0 +1,156 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.ML; + +namespace Microsoft.ML.Internal.Utilities +{ + internal sealed class MersenneTwisterRandomSource : IRandomSource, IRandomBulkSource + { + private readonly MersenneTwister _twister; + + public MersenneTwisterRandomSource(int seed) + { + _twister = new MersenneTwister((uint)seed); + } + + public int Next() + { + return (int)(_twister.NextTemperedUInt32() >> 1); + } + + public int Next(int maxValue) + { + if (maxValue <= 0) + throw new ArgumentOutOfRangeException(nameof(maxValue)); + + uint limit = (uint)maxValue; + uint threshold = uint.MaxValue - (uint.MaxValue % limit); + + uint sample; + do + { + sample = _twister.NextTemperedUInt32(); + } + while (sample >= threshold); + + return (int)(sample % limit); + } + + public int Next(int minValue, int maxValue) + { + if (minValue >= maxValue) + throw new ArgumentOutOfRangeException(nameof(minValue)); + + uint range = unchecked((uint)(maxValue - minValue)); + int offset = NextInt32InRange(range); + return minValue + offset; + } + + public long NextInt64() + { + return unchecked((long)NextUInt64()); + } + + public long NextInt64(long maxValue) + { + if (maxValue <= 0) + throw new ArgumentOutOfRangeException(nameof(maxValue)); + + return (long)NextUInt64InRange((ulong)maxValue); + } + + public long NextInt64(long minValue, long maxValue) + { + if (minValue >= maxValue) + throw new ArgumentOutOfRangeException(nameof(minValue)); + + ulong range = (ulong)(maxValue - minValue); + long offset = (long)NextUInt64InRange(range); + return minValue + offset; + } + + public double NextDouble() + { + return _twister.NextDouble(); + } + + public float NextSingle() + { + uint word = _twister.NextTemperedUInt32(); + return (word >> 8) * (1.0f / (1u << 24)); + } + + public void NextBytes(Span buffer) + { + int offset = 0; + + Span word = stackalloc uint[1]; + while (offset < buffer.Length) + { + _twister.NextTemperedUInt32(word); + uint value = word[0]; + + int bytesToCopy = Math.Min(4, buffer.Length - offset); + for (int i = 0; i < bytesToCopy; i++) + { + buffer[offset++] = (byte)value; + value >>= 8; + } + } + } + + public void NextDoubles(Span destination) + { + _twister.NextDoubles(destination); + } + + public void NextUInt32(Span destination) + { + _twister.NextTemperedUInt32(destination); + } + + private ulong NextUInt64() + { + Span words = stackalloc uint[2]; + _twister.NextTemperedUInt32(words); + return ((ulong)words[0] << 32) | words[1]; + } + + private int NextInt32InRange(uint range) + { + if (range == 0) + throw new ArgumentOutOfRangeException(nameof(range)); + + uint threshold = uint.MaxValue - (uint.MaxValue % range); + + uint sample; + do + { + sample = _twister.NextTemperedUInt32(); + } + while (sample >= threshold); + + return (int)(sample % range); + } + + private ulong NextUInt64InRange(ulong range) + { + if (range == 0) + throw new ArgumentOutOfRangeException(nameof(range)); + + ulong threshold = ulong.MaxValue - (ulong.MaxValue % range); + + ulong sample; + do + { + sample = NextUInt64(); + } + while (sample >= threshold); + + return sample % range; + } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/Random.cs b/src/Microsoft.ML.Core/Utilities/Random.cs index 137ce5d45d..3db60bb9d4 100644 --- a/src/Microsoft.ML.Core/Utilities/Random.cs +++ b/src/Microsoft.ML.Core/Utilities/Random.cs @@ -162,7 +162,11 @@ private static uint GetSeed(Random rng) } } +#if NET8_0_OR_GREATER + public override float NextSingle() +#else public float NextSingle() +#endif { NextState(); return GetSingle(); diff --git a/src/Microsoft.ML.Core/Utilities/RandomFromRandomSource.cs b/src/Microsoft.ML.Core/Utilities/RandomFromRandomSource.cs new file mode 100644 index 0000000000..fff763612d --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/RandomFromRandomSource.cs @@ -0,0 +1,45 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.ML; + +namespace Microsoft.ML.Internal.Utilities +{ + internal sealed class RandomFromRandomSource : Random + { + private readonly IRandomSource _source; + + public RandomFromRandomSource(IRandomSource source) + { + _source = source ?? throw new ArgumentNullException(nameof(source)); + } + + public override int Next() => _source.Next(); + + public override int Next(int maxValue) => _source.Next(maxValue); + + public override int Next(int minValue, int maxValue) => _source.Next(minValue, maxValue); + + public override double NextDouble() => _source.NextDouble(); + + public override void NextBytes(byte[] buffer) + { + if (buffer is null) + throw new ArgumentNullException(nameof(buffer)); + + _source.NextBytes(buffer); + } + + protected override double Sample() => _source.NextDouble(); + +#if NET6_0_OR_GREATER + public override void NextBytes(Span buffer) => _source.NextBytes(buffer); + public override float NextSingle() => _source.NextSingle(); + public override long NextInt64() => _source.NextInt64(); + public override long NextInt64(long maxValue) => _source.NextInt64(maxValue); + public override long NextInt64(long minValue, long maxValue) => _source.NextInt64(minValue, maxValue); +#endif + } +} diff --git a/src/Microsoft.ML.Core/Utilities/RandomShim.cs b/src/Microsoft.ML.Core/Utilities/RandomShim.cs new file mode 100644 index 0000000000..e306c19225 --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/RandomShim.cs @@ -0,0 +1,83 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Buffers; + +namespace Microsoft.ML.Internal.Utilities +{ + /// + /// Back-ports newer APIs for TFMs where they are unavailable. + /// + internal static class RandomShim + { +#if NET6_0_OR_GREATER + public static long NextInt64(Random random) => random.NextInt64(); + + public static long NextInt64(Random random, long maxValue) => random.NextInt64(maxValue); + + public static long NextInt64(Random random, long minValue, long maxValue) => random.NextInt64(minValue, maxValue); + + public static void NextBytes(Random random, Span buffer) => random.NextBytes(buffer); +#else + public static long NextInt64(Random random) => NextInt64(random, long.MinValue, long.MaxValue); + + public static long NextInt64(Random random, long maxValue) + { + if (maxValue <= 0) + throw new ArgumentOutOfRangeException(nameof(maxValue)); + + return NextInt64(random, 0, maxValue); + } + + public static long NextInt64(Random random, long minValue, long maxValue) + { + if (minValue >= maxValue) + throw new ArgumentOutOfRangeException(nameof(minValue)); + + ulong range = (ulong)(maxValue - minValue); + + while (true) + { + ulong sample = NextUInt64(random); + ulong threshold = ulong.MaxValue - (ulong.MaxValue % range); + if (sample < threshold) + return (long)(sample % range) + minValue; + } + } + + public static void NextBytes(Random random, Span buffer) + { + if (buffer.IsEmpty) + return; + + byte[] rented = ArrayPool.Shared.Rent(buffer.Length); + try + { + random.NextBytes(rented); + rented.AsSpan(0, buffer.Length).CopyTo(buffer); + } + finally + { + ArrayPool.Shared.Return(rented); + } + } + + private static ulong NextUInt64(Random random) + { + Span bytes = stackalloc byte[8]; + NextBytes(random, bytes); + + ulong value = 0; + for (int i = 0; i < bytes.Length; i++) + { + value |= (ulong)bytes[i] << (i * 8); + } + + return value; + } +#endif + } +} + diff --git a/src/Microsoft.ML.Core/Utilities/RandomSourceAdapter.cs b/src/Microsoft.ML.Core/Utilities/RandomSourceAdapter.cs new file mode 100644 index 0000000000..98d931267f --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/RandomSourceAdapter.cs @@ -0,0 +1,72 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.ML; + +namespace Microsoft.ML.Internal.Utilities +{ + internal sealed class RandomSourceAdapter : IRandomSource + { + private readonly Random _random; + + public RandomSourceAdapter(Random random) + { + _random = random ?? throw new ArgumentNullException(nameof(random)); + } + + public int Next() => _random.Next(); + + public int Next(int maxValue) => _random.Next(maxValue); + + public int Next(int minValue, int maxValue) => _random.Next(minValue, maxValue); + + public long NextInt64() + { +#if NET6_0_OR_GREATER + return _random.NextInt64(); +#else + return RandomShim.NextInt64(_random); +#endif + } + + public long NextInt64(long maxValue) + { +#if NET6_0_OR_GREATER + return _random.NextInt64(maxValue); +#else + return RandomShim.NextInt64(_random, maxValue); +#endif + } + + public long NextInt64(long minValue, long maxValue) + { +#if NET6_0_OR_GREATER + return _random.NextInt64(minValue, maxValue); +#else + return RandomShim.NextInt64(_random, minValue, maxValue); +#endif + } + + public double NextDouble() => _random.NextDouble(); + + public float NextSingle() + { +#if NET6_0_OR_GREATER + return _random.NextSingle(); +#else + return RandomUtils.NextSingle(_random); +#endif + } + + public void NextBytes(Span buffer) + { +#if NET6_0_OR_GREATER + _random.NextBytes(buffer); +#else + RandomShim.NextBytes(_random, buffer); +#endif + } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs b/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs index f50ecc9174..c0779237d4 100644 --- a/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs @@ -308,29 +308,57 @@ private async Task DownloadResource(IHostEnvironment env, IChannel ch /// The provided url to check public bool IsRedirectToDefaultPage(string url) { + if (!Uri.TryCreate(url, UriKind.Absolute, out var uri)) + return false; + + if (uri.IsFile) + return false; + + using var handler = new HttpClientHandler { AllowAutoRedirect = false }; + using var httpClient = new HttpClient(handler); + + static bool IsRedirectToDefault(HttpResponseMessage response) + { + if (response.StatusCode == HttpStatusCode.Redirect && response.Headers.Location is Uri location) + { + return string.Equals(location.AbsoluteUri, "https://www.microsoft.com/?ref=aka", StringComparison.OrdinalIgnoreCase); + } + + return false; + } + + using var headRequest = new HttpRequestMessage(HttpMethod.Head, uri); + + static HttpResponseMessage Send(HttpClient client, HttpRequestMessage request) + { + return client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, CancellationToken.None).GetAwaiter().GetResult(); + } + try { - var request = WebRequest.Create(url); - // FileWebRequests cannot be redirected to default aka.ms webpage - if (request.GetType() == typeof(FileWebRequest)) - return false; - HttpWebRequest httpWebRequest = (HttpWebRequest)request; - httpWebRequest.AllowAutoRedirect = false; - HttpWebResponse httpWebResponse = (HttpWebResponse)httpWebRequest.GetResponse(); + using var headResponse = Send(httpClient, headRequest); + if (headResponse.StatusCode == HttpStatusCode.MethodNotAllowed || headResponse.StatusCode == HttpStatusCode.NotImplemented) + { + using var getRequest = new HttpRequestMessage(HttpMethod.Get, uri); + using var getResponse = Send(httpClient, getRequest); + return IsRedirectToDefault(getResponse); + } + + return IsRedirectToDefault(headResponse); } - catch (WebException e) + catch (HttpRequestException) { - HttpWebResponse webResponse = (HttpWebResponse)e.Response; - // Redirects to default url - if (webResponse.StatusCode == HttpStatusCode.Redirect && webResponse.Headers["Location"] == "https://www.microsoft.com/?ref=aka") - return true; - // Redirects to another url - else if (webResponse.StatusCode == HttpStatusCode.MovedPermanently) - return false; - else + try + { + using var getRequest = new HttpRequestMessage(HttpMethod.Get, uri); + using var getResponse = Send(httpClient, getRequest); + return IsRedirectToDefault(getResponse); + } + catch (HttpRequestException) + { return false; + } } - return false; } public static ResourceDownloadResults GetErrorMessage(out string errorMessage, params ResourceDownloadResults[] result) diff --git a/src/Microsoft.ML.Data/MLContext.cs b/src/Microsoft.ML.Data/MLContext.cs index 4212b8a21f..dc098017ed 100644 --- a/src/Microsoft.ML.Data/MLContext.cs +++ b/src/Microsoft.ML.Data/MLContext.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Reflection; using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; namespace Microsoft.ML @@ -22,6 +23,7 @@ public sealed class MLContext : IHostEnvironmentInternal { // REVIEW: consider making LocalEnvironment and MLContext the same class instead of encapsulation. private readonly LocalEnvironment _env; + private readonly IRandomSource _randomSource; /// /// Gets the trainers and tasks specific to binary classification problems. @@ -114,6 +116,7 @@ public int? GpuDeviceId set { _env.GpuDeviceId = value; } } + internal IRandomSource RandomSource => _randomSource; /// /// Create the ML context. /// @@ -143,8 +146,14 @@ public int? GpuDeviceId /// So, the predictions from a loaded model don't depend on the seed value. /// public MLContext(int? seed = null) + : this(seed, rng: null) { - _env = new LocalEnvironment(seed); + } + + public MLContext(int? seed, IRandomSource rng = null) + { + _env = rng is null ? new LocalEnvironment(seed) : new LocalEnvironment(seed, rng); + _randomSource = _env.RandomSource; _env.AddListener(ProcessMessage); BinaryClassification = new BinaryClassificationCatalog(_env); @@ -176,6 +185,7 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message) IPipe IChannelProvider.StartPipe(string name) => _env.StartPipe(name); IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name); int? IHostEnvironmentInternal.Seed => _env.Seed; + IRandomSource IHostEnvironmentInternal.RandomSource => _randomSource; [BestFriend] internal void CancelExecution() => ((ICancelable)_env).CancelExecution(); diff --git a/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs b/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs index 9e8adf0ef9..bc9f3f649c 100644 --- a/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs +++ b/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using Microsoft.ML; using Microsoft.ML.Runtime; namespace Microsoft.ML.Data @@ -46,8 +47,9 @@ protected override void Dispose(bool disposing) /// Create an ML.NET for local execution. /// /// Random seed. Set to null for a non-deterministic environment. - public LocalEnvironment(int? seed = null) - : base(seed, verbose: false) + /// Optional random source backing this environment. + public LocalEnvironment(int? seed = null, IRandomSource randomSource = null) + : base(seed, verbose: false, randomSource) { } @@ -63,13 +65,14 @@ public void AddListener(Action listener) public void RemoveListener(Action listener) => RemoveListener(listener); - protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose) + protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, IRandomSource randomSource, bool verbose) { Contracts.AssertValue(rand); + Contracts.AssertValue(randomSource); Contracts.AssertValueOrNull(parentFullName); Contracts.AssertNonEmpty(shortName); Contracts.Assert(source == this || source is Host); - return new Host(source, shortName, parentFullName, rand, verbose); + return new Host(source, shortName, parentFullName, rand, randomSource, verbose); } protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name) @@ -90,8 +93,8 @@ protected override IPipe CreatePipe(ChannelProviderBase pare private sealed class Host : HostBase { - public Host(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose) - : base(source, shortName, parentFullName, rand, verbose) + public Host(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, IRandomSource randomSource, bool verbose) + : base(source, shortName, parentFullName, rand, randomSource, verbose) { IsCanceled = source.IsCanceled; } @@ -112,9 +115,9 @@ protected override IPipe CreatePipe(ChannelProviderBase pare return new Pipe(parent, name, GetDispatchDelegate()); } - protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose) + protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, IRandomSource randomSource, bool verbose) { - return new Host(source, shortName, parentFullName, rand, verbose); + return new Host(source, shortName, parentFullName, rand, randomSource, verbose); } } } diff --git a/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs b/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs index 04df47c3bf..5211a8ae85 100644 --- a/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs +++ b/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs @@ -527,8 +527,9 @@ public void AutoFitRecommendationTest() // STEP 2: Run AutoML experiment try { + // Use a slightly larger time budget to reduce flakiness on slower hosts ExperimentResult experimentResult = mlContext.Auto() - .CreateRecommendationExperiment(5) + .CreateRecommendationExperiment(new RecommendationExperimentSettings { MaxExperimentTimeInSeconds = 10 }) .Execute(trainDataView, testDataView, new ColumnInformation() { diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/AdditionalMetadataReferences.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/AdditionalMetadataReferences.cs index 8bcc0cc325..71b642371d 100644 --- a/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/AdditionalMetadataReferences.cs +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/AdditionalMetadataReferences.cs @@ -14,7 +14,10 @@ namespace Microsoft.ML.CodeAnalyzer.Tests.Helpers { internal static class AdditionalMetadataReferences { -#if NETCOREAPP +#if NET8_0_OR_GREATER + internal static readonly ReferenceAssemblies DefaultReferenceAssemblies = ReferenceAssemblies.Net.Net80 + .AddPackages(ImmutableArray.Create(new PackageIdentity("System.Memory", "4.5.1"))); +#elif NETCOREAPP internal static readonly ReferenceAssemblies DefaultReferenceAssemblies = ReferenceAssemblies.Default .AddPackages(ImmutableArray.Create(new PackageIdentity("System.Memory", "4.5.1"))); #else diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/CSharpCodeFixVerifier`2.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/CSharpCodeFixVerifier`2.cs index 351fb05d77..1c149ccf4a 100644 --- a/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/CSharpCodeFixVerifier`2.cs +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/CSharpCodeFixVerifier`2.cs @@ -9,7 +9,6 @@ using Microsoft.CodeAnalysis.CSharp.Testing; using Microsoft.CodeAnalysis.Diagnostics; using Microsoft.CodeAnalysis.Testing; -using Microsoft.CodeAnalysis.Testing.Verifiers; namespace Microsoft.ML.CodeAnalyzer.Tests.Helpers { @@ -18,15 +17,13 @@ internal static class CSharpCodeFixVerifier where TCodeFix : CodeFixProvider, new() { public static DiagnosticResult Diagnostic() -#pragma warning disable CS0618 // Type or member is obsolete - => CSharpCodeFixVerifier.Diagnostic(); + => CSharpCodeFixVerifier.Diagnostic(); public static DiagnosticResult Diagnostic(string diagnosticId) - => CSharpCodeFixVerifier.Diagnostic(diagnosticId); + => CSharpCodeFixVerifier.Diagnostic(diagnosticId); public static DiagnosticResult Diagnostic(DiagnosticDescriptor descriptor) - => CSharpCodeFixVerifier.Diagnostic(descriptor); -#pragma warning restore CS0618 // Type or member is obsolete + => CSharpCodeFixVerifier.Diagnostic(descriptor); public static async Task VerifyAnalyzerAsync(string source, params DiagnosticResult[] expected) { @@ -57,9 +54,7 @@ public static async Task VerifyCodeFixAsync(string source, DiagnosticResult[] ex await test.RunAsync(); } -#pragma warning disable CS0618 // Type or member is obsolete - internal class Test : CSharpCodeFixTest -#pragma warning restore CS0618 // Type or member is obsolete + internal class Test : CSharpCodeFixTest { public Test() { diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/CompatibleXUnitVerifier.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/CompatibleXUnitVerifier.cs new file mode 100644 index 0000000000..e9a5c4a1ec --- /dev/null +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/CompatibleXUnitVerifier.cs @@ -0,0 +1,224 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Testing; +using Xunit; +using Xunit.Sdk; + +namespace Microsoft.ML.CodeAnalyzer.Tests.Helpers +{ + /// + /// A drop-in replacement for xUnit.net's default verifier that composes nicely with xUnit.net 2.8+. + /// + internal class CompatibleXUnitVerifier : IVerifier + { + public CompatibleXUnitVerifier() + : this(ImmutableStack.Empty) + { + } + + private CompatibleXUnitVerifier(ImmutableStack context) + { + Context = context ?? throw new ArgumentNullException(nameof(context)); + } + + private ImmutableStack Context { get; } + + public void Empty(string collectionName, IEnumerable collection) + { + using var enumerator = collection.GetEnumerator(); + if (enumerator.MoveNext()) + { + throw new XunitException(ComposeMessage($"'{collectionName}' is not empty")); + } + } + + public void Equal(T expected, T actual, string? message = null) + { + if (message is null && Context.IsEmpty) + { + Assert.Equal(expected, actual); + return; + } + + if (EqualityComparer.Default.Equals(expected, actual)) + { + return; + } + + throw new XunitException(ComposeMessageWithDetails(message, BuildEqualityMessage(expected, actual))); + } + +#if NETCOREAPP + public void True([DoesNotReturnIf(false)] bool assert, string? message = null) +#else + public void True(bool assert, string? message = null) +#endif + { + if (message is null && Context.IsEmpty) + { + Assert.True(assert); + } + else + { + Assert.True(assert, ComposeMessage(message)); + } + } + +#if NETCOREAPP + public void False([DoesNotReturnIf(true)] bool assert, string? message = null) +#else + public void False(bool assert, string? message = null) +#endif + { + if (message is null && Context.IsEmpty) + { + Assert.False(assert); + } + else + { + Assert.False(assert, ComposeMessage(message)); + } + } + +#if NETCOREAPP + [DoesNotReturn] +#endif +#if !NETCOREAPP +#pragma warning disable CS8770 // Attribute unavailable on this target. +#endif + public void Fail(string? message = null) + => throw new XunitException(ComposeMessage(message)); +#if !NETCOREAPP +#pragma warning restore CS8770 +#endif + + public void LanguageIsSupported(string language) + { + if (language != LanguageNames.CSharp && language != LanguageNames.VisualBasic) + { + throw new XunitException(ComposeMessage($"Unsupported Language: '{language}'")); + } + } + + public void NotEmpty(string collectionName, IEnumerable collection) + { + using var enumerator = collection.GetEnumerator(); + if (!enumerator.MoveNext()) + { + throw new XunitException(ComposeMessage($"'{collectionName}' is empty")); + } + } + + public void SequenceEqual(IEnumerable expected, IEnumerable actual, IEqualityComparer? equalityComparer = null, string? message = null) + { + var comparer = new SequenceEqualEnumerableEqualityComparer(equalityComparer); + if (comparer.Equals(expected, actual)) + { + return; + } + + throw new XunitException(ComposeMessageWithDetails(message, BuildEqualityMessage(expected?.ToArray(), actual?.ToArray()))); + } + + public IVerifier PushContext(string context) + => new CompatibleXUnitVerifier(Context.Push(context)); + + private string ComposeMessage(string? message) + { + foreach (var frame in Context) + { + message = "Context: " + frame + Environment.NewLine + message; + } + + return message ?? string.Empty; + } + + private string ComposeMessageWithDetails(string? message, string details) + { + var baseMessage = ComposeMessage(message); + if (string.IsNullOrEmpty(baseMessage)) + { + return details; + } + + return baseMessage + Environment.NewLine + details; + } + + private static string BuildEqualityMessage(TExpected expected, TActual actual) + { + try + { + Assert.Equal((object?)expected, (object?)actual); + } + catch (XunitException ex) + { + return ex.Message; + } + + return "Values are not equal."; + } + + private sealed class SequenceEqualEnumerableEqualityComparer : IEqualityComparer?> + { + private readonly IEqualityComparer _itemComparer; + + public SequenceEqualEnumerableEqualityComparer(IEqualityComparer? itemComparer) + { + _itemComparer = itemComparer ?? EqualityComparer.Default; + } + + public bool Equals(IEnumerable? x, IEnumerable? y) + { + if (ReferenceEquals(x, y)) + { + return true; + } + + if (x is null || y is null) + { + return false; + } + + using var enumeratorX = x.GetEnumerator(); + using var enumeratorY = y.GetEnumerator(); + + while (true) + { + var hasX = enumeratorX.MoveNext(); + var hasY = enumeratorY.MoveNext(); + + if (!hasX || !hasY) + { + return hasX == hasY; + } + + if (!_itemComparer.Equals(enumeratorX.Current, enumeratorY.Current)) + { + return false; + } + } + } + + public int GetHashCode(IEnumerable? obj) + { + if (obj is null) + { + return 0; + } + + return obj.Select(item => _itemComparer.GetHashCode(item!)) + .Aggregate(0, (agg, next) => ((agg << 5) + agg) ^ next); + } + } + } +} diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/MersenneTwisterTests.cs b/test/Microsoft.ML.Core.Tests/UnitTests/MersenneTwisterTests.cs new file mode 100644 index 0000000000..8ed4cd9364 --- /dev/null +++ b/test/Microsoft.ML.Core.Tests/UnitTests/MersenneTwisterTests.cs @@ -0,0 +1,141 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.RunTests; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Core.Tests.UnitTests +{ + public sealed class MersenneTwisterTests : BaseTestBaseline + { + public MersenneTwisterTests(ITestOutputHelper output) + : base(output) + { + } + + [Fact] + [TestCategory("Utilities")] + public void MixedApiCallsConsumeTemperedSequenceWithoutGaps() + { + const uint seed = 5489u; + + var baseline = new uint[32]; + var baselineTwister = new MersenneTwister(seed); + baselineTwister.NextTemperedUInt32(baseline); + + var twister = new MersenneTwister(seed); + var index = 0; + + Assert.Equal(baseline[index++], twister.NextTemperedUInt32()); + + var firstDouble = twister.NextDouble(); + Assert.Equal(ToDoubleFromTempered(baseline[index], baseline[index + 1]), firstDouble); + index += 2; + + var buffer = new uint[5]; + twister.NextTemperedUInt32(buffer); + Assert.Equal(Slice(baseline, index, buffer.Length), buffer); + index += buffer.Length; + + Assert.Equal(baseline[index++], twister.NextTemperedUInt32()); + + var secondDouble = twister.NextDouble(); + Assert.Equal(ToDoubleFromTempered(baseline[index], baseline[index + 1]), secondDouble); + index += 2; + + var secondBuffer = new uint[3]; + twister.NextTemperedUInt32(secondBuffer); + Assert.Equal(Slice(baseline, index, secondBuffer.Length), secondBuffer); + index += secondBuffer.Length; + + var thirdDouble = twister.NextDouble(); + Assert.Equal(ToDoubleFromTempered(baseline[index], baseline[index + 1]), thirdDouble); + index += 2; + + var thirdBuffer = new uint[4]; + twister.NextTemperedUInt32(thirdBuffer); + Assert.Equal(Slice(baseline, index, thirdBuffer.Length), thirdBuffer); + index += thirdBuffer.Length; + + Assert.Equal(baseline[index++], twister.NextTemperedUInt32()); + + Assert.True(index <= baseline.Length); + } + + [Fact] + [TestCategory("Utilities")] + public void ProducesExpectedSequencesForDeterministicSeed() + { + const uint seed = 5489u; + + var expectedDoubles = new[] + { + 0.8147236863931789, + 0.9057919370756192, + 0.12698681629350606, + 0.9133758561390194, + 0.6323592462254095, + 0.09754040499940952, + 0.2784982188670484, + 0.5468815192049838, + 0.9575068354342976, + 0.9648885351992765, + }; + + var expectedTempered = new uint[] + { + 3499211612, + 581869302, + 3890346734, + 3586334585, + 545404204, + 4161255391, + 3922919429, + 949333985, + 2715962298, + 1323567403, + }; + + var doubleTwister = new MersenneTwister(seed); + var actualDoubles = new double[expectedDoubles.Length]; + for (var i = 0; i < actualDoubles.Length; i++) + { + actualDoubles[i] = doubleTwister.NextDouble(); + } + + var uintTwister = new MersenneTwister(seed); + var actualTempered = new uint[expectedTempered.Length]; + for (var i = 0; i < actualTempered.Length; i++) + { + actualTempered[i] = uintTwister.NextTemperedUInt32(); + } + + for (var i = 0; i < expectedDoubles.Length; i++) + { + Assert.Equal(expectedDoubles[i], actualDoubles[i], precision: 15); + } + + Assert.Equal(expectedTempered, actualTempered); + } + + private static uint[] Slice(uint[] source, int start, int length) + { + var result = new uint[length]; + Array.Copy(source, start, result, 0, length); + return result; + } + + private static double ToDoubleFromTempered(uint first, uint second) + { + var a = (ulong)(first >> 5); + var b = (ulong)(second >> 6); + var mantissa = (a << 26) | b; + const double inverseTwo53 = 1.0 / 9007199254740992.0; + return mantissa * inverseTwo53; + } + } +} diff --git a/tools-local/Microsoft.ML.InternalCodeAnalyzer/ContractsCheckAnalyzer.cs b/tools-local/Microsoft.ML.InternalCodeAnalyzer/ContractsCheckAnalyzer.cs index ef2bcdce78..0e8e7cb2fa 100644 --- a/tools-local/Microsoft.ML.InternalCodeAnalyzer/ContractsCheckAnalyzer.cs +++ b/tools-local/Microsoft.ML.InternalCodeAnalyzer/ContractsCheckAnalyzer.cs @@ -195,7 +195,10 @@ private static void Analyze(SyntaxNodeAnalysisContext context) var symbolInfo = context.SemanticModel.GetSymbolInfo(invocation); if (!(symbolInfo.Symbol is IMethodSymbol methodSymbol)) return; - var containingSymbolName = methodSymbol.ContainingSymbol.ToString(); + var containingType = methodSymbol.ContainingType; + if (containingType == null) + return; + var containingSymbolName = containingType.ToDisplayString(); // The "internal" version is one used by some projects that want to benefit from Contracts, // but for some reason cannot reference MLCore. // Contract functions defined Microsoft.ML.Internal.CpuMath.Core are introduced for breaking the dependencies From fa103f41a2e2940b988a3653c8e1459fd13a0394 Mon Sep 17 00:00:00 2001 From: Alexey Sokolov Date: Sun, 12 Oct 2025 13:03:00 +0700 Subject: [PATCH 2/9] Add APICompat suppression for IHost.RandomSource (CP0006) to allow 5.0.0 pack vs 4.0.0 baseline --- src/Microsoft.ML/CompatibilitySuppressions.xml | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 src/Microsoft.ML/CompatibilitySuppressions.xml diff --git a/src/Microsoft.ML/CompatibilitySuppressions.xml b/src/Microsoft.ML/CompatibilitySuppressions.xml new file mode 100644 index 0000000000..d0522301d5 --- /dev/null +++ b/src/Microsoft.ML/CompatibilitySuppressions.xml @@ -0,0 +1,11 @@ + + + + + CP0006 + P:Microsoft.ML.Runtime.IHost.RandomSource + lib/netstandard2.0/Microsoft.ML.Core.dll + lib/netstandard2.0/Microsoft.ML.Core.dll + true + + \ No newline at end of file From 60c6bdab7e04d2f3a2cd2a3a160190081ed296af Mon Sep 17 00:00:00 2001 From: Alexey Sokolov Date: Sun, 12 Oct 2025 14:04:33 +0700 Subject: [PATCH 3/9] Tests: cover MersenneTwisterRandomSource, RandomSourceAdapter, RandomFromRandomSource, and ResourceManagerUtils logic; fix Windows native cmake script; add APICompat suppression --- .../UnitTests/RandomSourceTests.cs | 269 ++++++++++++++++++ 1 file changed, 269 insertions(+) create mode 100644 test/Microsoft.ML.Core.Tests/UnitTests/RandomSourceTests.cs diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/RandomSourceTests.cs b/test/Microsoft.ML.Core.Tests/UnitTests/RandomSourceTests.cs new file mode 100644 index 0000000000..55c0c0295d --- /dev/null +++ b/test/Microsoft.ML.Core.Tests/UnitTests/RandomSourceTests.cs @@ -0,0 +1,269 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.IO; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Runtime; +using Microsoft.ML.RunTests; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Core.Tests.UnitTests +{ + public sealed class RandomSourceTests : BaseTestBaseline + { + public RandomSourceTests(ITestOutputHelper output) : base(output) { } + [Fact] + [TestCategory("Utilities")] + public void MersenneTwisterRandomSource_ReproducibleAndRanges() + { + var s1 = new MersenneTwisterRandomSource(12345); + var s2 = new MersenneTwisterRandomSource(12345); + + // Next() in [0, int.MaxValue) + for (int i = 0; i < 10; i++) + { + var a = s1.Next(); + var b = s2.Next(); + Assert.Equal(a, b); + Assert.InRange(a, 0, int.MaxValue - 1); + } + + // Next(max) + Assert.Throws(() => s1.Next(0)); + for (int i = 1; i <= 5; i++) + { + var a = s1.Next(i); + var b = s2.Next(i); + Assert.Equal(a, b); + Assert.InRange(a, 0, i - 1); + } + + // Next(min,max) + Assert.Throws(() => s1.Next(5, 5)); + for (int i = 0; i < 5; i++) + { + var a = s1.Next(-10, 10); + var b = s2.Next(-10, 10); + Assert.Equal(a, b); + Assert.InRange(a, -10, 9); + } + + // NextDouble and NextSingle in [0,1) + for (int i = 0; i < 8; i++) + { + var da = s1.NextDouble(); + var db = s2.NextDouble(); + Assert.Equal(da, db); + Assert.InRange(da, 0.0, 1.0 - double.Epsilon); + } + + for (int i = 0; i < 8; i++) + { + var fa = s1.NextSingle(); + var fb = s2.NextSingle(); + Assert.Equal(fa, fb); + Assert.InRange(fa, 0.0f, 1.0f); + } + + // Int64 variants + Assert.Throws(() => s1.NextInt64(0)); + Assert.Throws(() => s1.NextInt64(5, 5)); + + for (int i = 0; i < 5; i++) + { + var a = s1.NextInt64(); + var b = s2.NextInt64(); + Assert.Equal(a, b); + } + + for (int i = 0; i < 5; i++) + { + var a = s1.NextInt64(1000); + var b = s2.NextInt64(1000); + Assert.Equal(a, b); + Assert.InRange(a, 0, 999); + } + + for (int i = 0; i < 5; i++) + { + var a = s1.NextInt64(-123, 456); + var b = s2.NextInt64(-123, 456); + Assert.Equal(a, b); + Assert.InRange(a, -123, 455); + } + + // NextBytes and bulk APIs + var buf1 = new byte[13]; + var buf2 = new byte[13]; + s1.NextBytes(buf1); + s2.NextBytes(buf2); + Assert.Equal(buf1, buf2); + + var doubles1 = new double[7]; + var doubles2 = new double[7]; + s1.NextDoubles(doubles1); + s2.NextDoubles(doubles2); + Assert.Equal(doubles1, doubles2); + foreach (var d in doubles1) + Assert.InRange(d, 0.0, 1.0); + + var u321 = new uint[9]; + var u322 = new uint[9]; + s1.NextUInt32(u321); + s2.NextUInt32(u322); + Assert.Equal(u321, u322); + } + + [Fact] + [TestCategory("Utilities")] + public void RandomSourceAdapter_Matches_SystemRandom() + { + const int seed = 777; + var a1 = new RandomSourceAdapter(new Random(seed)); + var a2 = new RandomSourceAdapter(new Random(seed)); + + for (int i = 0; i < 5; i++) Assert.Equal(a1.Next(), a2.Next()); + for (int i = 1; i <= 5; i++) Assert.Equal(a1.Next(i), a2.Next(i)); + for (int i = 0; i < 5; i++) Assert.Equal(a1.Next(-50, 50), a2.Next(-50, 50)); + + Assert.Equal(a1.NextDouble(), a2.NextDouble()); + + var bytesA = new byte[17]; + var bytesB = new byte[17]; + a1.NextBytes(bytesA); + a2.NextBytes(bytesB); + Assert.Equal(bytesA, bytesB); + +#if NET6_0_OR_GREATER + a1 = new RandomSourceAdapter(new Random(seed)); + a2 = new RandomSourceAdapter(new Random(seed)); + Assert.Equal(a1.NextSingle(), a2.NextSingle()); + Assert.Equal(a1.NextInt64(), a2.NextInt64()); + Assert.Equal(a1.NextInt64(1000), a2.NextInt64(1000)); + Assert.Equal(a1.NextInt64(-5, 7), a2.NextInt64(-5, 7)); + + a1 = new RandomSourceAdapter(new Random(seed)); + a2 = new RandomSourceAdapter(new Random(seed)); + var spanA = new byte[8]; + var spanB = new byte[8]; + a1.NextBytes(spanA); + a2.NextBytes(spanB); + Assert.Equal(spanA, spanB); +#endif + } + + [Fact] + [TestCategory("Utilities")] + public void RandomFromRandomSource_Matches_Source() + { + // Use MT source to ensure deterministic test across TFMs + var srcForRandom = new MersenneTwisterRandomSource(9876); + var srcForCompare = new MersenneTwisterRandomSource(9876); + var random = new RandomFromRandomSource(srcForRandom); + + for (int i = 0; i < 5; i++) + Assert.Equal(srcForCompare.Next(), random.Next()); + + for (int i = 1; i <= 5; i++) + Assert.Equal(srcForCompare.Next(i), random.Next(i)); + + for (int i = 0; i < 5; i++) + Assert.Equal(srcForCompare.Next(-10, 10), random.Next(-10, 10)); + + var bytesA = new byte[21]; + var bytesB = new byte[21]; + srcForCompare.NextBytes(bytesA); + random.NextBytes(bytesB); + Assert.Equal(bytesA, bytesB); + + Assert.Equal(srcForCompare.NextDouble(), random.NextDouble()); + +#if NET6_0_OR_GREATER + // Span-based NextBytes and newer APIs + var sA = new byte[5]; + var sB = new byte[5]; + srcForCompare.NextBytes(sA); + random.NextBytes(sB); + Assert.Equal(sA, sB); + + Assert.Equal(srcForCompare.NextSingle(), random.NextSingle()); + Assert.Equal(srcForCompare.NextInt64(), random.NextInt64()); + Assert.Equal(srcForCompare.NextInt64(1000), random.NextInt64(1000)); + Assert.Equal(srcForCompare.NextInt64(-20, 33), random.NextInt64(-20, 33)); +#endif + } + + [Fact] + [TestCategory("Utilities")] + public void ResourceManagerUtils_BuildsUrl_And_ErrorMessages() + { + // Preserve env and restore on exit + var oldBase = Environment.GetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable); + try + { + Environment.SetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable, null); + var url = ResourceManagerUtils.GetUrl("foo/bar.txt"); + Assert.StartsWith("https://aka.ms/mlnet-resources/", url); + Assert.EndsWith("foo/bar.txt", url); + + Environment.SetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable, "https://example.com/custom/"); + url = ResourceManagerUtils.GetUrl("a/b"); + Assert.Equal("https://example.com/custom/a/b", url); + + // GetErrorMessage formatting + var r1 = new ResourceManagerUtils.ResourceDownloadResults("file", "some error"); + var r2 = new ResourceManagerUtils.ResourceDownloadResults("file", "other error", "https://host/resource"); + + var first = ResourceManagerUtils.GetErrorMessage(out var msg1, r1, r2); + Assert.Same(r1, first); + Assert.Contains("Error downloading resource:", msg1); + + first = ResourceManagerUtils.GetErrorMessage(out var msg2, r2, r1); + Assert.Same(r2, first); + Assert.Contains("Error downloading resource from", msg2); + + // IsRedirectToDefaultPage returns false for file URIs and invalid absolute + Assert.False(ResourceManagerUtils.Instance.IsRedirectToDefaultPage("file:///C:/temp/x")); + Assert.False(ResourceManagerUtils.Instance.IsRedirectToDefaultPage("not a uri")); + } + finally + { + Environment.SetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable, oldBase); + } + } + + [Fact] + [TestCategory("Utilities")] + public async System.Threading.Tasks.Task ResourceManagerUtils_Throws_For_NonAkaHost() + { + // Force downloads into a temp directory to avoid touching AppData + var oldPath = Environment.GetEnvironmentVariable(Utils.CustomSearchDirEnvVariable); + var oldBase = Environment.GetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable); + var tmp = Path.Combine(Path.GetTempPath(), "mlnet-test-resources", Guid.NewGuid().ToString("N")); + + try + { + Directory.CreateDirectory(tmp); + Environment.SetEnvironmentVariable(Utils.CustomSearchDirEnvVariable, tmp); + Environment.SetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable, "https://example.com/base/"); + + using var swOut = new StringWriter(); + using var swErr = new StringWriter(); + var env = new ConsoleEnvironment(1, outWriter: swOut, errWriter: swErr); + using var ch = env.Start("test"); + + await Assert.ThrowsAsync(() => + ResourceManagerUtils.Instance.EnsureResourceAsync(env, ch, "rel/path", "file.bin", "subdir", timeout: 1000)); + } + finally + { + Environment.SetEnvironmentVariable(Utils.CustomSearchDirEnvVariable, oldPath); + Environment.SetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable, oldBase); + try { if (Directory.Exists(tmp)) Directory.Delete(tmp, recursive: true); } catch { } + } + } + } +} From f1930f7dc04268c8e6ef59f936fd7c727eefcdba Mon Sep 17 00:00:00 2001 From: Alexey Sokolov Date: Sun, 12 Oct 2025 15:45:20 +0700 Subject: [PATCH 4/9] chore: trigger Azure Pipelines From e0eed95a66e58635fab39196065a04ca95627ade Mon Sep 17 00:00:00 2001 From: Alexey Sokolov Date: Sun, 12 Oct 2025 17:49:53 +0700 Subject: [PATCH 5/9] tests(Core): remove ProjectReference to Microsoft.ML.Tests to eliminate cross-test DLL/PDB contention under coverlet --- test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj index 8484d27c9a..6292a34e78 100644 --- a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj +++ b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj @@ -31,7 +31,6 @@ - From 7001ec75f7e8fd8deea8424ffab8cb80881408b0 Mon Sep 17 00:00:00 2001 From: Alexey Sokolov Date: Mon, 13 Oct 2025 03:52:36 +0700 Subject: [PATCH 6/9] chore: retrigger Azure Pipelines From 0905638073ddb30a9fea0186b00b4804aeb24b75 Mon Sep 17 00:00:00 2001 From: Alexey Sokolov Date: Mon, 13 Oct 2025 10:52:07 +0700 Subject: [PATCH 7/9] ResourceManagerUtils: make download mutex robust (unique name per path, handle abandoned, correct ownership release) to deflake concurrent model downloads in CI --- .../Utilities/ResourceManagerUtils.cs | 105 ++++++++++++------ 1 file changed, 68 insertions(+), 37 deletions(-) diff --git a/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs b/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs index c0779237d4..213c230b81 100644 --- a/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs @@ -251,58 +251,89 @@ private async Task DownloadResource(IHostEnvironment env, IChannel ch if (File.Exists(path)) return null; - var mutex = new Mutex(false, "Resource" + fileName); - mutex.WaitOne(); - if (File.Exists(path)) - { - mutex.ReleaseMutex(); - return null; - } - - Guid guid = Guid.NewGuid(); - string tempPath = Path.GetFullPath(Path.Combine(Path.GetDirectoryName(path), "temp-resource-" + guid.ToString())); + // Serialize access across processes/threads to avoid racing on the same file. + // Use a name derived from the absolute target path to avoid clashes across different directories. + string mutexName = GetSafeMutexName(path); + using var mutex = new Mutex(false, mutexName); + bool lockTaken = false; + string tempPath = null; try { - int blockSize = 4096; - - var response = await httpClient.GetAsync(uri, ct).ConfigureAwait(false); - using (var fh = env.CreateOutputFile(tempPath)) - using (var ws = fh.CreateWriteStream()) + try + { + lockTaken = mutex.WaitOne(); + } + catch (AbandonedMutexException) { - response.EnsureSuccessStatusCode(); - IEnumerable headers; - var hasHeader = response.Headers.TryGetValues("content-length", out headers); - if (uri.Host == "aka.ms" && IsRedirectToDefaultPage(uri.AbsoluteUri)) - throw new NotSupportedException($"The provided url ({uri}) redirects to the default url ({DefaultUrl})"); - if (!hasHeader || !long.TryParse(headers.First(), out var size)) - size = 10000000; + // The previous owner terminated without releasing. We now own the mutex; proceed safely. + lockTaken = true; + } - var stream = await response.EnsureSuccessStatusCode().Content.ReadAsStreamAsync().ConfigureAwait(false); + // Another process may have completed the download while we were waiting. + if (File.Exists(path)) + { + return null; + } - await stream.CopyToAsync(ws, blockSize, ct); + Guid guid = Guid.NewGuid(); + tempPath = Path.GetFullPath(Path.Combine(Path.GetDirectoryName(path), "temp-resource-" + guid.ToString())); + try + { + int blockSize = 4096; - if (ct.IsCancellationRequested) + var response = await httpClient.GetAsync(uri, ct).ConfigureAwait(false); + using (var fh = env.CreateOutputFile(tempPath)) + using (var ws = fh.CreateWriteStream()) { - ch.Error($"{fileName}: Download timed out"); - return ch.Except("Download timed out"); + response.EnsureSuccessStatusCode(); + IEnumerable headers; + var hasHeader = response.Headers.TryGetValues("content-length", out headers); + if (uri.Host == "aka.ms" && IsRedirectToDefaultPage(uri.AbsoluteUri)) + throw new NotSupportedException($"The provided url ({uri}) redirects to the default url ({DefaultUrl})"); + if (!hasHeader || !long.TryParse(headers.First(), out var size)) + size = 10000000; + + var stream = await response.EnsureSuccessStatusCode().Content.ReadAsStreamAsync().ConfigureAwait(false); + + await stream.CopyToAsync(ws, blockSize, ct); + + if (ct.IsCancellationRequested) + { + ch.Error($"{fileName}: Download timed out"); + return ch.Except("Download timed out"); + } } + File.Move(tempPath, path); + ch.Info($"{fileName}: Download complete"); + return null; + } + catch (WebException e) + { + ch.Error($"{fileName}: Could not download. HttpClient returned the following error: {e.Message}"); + return e; } - File.Move(tempPath, path); - ch.Info($"{fileName}: Download complete"); - return null; - } - catch (WebException e) - { - ch.Error($"{fileName}: Could not download. HttpClient returned the following error: {e.Message}"); - return e; } finally { - TryDelete(ch, tempPath, warn: false); - mutex.ReleaseMutex(); + if (!string.IsNullOrEmpty(tempPath)) + TryDelete(ch, tempPath, warn: false); + if (lockTaken) + { + try { mutex.ReleaseMutex(); } catch { /* ignore if not owned */ } + } } } + private static string GetSafeMutexName(string path) + { + // Named system mutexes have platform-specific constraints; keep it simple and portable. + // Derive a unique, stable identifier from the absolute path and sanitize invalid characters. + string name = path ?? string.Empty; + name = name.Replace('\\', '_').Replace('/', '_').Replace(':', '_'); + // Prefix to avoid collisions with other applications. + return $"MLNET_Resource_{name}"; + } + /// This method checks whether or not the provided aka.ms url redirects to /// Microsoft's homepage, as the default faulty aka.ms URLs redirect to https://www.microsoft.com/?ref=aka . /// The provided url to check From b911dcad810269f8be7d6af0f123c2c8c23664c5 Mon Sep 17 00:00:00 2001 From: Alexey Sokolov Date: Mon, 13 Oct 2025 10:58:44 +0700 Subject: [PATCH 8/9] chore: trigger Azure Pipelines From bfd8543715c17a87b08b713ebad75c085f7ed171 Mon Sep 17 00:00:00 2001 From: Alexey Sokolov Date: Tue, 14 Oct 2025 20:15:00 +0700 Subject: [PATCH 9/9] chore: trigger Azure Pipelines