Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/Microsoft.ML.Core/Data/IHostEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
using Microsoft.ML.Data;
using Microsoft.ML;

namespace Microsoft.ML.Runtime;

Expand Down Expand Up @@ -115,6 +116,11 @@ internal interface IHostEnvironmentInternal : IHostEnvironment
T GetOptionOrDefault<T>(string name);

bool RemoveOption(string name);

/// <summary>
/// Global random source underpinning this environment.
/// </summary>
IRandomSource RandomSource { get; }
}

/// <summary>
Expand All @@ -129,6 +135,11 @@ public interface IHost : IHostEnvironment
/// generators are NOT thread safe.
/// </summary>
Random Rand { get; }

/// <summary>
/// The random source backing <see cref="Rand"/>.
/// </summary>
IRandomSource RandomSource { get; }
}

/// <summary>
Expand Down
18 changes: 11 additions & 7 deletions src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.IO;
using System.Linq;
using System.Threading;
using Microsoft.ML;

namespace Microsoft.ML.Runtime;

Expand Down Expand Up @@ -359,14 +360,16 @@ protected override void Dispose(bool disposing)
/// </summary>
/// <param name="seed">Random seed. Set to <c>null</c> for a non-deterministic environment.</param>
/// <param name="verbose">Set to <c>true</c> for fully verbose logging.</param>
/// <param name="randomSource">Optional random source backing this environment.</param>
/// <param name="sensitivity">Allowed message sensitivity.</param>
/// <param name="outWriter">Text writer to print normal messages to.</param>
/// <param name="errWriter">Text writer to print error messages to.</param>
/// <param name="testWriter">Optional TextWriter to write messages if the host is a test environment.</param>
public ConsoleEnvironment(int? seed = null, bool verbose = false,
IRandomSource randomSource = null,
MessageSensitivity sensitivity = MessageSensitivity.All,
TextWriter outWriter = null, TextWriter errWriter = null, TextWriter testWriter = null)
: base(seed, verbose, nameof(ConsoleEnvironment))
: base(seed, verbose, randomSource, nameof(ConsoleEnvironment))
{
Contracts.CheckValueOrNull(outWriter);
Contracts.CheckValueOrNull(errWriter);
Expand All @@ -391,13 +394,14 @@ private void PrintMessage(IMessageSource src, ChannelMessage msg)
Root._consoleWriter.PrintMessage(src, msg);
}

protected override IHost RegisterCore(HostEnvironmentBase<ConsoleEnvironment> source, string shortName, string parentFullName, Random rand, bool verbose)
protected override IHost RegisterCore(HostEnvironmentBase<ConsoleEnvironment> source, string shortName, string parentFullName, Random rand, IRandomSource randomSource, bool verbose)
{
Contracts.AssertValue(rand);
Contracts.AssertValue(randomSource);
Contracts.AssertValueOrNull(parentFullName);
Contracts.AssertNonEmpty(shortName);
Contracts.Assert(source == this || source is Host);
return new Host(source, shortName, parentFullName, rand, verbose);
return new Host(source, shortName, parentFullName, rand, randomSource, verbose);
}

protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name)
Expand Down Expand Up @@ -462,8 +466,8 @@ public void Dispose()

private sealed class Host : HostBase
{
public Host(HostEnvironmentBase<ConsoleEnvironment> source, string shortName, string parentFullName, Random rand, bool verbose)
: base(source, shortName, parentFullName, rand, verbose)
public Host(HostEnvironmentBase<ConsoleEnvironment> source, string shortName, string parentFullName, Random rand, IRandomSource randomSource, bool verbose)
: base(source, shortName, parentFullName, rand, randomSource, verbose)
{
IsCanceled = source.IsCanceled;
}
Expand All @@ -484,9 +488,9 @@ protected override IPipe<TMessage> CreatePipe<TMessage>(ChannelProviderBase pare
return new Pipe<TMessage>(parent, name, GetDispatchDelegate<TMessage>());
}

protected override IHost RegisterCore(HostEnvironmentBase<ConsoleEnvironment> source, string shortName, string parentFullName, Random rand, bool verbose)
protected override IHost RegisterCore(HostEnvironmentBase<ConsoleEnvironment> source, string shortName, string parentFullName, Random rand, IRandomSource randomSource, bool verbose)
{
return new Host(source, shortName, parentFullName, rand, verbose);
return new Host(source, shortName, parentFullName, rand, randomSource, verbose);
}
}
}
35 changes: 27 additions & 8 deletions src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML;
using Microsoft.ML.Internal.Utilities;

