Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 17 additions & 18 deletions src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,22 @@
// See the LICENSE file in the project root for more information.

using System;
namespace Microsoft.ML.EntryPoints
{
/// <summary>
/// This is a signature for classes that are 'holders' of entry points and components.
/// </summary>
[BestFriend]
internal delegate void SignatureEntryPointModule();
namespace Microsoft.ML.EntryPoints;

/// <summary>
/// This is a signature for classes that are 'holders' of entry points and components.
/// </summary>
[BestFriend]
internal delegate void SignatureEntryPointModule();

/// <summary>
/// A simplified assembly attribute for marking EntryPoint modules.
/// </summary>
[AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
[BestFriend]
internal sealed class EntryPointModuleAttribute : LoadableClassAttributeBase
{
public EntryPointModuleAttribute(Type loaderType)
: base(null, typeof(void), loaderType, null, new[] { typeof(SignatureEntryPointModule) }, loaderType.FullName)
{ }
}
/// <summary>
/// A simplified assembly attribute for marking EntryPoint modules.
/// </summary>
[AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
[BestFriend]
internal sealed class EntryPointModuleAttribute : LoadableClassAttributeBase
{
public EntryPointModuleAttribute(Type loaderType)
: base(null, typeof(void), loaderType, null, new[] { typeof(SignatureEntryPointModule) }, loaderType.FullName)
{ }
}
213 changes: 106 additions & 107 deletions src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,127 +9,126 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;

namespace Microsoft.ML.EntryPoints
namespace Microsoft.ML.EntryPoints;

[BestFriend]
internal static class EntryPointUtils
{
[BestFriend]
internal static class EntryPointUtils
private static readonly FuncStaticMethodInfo1<TlcModule.RangeAttribute, object, bool> _isValueWithinRangeMethodInfo
= new FuncStaticMethodInfo1<TlcModule.RangeAttribute, object, bool>(IsValueWithinRange<int>);

private static bool IsValueWithinRange<T>(TlcModule.RangeAttribute range, object obj)
{
private static readonly FuncStaticMethodInfo1<TlcModule.RangeAttribute, object, bool> _isValueWithinRangeMethodInfo
= new FuncStaticMethodInfo1<TlcModule.RangeAttribute, object, bool>(IsValueWithinRange<int>);
T val;
if (obj is Optional<T> asOptional)
val = asOptional.Value;
else
val = (T)obj;

private static bool IsValueWithinRange<T>(TlcModule.RangeAttribute range, object obj)
{
T val;
if (obj is Optional<T> asOptional)
val = asOptional.Value;
else
val = (T)obj;

return
(range.Min == null || ((IComparable)range.Min).CompareTo(val) <= 0) &&
(range.Inf == null || ((IComparable)range.Inf).CompareTo(val) < 0) &&
(range.Max == null || ((IComparable)range.Max).CompareTo(val) >= 0) &&
(range.Sup == null || ((IComparable)range.Sup).CompareTo(val) > 0);
}
return
(range.Min == null || ((IComparable)range.Min).CompareTo(val) <= 0) &&
(range.Inf == null || ((IComparable)range.Inf).CompareTo(val) < 0) &&
(range.Max == null || ((IComparable)range.Max).CompareTo(val) >= 0) &&
(range.Sup == null || ((IComparable)range.Sup).CompareTo(val) > 0);
}

public static bool IsValueWithinRange(this TlcModule.RangeAttribute range, object val)
{
Contracts.AssertValue(range);
Contracts.AssertValue(val);
// Avoid trying to cast double as float. If range
// was specified using floats, but value being checked
// is double, change range to be of type double
if (range.Type == typeof(float) && val is double)
range.CastToDouble();
return Utils.MarshalInvoke(_isValueWithinRangeMethodInfo, range.Type, range, val);
}
public static bool IsValueWithinRange(this TlcModule.RangeAttribute range, object val)
{
Contracts.AssertValue(range);
Contracts.AssertValue(val);
// Avoid trying to cast double as float. If range
// was specified using floats, but value being checked
// is double, change range to be of type double
if (range.Type == typeof(float) && val is double)
range.CastToDouble();
return Utils.MarshalInvoke(_isValueWithinRangeMethodInfo, range.Type, range, val);
}

/// <summary>
/// Performs checks on an EntryPoint input class equivalent to the checks that are done
/// when parsing a JSON EntryPoint graph.
///
/// Call this method from EntryPoint methods to ensure that range and required checks are performed
/// in a consistent manner when EntryPoints are created directly from code.
/// </summary>
public static void CheckInputArgs(IExceptionContext ectx, object args)
/// <summary>
/// Performs checks on an EntryPoint input class equivalent to the checks that are done
/// when parsing a JSON EntryPoint graph.
///
/// Call this method from EntryPoint methods to ensure that range and required checks are performed
/// in a consistent manner when EntryPoints are created directly from code.
/// </summary>
public static void CheckInputArgs(IExceptionContext ectx, object args)
{
foreach (var fieldInfo in args.GetType().GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance))
{
foreach (var fieldInfo in args.GetType().GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance))
{
var attr = fieldInfo.GetCustomAttributes(typeof(ArgumentAttribute), false).FirstOrDefault()
as ArgumentAttribute;
if (attr == null || attr.Visibility == ArgumentAttribute.VisibilityType.CmdLineOnly)
continue;

var fieldVal = fieldInfo.GetValue(args);
var fieldType = fieldInfo.FieldType;

// Optionals are either left in their Implicit constructed state or
// a new Explicit optional is constructed. They should never be set
// to null.
if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>) && fieldVal == null)
throw ectx.Except("Field '{0}' is Optional<> and set to null instead of an explicit value.", fieldInfo.Name);

if (attr.IsRequired)
{
bool equalToDefault;
if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>))
equalToDefault = !((Optional)fieldVal).IsExplicit;
else
equalToDefault = fieldType.IsValueType ? Activator.CreateInstance(fieldType).Equals(fieldVal) : fieldVal == null;

if (equalToDefault)
throw ectx.Except("Field '{0}' is required but is not set.", fieldInfo.Name);
}

var rangeAttr = fieldInfo.GetCustomAttributes(typeof(TlcModule.RangeAttribute), false).FirstOrDefault()
as TlcModule.RangeAttribute;
if (rangeAttr != null && fieldVal != null && !rangeAttr.IsValueWithinRange(fieldVal))
throw ectx.Except("Field '{0}' is set to a value that falls outside the range bounds.", fieldInfo.Name);
}
}
var attr = fieldInfo.GetCustomAttributes(typeof(ArgumentAttribute), false).FirstOrDefault()
as ArgumentAttribute;
if (attr == null || attr.Visibility == ArgumentAttribute.VisibilityType.CmdLineOnly)
continue;

