diff --git a/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs b/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs
index 1dcfd675c1..14bf50f20c 100644
--- a/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs
@@ -3,23 +3,22 @@
// See the LICENSE file in the project root for more information.
using System;
-namespace Microsoft.ML.EntryPoints
-{
- ///
- /// This is a signature for classes that are 'holders' of entry points and components.
- ///
- [BestFriend]
- internal delegate void SignatureEntryPointModule();
+namespace Microsoft.ML.EntryPoints;
+
+///
+/// This is a signature for classes that are 'holders' of entry points and components.
+///
+[BestFriend]
+internal delegate void SignatureEntryPointModule();
- ///
- /// A simplified assembly attribute for marking EntryPoint modules.
- ///
- [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
- [BestFriend]
- internal sealed class EntryPointModuleAttribute : LoadableClassAttributeBase
- {
- public EntryPointModuleAttribute(Type loaderType)
- : base(null, typeof(void), loaderType, null, new[] { typeof(SignatureEntryPointModule) }, loaderType.FullName)
- { }
- }
+///
+/// A simplified assembly attribute for marking EntryPoint modules.
+///
+[AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
+[BestFriend]
+internal sealed class EntryPointModuleAttribute : LoadableClassAttributeBase
+{
+ public EntryPointModuleAttribute(Type loaderType)
+ : base(null, typeof(void), loaderType, null, new[] { typeof(SignatureEntryPointModule) }, loaderType.FullName)
+ { }
}
diff --git a/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs b/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs
index 3cf2e81016..d4ee9c2a42 100644
--- a/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs
@@ -9,127 +9,126 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
-namespace Microsoft.ML.EntryPoints
+namespace Microsoft.ML.EntryPoints;
+
+[BestFriend]
+internal static class EntryPointUtils
{
- [BestFriend]
- internal static class EntryPointUtils
+ private static readonly FuncStaticMethodInfo1 _isValueWithinRangeMethodInfo
+ = new FuncStaticMethodInfo1(IsValueWithinRange);
+
+ private static bool IsValueWithinRange(TlcModule.RangeAttribute range, object obj)
{
- private static readonly FuncStaticMethodInfo1 _isValueWithinRangeMethodInfo
- = new FuncStaticMethodInfo1(IsValueWithinRange);
+ T val;
+ if (obj is Optional asOptional)
+ val = asOptional.Value;
+ else
+ val = (T)obj;
- private static bool IsValueWithinRange(TlcModule.RangeAttribute range, object obj)
- {
- T val;
- if (obj is Optional asOptional)
- val = asOptional.Value;
- else
- val = (T)obj;
-
- return
- (range.Min == null || ((IComparable)range.Min).CompareTo(val) <= 0) &&
- (range.Inf == null || ((IComparable)range.Inf).CompareTo(val) < 0) &&
- (range.Max == null || ((IComparable)range.Max).CompareTo(val) >= 0) &&
- (range.Sup == null || ((IComparable)range.Sup).CompareTo(val) > 0);
- }
+ return
+ (range.Min == null || ((IComparable)range.Min).CompareTo(val) <= 0) &&
+ (range.Inf == null || ((IComparable)range.Inf).CompareTo(val) < 0) &&
+ (range.Max == null || ((IComparable)range.Max).CompareTo(val) >= 0) &&
+ (range.Sup == null || ((IComparable)range.Sup).CompareTo(val) > 0);
+ }
- public static bool IsValueWithinRange(this TlcModule.RangeAttribute range, object val)
- {
- Contracts.AssertValue(range);
- Contracts.AssertValue(val);
- // Avoid trying to cast double as float. If range
- // was specified using floats, but value being checked
- // is double, change range to be of type double
- if (range.Type == typeof(float) && val is double)
- range.CastToDouble();
- return Utils.MarshalInvoke(_isValueWithinRangeMethodInfo, range.Type, range, val);
- }
+ public static bool IsValueWithinRange(this TlcModule.RangeAttribute range, object val)
+ {
+ Contracts.AssertValue(range);
+ Contracts.AssertValue(val);
+ // Avoid trying to cast double as float. If range
+ // was specified using floats, but value being checked
+ // is double, change range to be of type double
+ if (range.Type == typeof(float) && val is double)
+ range.CastToDouble();
+ return Utils.MarshalInvoke(_isValueWithinRangeMethodInfo, range.Type, range, val);
+ }
- ///
- /// Performs checks on an EntryPoint input class equivalent to the checks that are done
- /// when parsing a JSON EntryPoint graph.
- ///
- /// Call this method from EntryPoint methods to ensure that range and required checks are performed
- /// in a consistent manner when EntryPoints are created directly from code.
- ///
- public static void CheckInputArgs(IExceptionContext ectx, object args)
+ ///
+ /// Performs checks on an EntryPoint input class equivalent to the checks that are done
+ /// when parsing a JSON EntryPoint graph.
+ ///
+ /// Call this method from EntryPoint methods to ensure that range and required checks are performed
+ /// in a consistent manner when EntryPoints are created directly from code.
+ ///
+ public static void CheckInputArgs(IExceptionContext ectx, object args)
+ {
+ foreach (var fieldInfo in args.GetType().GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance))
{
- foreach (var fieldInfo in args.GetType().GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance))
- {
- var attr = fieldInfo.GetCustomAttributes(typeof(ArgumentAttribute), false).FirstOrDefault()
- as ArgumentAttribute;
- if (attr == null || attr.Visibility == ArgumentAttribute.VisibilityType.CmdLineOnly)
- continue;
-
- var fieldVal = fieldInfo.GetValue(args);
- var fieldType = fieldInfo.FieldType;
-
- // Optionals are either left in their Implicit constructed state or
- // a new Explicit optional is constructed. They should never be set
- // to null.
- if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>) && fieldVal == null)
- throw ectx.Except("Field '{0}' is Optional<> and set to null instead of an explicit value.", fieldInfo.Name);
-
- if (attr.IsRequired)
- {
- bool equalToDefault;
- if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>))
- equalToDefault = !((Optional)fieldVal).IsExplicit;
- else
- equalToDefault = fieldType.IsValueType ? Activator.CreateInstance(fieldType).Equals(fieldVal) : fieldVal == null;
-
- if (equalToDefault)
- throw ectx.Except("Field '{0}' is required but is not set.", fieldInfo.Name);
- }
-
- var rangeAttr = fieldInfo.GetCustomAttributes(typeof(TlcModule.RangeAttribute), false).FirstOrDefault()
- as TlcModule.RangeAttribute;
- if (rangeAttr != null && fieldVal != null && !rangeAttr.IsValueWithinRange(fieldVal))
- throw ectx.Except("Field '{0}' is set to a value that falls outside the range bounds.", fieldInfo.Name);
- }
- }
+ var attr = fieldInfo.GetCustomAttributes(typeof(ArgumentAttribute), false).FirstOrDefault()
+ as ArgumentAttribute;
+ if (attr == null || attr.Visibility == ArgumentAttribute.VisibilityType.CmdLineOnly)
+ continue;
- public static IHost CheckArgsAndCreateHost(IHostEnvironment env, string hostName, object input)
- {
- Contracts.CheckValue(env, nameof(env));
- var host = env.Register(hostName);
- host.CheckValue(input, nameof(input));
- CheckInputArgs(host, input);
- return host;
- }
+ var fieldVal = fieldInfo.GetValue(args);
+ var fieldType = fieldInfo.FieldType;
- ///
- /// Searches for the given column name in the schema. This method applies a
- /// common policy that throws an exception if the column is not found
- /// and the column name was explicitly specified. If the column is not found
- /// and the column name was not explicitly specified, it returns null.
- ///
- public static string FindColumnOrNull(IExceptionContext ectx, DataViewSchema schema, Optional value)
- {
- Contracts.CheckValueOrNull(ectx);
- ectx.CheckValue(schema, nameof(schema));
- ectx.CheckValue(value, nameof(value));
+ // Optionals are either left in their Implicit constructed state or
+ // a new Explicit optional is constructed. They should never be set
+ // to null.
+ if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>) && fieldVal == null)
+ throw ectx.Except("Field '{0}' is Optional<> and set to null instead of an explicit value.", fieldInfo.Name);
- if (value == "")
- return null;
- if (schema.GetColumnOrNull(value) == null)
+ if (attr.IsRequired)
{
- if (value.IsExplicit)
- throw ectx.Except("Column '{0}' not found", value);
- return null;
+ bool equalToDefault;
+ if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>))
+ equalToDefault = !((Optional)fieldVal).IsExplicit;
+ else
+ equalToDefault = fieldType.IsValueType ? Activator.CreateInstance(fieldType).Equals(fieldVal) : fieldVal == null;
+
+ if (equalToDefault)
+ throw ectx.Except("Field '{0}' is required but is not set.", fieldInfo.Name);
}
- return value;
+
+ var rangeAttr = fieldInfo.GetCustomAttributes(typeof(TlcModule.RangeAttribute), false).FirstOrDefault()
+ as TlcModule.RangeAttribute;
+ if (rangeAttr != null && fieldVal != null && !rangeAttr.IsValueWithinRange(fieldVal))
+ throw ectx.Except("Field '{0}' is set to a value that falls outside the range bounds.", fieldInfo.Name);
}
+ }
- ///
- /// Converts EntryPoint Optional{T} types into nullable types, with the
- /// implicit value being converted to the null value.
- ///
- public static T? AsNullable(this Optional opt) where T : struct
+ public static IHost CheckArgsAndCreateHost(IHostEnvironment env, string hostName, object input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ var host = env.Register(hostName);
+ host.CheckValue(input, nameof(input));
+ CheckInputArgs(host, input);
+ return host;
+ }
+
+ ///
+ /// Searches for the given column name in the schema. This method applies a
+ /// common policy that throws an exception if the column is not found
+ /// and the column name was explicitly specified. If the column is not found
+ /// and the column name was not explicitly specified, it returns null.
+ ///
+ public static string FindColumnOrNull(IExceptionContext ectx, DataViewSchema schema, Optional value)
+ {
+ Contracts.CheckValueOrNull(ectx);
+ ectx.CheckValue(schema, nameof(schema));
+ ectx.CheckValue(value, nameof(value));
+
+ if (value == "")
+ return null;
+ if (schema.GetColumnOrNull(value) == null)
{
- if (opt.IsExplicit)
- return opt.Value;
- else
- return null;
+ if (value.IsExplicit)
+ throw ectx.Except("Column '{0}' not found", value);
+ return null;
}
+ return value;
+ }
+
+ ///
+ /// Converts EntryPoint Optional{T} types into nullable types, with the
+ /// implicit value being converted to the null value.
+ ///
+ public static T? AsNullable(this Optional opt) where T : struct
+ {
+ if (opt.IsExplicit)
+ return opt.Value;
+ else
+ return null;
}
}
diff --git a/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs b/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs
index 867df0a648..15d0fc56e6 100644
--- a/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs
@@ -9,723 +9,722 @@
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
-namespace Microsoft.ML.EntryPoints
+namespace Microsoft.ML.EntryPoints;
+
+///
+/// This class defines attributes to annotate module inputs, outputs, entry points etc. when defining
+/// the module interface.
+///
+[BestFriend]
+internal static class TlcModule
{
///
- /// This class defines attributes to annotate module inputs, outputs, entry points etc. when defining
- /// the module interface.
+ /// An attribute used to annotate the component.
///
- [BestFriend]
- internal static class TlcModule
+ [AttributeUsage(AttributeTargets.Class)]
+ public sealed class ComponentAttribute : Attribute
{
///
- /// An attribute used to annotate the component.
+ /// The load name of the component. Must be unique within its kind.
///
- [AttributeUsage(AttributeTargets.Class)]
- public sealed class ComponentAttribute : Attribute
+ public string Name { get; set; }
+
+ ///
+ /// UI friendly name. Can contain spaces and other forbidden for Name symbols.
+ ///
+ public string FriendlyName { get; set; }
+
+ ///
+ /// Alternative names of the component. Each alias must also be unique in the component's kind.
+ ///
+ public string[] Aliases { get; set; }
+
+ ///
+ /// Comma-separated .
+ ///
+ public string Alias
{
- ///
- /// The load name of the component. Must be unique within its kind.
- ///
- public string Name { get; set; }
-
- ///
- /// UI friendly name. Can contain spaces and other forbidden for Name symbols.
- ///
- public string FriendlyName { get; set; }
-
- ///
- /// Alternative names of the component. Each alias must also be unique in the component's kind.
- ///
- public string[] Aliases { get; set; }
-
- ///
- /// Comma-separated .
- ///
- public string Alias
+ get
{
- get
- {
- if (Aliases == null)
- return null;
- return string.Join(",", Aliases);
- }
- set
+ if (Aliases == null)
+ return null;
+ return string.Join(",", Aliases);
+ }
+ set
+ {
+ if (string.IsNullOrWhiteSpace(value))
+ Aliases = null;
+ else
{
- if (string.IsNullOrWhiteSpace(value))
- Aliases = null;
- else
- {
- var parts = value.Split(',');
- Aliases = parts.Select(x => x.Trim()).ToArray();
- }
+ var parts = value.Split(',');
+ Aliases = parts.Select(x => x.Trim()).ToArray();
}
}
-
- ///
- /// Description of the component.
- ///
- public string Desc { get; set; }
-
- ///
- /// This should indicate a name of an embedded resource that contains detailed documents
- /// for the component, for example, markdown document with the .md extension. The embedded resource
- /// is assumed to be in the same assembly as the class on which this attribute is ascribed.
- ///
- public string DocName { get; set; }
}
///
- /// An attribute used to annotate the signature interface.
- /// Effectively, this is a way to associate the signature interface with a user-friendly name.
+ /// Description of the component.
///
- [AttributeUsage(AttributeTargets.Interface)]
- public sealed class ComponentKindAttribute : Attribute
- {
- public readonly string Kind;
-
- public ComponentKindAttribute(string kind)
- {
- Kind = kind;
- }
- }
+ public string Desc { get; set; }
///
- /// An attribute used to annotate the kind of entry points.
- /// Typically it is used on the input classes.
+ /// This should indicate a name of an embedded resource that contains detailed documents
+ /// for the component, for example, markdown document with the .md extension. The embedded resource
+ /// is assumed to be in the same assembly as the class on which this attribute is ascribed.
///
- [AttributeUsage(AttributeTargets.Class)]
- public sealed class EntryPointKindAttribute : Attribute
- {
- public readonly Type[] Kinds;
+ public string DocName { get; set; }
+ }
- public EntryPointKindAttribute(params Type[] kinds)
- {
- Kinds = kinds;
- }
+ ///
+ /// An attribute used to annotate the signature interface.
+ /// Effectively, this is a way to associate the signature interface with a user-friendly name.
+ ///
+ [AttributeUsage(AttributeTargets.Interface)]
+ public sealed class ComponentKindAttribute : Attribute
+ {
+ public readonly string Kind;
+
+ public ComponentKindAttribute(string kind)
+ {
+ Kind = kind;
}
+ }
- ///
- /// An attribute used to annotate the outputs of the module.
- ///
- [AttributeUsage(AttributeTargets.Field)]
- public sealed class OutputAttribute : Attribute
+ ///
+ /// An attribute used to annotate the kind of entry points.
+ /// Typically it is used on the input classes.
+ ///
+ [AttributeUsage(AttributeTargets.Class)]
+ public sealed class EntryPointKindAttribute : Attribute
+ {
+ public readonly Type[] Kinds;
+
+ public EntryPointKindAttribute(params Type[] kinds)
{
- ///
- /// Official name of the output. If it is not specified, the field name is used.
- ///
- public string Name { get; set; }
-
- ///
- /// The description of the output.
- ///
- public string Desc { get; set; }
-
- ///
- /// The rank order of the output. Because .NET reflection returns members in an unspecified order, this
- /// is the only way to ensure consistency.
- ///
- public Double SortOrder { get; set; }
+ Kinds = kinds;
}
+ }
+ ///
+ /// An attribute used to annotate the outputs of the module.
+ ///
+ [AttributeUsage(AttributeTargets.Field)]
+ public sealed class OutputAttribute : Attribute
+ {
///
- /// An attribute to indicate that a field is optional in an EntryPoint module.
- /// A node can be run without optional input fields.
+ /// Official name of the output. If it is not specified, the field name is used.
///
- [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)]
- public sealed class OptionalInputAttribute : Attribute { }
+ public string Name { get; set; }
///
- /// An attribute used to annotate the valid range of a numeric input.
+ /// The description of the output.
///
- [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)]
- public sealed class RangeAttribute : Attribute
- {
- private object _min;
- private object _max;
- private object _inf;
- private object _sup;
- private Type _type;
-
- ///
- /// The target type of this range attribute, as determined by the type of
- /// the set range bound values.
- ///
- public Type Type => _type;
-
- ///
- /// An inclusive lower bound of the value.
- ///
- public object Min
- {
- get { return _min; }
- set
- {
- CheckType(value);
- Contracts.Check(_inf == null,
- "The minimum and infimum cannot be both set in a range attribute");
- Contracts.Check(_max == null || ((IComparable)_max).CompareTo(value) != -1,
- "The minimum must be less than or equal to the maximum");
- Contracts.Check(_sup == null || ((IComparable)_sup).CompareTo(value) == 1,
- "The minimum must be less than the supremum");
- _min = value;
- }
- }
-
- ///
- /// An inclusive upper bound of the value.
- ///
- public object Max
- {
- get { return _max; }
- set
- {
- CheckType(value);
- Contracts.Check(_sup == null,
- "The maximum and supremum cannot be both set in a range attribute");
- Contracts.Check(_min == null || ((IComparable)_min).CompareTo(value) != 1,
- "The maximum must be greater than or equal to the minimum");
- Contracts.Check(_inf == null || ((IComparable)_inf).CompareTo(value) == -1,
- "The maximum must be greater than the infimum");
- _max = value;
- }
- }
+ public string Desc { get; set; }
- ///
- /// An exclusive lower bound of the value.
- ///
- public object Inf
- {
- get { return _inf; }
- set
- {
- CheckType(value);
- Contracts.Check(_min == null,
- "The infimum and minimum cannot be both set in a range attribute");
- Contracts.Check(_max == null || ((IComparable)_max).CompareTo(value) == 1,
- "The infimum must be less than the maximum");
- Contracts.Check(_sup == null || ((IComparable)_sup).CompareTo(value) == 1,
- "The infimum must be less than the supremum");
- _inf = value;
- }
- }
+ ///
+ /// The rank order of the output. Because .NET reflection returns members in an unspecified order, this
+ /// is the only way to ensure consistency.
+ ///
+ public Double SortOrder { get; set; }
+ }
- ///
- /// An exclusive upper bound of the value.
- ///
- public object Sup
- {
- get { return _sup; }
- set
- {
- CheckType(value);
- Contracts.Check(_max == null,
- "The supremum and maximum cannot be both set in a range attribute");
- Contracts.Check(_min == null || ((IComparable)_min).CompareTo(value) == -1,
- "The supremum must be greater than the minimum");
- Contracts.Check(_inf == null || ((IComparable)_inf).CompareTo(value) == -1,
- "The supremum must be greater than the infimum");
- _sup = value;
- }
- }
+ ///
+ /// An attribute to indicate that a field is optional in an EntryPoint module.
+ /// A node can be run without optional input fields.
+ ///
+ [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)]
+ public sealed class OptionalInputAttribute : Attribute { }
- private void CheckType(object val)
- {
- Contracts.CheckValue(val, nameof(val));
- if (_type == null)
- {
- Contracts.Check(val is IComparable, "Type for range attribute must support IComparable");
- _type = val.GetType();
- }
- else
- Contracts.Check(_type == val.GetType(), "All Range attribute values must be of the same type");
- }
+ ///
+ /// An attribute used to annotate the valid range of a numeric input.
+ ///
+ [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)]
+ public sealed class RangeAttribute : Attribute
+ {
+ private object _min;
+ private object _max;
+ private object _inf;
+ private object _sup;
+ private Type _type;
- public void CastToDouble()
- {
- _type = typeof(double);
- if (_inf != null)
- _inf = Convert.ToDouble(_inf);
- if (_min != null)
- _min = Convert.ToDouble(_min);
- if (_max != null)
- _max = Convert.ToDouble(_max);
- if (_sup != null)
- _sup = Convert.ToDouble(_sup);
- }
+ ///
+ /// The target type of this range attribute, as determined by the type of
+ /// the set range bound values.
+ ///
+ public Type Type => _type;
- public override string ToString()
+ ///
+ /// An inclusive lower bound of the value.
+ ///
+ public object Min
+ {
+ get { return _min; }
+ set
{
- string optionalTypeSpecifier = "";
- if (_type == typeof(double))
- optionalTypeSpecifier = "d";
- else if (_type == typeof(float))
- optionalTypeSpecifier = "f";
-
- var pieces = new List();
- if (_inf != null)
- pieces.Add($"Inf = {_inf}{optionalTypeSpecifier}");
- if (_min != null)
- pieces.Add($"Min = {_min}{optionalTypeSpecifier}");
- if (_max != null)
- pieces.Add($"Max = {_max}{optionalTypeSpecifier}");
- if (_sup != null)
- pieces.Add($"Sup = {_sup}{optionalTypeSpecifier}");
- return $"[TlcModule.Range({string.Join(", ", pieces)})]";
+ CheckType(value);
+ Contracts.Check(_inf == null,
+ "The minimum and infimum cannot be both set in a range attribute");
+ Contracts.Check(_max == null || ((IComparable)_max).CompareTo(value) != -1,
+ "The minimum must be less than or equal to the maximum");
+ Contracts.Check(_sup == null || ((IComparable)_sup).CompareTo(value) == 1,
+ "The minimum must be less than the supremum");
+ _min = value;
}
}
///
- /// An attribute used to indicate suggested sweep ranges for parameter sweeping.
+ /// An inclusive upper bound of the value.
///
- public abstract class SweepableParamAttribute : Attribute
+ public object Max
{
- public string Name { get; set; }
- private IComparable _rawValue;
- public virtual IComparable RawValue
+ get { return _max; }
+ set
{
- get => _rawValue;
- set
- {
- if (!Frozen)
- _rawValue = value;
- }
+ CheckType(value);
+ Contracts.Check(_sup == null,
+ "The maximum and supremum cannot be both set in a range attribute");
+ Contracts.Check(_min == null || ((IComparable)_min).CompareTo(value) != 1,
+ "The maximum must be greater than or equal to the minimum");
+ Contracts.Check(_inf == null || ((IComparable)_inf).CompareTo(value) == -1,
+ "The maximum must be greater than the infimum");
+ _max = value;
}
-
- // The raw value will store an index for discrete parameters,
- // but sometimes we want the text or numeric value itself,
- // not the hot index. The processed value does that for discrete
- // params. For other params, it just returns the raw value itself.
- public virtual IComparable ProcessedValue() => _rawValue;
-
- // Allows for hyperparameter value freezing, so that sweeps
- // will not alter the current value when true.
- public bool Frozen { get; set; }
-
- // Allows the sweepable param to be set directly using the
- // available ValueText attribute on IParameterValues (from
- // the ParameterSets used in the old hyperparameter sweepers).
- public abstract void SetUsingValueText(string valueText);
-
- public abstract SweepableParamAttribute Clone();
}
///
- /// An attribute used to indicate suggested sweep ranges for discrete parameter sweeping.
- /// The value is the index of the option chosen. Use Options[Value] to get the corresponding
- /// option using the index.
+ /// An exclusive lower bound of the value.
///
- [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)]
- public sealed class SweepableDiscreteParamAttribute : SweepableParamAttribute
+ public object Inf
{
- public object[] Options { get; }
-
- public SweepableDiscreteParamAttribute(string name, object[] values, bool isBool = false) : this(values, isBool)
+ get { return _inf; }
+ set
{
- Name = name;
- }
-
- public SweepableDiscreteParamAttribute(object[] values, bool isBool = false)
- {
- Options = isBool ? new object[] { false, true } : values;
- }
-
- public override IComparable RawValue
- {
- get => base.RawValue;
- set
- {
- var val = Convert.ToInt32(value);
- if (!Frozen && 0 <= val && val < Options.Length)
- base.RawValue = val;
- }
+ CheckType(value);
+ Contracts.Check(_min == null,
+ "The infimum and minimum cannot be both set in a range attribute");
+ Contracts.Check(_max == null || ((IComparable)_max).CompareTo(value) == 1,
+ "The infimum must be less than the maximum");
+ Contracts.Check(_sup == null || ((IComparable)_sup).CompareTo(value) == 1,
+ "The infimum must be less than the supremum");
+ _inf = value;
}
-
- public override void SetUsingValueText(string valueText)
- {
- for (int i = 0; i < Options.Length; i++)
- if (valueText == Options[i].ToString())
- RawValue = i;
- }
-
- public int IndexOf(object option)
- {
- for (int i = 0; i < Options.Length; i++)
- if (option == Options[i])
- return i;
- return -1;
- }
-
- private static string TranslateOption(object o)
- {
- switch (o)
- {
- case float _:
- case double _:
- return $"{o}f";
- case long _:
- case int _:
- case byte _:
- case short _:
- return o.ToString();
- case bool _:
- return o.ToString().ToLower();
- case Enum _:
- var type = o.GetType();
- var defaultName = $"Enums.{type.Name}.{o.ToString()}";
- var name = type.FullName?.Replace("+", ".");
- if (name == null)
- return defaultName;
- var index1 = name.LastIndexOf(".", StringComparison.Ordinal);
- var index2 = name.Substring(0, index1).LastIndexOf(".", StringComparison.Ordinal) + 1;
- if (index2 >= 0)
- return $"{name.Substring(index2)}.{o.ToString()}";
- return defaultName;
- default:
- return $"\"{o}\"";
- }
- }
-
- public override SweepableParamAttribute Clone() =>
- new SweepableDiscreteParamAttribute(Name, Options) { RawValue = RawValue, Frozen = Frozen };
-
- public override string ToString()
- {
- var name = string.IsNullOrEmpty(Name) ? "" : $"\"{Name}\", ";
- return $"[TlcModule.{GetType().Name}({name}new object[]{{{string.Join(", ", Options.Select(TranslateOption))}}})]";
- }
-
- public override IComparable ProcessedValue() => (IComparable)Options[(int)RawValue];
}
///
- /// An attribute used to indicate suggested sweep ranges for float parameter sweeping.
+ /// An exclusive upper bound of the value.
///
- [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)]
- public sealed class SweepableFloatParamAttribute : SweepableParamAttribute
+ public object Sup
{
- public float Min { get; }
- public float Max { get; }
- public float? StepSize { get; }
- public int? NumSteps { get; }
- public bool IsLogScale { get; }
-
- public SweepableFloatParamAttribute(string name, float min, float max, float stepSize = -1, int numSteps = -1,
- bool isLogScale = false) : this(min, max, stepSize, numSteps, isLogScale)
+ get { return _sup; }
+ set
{
- Name = name;
- }
-
- public SweepableFloatParamAttribute(float min, float max, float stepSize = -1, int numSteps = -1, bool isLogScale = false)
- {
- Min = min;
- Max = max;
- if (!stepSize.Equals(-1))
- StepSize = stepSize;
- if (numSteps != -1)
- NumSteps = numSteps;
- IsLogScale = isLogScale;
+ CheckType(value);
+ Contracts.Check(_max == null,
+ "The supremum and maximum cannot be both set in a range attribute");
+ Contracts.Check(_min == null || ((IComparable)_min).CompareTo(value) == -1,
+ "The supremum must be greater than the minimum");
+ Contracts.Check(_inf == null || ((IComparable)_inf).CompareTo(value) == -1,
+ "The supremum must be greater than the infimum");
+ _sup = value;
}
+ }
- public override void SetUsingValueText(string valueText)
+ private void CheckType(object val)
+ {
+ Contracts.CheckValue(val, nameof(val));
+ if (_type == null)
{
- RawValue = float.Parse(valueText);
+ Contracts.Check(val is IComparable, "Type for range attribute must support IComparable");
+ _type = val.GetType();
}
+ else
+ Contracts.Check(_type == val.GetType(), "All Range attribute values must be of the same type");
+ }
- public override SweepableParamAttribute Clone() =>
- new SweepableFloatParamAttribute(Name, Min, Max, StepSize ?? -1, NumSteps ?? -1, IsLogScale) { RawValue = RawValue, Frozen = Frozen };
+ public void CastToDouble()
+ {
+ _type = typeof(double);
+ if (_inf != null)
+ _inf = Convert.ToDouble(_inf);
+ if (_min != null)
+ _min = Convert.ToDouble(_min);
+ if (_max != null)
+ _max = Convert.ToDouble(_max);
+ if (_sup != null)
+ _sup = Convert.ToDouble(_sup);
+ }
- public override string ToString()
- {
- var optional = new StringBuilder();
- if (StepSize != null)
- optional.Append($", stepSize:{StepSize}");
- if (NumSteps != null)
- optional.Append($", numSteps:{NumSteps}");
- if (IsLogScale)
- optional.Append($", isLogScale:true");
- var name = string.IsNullOrEmpty(Name) ? "" : $"\"{Name}\", ";
- return $"[TlcModule.{GetType().Name}({name}{Min}f, {Max}f{optional})]";
- }
+ public override string ToString()
+ {
+ string optionalTypeSpecifier = "";
+ if (_type == typeof(double))
+ optionalTypeSpecifier = "d";
+ else if (_type == typeof(float))
+ optionalTypeSpecifier = "f";
+
+ var pieces = new List();
+ if (_inf != null)
+ pieces.Add($"Inf = {_inf}{optionalTypeSpecifier}");
+ if (_min != null)
+ pieces.Add($"Min = {_min}{optionalTypeSpecifier}");
+ if (_max != null)
+ pieces.Add($"Max = {_max}{optionalTypeSpecifier}");
+ if (_sup != null)
+ pieces.Add($"Sup = {_sup}{optionalTypeSpecifier}");
+ return $"[TlcModule.Range({string.Join(", ", pieces)})]";
}
+ }
- ///
- /// An attribute used to indicate suggested sweep ranges for long parameter sweeping.
- ///
- [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)]
- public sealed class SweepableLongParamAttribute : SweepableParamAttribute
+ ///
+ /// An attribute used to indicate suggested sweep ranges for parameter sweeping.
+ ///
+ public abstract class SweepableParamAttribute : Attribute
+ {
+ public string Name { get; set; }
+ private IComparable _rawValue;
+ public virtual IComparable RawValue
{
- public long Min { get; }
- public long Max { get; }
- public float? StepSize { get; }
- public int? NumSteps { get; }
- public bool IsLogScale { get; }
-
- public SweepableLongParamAttribute(string name, long min, long max, float stepSize = -1, int numSteps = -1,
- bool isLogScale = false) : this(min, max, stepSize, numSteps, isLogScale)
+ get => _rawValue;
+ set
{
- Name = name;
+ if (!Frozen)
+ _rawValue = value;
}
+ }
- public SweepableLongParamAttribute(long min, long max, float stepSize = -1, int numSteps = -1, bool isLogScale = false)
- {
- Min = min;
- Max = max;
- if (!stepSize.Equals(-1))
- StepSize = stepSize;
- if (numSteps != -1)
- NumSteps = numSteps;
- IsLogScale = isLogScale;
- }
+ // The raw value will store an index for discrete parameters,
+ // but sometimes we want the text or numeric value itself,
+ // not the hot index. The processed value does that for discrete
+ // params. For other params, it just returns the raw value itself.
+ public virtual IComparable ProcessedValue() => _rawValue;
- public override void SetUsingValueText(string valueText)
- {
- RawValue = long.Parse(valueText);
- }
+ // Allows for hyperparameter value freezing, so that sweeps
+ // will not alter the current value when true.
+ public bool Frozen { get; set; }
+
+ // Allows the sweepable param to be set directly using the
+ // available ValueText attribute on IParameterValues (from
+ // the ParameterSets used in the old hyperparameter sweepers).
+ public abstract void SetUsingValueText(string valueText);
- public override SweepableParamAttribute Clone() =>
- new SweepableLongParamAttribute(Name, Min, Max, StepSize ?? -1, NumSteps ?? -1, IsLogScale) { RawValue = RawValue, Frozen = Frozen };
+ public abstract SweepableParamAttribute Clone();
+ }
+
+ ///
+ /// An attribute used to indicate suggested sweep ranges for discrete parameter sweeping.
+ /// The value is the index of the option chosen. Use Options[Value] to get the corresponding
+ /// option using the index.
+ ///
+ [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)]
+ public sealed class SweepableDiscreteParamAttribute : SweepableParamAttribute
+ {
+ public object[] Options { get; }
- public override string ToString()
+ public SweepableDiscreteParamAttribute(string name, object[] values, bool isBool = false) : this(values, isBool)
+ {
+ Name = name;
+ }
+
+ public SweepableDiscreteParamAttribute(object[] values, bool isBool = false)
+ {
+ Options = isBool ? new object[] { false, true } : values;
+ }
+
+ public override IComparable RawValue
+ {
+ get => base.RawValue;
+ set
{
- var optional = new StringBuilder();
- if (StepSize != null)
- optional.Append($", stepSize:{StepSize}");
- if (NumSteps != null)
- optional.Append($", numSteps:{NumSteps}");
- if (IsLogScale)
- optional.Append($", isLogScale:true");
- var name = string.IsNullOrEmpty(Name) ? "" : $"\"{Name}\", ";
- return $"[TlcModule.{GetType().Name}({name}{Min}, {Max}{optional})]";
+ var val = Convert.ToInt32(value);
+ if (!Frozen && 0 <= val && val < Options.Length)
+ base.RawValue = val;
}
}
- ///
- /// An attribute to mark an entry point of a module.
- ///
- [AttributeUsage(AttributeTargets.Method)]
- public sealed class EntryPointAttribute : Attribute
+ public override void SetUsingValueText(string valueText)
{
- ///
- /// The entry point name.
- ///
- public string Name { get; set; }
-
- ///
- /// The entry point description.
- ///
- public string Desc { get; set; }
-
- ///
- /// UI friendly name. Can contain spaces and other forbidden for Name symbols.
- ///
- public string UserName { get; set; }
-
- ///
- /// Short name of the Entry Point
- ///
- public string ShortName { get; set; }
+ for (int i = 0; i < Options.Length; i++)
+ if (valueText == Options[i].ToString())
+ RawValue = i;
}
- ///
- /// The list of data types that are supported as inputs or outputs.
- ///
- public enum DataKind
+ public int IndexOf(object option)
{
- ///
- /// Not used.
- ///
- Unknown = 0,
- ///
- /// Integer, including long.
- ///
- Int,
- ///
- /// Unsigned integer, including ulong.
- ///
- UInt,
- ///
- /// Floating point, including double.
- ///
- Float,
- ///
- /// A char.
- ///
- Char,
- ///
- /// A string.
- ///
- String,
- ///
- /// A boolean value.
- ///
- Bool,
- ///
- /// A dataset, represented by an .
- ///
- DataView,
- ///
- /// A file handle, represented by an .
- ///
- FileHandle,
- ///
- /// A transform model, represented by an .
- ///
- TransformModel,
- ///
- /// A predictor model, represented by an .
- ///
- PredictorModel,
- ///
- /// An enum: one value of a specified list.
- ///
- Enum,
- ///
- /// An array (0 or more values of the same type, accessible by index).
- ///
- Array,
- ///
- /// A dictionary (0 or more values of the same type, identified by a unique string key).
- /// The underlying C# representation is
- ///
- Dictionary,
- ///
- /// A component of a specified kind. The component is identified by the "load name" (unique per kind) and,
- /// optionally, a set of parameters, unique to each component. Example: "BinaryClassifierEvaluator{threshold=0.5}".
- /// The C# representation is .
- ///
- Component
+ for (int i = 0; i < Options.Length; i++)
+ if (option == Options[i])
+ return i;
+ return -1;
}
- public static DataKind GetDataType(Type type)
+ private static string TranslateOption(object o)
{
- Contracts.AssertValue(type);
-
- // If this is a Optional-wrapped type, unwrap it and examine
- // the inner type.
- if (type.IsGenericType && (type.GetGenericTypeDefinition() == typeof(Optional<>) || type.GetGenericTypeDefinition() == typeof(Nullable<>)))
- type = type.GetGenericArguments()[0];
-
- if (type == typeof(char))
- return DataKind.Char;
- if (type == typeof(string))
- return DataKind.String;
- if (type == typeof(bool))
- return DataKind.Bool;
- if (type == typeof(int) || type == typeof(long))
- return DataKind.Int;
- if (type == typeof(uint) || type == typeof(ulong))
- return DataKind.UInt;
- if (type == typeof(Single) || type == typeof(Double))
- return DataKind.Float;
- if (typeof(IDataView).IsAssignableFrom(type))
- return DataKind.DataView;
- if (typeof(TransformModel).IsAssignableFrom(type))
- return DataKind.TransformModel;
- if (typeof(PredictorModel).IsAssignableFrom(type))
- return DataKind.PredictorModel;
- if (typeof(IFileHandle).IsAssignableFrom(type))
- return DataKind.FileHandle;
- if (type.IsEnum)
- return DataKind.Enum;
- if (type.IsArray)
- return DataKind.Array;
- if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Dictionary<,>)
- && type.GetGenericArguments()[0] == typeof(string))
+ switch (o)
{
- return DataKind.Dictionary;
+ case float _:
+ case double _:
+ return $"{o}f";
+ case long _:
+ case int _:
+ case byte _:
+ case short _:
+ return o.ToString();
+ case bool _:
+ return o.ToString().ToLower();
+ case Enum _:
+ var type = o.GetType();
+ var defaultName = $"Enums.{type.Name}.{o.ToString()}";
+ var name = type.FullName?.Replace("+", ".");
+ if (name == null)
+ return defaultName;
+ var index1 = name.LastIndexOf(".", StringComparison.Ordinal);
+ var index2 = name.Substring(0, index1).LastIndexOf(".", StringComparison.Ordinal) + 1;
+ if (index2 >= 0)
+ return $"{name.Substring(index2)}.{o.ToString()}";
+ return defaultName;
+ default:
+ return $"\"{o}\"";
}
- if (typeof(IComponentFactory).IsAssignableFrom(type))
- return DataKind.Component;
-
- return DataKind.Unknown;
}
- public static bool IsNumericKind(DataKind kind)
+ public override SweepableParamAttribute Clone() =>
+ new SweepableDiscreteParamAttribute(Name, Options) { RawValue = RawValue, Frozen = Frozen };
+
+ public override string ToString()
{
- return kind == DataKind.Int || kind == DataKind.UInt || kind == DataKind.Float;
+ var name = string.IsNullOrEmpty(Name) ? "" : $"\"{Name}\", ";
+ return $"[TlcModule.{GetType().Name}({name}new object[]{{{string.Join(", ", Options.Select(TranslateOption))}}})]";
}
+
+ public override IComparable ProcessedValue() => (IComparable)Options[(int)RawValue];
}
///
- /// The untyped base class for 'maybe'.
+ /// An attribute used to indicate suggested sweep ranges for float parameter sweeping.
///
- [BestFriend]
- internal abstract class Optional
+ [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)]
+ public sealed class SweepableFloatParamAttribute : SweepableParamAttribute
{
- ///
- /// Whether the value was set 'explicitly', or 'implicitly'.
- ///
- public readonly bool IsExplicit;
+ public float Min { get; }
+ public float Max { get; }
+ public float? StepSize { get; }
+ public int? NumSteps { get; }
+ public bool IsLogScale { get; }
+
+ public SweepableFloatParamAttribute(string name, float min, float max, float stepSize = -1, int numSteps = -1,
+ bool isLogScale = false) : this(min, max, stepSize, numSteps, isLogScale)
+ {
+ Name = name;
+ }
- public abstract object GetValue();
+ public SweepableFloatParamAttribute(float min, float max, float stepSize = -1, int numSteps = -1, bool isLogScale = false)
+ {
+ Min = min;
+ Max = max;
+ if (!stepSize.Equals(-1))
+ StepSize = stepSize;
+ if (numSteps != -1)
+ NumSteps = numSteps;
+ IsLogScale = isLogScale;
+ }
+
+ public override void SetUsingValueText(string valueText)
+ {
+ RawValue = float.Parse(valueText);
+ }
- private protected Optional(bool isExplicit)
+ public override SweepableParamAttribute Clone() =>
+ new SweepableFloatParamAttribute(Name, Min, Max, StepSize ?? -1, NumSteps ?? -1, IsLogScale) { RawValue = RawValue, Frozen = Frozen };
+
+ public override string ToString()
{
- IsExplicit = isExplicit;
+ var optional = new StringBuilder();
+ if (StepSize != null)
+ optional.Append($", stepSize:{StepSize}");
+ if (NumSteps != null)
+ optional.Append($", numSteps:{NumSteps}");
+ if (IsLogScale)
+ optional.Append($", isLogScale:true");
+ var name = string.IsNullOrEmpty(Name) ? "" : $"\"{Name}\", ";
+ return $"[TlcModule.{GetType().Name}({name}{Min}f, {Max}f{optional})]";
}
}
///
- /// This is a 'maybe' class that is able to differentiate the cases when the value is set 'explicitly', or 'implicitly'.
- /// The idea is that if the default value is specified by the user, in some cases it needs to be treated differently
- /// than if it's auto-filled.
- ///
- /// An example is the weight column: the default behavior is to use 'Weight' column if it's present. But if the user explicitly sets
- /// the weight column to be 'Weight', we need to actually enforce the presence of the column.
+ /// An attribute used to indicate suggested sweep ranges for long parameter sweeping.
///
- /// The type of the value
- [BestFriend]
- internal sealed class Optional : Optional
+ [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)]
+ public sealed class SweepableLongParamAttribute : SweepableParamAttribute
{
- public readonly T Value;
+ public long Min { get; }
+ public long Max { get; }
+ public float? StepSize { get; }
+ public int? NumSteps { get; }
+ public bool IsLogScale { get; }
+
+ public SweepableLongParamAttribute(string name, long min, long max, float stepSize = -1, int numSteps = -1,
+ bool isLogScale = false) : this(min, max, stepSize, numSteps, isLogScale)
+ {
+ Name = name;
+ }
- private Optional(bool isExplicit, T value)
- : base(isExplicit)
+ public SweepableLongParamAttribute(long min, long max, float stepSize = -1, int numSteps = -1, bool isLogScale = false)
{
- Value = value;
+ Min = min;
+ Max = max;
+ if (!stepSize.Equals(-1))
+ StepSize = stepSize;
+ if (numSteps != -1)
+ NumSteps = numSteps;
+ IsLogScale = isLogScale;
}
- ///
- /// Create the 'implicit' value.
- ///
- public static Optional Implicit(T value)
+ public override void SetUsingValueText(string valueText)
{
- return new Optional(false, value);
+ RawValue = long.Parse(valueText);
}
- public static Optional Explicit(T value)
+ public override SweepableParamAttribute Clone() =>
+ new SweepableLongParamAttribute(Name, Min, Max, StepSize ?? -1, NumSteps ?? -1, IsLogScale) { RawValue = RawValue, Frozen = Frozen };
+
+ public override string ToString()
{
- return new Optional(true, value);
+ var optional = new StringBuilder();
+ if (StepSize != null)
+ optional.Append($", stepSize:{StepSize}");
+ if (NumSteps != null)
+ optional.Append($", numSteps:{NumSteps}");
+ if (IsLogScale)
+ optional.Append($", isLogScale:true");
+ var name = string.IsNullOrEmpty(Name) ? "" : $"\"{Name}\", ";
+ return $"[TlcModule.{GetType().Name}({name}{Min}, {Max}{optional})]";
}
+ }
+ ///
+ /// An attribute to mark an entry point of a module.
+ ///
+ [AttributeUsage(AttributeTargets.Method)]
+ public sealed class EntryPointAttribute : Attribute
+ {
///
- /// The implicit conversion into .
+ /// The entry point name.
///
- public static implicit operator T(Optional optional)
- {
- return optional.Value;
- }
+ public string Name { get; set; }
///
- /// The implicit conversion from .
- /// This will assume that the parameter is set 'explicitly'.
+ /// The entry point description.
///
- public static implicit operator Optional(T value)
- {
- return new Optional(true, value);
- }
+ public string Desc { get; set; }
- public override object GetValue()
- {
- return Value;
- }
+ ///
+ /// UI friendly name. Can contain spaces and other forbidden for Name symbols.
+ ///
+ public string UserName { get; set; }
- public override string ToString()
+ ///
+ /// Short name of the Entry Point
+ ///
+ public string ShortName { get; set; }
+ }
+
+ ///
+ /// The list of data types that are supported as inputs or outputs.
+ ///
+ public enum DataKind
+ {
+ ///
+ /// Not used.
+ ///
+ Unknown = 0,
+ ///
+ /// Integer, including long.
+ ///
+ Int,
+ ///
+ /// Unsigned integer, including ulong.
+ ///
+ UInt,
+ ///
+ /// Floating point, including double.
+ ///
+ Float,
+ ///
+ /// A char.
+ ///
+ Char,
+ ///
+ /// A string.
+ ///
+ String,
+ ///
+ /// A boolean value.
+ ///
+ Bool,
+ ///
+ /// A dataset, represented by an .
+ ///
+ DataView,
+ ///
+ /// A file handle, represented by an .
+ ///
+ FileHandle,
+ ///
+ /// A transform model, represented by an .
+ ///
+ TransformModel,
+ ///
+ /// A predictor model, represented by an .
+ ///
+ PredictorModel,
+ ///
+ /// An enum: one value of a specified list.
+ ///
+ Enum,
+ ///
+ /// An array (0 or more values of the same type, accessible by index).
+ ///
+ Array,
+ ///
+ /// A dictionary (0 or more values of the same type, identified by a unique string key).
+ /// The underlying C# representation is
+ ///
+ Dictionary,
+ ///
+ /// A component of a specified kind. The component is identified by the "load name" (unique per kind) and,
+ /// optionally, a set of parameters, unique to each component. Example: "BinaryClassifierEvaluator{threshold=0.5}".
+ /// The C# representation is .
+ ///
+ Component
+ }
+
+ public static DataKind GetDataType(Type type)
+ {
+ Contracts.AssertValue(type);
+
+ // If this is a Optional-wrapped type, unwrap it and examine
+ // the inner type.
+ if (type.IsGenericType && (type.GetGenericTypeDefinition() == typeof(Optional<>) || type.GetGenericTypeDefinition() == typeof(Nullable<>)))
+ type = type.GetGenericArguments()[0];
+
+ if (type == typeof(char))
+ return DataKind.Char;
+ if (type == typeof(string))
+ return DataKind.String;
+ if (type == typeof(bool))
+ return DataKind.Bool;
+ if (type == typeof(int) || type == typeof(long))
+ return DataKind.Int;
+ if (type == typeof(uint) || type == typeof(ulong))
+ return DataKind.UInt;
+ if (type == typeof(Single) || type == typeof(Double))
+ return DataKind.Float;
+ if (typeof(IDataView).IsAssignableFrom(type))
+ return DataKind.DataView;
+ if (typeof(TransformModel).IsAssignableFrom(type))
+ return DataKind.TransformModel;
+ if (typeof(PredictorModel).IsAssignableFrom(type))
+ return DataKind.PredictorModel;
+ if (typeof(IFileHandle).IsAssignableFrom(type))
+ return DataKind.FileHandle;
+ if (type.IsEnum)
+ return DataKind.Enum;
+ if (type.IsArray)
+ return DataKind.Array;
+ if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Dictionary<,>)
+ && type.GetGenericArguments()[0] == typeof(string))
{
- if (Value == null)
- return "";
- return Value.ToString();
+ return DataKind.Dictionary;
}
+ if (typeof(IComponentFactory).IsAssignableFrom(type))
+ return DataKind.Component;
+
+ return DataKind.Unknown;
+ }
+
+ public static bool IsNumericKind(DataKind kind)
+ {
+ return kind == DataKind.Int || kind == DataKind.UInt || kind == DataKind.Float;
+ }
+}
+
+///
+/// The untyped base class for 'maybe'.
+///
+[BestFriend]
+internal abstract class Optional
+{
+ ///
+ /// Whether the value was set 'explicitly', or 'implicitly'.
+ ///
+ public readonly bool IsExplicit;
+
+ public abstract object GetValue();
+
+ private protected Optional(bool isExplicit)
+ {
+ IsExplicit = isExplicit;
+ }
+}
+
+///
+/// This is a 'maybe' class that is able to differentiate the cases when the value is set 'explicitly', or 'implicitly'.
+/// The idea is that if the default value is specified by the user, in some cases it needs to be treated differently
+/// than if it's auto-filled.
+///
+/// An example is the weight column: the default behavior is to use 'Weight' column if it's present. But if the user explicitly sets
+/// the weight column to be 'Weight', we need to actually enforce the presence of the column.
+///
+/// The type of the value
+[BestFriend]
+internal sealed class Optional : Optional
+{
+ public readonly T Value;
+
+ private Optional(bool isExplicit, T value)
+ : base(isExplicit)
+ {
+ Value = value;
+ }
+
+ ///
+ /// Create the 'implicit' value.
+ ///
+ public static Optional Implicit(T value)
+ {
+ return new Optional(false, value);
+ }
+
+ public static Optional Explicit(T value)
+ {
+ return new Optional(true, value);
+ }
+
+ ///
+ /// The implicit conversion into .
+ ///
+ public static implicit operator T(Optional optional)
+ {
+ return optional.Value;
+ }
+
+ ///
+ /// The implicit conversion from .
+ /// This will assume that the parameter is set 'explicitly'.
+ ///
+ public static implicit operator Optional(T value)
+ {
+ return new Optional(true, value);
+ }
+
+ public override object GetValue()
+ {
+ return Value;
+ }
+
+ public override string ToString()
+ {
+ if (Value == null)
+ return "";
+ return Value.ToString();
}
}
diff --git a/src/Microsoft.ML.Core/EntryPoints/PredictorModel.cs b/src/Microsoft.ML.Core/EntryPoints/PredictorModel.cs
index aef1b8a298..9eeaddadfe 100644
--- a/src/Microsoft.ML.Core/EntryPoints/PredictorModel.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/PredictorModel.cs
@@ -6,64 +6,63 @@
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
-namespace Microsoft.ML.EntryPoints
+namespace Microsoft.ML.EntryPoints;
+
+///
+/// Base type for standard predictor model port type.
+///
+[BestFriend]
+internal abstract class PredictorModel
{
- ///
- /// Base type for standard predictor model port type.
- ///
[BestFriend]
- internal abstract class PredictorModel
+ private protected PredictorModel()
{
- [BestFriend]
- private protected PredictorModel()
- {
- }
+ }
- ///
- /// Save the model to the given stream.
- ///
- [BestFriend]
- internal abstract void Save(IHostEnvironment env, Stream stream);
+ ///
+ /// Save the model to the given stream.
+ ///
+ [BestFriend]
+ internal abstract void Save(IHostEnvironment env, Stream stream);
- ///
- /// Extract only the transform portion of the predictor model.
- ///
- [BestFriend]
- internal abstract TransformModel TransformModel { get; }
+ ///
+ /// Extract only the transform portion of the predictor model.
+ ///
+ [BestFriend]
+ internal abstract TransformModel TransformModel { get; }
- ///
- /// Extract the predictor object out of the predictor model.
- ///
- [BestFriend]
- internal abstract IPredictor Predictor { get; }
+ ///
+ /// Extract the predictor object out of the predictor model.
+ ///
+ [BestFriend]
+ internal abstract IPredictor Predictor { get; }
- ///
- /// Apply the predictor model to the transform model and return the resulting predictor model.
- ///
- [BestFriend]
- internal abstract PredictorModel Apply(IHostEnvironment env, TransformModel transformModel);
+ ///
+ /// Apply the predictor model to the transform model and return the resulting predictor model.
+ ///
+ [BestFriend]
+ internal abstract PredictorModel Apply(IHostEnvironment env, TransformModel transformModel);
- ///
- /// For a given input data, return role mapped data and the predictor object.
- /// The scoring entry point will hopefully know how to construct a scorer out of them.
- ///
- [BestFriend]
- internal abstract void PrepareData(IHostEnvironment env, IDataView input, out RoleMappedData roleMappedData, out IPredictor predictor);
+ ///
+ /// For a given input data, return role mapped data and the predictor object.
+ /// The scoring entry point will hopefully know how to construct a scorer out of them.
+ ///
+ [BestFriend]
+ internal abstract void PrepareData(IHostEnvironment env, IDataView input, out RoleMappedData roleMappedData, out IPredictor predictor);
- ///
- /// Returns a string array containing the label names of the label column type predictor was trained on.
- /// If the training label is a key with text key value annotation, it should return this annotation. The order of the labels should be consistent
- /// with the key values. Otherwise, it returns null.
- ///
- ///
- /// The column type of the label the predictor was trained on.
- [BestFriend]
- internal abstract string[] GetLabelInfo(IHostEnvironment env, out DataViewType labelType);
+ ///
+ /// Returns a string array containing the label names of the label column type predictor was trained on.
+ /// If the training label is a key with text key value annotation, it should return this annotation. The order of the labels should be consistent
+ /// with the key values. Otherwise, it returns null.
+ ///
+ ///
+ /// The column type of the label the predictor was trained on.
+ [BestFriend]
+ internal abstract string[] GetLabelInfo(IHostEnvironment env, out DataViewType labelType);
- ///
- /// Returns the that was used in training.
- ///
- [BestFriend]
- internal abstract RoleMappedSchema GetTrainingSchema(IHostEnvironment env);
- }
+ ///
+ /// Returns the that was used in training.
+ ///
+ [BestFriend]
+ internal abstract RoleMappedSchema GetTrainingSchema(IHostEnvironment env);
}
diff --git a/src/Microsoft.ML.Core/EntryPoints/TransformModel.cs b/src/Microsoft.ML.Core/EntryPoints/TransformModel.cs
index 8e2ddb7e66..a3d2eac26c 100644
--- a/src/Microsoft.ML.Core/EntryPoints/TransformModel.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/TransformModel.cs
@@ -6,65 +6,64 @@
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
-namespace Microsoft.ML.EntryPoints
+namespace Microsoft.ML.EntryPoints;
+
+///
+/// Interface for standard transform model port type.
+///
+[BestFriend]
+internal abstract class TransformModel
{
- ///
- /// Interface for standard transform model port type.
- ///
[BestFriend]
- internal abstract class TransformModel
+ private protected TransformModel()
{
- [BestFriend]
- private protected TransformModel()
- {
- }
+ }
- ///
- /// The input schema that this transform model was originally instantiated on.
- /// Note that the schema may have columns that aren't needed by this transform model.
- /// If an exists with this schema, then applying this transform model to it
- /// shouldn't fail because of column type issues.
- ///
- // REVIEW: Would be nice to be able to trim this to the minimum needed somehow. Note
- // however that doing so may cause issues for composing transform models. For example,
- // if transform model A needs column X and model B needs Y, that is NOT produced by A,
- // then trimming A's input schema would cause composition to fail.
- [BestFriend]
- internal abstract DataViewSchema InputSchema { get; }
+ ///
+ /// The input schema that this transform model was originally instantiated on.
+ /// Note that the schema may have columns that aren't needed by this transform model.
+ /// If an exists with this schema, then applying this transform model to it
+ /// shouldn't fail because of column type issues.
+ ///
+ // REVIEW: Would be nice to be able to trim this to the minimum needed somehow. Note
+ // however that doing so may cause issues for composing transform models. For example,
+ // if transform model A needs column X and model B needs Y, that is NOT produced by A,
+ // then trimming A's input schema would cause composition to fail.
+ [BestFriend]
+ internal abstract DataViewSchema InputSchema { get; }
- ///
- /// The output schema that this transform model was originally instantiated on. The schema resulting
- /// from may differ from this, similarly to how
- /// may differ from the schema of dataviews we apply this transform model to.
- ///
- [BestFriend]
- internal abstract DataViewSchema OutputSchema { get; }
+ ///
+ /// The output schema that this transform model was originally instantiated on. The schema resulting
+ /// from may differ from this, similarly to how
+ /// may differ from the schema of dataviews we apply this transform model to.
+ ///
+ [BestFriend]
+ internal abstract DataViewSchema OutputSchema { get; }
- ///
- /// Apply the transform(s) in the model to the given input data.
- ///
- [BestFriend]
- internal abstract IDataView Apply(IHostEnvironment env, IDataView input);
+ ///
+ /// Apply the transform(s) in the model to the given input data.
+ ///
+ [BestFriend]
+ internal abstract IDataView Apply(IHostEnvironment env, IDataView input);
- ///
- /// Apply the transform(s) in the model to the given transform model.
- ///
- [BestFriend]
- internal abstract TransformModel Apply(IHostEnvironment env, TransformModel input);
+ ///
+ /// Apply the transform(s) in the model to the given transform model.
+ ///
+ [BestFriend]
+ internal abstract TransformModel Apply(IHostEnvironment env, TransformModel input);
- ///
- /// Save the model to the given stream.
- ///
- [BestFriend]
- internal abstract void Save(IHostEnvironment env, Stream stream);
+ ///
+ /// Save the model to the given stream.
+ ///
+ [BestFriend]
+ internal abstract void Save(IHostEnvironment env, Stream stream);
- ///
- /// Returns the transform model as an that can output a row
- /// given a row with the same schema as .
- ///
- /// The transform model as an . If not all transforms
- /// in the pipeline are then it returns .
- [BestFriend]
- internal abstract IRowToRowMapper AsRowToRowMapper(IExceptionContext ectx);
- }
+ ///
+ /// Returns the transform model as an that can output a row
+ /// given a row with the same schema as .
+ ///
+ /// The transform model as an . If not all transforms
+ /// in the pipeline are then it returns .
+ [BestFriend]
+ internal abstract IRowToRowMapper AsRowToRowMapper(IExceptionContext ectx);
}