namespace Microsoft.ML.Runtime;

Expand Down Expand Up @@ -121,8 +123,8 @@ public abstract class HostBase : HostEnvironmentBase<TEnv>, IHost

public Random Rand => _rand;

public HostBase(HostEnvironmentBase<TEnv> source, string shortName, string parentFullName, Random rand, bool verbose)
: base(source, rand, verbose, shortName, parentFullName)
public HostBase(HostEnvironmentBase<TEnv> source, string shortName, string parentFullName, Random rand, IRandomSource randomSource, bool verbose)
: base(source, rand, randomSource ?? new RandomSourceAdapter(rand), verbose, shortName, parentFullName)
{
Depth = source.Depth + 1;
}
Expand All @@ -140,7 +142,8 @@ public HostBase(HostEnvironmentBase<TEnv> source, string shortName, string paren
{
_children.RemoveAll(r => r.TryGetTarget(out IHost _) == false);
Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand);
host = RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose);
IRandomSource randomSource = new RandomSourceAdapter(rand);
host = RegisterCore(this, name, Master?.FullName, rand, randomSource, verbose ?? Verbose);
if (!IsCanceled)
_children.Add(new WeakReference<IHost>(host));
}
Expand Down Expand Up @@ -338,6 +341,8 @@ public void RemoveListener(Action<IMessageSource, TMessage> listenerFunc)
protected Dictionary<string, object> Options { get; } = [];
#pragma warning restore MSML_NoInstanceInitializers

public IRandomSource RandomSource => _randomSource;

protected readonly TEnv Root;
// This is non-null iff this environment was a fork of another. Disposing a fork
// doesn't free temp files. That is handled when the master is disposed.
Expand All @@ -348,6 +353,7 @@ public void RemoveListener(Action<IMessageSource, TMessage> listenerFunc)

// The random number generator for this host.
private readonly Random _rand;
private readonly IRandomSource _randomSource;

public int? Seed { get; }