public static IHost CheckArgsAndCreateHost(IHostEnvironment env, string hostName, object input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register(hostName);
host.CheckValue(input, nameof(input));
CheckInputArgs(host, input);
return host;
}
var fieldVal = fieldInfo.GetValue(args);
var fieldType = fieldInfo.FieldType;

/// <summary>
/// Searches for the given column name in the schema. This method applies a
/// common policy that throws an exception if the column is not found
/// and the column name was explicitly specified. If the column is not found
/// and the column name was not explicitly specified, it returns null.
/// </summary>
public static string FindColumnOrNull(IExceptionContext ectx, DataViewSchema schema, Optional<string> value)
{
Contracts.CheckValueOrNull(ectx);
ectx.CheckValue(schema, nameof(schema));
ectx.CheckValue(value, nameof(value));
// Optionals are either left in their Implicit constructed state or
// a new Explicit optional is constructed. They should never be set
// to null.
if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>) && fieldVal == null)
throw ectx.Except("Field '{0}' is Optional<> and set to null instead of an explicit value.", fieldInfo.Name);

if (value == "")
return null;
if (schema.GetColumnOrNull(value) == null)
if (attr.IsRequired)
{
if (value.IsExplicit)
throw ectx.Except("Column '{0}' not found", value);
return null;
bool equalToDefault;
if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>))
equalToDefault = !((Optional)fieldVal).IsExplicit;
else
equalToDefault = fieldType.IsValueType ? Activator.CreateInstance(fieldType).Equals(fieldVal) : fieldVal == null;

if (equalToDefault)
throw ectx.Except("Field '{0}' is required but is not set.", fieldInfo.Name);
}
return value;

var rangeAttr = fieldInfo.GetCustomAttributes(typeof(TlcModule.RangeAttribute), false).FirstOrDefault()
as TlcModule.RangeAttribute;
if (rangeAttr != null && fieldVal != null && !rangeAttr.IsValueWithinRange(fieldVal))
throw ectx.Except("Field '{0}' is set to a value that falls outside the range bounds.", fieldInfo.Name);
}
}

/// <summary>
/// Converts EntryPoint Optional{T} types into nullable types, with the
/// implicit value being converted to the null value.
/// </summary>
public static T? AsNullable<T>(this Optional<T> opt) where T : struct
public static IHost CheckArgsAndCreateHost(IHostEnvironment env, string hostName, object input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register(hostName);
host.CheckValue(input, nameof(input));
CheckInputArgs(host, input);
return host;
}

/// <summary>
/// Searches for the given column name in the schema. This method applies a
/// common policy that throws an exception if the column is not found
/// and the column name was explicitly specified. If the column is not found
/// and the column name was not explicitly specified, it returns null.
/// </summary>
public static string FindColumnOrNull(IExceptionContext ectx, DataViewSchema schema, Optional<string> value)
{
Contracts.CheckValueOrNull(ectx);
ectx.CheckValue(schema, nameof(schema));
ectx.CheckValue(value, nameof(value));

if (value == "")
return null;
if (schema.GetColumnOrNull(value) == null)
{
if (opt.IsExplicit)
return opt.Value;
else
return null;
if (value.IsExplicit)
throw ectx.Except("Column '{0}' not found", value);
return null;
}
return value;
}

/// <summary>
/// Converts EntryPoint Optional{T} types into nullable types, with the
/// implicit value being converted to the null value.
/// </summary>
public static T? AsNullable<T>(this Optional<T> opt) where T : struct
{
if (opt.IsExplicit)
return opt.Value;
else
return null;
}
}
Loading