Expand All @@ -369,11 +375,22 @@ public void RemoveListener(Action<IMessageSource, TMessage> listenerFunc)
/// The main constructor.
/// </summary>
protected HostEnvironmentBase(int? seed, bool verbose,
IRandomSource randomSource = null,
string shortName = null, string parentFullName = null)
: base(shortName, parentFullName, verbose)
{
Seed = seed;
_rand = RandomUtils.Create(Seed);
if (randomSource is null)
{
var baseRandom = RandomUtils.Create(Seed);
_rand = baseRandom;
_randomSource = new RandomSourceAdapter(baseRandom);
}
else
{
_randomSource = randomSource;
_rand = randomSource as Random ?? new RandomFromRandomSource(randomSource);
}
ListenerDict = new ConcurrentDictionary<Type, Dispatcher>();
ProgressTracker = new ProgressReporting.ProgressTracker(this);
_cancelLock = new object();
Expand All @@ -385,13 +402,14 @@ protected HostEnvironmentBase(int? seed, bool verbose,
/// <summary>
/// This constructor is for forking.
/// </summary>
protected HostEnvironmentBase(HostEnvironmentBase<TEnv> source, Random rand, bool verbose,
protected HostEnvironmentBase(HostEnvironmentBase<TEnv> source, Random rand, IRandomSource randomSource, bool verbose,
string shortName = null, string parentFullName = null)
: base(shortName, parentFullName, verbose)
{
Contracts.CheckValue(source, nameof(source));
Contracts.CheckValueOrNull(rand);
_rand = rand ?? RandomUtils.Create();
_randomSource = randomSource ?? (rand != null ? new RandomSourceAdapter(rand) : new RandomSourceAdapter(RandomUtils.Create()));
_rand = rand ?? (_randomSource as Random ?? new RandomFromRandomSource(_randomSource));
_cancelLock = new object();

// This fork shares some stuff with the master.
Expand Down Expand Up @@ -419,7 +437,8 @@ public IHost Register(string name, int? seed = null, bool? verbose = null)
{
_children.RemoveAll(r => r.TryGetTarget(out IHost _) == false);
Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand);
host = RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose);
IRandomSource randomSource = new RandomSourceAdapter(rand);
host = RegisterCore(this, name, Master?.FullName, rand, randomSource, verbose ?? Verbose);

// Need to manually copy over the parameters
//((IHostEnvironmentInternal)host).Seed = this.Seed;
Expand All @@ -433,7 +452,7 @@ public IHost Register(string name, int? seed = null, bool? verbose = null)
}

protected abstract IHost RegisterCore(HostEnvironmentBase<TEnv> source, string shortName,
string parentFullName, Random rand, bool verbose);
string parentFullName, Random rand, IRandomSource randomSource, bool verbose);

public IProgressChannel StartProgressChannel(string name)
{
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/Microsoft.ML.Core.csproj
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
ο»Ώ<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<TargetFrameworks>netstandard2.0;net8.0</TargetFrameworks>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<DefineConstants>CORECLR</DefineConstants>
<RootNamespace>Microsoft.ML</RootNamespace>
Expand Down
22 changes: 16 additions & 6 deletions src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`2.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,26 @@ public static FuncInstanceMethodInfo1<TTarget, TResult> Create(Expression<Func<T
// Verify that we are creating a delegate of type Func<TRet>
Contracts.CheckParam(methodCallExpression.Arguments.Count == 2, nameof(expression), "Unexpected expression form");
Contracts.CheckParam(methodCallExpression.Arguments[0] is ConstantExpression, nameof(expression), "Unexpected expression form");
Contracts.CheckParam(((ConstantExpression)methodCallExpression.Arguments[0]).Type == typeof(Type), nameof(expression), "Unexpected expression form");
Contracts.CheckParam((Type)((ConstantExpression)methodCallExpression.Arguments[0]).Value == typeof(Func<TResult>), nameof(expression), "Unexpected expression form");
var delegateTypeExpression = (ConstantExpression)methodCallExpression.Arguments[0];
Contracts.CheckParam(delegateTypeExpression.Type == typeof(Type), nameof(expression), "Unexpected expression form");
if (delegateTypeExpression.Value is not Type delegateType)
{
throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form");
}
Contracts.CheckParam(delegateType == typeof(Func<TResult>), nameof(expression), "Unexpected expression form");
Contracts.CheckParam(methodCallExpression.Arguments[1] is ParameterExpression, nameof(expression), "Unexpected expression form");
Contracts.CheckParam(methodCallExpression.Arguments[1] == expression.Parameters[0], nameof(expression), "Unexpected expression form");

// Check the MethodInfo
Contracts.CheckParam(methodCallExpression.Object is ConstantExpression, nameof(expression), "Unexpected expression form");
Contracts.CheckParam(((ConstantExpression)methodCallExpression.Object).Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form");

var methodInfo = (MethodInfo)((ConstantExpression)methodCallExpression.Object).Value;
if (methodCallExpression.Object is not ConstantExpression methodInfoExpression)
{
throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form");
}
Contracts.CheckParam(methodInfoExpression.Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form");
if (methodInfoExpression.Value is not MethodInfo methodInfo)
{
throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form");
}
Contracts.CheckParam(expression.Body is UnaryExpression, nameof(expression), "Unexpected expression form");
Contracts.CheckParam(((UnaryExpression)expression.Body).Operand is MethodCallExpression, nameof(expression), "Unexpected expression form");

Expand Down
22 changes: 16 additions & 6 deletions src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`3.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,26 @@ public static FuncInstanceMethodInfo1<TTarget, T, TResult> Create(Expression<Fun
// Verify that we are creating a delegate of type Func<T, TResult>
Contracts.CheckParam(methodCallExpression.Arguments.Count == 2, nameof(expression), "Unexpected expression form");
Contracts.CheckParam(methodCallExpression.Arguments[0] is ConstantExpression, nameof(expression), "Unexpected expression form");
Contracts.CheckParam(((ConstantExpression)methodCallExpression.Arguments[0]).Type == typeof(Type), nameof(expression), "Unexpected expression form");
Contracts.CheckParam((Type)((ConstantExpression)methodCallExpression.Arguments[0]).Value == typeof(Func<T, TResult>), nameof(expression), "Unexpected expression form");
var delegateTypeExpression = (ConstantExpression)methodCallExpression.Arguments[0];
Contracts.CheckParam(delegateTypeExpression.Type == typeof(Type), nameof(expression), "Unexpected expression form");
if (delegateTypeExpression.Value is not Type delegateType)
{
throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form");
}
Contracts.CheckParam(delegateType == typeof(Func<T, TResult>), nameof(expression), "Unexpected expression form");
Contracts.CheckParam(methodCallExpression.Arguments[1] is ParameterExpression, nameof(expression), "Unexpected expression form");
Contracts.CheckParam(methodCallExpression.Arguments[1] == expression.Parameters[0], nameof(expression), "Unexpected expression form");

// Check the MethodInfo
Contracts.CheckParam(methodCallExpression.Object is ConstantExpression, nameof(expression), "Unexpected expression form");
Contracts.CheckParam(((ConstantExpression)methodCallExpression.Object).Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form");

var methodInfo = (MethodInfo)((ConstantExpression)methodCallExpression.Object).Value;
if (methodCallExpression.Object is not ConstantExpression methodInfoExpression)
{
throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form");
}
Contracts.CheckParam(methodInfoExpression.Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form");
if (methodInfoExpression.Value is not MethodInfo methodInfo)
{
throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form");
}
Contracts.CheckParam(expression.Body is UnaryExpression, nameof(expression), "Unexpected expression form");
Contracts.CheckParam(((UnaryExpression)expression.Body).Operand is MethodCallExpression, nameof(expression), "Unexpected expression form");

Expand Down
22 changes: 16 additions & 6 deletions src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`4.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,26 @@ public static FuncInstanceMethodInfo1<TTarget, T1, T2, TResult> Create(Expressio
// Verify that we are creating a delegate of type Func<T, TResult>
Contracts.CheckParam(methodCallExpression.Arguments.Count == 2, nameof(expression), "Unexpected expression form");
Contracts.CheckParam(methodCallExpression.Arguments[0] is ConstantExpression, nameof(expression), "Unexpected expression form");
Contracts.CheckParam(((ConstantExpression)methodCallExpression.Arguments[0]).Type == typeof(Type), nameof(expression), "Unexpected expression form");
Contracts.CheckParam((Type)((ConstantExpression)methodCallExpression.Arguments[0]).Value == typeof(Func<T1, T2, TResult>), nameof(expression), "Unexpected expression form");
var delegateTypeExpression = (ConstantExpression)methodCallExpression.Arguments[0];
Contracts.CheckParam(delegateTypeExpression.Type == typeof(Type), nameof(expression), "Unexpected expression form");
if (delegateTypeExpression.Value is not Type delegateType)
{
throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form");
}
Contracts.CheckParam(delegateType == typeof(Func<T1, T2, TResult>), nameof(expression), "Unexpected expression form");
Contracts.CheckParam(methodCallExpression.Arguments[1] is ParameterExpression, nameof(expression), "Unexpected expression form");
Contracts.CheckParam(methodCallExpression.Arguments[1] == expression.Parameters[0], nameof(expression), "Unexpected expression form");

// Check the MethodInfo
Contracts.CheckParam(methodCallExpression.Object is ConstantExpression, nameof(expression), "Unexpected expression form");
Contracts.CheckParam(((ConstantExpression)methodCallExpression.Object).Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form");

var methodInfo = (MethodInfo)((ConstantExpression)methodCallExpression.Object).Value;
if (methodCallExpression.Object is not ConstantExpression methodInfoExpression)
{
throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form");
}
Contracts.CheckParam(methodInfoExpression.Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form");
if (methodInfoExpression.Value is not MethodInfo methodInfo)
{
throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form");
}
Contracts.CheckParam(expression.Body is UnaryExpression, nameof(expression), "Unexpected expression form");
Contracts.CheckParam(((UnaryExpression)expression.Body).Operand is MethodCallExpression, nameof(expression), "Unexpected expression form");

Expand Down
Loading