diff --git a/src/Microsoft.ML.Core/Data/AnnotationBuilderExtensions.cs b/src/Microsoft.ML.Core/Data/AnnotationBuilderExtensions.cs
index 63e893569f..a93e588970 100644
--- a/src/Microsoft.ML.Core/Data/AnnotationBuilderExtensions.cs
+++ b/src/Microsoft.ML.Core/Data/AnnotationBuilderExtensions.cs
@@ -4,29 +4,28 @@
using System;
-namespace Microsoft.ML.Data
+namespace Microsoft.ML.Data;
+
+[BestFriend]
+internal static class AnnotationBuilderExtensions
{
- [BestFriend]
- internal static class AnnotationBuilderExtensions
- {
- ///
- /// Add slot names annotation.
- ///
- /// The to which to add the slot names.
- /// The size of the slot names vector.
- /// The getter delegate for the slot names.
- public static void AddSlotNames(this DataViewSchema.Annotations.Builder builder, int size, ValueGetter>> getter)
- => builder.Add(AnnotationUtils.Kinds.SlotNames, new VectorDataViewType(TextDataViewType.Instance, size), getter);
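+ /// <summary>
+ /// Add slot names annotation.
+ /// </summary>
+ /// <param name="builder">The <see cref="DataViewSchema.Annotations.Builder"/> to which to add the slot names.</param>
+ /// <param name="size">The size of the slot names vector.</param>
+ /// <param name="getter">The getter delegate for the slot names.</param>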
+ public static void AddSlotNames(this DataViewSchema.Annotations.Builder builder, int size, ValueGetter>> getter)
+ => builder.Add(AnnotationUtils.Kinds.SlotNames, new VectorDataViewType(TextDataViewType.Instance, size), getter);
- ///
- /// Add key values annotation.
- ///
- /// The value type of key values.
- /// The to which to add the key values.
- /// The size of key values vector.
- /// The value type of key values. Its raw type must match .
- /// The getter delegate for the key values.
- public static void AddKeyValues(this DataViewSchema.Annotations.Builder builder, int size, PrimitiveDataViewType valueType, ValueGetter> getter)
- => builder.Add(AnnotationUtils.Kinds.KeyValues, new VectorDataViewType(valueType, size), getter);
- }
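+ /// <summary>
+ /// Add key values annotation.
+ /// </summary>
+ /// <typeparam name="TValue">The value type of key values.</typeparam>
+ /// <param name="builder">The <see cref="DataViewSchema.Annotations.Builder"/> to which to add the key values.</param>
+ /// <param name="size">The size of key values vector.</param>
+ /// <param name="valueType">The value type of key values. Its raw type must match <typeparamref name="TValue"/>.</param>
+ /// <param name="getter">The getter delegate for the key values.</param>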
+ public static void AddKeyValues(this DataViewSchema.Annotations.Builder builder, int size, PrimitiveDataViewType valueType, ValueGetter> getter)
+ => builder.Add(AnnotationUtils.Kinds.KeyValues, new VectorDataViewType(valueType, size), getter);
}
diff --git a/src/Microsoft.ML.Core/Data/AnnotationUtils.cs b/src/Microsoft.ML.Core/Data/AnnotationUtils.cs
index f942e4755e..257eb73728 100644
--- a/src/Microsoft.ML.Core/Data/AnnotationUtils.cs
+++ b/src/Microsoft.ML.Core/Data/AnnotationUtils.cs
@@ -9,493 +9,492 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
-namespace Microsoft.ML.Data
+namespace Microsoft.ML.Data;
+
+/// <summary>
+/// Utilities for implementing and using the annotation API of <see cref="DataViewSchema"/>.
+/// </summary>
+[BestFriend]
+internal static class AnnotationUtils
{
///
- /// Utilities for implementing and using the annotation API of .
+ /// This class lists the canonical annotation kinds
///
- [BestFriend]
- internal static class AnnotationUtils
+ public static class Kinds
{
///
- /// This class lists the canonical annotation kinds
+ /// Annotation kind for names associated with slots/positions in a vector-valued column.
+ /// The associated annotation type is typically fixed-sized vector of Text.
///
- public static class Kinds
- {
- ///
- /// Annotation kind for names associated with slots/positions in a vector-valued column.
- /// The associated annotation type is typically fixed-sized vector of Text.
- ///
- public const string SlotNames = "SlotNames";
-
- ///
- /// Annotation kind for values associated with the key indices when the column type's item type
- /// is a key type. The associated annotation type is typically fixed-sized vector of a primitive
- /// type. The primitive type is frequently Text, but can be anything.
- ///
- public const string KeyValues = "KeyValues";
-
- ///
- /// Annotation kind for sets of score columns. The value is typically a with raw type U4.
- ///
- public const string ScoreColumnSetId = "ScoreColumnSetId";
-
- ///
- /// Annotation kind that indicates the prediction kind as a string. For example, "BinaryClassification".
- /// The value is typically a ReadOnlyMemory<char>.
- ///
- public const string ScoreColumnKind = "ScoreColumnKind";
-
- ///
- /// Annotation kind that indicates the value kind of the score column as a string. For example, "Score", "PredictedLabel", "Probability". The value is typically a ReadOnlyMemory.
- ///
- public const string ScoreValueKind = "ScoreValueKind";
-
- ///
- /// Annotation kind that indicates if a column is normalized. The value is typically a Bool.
- ///
- public const string IsNormalized = "IsNormalized";
-
- ///
- /// Annotation kind that indicates if a column is visible to the users. The value is typically a Bool.
- /// Not to be confused with IsHidden() that determines if a column is masked.
- ///
- public const string IsUserVisible = "IsUserVisible";
-
- ///
- /// Annotation kind for the label values used in training to be used for the predicted label.
- /// The value is typically a fixed-sized vector of Text.
- ///
- public const string TrainingLabelValues = "TrainingLabelValues";
-
- ///
- /// Annotation kind that indicates the ranges within a column that are categorical features.
- /// The value is a vector type of ints with dimension of two. The first dimension
- /// represents the number of categorical features and second dimension represents the range
- /// and is of size two. The range has start and end index(both inclusive) of categorical
- /// slots within that column.
- ///
- public const string CategoricalSlotRanges = "CategoricalSlotRanges";
- }
+ public const string SlotNames = "SlotNames";
///
- /// This class holds all pre-defined string values that can be found in canonical annotations
+ /// Annotation kind for values associated with the key indices when the column type's item type
+ /// is a key type. The associated annotation type is typically fixed-sized vector of a primitive
+ /// type. The primitive type is frequently Text, but can be anything.
///
- public static class Const
- {
- public static class ScoreColumnKind
- {
- public const string BinaryClassification = "BinaryClassification";
- public const string MulticlassClassification = "MulticlassClassification";
- public const string Regression = "Regression";
- public const string Ranking = "Ranking";
- public const string Clustering = "Clustering";
- public const string MultiOutputRegression = "MultiOutputRegression";
- public const string AnomalyDetection = "AnomalyDetection";
- public const string SequenceClassification = "SequenceClassification";
- public const string QuantileRegression = "QuantileRegression";
- public const string Recommender = "Recommender";
- public const string ItemSimilarity = "ItemSimilarity";
- public const string FeatureContribution = "FeatureContribution";
- }
-
- public static class ScoreValueKind
- {
- public const string Score = "Score";
- public const string PredictedLabel = "PredictedLabel";
- public const string Probability = "Probability";
- }
- }
+ public const string KeyValues = "KeyValues";
///
- /// Helper delegate for marshaling from generic land to specific types. Used by the Marshal method below.
+ /// Annotation kind for sets of score columns. The value is typically a <see cref="KeyDataViewType"/> with raw type U4.
///
- public delegate void AnnotationGetter(int col, ref TValue dst);
+ public const string ScoreColumnSetId = "ScoreColumnSetId";
///
- /// Returns a standard exception for responding to an invalid call to GetAnnotation.
+ /// Annotation kind that indicates the prediction kind as a string. For example, "BinaryClassification".
+ /// The value is typically a ReadOnlyMemory&lt;char&gt;.
///
- public static Exception ExceptGetAnnotation() => Contracts.Except("Invalid call to GetAnnotation");
+ public const string ScoreColumnKind = "ScoreColumnKind";
///
- /// Returns a standard exception for responding to an invalid call to GetAnnotation.
+ /// Annotation kind that indicates the value kind of the score column as a string. For example, "Score", "PredictedLabel", "Probability". The value is typically a ReadOnlyMemory&lt;char&gt;.
///
- public static Exception ExceptGetAnnotation(this IExceptionContext ctx) => ctx.Except("Invalid call to GetAnnotation");
+ public const string ScoreValueKind = "ScoreValueKind";
///
- /// Helper to marshal a call to GetAnnotation{TValue} to a specific type.
+ /// Annotation kind that indicates if a column is normalized. The value is typically a Bool.
///
- public static void Marshal(this AnnotationGetter getter, int col, ref TNeed dst)
- {
- Contracts.CheckValue(getter, nameof(getter));
-
- if (typeof(TNeed) != typeof(THave))
- throw ExceptGetAnnotation();
- var get = (AnnotationGetter)(Delegate)getter;
- get(col, ref dst);
- }
+ public const string IsNormalized = "IsNormalized";
///
- /// Returns a vector type with item type text and the given size. The size must be positive.
- /// This is a standard type for annotation consisting of multiple text values, eg SlotNames.
+ /// Annotation kind that indicates if a column is visible to the users. The value is typically a Bool.
+ /// Not to be confused with IsHidden() that determines if a column is masked.
///
- public static VectorDataViewType GetNamesType(int size)
- {
- Contracts.CheckParam(size > 0, nameof(size), "must be known size");
- return new VectorDataViewType(TextDataViewType.Instance, size);
- }
+ public const string IsUserVisible = "IsUserVisible";
///
- /// Returns a vector type with item type int and the given size.
- /// The range count must be a positive integer.
- /// This is a standard type for annotation consisting of multiple int values that represent
- /// categorical slot ranges with in a column.
+ /// Annotation kind for the label values used in training to be used for the predicted label.
+ /// The value is typically a fixed-sized vector of Text.
///
- public static VectorDataViewType GetCategoricalType(int rangeCount)
- {
- Contracts.CheckParam(rangeCount > 0, nameof(rangeCount), "must be known size");
- return new VectorDataViewType(NumberDataViewType.Int32, rangeCount, 2);
- }
-
- private static volatile KeyDataViewType _scoreColumnSetIdType;
+ public const string TrainingLabelValues = "TrainingLabelValues";
///
- /// The type of the ScoreColumnSetId annotation.
+ /// Annotation kind that indicates the ranges within a column that are categorical features.
+ /// The value is a vector type of ints with dimension of two. The first dimension
+ /// represents the number of categorical features and second dimension represents the range
+ /// and is of size two. The range has start and end index (both inclusive) of categorical
+ /// slots within that column.
///
- public static KeyDataViewType ScoreColumnSetIdType
- {
- get
- {
- return _scoreColumnSetIdType ??
- Interlocked.CompareExchange(ref _scoreColumnSetIdType, new KeyDataViewType(typeof(uint), int.MaxValue), null) ??
- _scoreColumnSetIdType;
- }
- }
+ public const string CategoricalSlotRanges = "CategoricalSlotRanges";
+ }
- ///
- /// Returns a key-value pair useful when implementing GetAnnotationTypes(col).
- ///
- public static KeyValuePair GetSlotNamesPair(int size)
+ ///
+ /// This class holds all pre-defined string values that can be found in canonical annotations
+ ///
+ public static class Const
+ {
+ public static class ScoreColumnKind
{
- return GetNamesType(size).GetPair(Kinds.SlotNames);
+ public const string BinaryClassification = "BinaryClassification";
+ public const string MulticlassClassification = "MulticlassClassification";
+ public const string Regression = "Regression";
+ public const string Ranking = "Ranking";
+ public const string Clustering = "Clustering";
+ public const string MultiOutputRegression = "MultiOutputRegression";
+ public const string AnomalyDetection = "AnomalyDetection";
+ public const string SequenceClassification = "SequenceClassification";
+ public const string QuantileRegression = "QuantileRegression";
+ public const string Recommender = "Recommender";
+ public const string ItemSimilarity = "ItemSimilarity";
+ public const string FeatureContribution = "FeatureContribution";
}
- ///
- /// Returns a key-value pair useful when implementing GetAnnotationTypes(col). This assumes
- /// that the values of the key type are Text.
- ///
- public static KeyValuePair GetKeyNamesPair(int size)
+ public static class ScoreValueKind
{
- return GetNamesType(size).GetPair(Kinds.KeyValues);
+ public const string Score = "Score";
+ public const string PredictedLabel = "PredictedLabel";
+ public const string Probability = "Probability";
}
+ }
- ///
- /// Given a type and annotation kind string, returns a key-value pair. This is useful when
- /// implementing GetAnnotationTypes(col).
- ///
- public static KeyValuePair GetPair(this DataViewType type, string kind)
- {
- Contracts.CheckValue(type, nameof(type));
- return new KeyValuePair(kind, type);
- }
+ ///
+ /// Helper delegate for marshaling from generic land to specific types. Used by the Marshal method below.
+ ///
+ public delegate void AnnotationGetter(int col, ref TValue dst);
- // REVIEW: This should be in some general utility code.
+ ///
+ /// Returns a standard exception for responding to an invalid call to GetAnnotation.
+ ///
+ public static Exception ExceptGetAnnotation() => Contracts.Except("Invalid call to GetAnnotation");
- ///
- /// Prepends a params array to an enumerable. Useful when implementing GetAnnotationTypes.
- ///
- public static IEnumerable Prepend(this IEnumerable tail, params T[] head)
+ ///
+ /// Returns a standard exception for responding to an invalid call to GetAnnotation.
+ ///
+ public static Exception ExceptGetAnnotation(this IExceptionContext ctx) => ctx.Except("Invalid call to GetAnnotation");
+
+ ///
+ /// Helper to marshal a call to GetAnnotation{TValue} to a specific type.
+ ///
+ public static void Marshal(this AnnotationGetter getter, int col, ref TNeed dst)
+ {
+ Contracts.CheckValue(getter, nameof(getter));
+
+ if (typeof(TNeed) != typeof(THave))
+ throw ExceptGetAnnotation();
+ var get = (AnnotationGetter)(Delegate)getter;
+ get(col, ref dst);
+ }
+
+ ///
+ /// Returns a vector type with item type text and the given size. The size must be positive.
+ /// This is a standard type for annotation consisting of multiple text values, e.g. SlotNames.
+ ///
+ public static VectorDataViewType GetNamesType(int size)
+ {
+ Contracts.CheckParam(size > 0, nameof(size), "must be known size");
+ return new VectorDataViewType(TextDataViewType.Instance, size);
+ }
+
+ ///
+ /// Returns a vector type with item type int and the given size.
+ /// The range count must be a positive integer.
+ /// This is a standard type for annotation consisting of multiple int values that represent
+ /// categorical slot ranges within a column.
+ ///
+ public static VectorDataViewType GetCategoricalType(int rangeCount)
+ {
+ Contracts.CheckParam(rangeCount > 0, nameof(rangeCount), "must be known size");
+ return new VectorDataViewType(NumberDataViewType.Int32, rangeCount, 2);
+ }
+
+ private static volatile KeyDataViewType _scoreColumnSetIdType;
+
+ ///
+ /// The type of the ScoreColumnSetId annotation.
+ ///
+ public static KeyDataViewType ScoreColumnSetIdType
+ {
+ get
{
- return head.Concat(tail);
+ return _scoreColumnSetIdType ??
+ Interlocked.CompareExchange(ref _scoreColumnSetIdType, new KeyDataViewType(typeof(uint), int.MaxValue), null) ??
+ _scoreColumnSetIdType;
}
+ }
- ///
- /// Returns the max value for the specified annotation kind.
- /// The annotation type should be a with raw type U4.
- /// colMax will be set to the first column that has the max value for the specified annotation.
- /// If no column has the specified annotation, colMax is set to -1 and the method returns zero.
- /// The filter function is called for each column, passing in the schema and the column index, and returns
- /// true if the column should be considered, false if the column should be skipped.
- ///
- public static uint GetMaxAnnotationKind(this DataViewSchema schema, out int colMax, string annotationKind, Func filterFunc = null)
+ ///
+ /// Returns a key-value pair useful when implementing GetAnnotationTypes(col).
+ ///
+ public static KeyValuePair GetSlotNamesPair(int size)
+ {
+ return GetNamesType(size).GetPair(Kinds.SlotNames);
+ }
+
+ ///
+ /// Returns a key-value pair useful when implementing GetAnnotationTypes(col). This assumes
+ /// that the values of the key type are Text.
+ ///
+ public static KeyValuePair GetKeyNamesPair(int size)
+ {
+ return GetNamesType(size).GetPair(Kinds.KeyValues);
+ }
+
+ ///
+ /// Given a type and annotation kind string, returns a key-value pair. This is useful when
+ /// implementing GetAnnotationTypes(col).
+ ///
+ public static KeyValuePair GetPair(this DataViewType type, string kind)
+ {
+ Contracts.CheckValue(type, nameof(type));
+ return new KeyValuePair(kind, type);
+ }
+
+ // REVIEW: This should be in some general utility code.
+
+ ///
+ /// Prepends a params array to an enumerable. Useful when implementing GetAnnotationTypes.
+ ///
+ public static IEnumerable Prepend(this IEnumerable tail, params T[] head)
+ {
+ return head.Concat(tail);
+ }
+
+ ///
+ /// Returns the max value for the specified annotation kind.
+ /// The annotation type should be a <see cref="KeyDataViewType"/> with raw type U4.
+ /// colMax will be set to the first column that has the max value for the specified annotation.
+ /// If no column has the specified annotation, colMax is set to -1 and the method returns zero.
+ /// The filter function is called for each column, passing in the schema and the column index, and returns
+ /// true if the column should be considered, false if the column should be skipped.
+ ///
+ public static uint GetMaxAnnotationKind(this DataViewSchema schema, out int colMax, string annotationKind, Func filterFunc = null)
+ {
+ uint max = 0;
+ colMax = -1;
+ for (int col = 0; col < schema.Count; col++)
{
- uint max = 0;
- colMax = -1;
- for (int col = 0; col < schema.Count; col++)
+ var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type;
+ if (!(columnType is KeyDataViewType) || columnType.RawType != typeof(uint))
+ continue;
+ if (filterFunc != null && !filterFunc(schema, col))
+ continue;
+ uint value = 0;
+ schema[col].Annotations.GetValue(annotationKind, ref value);
+ if (max < value)
{
- var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type;
- if (!(columnType is KeyDataViewType) || columnType.RawType != typeof(uint))
- continue;
- if (filterFunc != null && !filterFunc(schema, col))
- continue;
- uint value = 0;
- schema[col].Annotations.GetValue(annotationKind, ref value);
- if (max < value)
- {
- max = value;
- colMax = col;
- }
+ max = value;
+ colMax = col;
}
- return max;
}
+ return max;
+ }
- ///
- /// Returns the set of column ids which match the value of specified annotation kind.
- /// The annotation type should be a with raw type U4.
- ///
- public static IEnumerable GetColumnSet(this DataViewSchema schema, string annotationKind, uint value)
+ ///
+ /// Returns the set of column ids which match the value of specified annotation kind.
+ /// The annotation type should be a <see cref="KeyDataViewType"/> with raw type U4.
+ ///
+ public static IEnumerable GetColumnSet(this DataViewSchema schema, string annotationKind, uint value)
+ {
+ for (int col = 0; col < schema.Count; col++)
{
- for (int col = 0; col < schema.Count; col++)
+ var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type;
+ if (columnType is KeyDataViewType && columnType.RawType == typeof(uint))
{
- var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type;
- if (columnType is KeyDataViewType && columnType.RawType == typeof(uint))
- {
- uint val = 0;
- schema[col].Annotations.GetValue(annotationKind, ref val);
- if (val == value)
- yield return col;
- }
+ uint val = 0;
+ schema[col].Annotations.GetValue(annotationKind, ref val);
+ if (val == value)
+ yield return col;
}
}
+ }
- ///
- /// Returns the set of column ids which match the value of specified annotation kind.
- /// The annotation type should be of type text.
- ///
- public static IEnumerable GetColumnSet(this DataViewSchema schema, string annotationKind, string value)
+ ///
+ /// Returns the set of column ids which match the value of specified annotation kind.
+ /// The annotation type should be of type text.
+ ///
+ public static IEnumerable GetColumnSet(this DataViewSchema schema, string annotationKind, string value)
+ {
+ for (int col = 0; col < schema.Count; col++)
{
- for (int col = 0; col < schema.Count; col++)
+ var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type;
+ if (columnType is TextDataViewType)
{
- var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type;
- if (columnType is TextDataViewType)
- {
- ReadOnlyMemory val = default;
- schema[col].Annotations.GetValue(annotationKind, ref val);
- if (ReadOnlyMemoryUtils.EqualsStr(value, val))
- yield return col;
- }
+ ReadOnlyMemory val = default;
+ schema[col].Annotations.GetValue(annotationKind, ref val);
+ if (ReadOnlyMemoryUtils.EqualsStr(value, val))
+ yield return col;
}
}
+ }
- ///
- /// Returns true if the specified column:
- /// * has a SlotNames annotation
- /// * annotation type is VBuffer<ReadOnlyMemory<char>> of length .
- ///
- public static bool HasSlotNames(this DataViewSchema.Column column, int vectorSize)
- {
- if (vectorSize == 0)
- return false;
-
- var metaColumn = column.Annotations.Schema.GetColumnOrNull(Kinds.SlotNames);
- return
- metaColumn != null
- && metaColumn.Value.Type is VectorDataViewType vectorType
- && vectorType.Size == vectorSize
- && vectorType.ItemType is TextDataViewType;
- }
+ ///
+ /// Returns true if the specified column:
+ /// * has a SlotNames annotation
+ /// * annotation type is VBuffer&lt;ReadOnlyMemory&lt;char&gt;&gt; of length <paramref name="vectorSize"/>.
+ ///
+ public static bool HasSlotNames(this DataViewSchema.Column column, int vectorSize)
+ {
+ if (vectorSize == 0)
+ return false;
+
+ var metaColumn = column.Annotations.Schema.GetColumnOrNull(Kinds.SlotNames);
+ return
+ metaColumn != null
+ && metaColumn.Value.Type is VectorDataViewType vectorType
+ && vectorType.Size == vectorSize
+ && vectorType.ItemType is TextDataViewType;
+ }
- public static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.ColumnRole role, int vectorSize, ref VBuffer> slotNames)
- {
- Contracts.CheckValueOrNull(schema);
- Contracts.CheckParam(vectorSize >= 0, nameof(vectorSize));
-
- IReadOnlyList list = schema?.GetColumns(role);
- if (list?.Count != 1 || !schema.Schema[list[0].Index].HasSlotNames(vectorSize))
- VBufferUtils.Resize(ref slotNames, vectorSize, 0);
- else
- schema.Schema[list[0].Index].Annotations.GetValue(Kinds.SlotNames, ref slotNames);
- }
+ public static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.ColumnRole role, int vectorSize, ref VBuffer> slotNames)
+ {
+ Contracts.CheckValueOrNull(schema);
+ Contracts.CheckParam(vectorSize >= 0, nameof(vectorSize));
+
+ IReadOnlyList list = schema?.GetColumns(role);
+ if (list?.Count != 1 || !schema.Schema[list[0].Index].HasSlotNames(vectorSize))
+ VBufferUtils.Resize(ref slotNames, vectorSize, 0);
+ else
+ schema.Schema[list[0].Index].Annotations.GetValue(Kinds.SlotNames, ref slotNames);
+ }
- public static bool NeedsSlotNames(this SchemaShape.Column col)
- {
- return col.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol)
- && metaCol.Kind == SchemaShape.Column.VectorKind.Vector
- && metaCol.ItemType is TextDataViewType;
- }
+ public static bool NeedsSlotNames(this SchemaShape.Column col)
+ {
+ return col.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol)
+ && metaCol.Kind == SchemaShape.Column.VectorKind.Vector
+ && metaCol.ItemType is TextDataViewType;
+ }
- ///
- /// Returns whether a column has the annotation indicated by
- /// the schema shape.
- ///
- /// The schema shape column to query
- /// True if and only if the column has the annotation
- /// of a scalar type, which we assume, if set, should be true.
- public static bool IsNormalized(this SchemaShape.Column column)
- {
- Contracts.CheckParam(column.IsValid, nameof(column), "struct not initialized properly");
- return column.Annotations.TryFindColumn(Kinds.IsNormalized, out var metaCol)
- && metaCol.Kind == SchemaShape.Column.VectorKind.Scalar && !metaCol.IsKey
- && metaCol.ItemType == BooleanDataViewType.Instance;
- }
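+ /// <summary>
+ /// Returns whether a column has the <see cref="Kinds.IsNormalized"/> annotation indicated by
+ /// the schema shape.
+ /// </summary>
+ /// <param name="column">The schema shape column to query</param>
+ /// <returns>True if and only if the column has the <see cref="Kinds.IsNormalized"/> annotation
+ /// of a scalar <see cref="BooleanDataViewType"/> type, which we assume, if set, should be true.</returns>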
+ public static bool IsNormalized(this SchemaShape.Column column)
+ {
+ Contracts.CheckParam(column.IsValid, nameof(column), "struct not initialized properly");
+ return column.Annotations.TryFindColumn(Kinds.IsNormalized, out var metaCol)
+ && metaCol.Kind == SchemaShape.Column.VectorKind.Scalar && !metaCol.IsKey
+ && metaCol.ItemType == BooleanDataViewType.Instance;
+ }
- ///
- /// Returns whether a column has the annotation indicated by
- /// the schema shape.
- ///
- /// The schema shape column to query
- /// True if and only if the column is a definite sized vector type, has the
- /// annotation of definite sized vectors of text.
- public static bool HasSlotNames(this SchemaShape.Column col)
- {
- Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly");
- return col.Kind == SchemaShape.Column.VectorKind.Vector
- && col.Annotations.TryFindColumn(Kinds.SlotNames, out var metaCol)
- && metaCol.Kind == SchemaShape.Column.VectorKind.Vector && !metaCol.IsKey
- && metaCol.ItemType == TextDataViewType.Instance;
- }
+ /// <summary>
+ /// Returns whether a column has the <see cref="Kinds.SlotNames"/> annotation indicated by
+ /// the schema shape.
+ /// </summary>
+ /// <param name="col">The schema shape column to query</param>
+ /// <returns>True if and only if the column is a definite sized vector type, has the
+ /// <see cref="Kinds.SlotNames"/> annotation of definite sized vectors of text.</returns>
+ public static bool HasSlotNames(this SchemaShape.Column col)
+ {
+ Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly");
+ return col.Kind == SchemaShape.Column.VectorKind.Vector
+ && col.Annotations.TryFindColumn(Kinds.SlotNames, out var metaCol)
+ && metaCol.Kind == SchemaShape.Column.VectorKind.Vector && !metaCol.IsKey
+ && metaCol.ItemType == TextDataViewType.Instance;
+ }
- ///
- /// Tries to get the annotation kind of the specified type for a column.
- ///
- /// The raw type of the annotation, should match the PrimitiveType type
- /// The schema
- /// The type of the annotation
- /// The annotation kind
- /// The column
- /// The value to return, if successful
- /// True if the annotation of the right type exists, false otherwise
- public static bool TryGetAnnotation(this DataViewSchema schema, PrimitiveDataViewType type, string kind, int col, ref T value)
- {
- Contracts.CheckValue(schema, nameof(schema));
- Contracts.CheckValue(type, nameof(type));
-
- var annotationType = schema[col].Annotations.Schema.GetColumnOrNull(kind)?.Type;
- if (!type.Equals(annotationType))
- return false;
- schema[col].Annotations.GetValue(kind, ref value);
- return true;
- }
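+ /// <summary>
+ /// Tries to get the annotation kind of the specified type for a column.
+ /// </summary>
+ /// <typeparam name="T">The raw type of the annotation, should match the PrimitiveType type</typeparam>
+ /// <param name="schema">The schema</param>
+ /// <param name="type">The type of the annotation</param>
+ /// <param name="kind">The annotation kind</param>
+ /// <param name="col">The column</param>
+ /// <param name="value">The value to return, if successful</param>
+ /// <returns>True if the annotation of the right type exists, false otherwise</returns>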
+ public static bool TryGetAnnotation(this DataViewSchema schema, PrimitiveDataViewType type, string kind, int col, ref T value)
+ {
+ Contracts.CheckValue(schema, nameof(schema));
+ Contracts.CheckValue(type, nameof(type));
+
+ var annotationType = schema[col].Annotations.Schema.GetColumnOrNull(kind)?.Type;
+ if (!type.Equals(annotationType))
+ return false;
+ schema[col].Annotations.GetValue(kind, ref value);
+ return true;
+ }
- ///
- /// The categoricalFeatures is a vector of the indices of categorical features slots.
- /// This vector should always have an even number of elements, and the elements should be parsed in groups of two consecutive numbers.
- /// So if its value is the range of numbers: 0,2,3,4,8,9
- /// look at it as [0,2],[3,4],[8,9].
- /// The way to interpret that is: feature with indices 0, 1, and 2 are one categorical
- /// Features with indices 3 and 4 are another categorical. Features 5 and 6 don't appear there, so they are not categoricals.
- ///
- public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int colIndex, out int[] categoricalFeatures)
- {
- Contracts.CheckValue(schema, nameof(schema));
- Contracts.Check(colIndex >= 0, nameof(colIndex));
+ ///
+ /// The categoricalFeatures is a vector of the indices of categorical features slots.
+ /// This vector should always have an even number of elements, and the elements should be parsed in groups of two consecutive numbers.
+ /// So if its value is the range of numbers: 0,2,3,4,8,9
+ /// look at it as [0,2],[3,4],[8,9].
+ /// The way to interpret that is: feature with indices 0, 1, and 2 are one categorical
+ /// Features with indices 3 and 4 are another categorical. Features 5 and 6 don't appear there, so they are not categoricals.
+ ///
+ public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int colIndex, out int[] categoricalFeatures)
+ {
+ Contracts.CheckValue(schema, nameof(schema));
+ Contracts.Check(colIndex >= 0, nameof(colIndex));
- bool isValid = false;
- categoricalFeatures = null;
- if (!(schema[colIndex].Type is VectorDataViewType vecType && vecType.Size > 0))
- return isValid;
+ bool isValid = false;
+ categoricalFeatures = null;
+ if (!(schema[colIndex].Type is VectorDataViewType vecType && vecType.Size > 0))
+ return isValid;
- var type = schema[colIndex].Annotations.Schema.GetColumnOrNull(Kinds.CategoricalSlotRanges)?.Type;
- if (type?.RawType == typeof(VBuffer))
+ var type = schema[colIndex].Annotations.Schema.GetColumnOrNull(Kinds.CategoricalSlotRanges)?.Type;
+ if (type?.RawType == typeof(VBuffer))
+ {
+ VBuffer catIndices = default(VBuffer);
+ schema[colIndex].Annotations.GetValue(Kinds.CategoricalSlotRanges, ref catIndices);
+ VBufferUtils.Densify(ref catIndices);
+ int columnSlotsCount = vecType.Size;
+ if (catIndices.Length > 0 && catIndices.Length % 2 == 0 && catIndices.Length <= columnSlotsCount * 2)
{
- VBuffer catIndices = default(VBuffer);
- schema[colIndex].Annotations.GetValue(Kinds.CategoricalSlotRanges, ref catIndices);
- VBufferUtils.Densify(ref catIndices);
- int columnSlotsCount = vecType.Size;
- if (catIndices.Length > 0 && catIndices.Length % 2 == 0 && catIndices.Length <= columnSlotsCount * 2)
+ int previousEndIndex = -1;
+ isValid = true;
+ var catIndicesValues = catIndices.GetValues();
+ for (int i = 0; i < catIndicesValues.Length; i += 2)
{
- int previousEndIndex = -1;
- isValid = true;
- var catIndicesValues = catIndices.GetValues();
- for (int i = 0; i < catIndicesValues.Length; i += 2)
+ if (catIndicesValues[i] > catIndicesValues[i + 1] ||
+ catIndicesValues[i] <= previousEndIndex ||
+ catIndicesValues[i] >= columnSlotsCount ||
+ catIndicesValues[i + 1] >= columnSlotsCount)
{
- if (catIndicesValues[i] > catIndicesValues[i + 1] ||
- catIndicesValues[i] <= previousEndIndex ||
- catIndicesValues[i] >= columnSlotsCount ||
- catIndicesValues[i + 1] >= columnSlotsCount)
- {
- isValid = false;
- break;
- }
-
- previousEndIndex = catIndicesValues[i + 1];
+ isValid = false;
+ break;
}
- if (isValid)
- categoricalFeatures = catIndicesValues.ToArray();
+
+ previousEndIndex = catIndicesValues[i + 1];
}
+ if (isValid)
+ categoricalFeatures = catIndicesValues.ToArray();
}
-
- return isValid;
}
- ///
- /// Produces sequence of columns that are generated by trainer estimators.
- ///
- /// whether we should also append 'IsNormalized' (typically for probability column)
- public static IEnumerable GetTrainerOutputAnnotation(bool isNormalized = false)
- {
- var cols = new List();
- cols.Add(new SchemaShape.Column(Kinds.ScoreColumnSetId, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true));
- cols.Add(new SchemaShape.Column(Kinds.ScoreColumnKind, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false));
- cols.Add(new SchemaShape.Column(Kinds.ScoreValueKind, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false));
- if (isNormalized)
- cols.Add(new SchemaShape.Column(Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false));
- return cols;
- }
+ return isValid;
+ }
- ///
- /// Produces annotations for the score column generated by trainer estimators for multiclass classification.
- /// If input LabelColumn is not available it produces slotnames annotation by default.
- ///
- /// Label column.
- public static IEnumerable AnnotationsForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null)
+ ///
+ /// Produces sequence of columns that are generated by trainer estimators.
+ ///
+ /// whether we should also append 'IsNormalized' (typically for probability column)
+ public static IEnumerable GetTrainerOutputAnnotation(bool isNormalized = false)
+ {
+ var cols = new List();
+ cols.Add(new SchemaShape.Column(Kinds.ScoreColumnSetId, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true));
+ cols.Add(new SchemaShape.Column(Kinds.ScoreColumnKind, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false));
+ cols.Add(new SchemaShape.Column(Kinds.ScoreValueKind, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false));
+ if (isNormalized)
+ cols.Add(new SchemaShape.Column(Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false));
+ return cols;
+ }
+
+ ///
+ /// Produces annotations for the score column generated by trainer estimators for multiclass classification.
+ /// If input LabelColumn is not available it produces slotnames annotation by default.
+ ///
+ /// Label column.
+ public static IEnumerable AnnotationsForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null)
+ {
+ var cols = new List();
+ if (labelColumn != null && labelColumn.Value.IsKey)
{
- var cols = new List();
- if (labelColumn != null && labelColumn.Value.IsKey)
+ if (labelColumn.Value.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol) &&
+ metaCol.Kind == SchemaShape.Column.VectorKind.Vector)
{
- if (labelColumn.Value.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol) &&
- metaCol.Kind == SchemaShape.Column.VectorKind.Vector)
- {
- if (metaCol.ItemType is TextDataViewType)
- cols.Add(new SchemaShape.Column(Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
- cols.Add(new SchemaShape.Column(Kinds.TrainingLabelValues, SchemaShape.Column.VectorKind.Vector, metaCol.ItemType, false));
- }
+ if (metaCol.ItemType is TextDataViewType)
+ cols.Add(new SchemaShape.Column(Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
+ cols.Add(new SchemaShape.Column(Kinds.TrainingLabelValues, SchemaShape.Column.VectorKind.Vector, metaCol.ItemType, false));
}
- cols.AddRange(GetTrainerOutputAnnotation());
- return cols;
}
+ cols.AddRange(GetTrainerOutputAnnotation());
+ return cols;
+ }
+
+ private sealed class AnnotationRow : DataViewRow
+ {
+ private readonly DataViewSchema.Annotations _annotations;
- private sealed class AnnotationRow : DataViewRow
+ public AnnotationRow(DataViewSchema.Annotations annotations)
{
- private readonly DataViewSchema.Annotations _annotations;
+ Contracts.AssertValue(annotations);
+ _annotations = annotations;
+ }
- public AnnotationRow(DataViewSchema.Annotations annotations)
- {
- Contracts.AssertValue(annotations);
- _annotations = annotations;
- }
+ public override DataViewSchema Schema => _annotations.Schema;
+ public override long Position => 0;
+ public override long Batch => 0;
- public override DataViewSchema Schema => _annotations.Schema;
- public override long Position => 0;
- public override long Batch => 0;
-
- ///
- /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
- /// This throws if the column is not active in this row, or if the type
- /// differs from this column's type.
- ///
- /// is the column's content type.
- /// is the output column whose getter should be returned.
- public override ValueGetter GetGetter(DataViewSchema.Column column) => _annotations.GetGetter(column);
-
- public override ValueGetter GetIdGetter() => (ref DataViewRowId dst) => dst = default;
-
- ///
- /// Returns whether the given column is active in this row.
- ///
- public override bool IsColumnActive(DataViewSchema.Column column) => true;
- }
+ ///
+ /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
+ /// This throws if the column is not active in this row, or if the type
+ /// differs from this column's type.
+ ///
+ /// is the column's content type.
+ /// is the output column whose getter should be returned.
+ public override ValueGetter GetGetter(DataViewSchema.Column column) => _annotations.GetGetter(column);
+
+ public override ValueGetter GetIdGetter() => (ref DataViewRowId dst) => dst = default;
///
- /// Presents a as a an .
+ /// Returns whether the given column is active in this row.
///
- /// The annotations to wrap.
- /// A row that wraps an input annotations.
- [BestFriend]
- internal static DataViewRow AnnotationsAsRow(DataViewSchema.Annotations annotations)
- {
- Contracts.CheckValue(annotations, nameof(annotations));
- return new AnnotationRow(annotations);
- }
+ public override bool IsColumnActive(DataViewSchema.Column column) => true;
+ }
+
+ ///
+ /// Presents a as a an .
+ ///
+ /// The annotations to wrap.
+ /// A row that wraps an input annotations.
+ [BestFriend]
+ internal static DataViewRow AnnotationsAsRow(DataViewSchema.Annotations annotations)
+ {
+ Contracts.CheckValue(annotations, nameof(annotations));
+ return new AnnotationRow(annotations);
}
}
diff --git a/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs b/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs
index d1465f2c03..d595ff4d06 100644
--- a/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs
+++ b/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs
@@ -5,165 +5,164 @@
using System;
using Microsoft.ML.Runtime;
-namespace Microsoft.ML.Data
+namespace Microsoft.ML.Data;
+
+///
+/// Extension methods related to the ColumnType class.
+///
+[BestFriend]
+internal static class ColumnTypeExtensions
{
///
- /// Extension methods related to the ColumnType class.
+ /// Whether this type is a standard scalar type completely determined by its
+ /// (not a or , etc).
+ ///
+ public static bool IsStandardScalar(this DataViewType columnType) =>
+ (columnType is NumberDataViewType) || (columnType is TextDataViewType) || (columnType is BooleanDataViewType) ||
+ (columnType is RowIdDataViewType) || (columnType is TimeSpanDataViewType) ||
+ (columnType is DateTimeDataViewType) || (columnType is DateTimeOffsetDataViewType);
+
+ ///
+ /// Zero return means it's not a key type.
+ ///
+ public static ulong GetKeyCount(this DataViewType columnType) => (columnType as KeyDataViewType)?.Count ?? 0;
+
+ ///
+ /// Sometimes it is necessary to cast the Count to an int. This performs overflow check.
+ /// Zero return means it's not a key type.
///
- [BestFriend]
- internal static class ColumnTypeExtensions
+ public static int GetKeyCountAsInt32(this DataViewType columnType, IExceptionContext ectx = null)
{
- ///
- /// Whether this type is a standard scalar type completely determined by its
- /// (not a or , etc).
- ///
- public static bool IsStandardScalar(this DataViewType columnType) =>
- (columnType is NumberDataViewType) || (columnType is TextDataViewType) || (columnType is BooleanDataViewType) ||
- (columnType is RowIdDataViewType) || (columnType is TimeSpanDataViewType) ||
- (columnType is DateTimeDataViewType) || (columnType is DateTimeOffsetDataViewType);
-
- ///
- /// Zero return means it's not a key type.
- ///
- public static ulong GetKeyCount(this DataViewType columnType) => (columnType as KeyDataViewType)?.Count ?? 0;
-
- ///
- /// Sometimes it is necessary to cast the Count to an int. This performs overflow check.
- /// Zero return means it's not a key type.
- ///
- public static int GetKeyCountAsInt32(this DataViewType columnType, IExceptionContext ectx = null)
- {
- ulong count = columnType.GetKeyCount();
- ectx.Check(count <= int.MaxValue, nameof(KeyDataViewType) + "." + nameof(KeyDataViewType.Count) + " exceeds int.MaxValue.");
- return (int)count;
- }
+ ulong count = columnType.GetKeyCount();
+ ectx.Check(count <= int.MaxValue, nameof(KeyDataViewType) + "." + nameof(KeyDataViewType.Count) + " exceeds int.MaxValue.");
+ return (int)count;
+ }
- ///
- /// For non-vector types, this returns the column type itself (i.e., return ).
- /// For vector types, this returns the type of the items stored as values in vector.
- ///
- public static DataViewType GetItemType(this DataViewType columnType) => (columnType as VectorDataViewType)?.ItemType ?? columnType;
-
- ///
- /// Zero return means either it's not a vector or the size is unknown.
- ///
- public static int GetVectorSize(this DataViewType columnType) => (columnType as VectorDataViewType)?.Size ?? 0;
-
- ///
- /// For non-vectors, this returns one. For unknown size vectors, it returns zero.
- /// For known sized vectors, it returns size.
- ///
- public static int GetValueCount(this DataViewType columnType) => (columnType as VectorDataViewType)?.Size ?? 1;
-
- ///
- /// Whether this is a vector type with known size. Returns false for non-vector types.
- /// Equivalent to > 0.
- ///
- public static bool IsKnownSizeVector(this DataViewType columnType) => columnType.GetVectorSize() > 0;
-
- ///
- /// Gets the equivalent for the 's RawType.
- /// This can return default() if the RawType doesn't have a corresponding
- /// .
- ///
- public static InternalDataKind GetRawKind(this DataViewType columnType)
- {
- columnType.RawType.TryGetDataKind(out InternalDataKind result);
- return result;
- }
+ ///
+ /// For non-vector types, this returns the column type itself (i.e., return ).
+ /// For vector types, this returns the type of the items stored as values in vector.
+ ///
+ public static DataViewType GetItemType(this DataViewType columnType) => (columnType as VectorDataViewType)?.ItemType ?? columnType;
- ///
- /// Equivalent to calling Equals(ColumnType) for non-vector types. For vector type,
- /// returns true if current and other vector types have the same size and item type.
- ///
- public static bool SameSizeAndItemType(this DataViewType columnType, DataViewType other)
- {
- if (other == null)
- return false;
-
- if (columnType.Equals(other))
- return true;
-
- // For vector types, we don't care about the factoring of the dimensions.
- if (!(columnType is VectorDataViewType vectorType) || !(other is VectorDataViewType otherVectorType))
- return false;
- if (!vectorType.ItemType.Equals(otherVectorType.ItemType))
- return false;
- return vectorType.Size == otherVectorType.Size;
- }
+ ///
+ /// Zero return means either it's not a vector or the size is unknown.
+ ///
+ public static int GetVectorSize(this DataViewType columnType) => (columnType as VectorDataViewType)?.Size ?? 0;
- public static PrimitiveDataViewType PrimitiveTypeFromType(Type type)
- {
- if (type == typeof(ReadOnlyMemory) || type == typeof(string))
- return TextDataViewType.Instance;
- if (type == typeof(bool))
- return BooleanDataViewType.Instance;
- if (type == typeof(TimeSpan))
- return TimeSpanDataViewType.Instance;
- if (type == typeof(DateTime))
- return DateTimeDataViewType.Instance;
- if (type == typeof(DateTimeOffset))
- return DateTimeOffsetDataViewType.Instance;
- if (type == typeof(DataViewRowId))
- return RowIdDataViewType.Instance;
- return NumberTypeFromType(type);
- }
+ ///
+ /// For non-vectors, this returns one. For unknown size vectors, it returns zero.
+ /// For known sized vectors, it returns size.
+ ///
+ public static int GetValueCount(this DataViewType columnType) => (columnType as VectorDataViewType)?.Size ?? 1;
- public static PrimitiveDataViewType PrimitiveTypeFromKind(InternalDataKind kind)
- {
- if (kind == InternalDataKind.TX)
- return TextDataViewType.Instance;
- if (kind == InternalDataKind.BL)
- return BooleanDataViewType.Instance;
- if (kind == InternalDataKind.TS)
- return TimeSpanDataViewType.Instance;
- if (kind == InternalDataKind.DT)
- return DateTimeDataViewType.Instance;
- if (kind == InternalDataKind.DZ)
- return DateTimeOffsetDataViewType.Instance;
- if (kind == InternalDataKind.UG)
- return RowIdDataViewType.Instance;
- return NumberTypeFromKind(kind);
- }
+ ///
+ /// Whether this is a vector type with known size. Returns false for non-vector types.
+ /// Equivalent to > 0.
+ ///
+ public static bool IsKnownSizeVector(this DataViewType columnType) => columnType.GetVectorSize() > 0;
- public static NumberDataViewType NumberTypeFromType(Type type)
- {
- InternalDataKind kind;
- if (type.TryGetDataKind(out kind))
- return NumberTypeFromKind(kind);
+ ///
+ /// Gets the equivalent for the 's RawType.
+ /// This can return default() if the RawType doesn't have a corresponding
+ /// .
+ ///
+ public static InternalDataKind GetRawKind(this DataViewType columnType)
+ {
+ columnType.RawType.TryGetDataKind(out InternalDataKind result);
+ return result;
+ }
- Contracts.Assert(false);
- throw new InvalidOperationException($"Bad type in {nameof(ColumnTypeExtensions)}.{nameof(NumberTypeFromType)}: {type}");
- }
+ ///
+ /// Equivalent to calling Equals(ColumnType) for non-vector types. For vector type,
+ /// returns true if current and other vector types have the same size and item type.
+ ///
+ public static bool SameSizeAndItemType(this DataViewType columnType, DataViewType other)
+ {
+ if (other == null)
+ return false;
+
+ if (columnType.Equals(other))
+ return true;
+
+ // For vector types, we don't care about the factoring of the dimensions.
+ if (!(columnType is VectorDataViewType vectorType) || !(other is VectorDataViewType otherVectorType))
+ return false;
+ if (!vectorType.ItemType.Equals(otherVectorType.ItemType))
+ return false;
+ return vectorType.Size == otherVectorType.Size;
+ }
- private static NumberDataViewType NumberTypeFromKind(InternalDataKind kind)
+ public static PrimitiveDataViewType PrimitiveTypeFromType(Type type)
+ {
+ if (type == typeof(ReadOnlyMemory) || type == typeof(string))
+ return TextDataViewType.Instance;
+ if (type == typeof(bool))
+ return BooleanDataViewType.Instance;
+ if (type == typeof(TimeSpan))
+ return TimeSpanDataViewType.Instance;
+ if (type == typeof(DateTime))
+ return DateTimeDataViewType.Instance;
+ if (type == typeof(DateTimeOffset))
+ return DateTimeOffsetDataViewType.Instance;
+ if (type == typeof(DataViewRowId))
+ return RowIdDataViewType.Instance;
+ return NumberTypeFromType(type);
+ }
+
+ public static PrimitiveDataViewType PrimitiveTypeFromKind(InternalDataKind kind)
+ {
+ if (kind == InternalDataKind.TX)
+ return TextDataViewType.Instance;
+ if (kind == InternalDataKind.BL)
+ return BooleanDataViewType.Instance;
+ if (kind == InternalDataKind.TS)
+ return TimeSpanDataViewType.Instance;
+ if (kind == InternalDataKind.DT)
+ return DateTimeDataViewType.Instance;
+ if (kind == InternalDataKind.DZ)
+ return DateTimeOffsetDataViewType.Instance;
+ if (kind == InternalDataKind.UG)
+ return RowIdDataViewType.Instance;
+ return NumberTypeFromKind(kind);
+ }
+
+ public static NumberDataViewType NumberTypeFromType(Type type)
+ {
+ InternalDataKind kind;
+ if (type.TryGetDataKind(out kind))
+ return NumberTypeFromKind(kind);
+
+ Contracts.Assert(false);
+ throw new InvalidOperationException($"Bad type in {nameof(ColumnTypeExtensions)}.{nameof(NumberTypeFromType)}: {type}");
+ }
+
+ private static NumberDataViewType NumberTypeFromKind(InternalDataKind kind)
+ {
+ switch (kind)
{
- switch (kind)
- {
- case InternalDataKind.I1:
- return NumberDataViewType.SByte;
- case InternalDataKind.U1:
- return NumberDataViewType.Byte;
- case InternalDataKind.I2:
- return NumberDataViewType.Int16;
- case InternalDataKind.U2:
- return NumberDataViewType.UInt16;
- case InternalDataKind.I4:
- return NumberDataViewType.Int32;
- case InternalDataKind.U4:
- return NumberDataViewType.UInt32;
- case InternalDataKind.I8:
- return NumberDataViewType.Int64;
- case InternalDataKind.U8:
- return NumberDataViewType.UInt64;
- case InternalDataKind.R4:
- return NumberDataViewType.Single;
- case InternalDataKind.R8:
- return NumberDataViewType.Double;
- }
-
- Contracts.Assert(false);
- throw new InvalidOperationException($"Bad data kind in {nameof(ColumnTypeExtensions)}.{nameof(NumberTypeFromKind)}: {kind}");
+ case InternalDataKind.I1:
+ return NumberDataViewType.SByte;
+ case InternalDataKind.U1:
+ return NumberDataViewType.Byte;
+ case InternalDataKind.I2:
+ return NumberDataViewType.Int16;
+ case InternalDataKind.U2:
+ return NumberDataViewType.UInt16;
+ case InternalDataKind.I4:
+ return NumberDataViewType.Int32;
+ case InternalDataKind.U4:
+ return NumberDataViewType.UInt32;
+ case InternalDataKind.I8:
+ return NumberDataViewType.Int64;
+ case InternalDataKind.U8:
+ return NumberDataViewType.UInt64;
+ case InternalDataKind.R4:
+ return NumberDataViewType.Single;
+ case InternalDataKind.R8:
+ return NumberDataViewType.Double;
}
+
+ Contracts.Assert(false);
+ throw new InvalidOperationException($"Bad data kind in {nameof(ColumnTypeExtensions)}.{nameof(NumberTypeFromKind)}: {kind}");
}
}
diff --git a/src/Microsoft.ML.Core/Data/DataKind.cs b/src/Microsoft.ML.Core/Data/DataKind.cs
index 3baf7b5511..567819278e 100644
--- a/src/Microsoft.ML.Core/Data/DataKind.cs
+++ b/src/Microsoft.ML.Core/Data/DataKind.cs
@@ -4,380 +4,379 @@
using System;
using Microsoft.ML.Runtime;
-namespace Microsoft.ML.Data
+namespace Microsoft.ML.Data;
+
+///
+/// Specifies a simple data type.
+///
+///
+/// or [text](xref:Microsoft.ML.Data.TextDataViewType) | Empty or `null` string (both result in empty `System.ReadOnlyMemory` | |
+/// | [Key](xref:Microsoft.ML.Data.KeyDataViewType) type (supported by the unsigned integer types in `DataKind`) | Not defined | Always `false` |
+/// | All other types | Default value of the corresponding system type as defined by .NET standard. In C#, default value expression `default(T)` provides that value. | Equality test with the default value |
+///
+/// The table below shows the missing value definition for each of the data types.
+///
+/// | Type | Missing Value | IsMissing Indicator |
+/// | -- | -- | -- |
+/// | or [text](xref:Microsoft.ML.Data.TextDataViewType) | Not defined | Always `false` |
+/// | [Key](xref:Microsoft.ML.Data.KeyDataViewType) type (supported by the unsigned integer types in `DataKind`) | `0` | Equality test with `0` |
+/// | | | |
+/// | | | |
+/// | All other types | Not defined | Always `false` |
+///
+/// ]]>
+///
+///
+// Data type specifiers mainly used in creating text loader and type converter.
+public enum DataKind : byte
{
+ /// 1-byte integer, type of .
+ SByte = 1,
+ /// 1-byte unsigned integer, type of .
+ Byte = 2,
+ /// 2-byte integer, type of .
+ Int16 = 3,
+ /// 2-byte unsigned integer, type of .
+ UInt16 = 4,
+ /// 4-byte integer, type of .
+ Int32 = 5,
+ /// 4-byte unsigned integer, type of .
+ UInt32 = 6,
+ /// 8-byte integer, type of .
+ Int64 = 7,
+ /// 8-byte unsigned integer, type of .
+ UInt64 = 8,
+ /// 4-byte floating-point number, type of .
+ Single = 9,
+ /// 8-byte floating-point number, type of .
+ Double = 10,
///
- /// Specifies a simple data type.
+ /// string, type of , where T is .
+ /// Also compatible with .
///
- ///
- /// or [text](xref:Microsoft.ML.Data.TextDataViewType) | Empty or `null` string (both result in empty `System.ReadOnlyMemory` | |
- /// | [Key](xref:Microsoft.ML.Data.KeyDataViewType) type (supported by the unsigned integer types in `DataKind`) | Not defined | Always `false` |
- /// | All other types | Default value of the corresponding system type as defined by .NET standard. In C#, default value expression `default(T)` provides that value. | Equality test with the default value |
- ///
- /// The table below shows the missing value definition for each of the data types.
- ///
- /// | Type | Missing Value | IsMissing Indicator |
- /// | -- | -- | -- |
- /// | or [text](xref:Microsoft.ML.Data.TextDataViewType) | Not defined | Always `false` |
- /// | [Key](xref:Microsoft.ML.Data.KeyDataViewType) type (supported by the unsigned integer types in `DataKind`) | `0` | Equality test with `0` |
- /// | | | |
- /// | | | |
- /// | All other types | Not defined | Always `false` |
- ///
- /// ]]>
- ///
- ///
- // Data type specifiers mainly used in creating text loader and type converter.
- public enum DataKind : byte
- {
- /// 1-byte integer, type of .
- SByte = 1,
- /// 1-byte unsigned integer, type of .
- Byte = 2,
- /// 2-byte integer, type of .
- Int16 = 3,
- /// 2-byte unsigned integer, type of .
- UInt16 = 4,
- /// 4-byte integer, type of .
- Int32 = 5,
- /// 4-byte unsigned integer, type of .
- UInt32 = 6,
- /// 8-byte integer, type of .
- Int64 = 7,
- /// 8-byte unsigned integer, type of .
- UInt64 = 8,
- /// 4-byte floating-point number, type of .
- Single = 9,
- /// 8-byte floating-point number, type of .
- Double = 10,
- ///
- /// string, type of , where T is .
- /// Also compatible with .
- ///
- String = 11,
- /// boolean variable type, type of .
- Boolean = 12,
- /// type of .
- TimeSpan = 13,
- /// type of .
- DateTime = 14,
- /// type of .
- DateTimeOffset = 15,
- }
+ String = 11,
+ /// boolean variable type, type of .
+ Boolean = 12,
+ /// type of .
+ TimeSpan = 13,
+ /// type of .
+ DateTime = 14,
+ /// type of .
+ DateTimeOffset = 15,
+}
- ///
- /// Data type specifier used in command line. is the underlying version of
- /// used for command line and entry point BC.
- ///
- [BestFriend]
- internal enum InternalDataKind : byte
- {
- // Notes:
- // * These values are serialized, so changing them breaks binary formats.
- // * We intentionally skip zero.
- // * Some code depends on sizeof(DataKind) == sizeof(byte).
+///
+/// Data type specifier used in command line. is the underlying version of
+/// used for command line and entry point BC.
+///
+[BestFriend]
+internal enum InternalDataKind : byte
+{
+ // Notes:
+ // * These values are serialized, so changing them breaks binary formats.
+ // * We intentionally skip zero.
+ // * Some code depends on sizeof(DataKind) == sizeof(byte).
- I1 = DataKind.SByte,
- U1 = DataKind.Byte,
- I2 = DataKind.Int16,
- U2 = DataKind.UInt16,
- I4 = DataKind.Int32,
- U4 = DataKind.UInt32,
- I8 = DataKind.Int64,
- U8 = DataKind.UInt64,
- R4 = DataKind.Single,
- R8 = DataKind.Double,
- Num = R4,
+ I1 = DataKind.SByte,
+ U1 = DataKind.Byte,
+ I2 = DataKind.Int16,
+ U2 = DataKind.UInt16,
+ I4 = DataKind.Int32,
+ U4 = DataKind.UInt32,
+ I8 = DataKind.Int64,
+ U8 = DataKind.UInt64,
+ R4 = DataKind.Single,
+ R8 = DataKind.Double,
+ Num = R4,
- TX = DataKind.String,
+ TX = DataKind.String,
#pragma warning disable MSML_GeneralName // The data kind enum has its own logic, independent of C# naming conventions.
- TXT = TX,
- Text = TX,
+ TXT = TX,
+ Text = TX,
- BL = DataKind.Boolean,
- Bool = BL,
+ BL = DataKind.Boolean,
+ Bool = BL,
- TS = DataKind.TimeSpan,
- TimeSpan = TS,
- DT = DataKind.DateTime,
- DateTime = DT,
- DZ = DataKind.DateTimeOffset,
- DateTimeZone = DZ,
+ TS = DataKind.TimeSpan,
+ TimeSpan = TS,
+ DT = DataKind.DateTime,
+ DateTime = DT,
+ DZ = DataKind.DateTimeOffset,
+ DateTimeZone = DZ,
- UG = 16, // Unsigned 16-byte integer.
- U16 = UG,
+ UG = 16, // Unsigned 16-byte integer.
+ U16 = UG,
#pragma warning restore MSML_GeneralName
- }
+}
+
+///
+/// Extension methods related to the DataKind enum.
+///
+[BestFriend]
+internal static class InternalDataKindExtensions
+{
+ public const InternalDataKind KindMin = InternalDataKind.I1;
+ public const InternalDataKind KindLim = InternalDataKind.U16 + 1;
+ public const int KindCount = KindLim - KindMin;
///
- /// Extension methods related to the DataKind enum.
+ /// Maps a DataKind to a value suitable for indexing into an array of size KindCount.
///
- [BestFriend]
- internal static class InternalDataKindExtensions
+ public static int ToIndex(this InternalDataKind kind)
{
- public const InternalDataKind KindMin = InternalDataKind.I1;
- public const InternalDataKind KindLim = InternalDataKind.U16 + 1;
- public const int KindCount = KindLim - KindMin;
-
- ///
- /// Maps a DataKind to a value suitable for indexing into an array of size KindCount.
- ///
- public static int ToIndex(this InternalDataKind kind)
- {
- return kind - KindMin;
- }
-
- ///
- /// Maps from an index into an array of size KindCount to the corresponding DataKind
- ///
- public static InternalDataKind FromIndex(int index)
- {
- Contracts.Check(0 <= index && index < KindCount);
- return (InternalDataKind)(index + (int)KindMin);
- }
-
- ///
- /// This function converts to .
- /// Because is a subset of , the conversion is straightforward.
- ///
- public static InternalDataKind ToInternalDataKind(this DataKind dataKind) => (InternalDataKind)dataKind;
+ return kind - KindMin;
+ }
- ///
- /// This function converts to .
- /// Because is a subset of , we should check if
- /// can be found in .
- ///
- public static DataKind ToDataKind(this InternalDataKind kind)
- {
- Contracts.Check(kind != InternalDataKind.UG);
- return (DataKind)kind;
- }
+ ///
+ /// Maps from an index into an array of size KindCount to the corresponding DataKind
+ ///
+ public static InternalDataKind FromIndex(int index)
+ {
+ Contracts.Check(0 <= index && index < KindCount);
+ return (InternalDataKind)(index + (int)KindMin);
+ }
- ///
- /// For integer DataKinds, this returns the maximum legal value. For un-supported kinds,
- /// it returns zero.
- ///
- public static ulong ToMaxInt(this InternalDataKind kind)
- {
- switch (kind)
- {
- case InternalDataKind.I1:
- return (ulong)sbyte.MaxValue;
- case InternalDataKind.U1:
- return byte.MaxValue;
- case InternalDataKind.I2:
- return (ulong)short.MaxValue;
- case InternalDataKind.U2:
- return ushort.MaxValue;
- case InternalDataKind.I4:
- return int.MaxValue;
- case InternalDataKind.U4:
- return uint.MaxValue;
- case InternalDataKind.I8:
- return long.MaxValue;
- case InternalDataKind.U8:
- return ulong.MaxValue;
- }
+ ///
+ /// This function converts to .
+ /// Because is a subset of , the conversion is straightforward.
+ ///
+ public static InternalDataKind ToInternalDataKind(this DataKind dataKind) => (InternalDataKind)dataKind;
- return 0;
- }
+ ///
+ /// This function converts to .
+ /// Because is a subset of , we should check if
+ /// can be found in .
+ ///
+ public static DataKind ToDataKind(this InternalDataKind kind)
+ {
+ Contracts.Check(kind != InternalDataKind.UG);
+ return (DataKind)kind;
+ }
- ///
- /// For integer Types, this returns the maximum legal value. For un-supported Types,
- /// it returns zero.
- ///
- public static ulong ToMaxInt(this Type type)
+ ///
+ /// For integer DataKinds, this returns the maximum legal value. For un-supported kinds,
+ /// it returns zero.
+ ///
+ public static ulong ToMaxInt(this InternalDataKind kind)
+ {
+ switch (kind)
{
- if (type == typeof(sbyte))
+ case InternalDataKind.I1:
return (ulong)sbyte.MaxValue;
- else if (type == typeof(byte))
+ case InternalDataKind.U1:
return byte.MaxValue;
- else if (type == typeof(short))
+ case InternalDataKind.I2:
return (ulong)short.MaxValue;
- else if (type == typeof(ushort))
+ case InternalDataKind.U2:
return ushort.MaxValue;
- else if (type == typeof(int))
+ case InternalDataKind.I4:
return int.MaxValue;
- else if (type == typeof(uint))
+ case InternalDataKind.U4:
return uint.MaxValue;
- else if (type == typeof(long))
+ case InternalDataKind.I8:
return long.MaxValue;
- else if (type == typeof(ulong))
+ case InternalDataKind.U8:
return ulong.MaxValue;
-
- return 0;
}
- ///
- /// For integer DataKinds, this returns the minimum legal value. For un-supported kinds,
- /// it returns one.
- ///
- public static long ToMinInt(this InternalDataKind kind)
- {
- switch (kind)
- {
- case InternalDataKind.I1:
- return sbyte.MinValue;
- case InternalDataKind.U1:
- return byte.MinValue;
- case InternalDataKind.I2:
- return short.MinValue;
- case InternalDataKind.U2:
- return ushort.MinValue;
- case InternalDataKind.I4:
- return int.MinValue;
- case InternalDataKind.U4:
- return uint.MinValue;
- case InternalDataKind.I8:
- return long.MinValue;
- case InternalDataKind.U8:
- return 0;
- }
+ return 0;
+ }
- return 1;
- }
+ ///
+ /// For integer Types, this returns the maximum legal value. For un-supported Types,
+ /// it returns zero.
+ ///
+ public static ulong ToMaxInt(this Type type)
+ {
+ if (type == typeof(sbyte))
+ return (ulong)sbyte.MaxValue;
+ else if (type == typeof(byte))
+ return byte.MaxValue;
+ else if (type == typeof(short))
+ return (ulong)short.MaxValue;
+ else if (type == typeof(ushort))
+ return ushort.MaxValue;
+ else if (type == typeof(int))
+ return int.MaxValue;
+ else if (type == typeof(uint))
+ return uint.MaxValue;
+ else if (type == typeof(long))
+ return long.MaxValue;
+ else if (type == typeof(ulong))
+ return ulong.MaxValue;
- ///
- /// Maps a DataKind to the associated .Net representation type.
- ///
- public static Type ToType(this InternalDataKind kind)
- {
- switch (kind)
- {
- case InternalDataKind.I1:
- return typeof(sbyte);
- case InternalDataKind.U1:
- return typeof(byte);
- case InternalDataKind.I2:
- return typeof(short);
- case InternalDataKind.U2:
- return typeof(ushort);
- case InternalDataKind.I4:
- return typeof(int);
- case InternalDataKind.U4:
- return typeof(uint);
- case InternalDataKind.I8:
- return typeof(long);
- case InternalDataKind.U8:
- return typeof(ulong);
- case InternalDataKind.R4:
- return typeof(Single);
- case InternalDataKind.R8:
- return typeof(Double);
- case InternalDataKind.TX:
- return typeof(ReadOnlyMemory);
- case InternalDataKind.BL:
- return typeof(bool);
- case InternalDataKind.TS:
- return typeof(TimeSpan);
- case InternalDataKind.DT:
- return typeof(DateTime);
- case InternalDataKind.DZ:
- return typeof(DateTimeOffset);
- case InternalDataKind.UG:
- return typeof(DataViewRowId);
- }
+ return 0;
+ }
- return null;
+ ///
+ /// For integer DataKinds, this returns the minimum legal value. For un-supported kinds,
+ /// it returns one.
+ ///
+ public static long ToMinInt(this InternalDataKind kind)
+ {
+ switch (kind)
+ {
+ case InternalDataKind.I1:
+ return sbyte.MinValue;
+ case InternalDataKind.U1:
+ return byte.MinValue;
+ case InternalDataKind.I2:
+ return short.MinValue;
+ case InternalDataKind.U2:
+ return ushort.MinValue;
+ case InternalDataKind.I4:
+ return int.MinValue;
+ case InternalDataKind.U4:
+ return uint.MinValue;
+ case InternalDataKind.I8:
+ return long.MinValue;
+ case InternalDataKind.U8:
+ return 0;
}
- ///
- /// Try to map a System.Type to a corresponding DataKind value.
- ///
- public static bool TryGetDataKind(this Type type, out InternalDataKind kind)
+ return 1;
+ }
+
+ ///
+ /// Maps a DataKind to the associated .Net representation type.
+ ///
+ public static Type ToType(this InternalDataKind kind)
+ {
+ switch (kind)
{
- Contracts.CheckValueOrNull(type);
+ case InternalDataKind.I1:
+ return typeof(sbyte);
+ case InternalDataKind.U1:
+ return typeof(byte);
+ case InternalDataKind.I2:
+ return typeof(short);
+ case InternalDataKind.U2:
+ return typeof(ushort);
+ case InternalDataKind.I4:
+ return typeof(int);
+ case InternalDataKind.U4:
+ return typeof(uint);
+ case InternalDataKind.I8:
+ return typeof(long);
+ case InternalDataKind.U8:
+ return typeof(ulong);
+ case InternalDataKind.R4:
+ return typeof(Single);
+ case InternalDataKind.R8:
+ return typeof(Double);
+ case InternalDataKind.TX:
+ return typeof(ReadOnlyMemory);
+ case InternalDataKind.BL:
+ return typeof(bool);
+ case InternalDataKind.TS:
+ return typeof(TimeSpan);
+ case InternalDataKind.DT:
+ return typeof(DateTime);
+ case InternalDataKind.DZ:
+ return typeof(DateTimeOffset);
+ case InternalDataKind.UG:
+ return typeof(DataViewRowId);
+ }
+
+ return null;
+ }
- // REVIEW: Make this more efficient. Should we have a global dictionary?
- if (type == typeof(sbyte))
- kind = InternalDataKind.I1;
- else if (type == typeof(byte))
- kind = InternalDataKind.U1;
- else if (type == typeof(short))
- kind = InternalDataKind.I2;
- else if (type == typeof(ushort))
- kind = InternalDataKind.U2;
- else if (type == typeof(int))
- kind = InternalDataKind.I4;
- else if (type == typeof(uint))
- kind = InternalDataKind.U4;
- else if (type == typeof(long))
- kind = InternalDataKind.I8;
- else if (type == typeof(ulong))
- kind = InternalDataKind.U8;
- else if (type == typeof(Single))
- kind = InternalDataKind.R4;
- else if (type == typeof(Double))
- kind = InternalDataKind.R8;
- else if (type == typeof(ReadOnlyMemory) || type == typeof(string))
- kind = InternalDataKind.TX;
- else if (type == typeof(bool))
- kind = InternalDataKind.BL;
- else if (type == typeof(TimeSpan))
- kind = InternalDataKind.TS;
- else if (type == typeof(DateTime))
- kind = InternalDataKind.DT;
- else if (type == typeof(DateTimeOffset))
- kind = InternalDataKind.DZ;
- else if (type == typeof(DataViewRowId))
- kind = InternalDataKind.UG;
- else
- {
- kind = default(InternalDataKind);
- return false;
- }
+ ///
+ /// Try to map a System.Type to a corresponding DataKind value.
+ ///
+ public static bool TryGetDataKind(this Type type, out InternalDataKind kind)
+ {
+ Contracts.CheckValueOrNull(type);
- return true;
+ // REVIEW: Make this more efficient. Should we have a global dictionary?
+ if (type == typeof(sbyte))
+ kind = InternalDataKind.I1;
+ else if (type == typeof(byte))
+ kind = InternalDataKind.U1;
+ else if (type == typeof(short))
+ kind = InternalDataKind.I2;
+ else if (type == typeof(ushort))
+ kind = InternalDataKind.U2;
+ else if (type == typeof(int))
+ kind = InternalDataKind.I4;
+ else if (type == typeof(uint))
+ kind = InternalDataKind.U4;
+ else if (type == typeof(long))
+ kind = InternalDataKind.I8;
+ else if (type == typeof(ulong))
+ kind = InternalDataKind.U8;
+ else if (type == typeof(Single))
+ kind = InternalDataKind.R4;
+ else if (type == typeof(Double))
+ kind = InternalDataKind.R8;
+ else if (type == typeof(ReadOnlyMemory) || type == typeof(string))
+ kind = InternalDataKind.TX;
+ else if (type == typeof(bool))
+ kind = InternalDataKind.BL;
+ else if (type == typeof(TimeSpan))
+ kind = InternalDataKind.TS;
+ else if (type == typeof(DateTime))
+ kind = InternalDataKind.DT;
+ else if (type == typeof(DateTimeOffset))
+ kind = InternalDataKind.DZ;
+ else if (type == typeof(DataViewRowId))
+ kind = InternalDataKind.UG;
+ else
+ {
+ kind = default(InternalDataKind);
+ return false;
}
- ///
- /// Get the canonical string for a DataKind. Note that using DataKind.ToString() is not stable
- /// and is also slow, so use this instead.
- ///
- public static string GetString(this InternalDataKind kind)
+ return true;
+ }
+
+ ///
+ /// Get the canonical string for a DataKind. Note that using DataKind.ToString() is not stable
+ /// and is also slow, so use this instead.
+ ///
+ public static string GetString(this InternalDataKind kind)
+ {
+ switch (kind)
{
- switch (kind)
- {
- case InternalDataKind.I1:
- return "I1";
- case InternalDataKind.I2:
- return "I2";
- case InternalDataKind.I4:
- return "I4";
- case InternalDataKind.I8:
- return "I8";
- case InternalDataKind.U1:
- return "U1";
- case InternalDataKind.U2:
- return "U2";
- case InternalDataKind.U4:
- return "U4";
- case InternalDataKind.U8:
- return "U8";
- case InternalDataKind.R4:
- return "R4";
- case InternalDataKind.R8:
- return "R8";
- case InternalDataKind.BL:
- return "BL";
- case InternalDataKind.TX:
- return "TX";
- case InternalDataKind.TS:
- return "TS";
- case InternalDataKind.DT:
- return "DT";
- case InternalDataKind.DZ:
- return "DZ";
- case InternalDataKind.UG:
- return "UG";
- }
- return "";
+ case InternalDataKind.I1:
+ return "I1";
+ case InternalDataKind.I2:
+ return "I2";
+ case InternalDataKind.I4:
+ return "I4";
+ case InternalDataKind.I8:
+ return "I8";
+ case InternalDataKind.U1:
+ return "U1";
+ case InternalDataKind.U2:
+ return "U2";
+ case InternalDataKind.U4:
+ return "U4";
+ case InternalDataKind.U8:
+ return "U8";
+ case InternalDataKind.R4:
+ return "R4";
+ case InternalDataKind.R8:
+ return "R8";
+ case InternalDataKind.BL:
+ return "BL";
+ case InternalDataKind.TX:
+ return "TX";
+ case InternalDataKind.TS:
+ return "TS";
+ case InternalDataKind.DT:
+ return "DT";
+ case InternalDataKind.DZ:
+ return "DZ";
+ case InternalDataKind.UG:
+ return "UG";
}
+ return "";
}
}
diff --git a/src/Microsoft.ML.Core/Data/ICommand.cs b/src/Microsoft.ML.Core/Data/ICommand.cs
index 7d52222647..f71360510f 100644
--- a/src/Microsoft.ML.Core/Data/ICommand.cs
+++ b/src/Microsoft.ML.Core/Data/ICommand.cs
@@ -2,17 +2,16 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-namespace Microsoft.ML.Command
-{
- ///
- /// The signature for commands.
- ///
- [BestFriend]
- internal delegate void SignatureCommand();
+namespace Microsoft.ML.Command;
+
+///
+/// The signature for commands.
+///
+[BestFriend]
+internal delegate void SignatureCommand();
- [BestFriend]
- internal interface ICommand
- {
- void Run();
- }
+[BestFriend]
+internal interface ICommand
+{
+ void Run();
}
diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs
index 7dcbab4e75..1917bbeb15 100644
--- a/src/Microsoft.ML.Core/Data/IEstimator.cs
+++ b/src/Microsoft.ML.Core/Data/IEstimator.cs
@@ -8,315 +8,314 @@
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
-namespace Microsoft.ML
+namespace Microsoft.ML;
+
+/// <summary>
+/// A set of 'requirements' to the incoming schema, as well as a set of 'promises' of the outgoing schema.
+/// This is more relaxed than the proper <see cref="DataViewSchema"/>, since it's only a subset of the columns,
+/// and also since it doesn't specify exact <see cref="DataViewType"/>'s for vectors and keys.
+/// </summary>
+public sealed class SchemaShape : IReadOnlyList
{
- ///
- /// A set of 'requirements' to the incoming schema, as well as a set of 'promises' of the outgoing schema.
- /// This is more relaxed than the proper , since it's only a subset of the columns,
- /// and also since it doesn't specify exact 's for vectors and keys.
- ///
- public sealed class SchemaShape : IReadOnlyList
- {
- private readonly Column[] _columns;
+ private readonly Column[] _columns;
- private static readonly SchemaShape _empty = new SchemaShape(Enumerable.Empty());
+ private static readonly SchemaShape _empty = new SchemaShape(Enumerable.Empty());
- public int Count => _columns.Count();
+ public int Count => _columns.Length; // _columns is an array: read its length directly instead of going through LINQ Count()
- public Column this[int index] => _columns[index];
+ public Column this[int index] => _columns[index];
- public struct Column
+ public struct Column
+ {
+ public enum VectorKind
{
- public enum VectorKind
- {
- Scalar,
- Vector,
- VariableVector
- }
+ Scalar,
+ Vector,
+ VariableVector
+ }
- ///
- /// The column name.
- ///
- public readonly string Name;
-
- ///
- /// The type of the column: scalar, fixed vector or variable vector.
- ///
- public readonly VectorKind Kind;
-
- ///
- /// The 'raw' type of column item: must be a primitive type or a structured type.
- ///
- public readonly DataViewType ItemType;
- ///
- /// The flag whether the column is actually a key. If yes, is representing
- /// the underlying primitive type.
- ///
- public readonly bool IsKey;
- ///
- /// The annotations that are present for this column.
- ///
- public readonly SchemaShape Annotations;
-
- [BestFriend]
- internal Column(string name, VectorKind vecKind, DataViewType itemType, bool isKey, SchemaShape annotations = null)
- {
- Contracts.CheckNonEmpty(name, nameof(name));
- Contracts.CheckValueOrNull(annotations);
- Contracts.CheckParam(!(itemType is KeyDataViewType), nameof(itemType), "Item type cannot be a key");
- Contracts.CheckParam(!(itemType is VectorDataViewType), nameof(itemType), "Item type cannot be a vector");
- Contracts.CheckParam(!isKey || KeyDataViewType.IsValidDataType(itemType.RawType), nameof(itemType), "The item type must be valid for a key");
-
- Name = name;
- Kind = vecKind;
- ItemType = itemType;
- IsKey = isKey;
- Annotations = annotations ?? _empty;
- }
+ ///
+ /// The column name.
+ ///
+ public readonly string Name;
+
+ ///
+ /// The type of the column: scalar, fixed vector or variable vector.
+ ///
+ public readonly VectorKind Kind;
- ///
- /// Returns whether is a valid input, if this object represents a
- /// requirement.
- ///
- /// Namely, it returns true iff:
- /// - The , , , fields match.
- /// - The columns of of is a superset of our columns.
- /// - Each such annotation column is itself compatible with the input annotation column.
- ///
- [BestFriend]
- internal bool IsCompatibleWith(Column source)
+ ///
+ /// The 'raw' type of column item: must be a primitive type or a structured type.
+ ///
+ public readonly DataViewType ItemType;
+ ///
+ /// The flag whether the column is actually a key. If yes, is representing
+ /// the underlying primitive type.
+ ///
+ public readonly bool IsKey;
+ ///
+ /// The annotations that are present for this column.
+ ///
+ public readonly SchemaShape Annotations;
+
+ [BestFriend]
+ internal Column(string name, VectorKind vecKind, DataViewType itemType, bool isKey, SchemaShape annotations = null)
+ {
+ Contracts.CheckNonEmpty(name, nameof(name));
+ Contracts.CheckValueOrNull(annotations);
+ Contracts.CheckParam(!(itemType is KeyDataViewType), nameof(itemType), "Item type cannot be a key");
+ Contracts.CheckParam(!(itemType is VectorDataViewType), nameof(itemType), "Item type cannot be a vector");
+ Contracts.CheckParam(!isKey || KeyDataViewType.IsValidDataType(itemType.RawType), nameof(itemType), "The item type must be valid for a key");
+
+ Name = name;
+ Kind = vecKind;
+ ItemType = itemType;
+ IsKey = isKey;
+ Annotations = annotations ?? _empty;
+ }
+
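+ /// <summary>
+ /// Returns whether <paramref name="source"/> is a valid input, if this object represents a
+ /// requirement.
+ ///
+ /// Namely, it returns true iff:
+ /// - The <see cref="Name"/>, <see cref="Kind"/>, <see cref="ItemType"/>, <see cref="IsKey"/> fields match.
+ /// - The columns of <see cref="Annotations"/> of <paramref name="source"/> is a superset of our <see cref="Annotations"/> columns.
+ /// - Each such annotation column is itself compatible with the input annotation column.
+ /// </summary>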
+ [BestFriend]
+ internal bool IsCompatibleWith(Column source)
+ {
+ Contracts.Check(source.IsValid, nameof(source));
+ if (Name != source.Name)
+ return false;
+ if (Kind != source.Kind)
+ return false;
+ if (!ItemType.Equals(source.ItemType))
+ return false;
+ if (IsKey != source.IsKey)
+ return false;
+ foreach (var annotationCol in Annotations)
{
- Contracts.Check(source.IsValid, nameof(source));
- if (Name != source.Name)
+ if (!source.Annotations.TryFindColumn(annotationCol.Name, out var inputAnnotationCol))
return false;
- if (Kind != source.Kind)
+ if (!annotationCol.IsCompatibleWith(inputAnnotationCol))
return false;
- if (!ItemType.Equals(source.ItemType))
- return false;
- if (IsKey != source.IsKey)
- return false;
- foreach (var annotationCol in Annotations)
- {
- if (!source.Annotations.TryFindColumn(annotationCol.Name, out var inputAnnotationCol))
- return false;
- if (!annotationCol.IsCompatibleWith(inputAnnotationCol))
- return false;
- }
- return true;
}
-
- [BestFriend]
- internal string GetTypeString()
- {
- string result = ItemType.ToString();
- if (IsKey)
- result = $"Key<{result}>";
- if (Kind == VectorKind.Vector)
- result = $"Vector<{result}>";
- else if (Kind == VectorKind.VariableVector)
- result = $"VarVector<{result}>";
- return result;
- }
-
- ///
- /// Return if this structure is not identical to the default value of . If true,
- /// it means this structure is initialized properly and therefore considered as valid.
- ///
- [BestFriend]
- internal bool IsValid => Name != null;
+ return true;
}
- public SchemaShape(IEnumerable columns)
+ [BestFriend]
+ internal string GetTypeString()
{
- Contracts.CheckValue(columns, nameof(columns));
- _columns = columns.ToArray();
- Contracts.CheckParam(columns.All(c => c.IsValid), nameof(columns), "Some items are not initialized properly.");
+ string result = ItemType.ToString();
+ if (IsKey)
+ result = $"Key<{result}>";
+ if (Kind == VectorKind.Vector)
+ result = $"Vector<{result}>";
+ else if (Kind == VectorKind.VariableVector)
+ result = $"VarVector<{result}>";
+ return result;
}
///
- /// Given a , extract the type parameters that describe this type
- /// as a 's column type.
+ /// Return <see langword="true"/> if this structure is not identical to the default value of <see cref="Column"/>. If true,
+ /// it means this structure is initialized properly and therefore considered as valid.
///
- /// The actual column type to process.
- /// The vector kind of .
- /// The item type of .
- /// Whether (or its item type) is a key.
[BestFriend]
- internal static void GetColumnTypeShape(DataViewType type,
- out Column.VectorKind vecKind,
- out DataViewType itemType,
- out bool isKey)
+ internal bool IsValid => Name != null;
+ }
+
+ public SchemaShape(IEnumerable columns)
+ {
+ Contracts.CheckValue(columns, nameof(columns));
+ _columns = columns.ToArray(); // materialize once; the validation below runs on this stored copy
+ Contracts.CheckParam(_columns.All(c => c.IsValid), nameof(columns), "Some items are not initialized properly.");
+ }
+
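+ /// <summary>
+ /// Given a <see cref="DataViewType"/>, extract the type parameters that describe this type
+ /// as a <see cref="SchemaShape"/>'s column type.
+ /// </summary>
+ /// <param name="type">The actual column type to process.</param>
+ /// <param name="vecKind">The vector kind of <paramref name="type"/>.</param>
+ /// <param name="itemType">The item type of <paramref name="type"/>.</param>
+ /// <param name="isKey">Whether <paramref name="type"/> (or its item type) is a key.</param>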
+ [BestFriend]
+ internal static void GetColumnTypeShape(DataViewType type,
+ out Column.VectorKind vecKind,
+ out DataViewType itemType,
+ out bool isKey)
+ {
+ if (type is VectorDataViewType vectorType)
{
- if (type is VectorDataViewType vectorType)
+ if (vectorType.IsKnownSize)
{
- if (vectorType.IsKnownSize)
- {
- vecKind = Column.VectorKind.Vector;
- }
- else
- {
- vecKind = Column.VectorKind.VariableVector;
- }
-
- itemType = vectorType.ItemType;
+ vecKind = Column.VectorKind.Vector;
}
else
{
- vecKind = Column.VectorKind.Scalar;
- itemType = type;
+ vecKind = Column.VectorKind.VariableVector;
}
- isKey = itemType is KeyDataViewType;
- if (isKey)
- itemType = ColumnTypeExtensions.PrimitiveTypeFromType(itemType.RawType);
+ itemType = vectorType.ItemType;
}
-
- ///
- /// Create a schema shape out of the fully defined schema.
- ///
- [BestFriend]
- internal static SchemaShape Create(DataViewSchema schema)
+ else
{
- Contracts.CheckValue(schema, nameof(schema));
- var cols = new List();
+ vecKind = Column.VectorKind.Scalar;
+ itemType = type;
+ }
+
+ isKey = itemType is KeyDataViewType;
+ if (isKey)
+ itemType = ColumnTypeExtensions.PrimitiveTypeFromType(itemType.RawType);
+ }
+
+ ///
+ /// Create a schema shape out of the fully defined schema.
+ ///
+ [BestFriend]
+ internal static SchemaShape Create(DataViewSchema schema)
+ {
+ Contracts.CheckValue(schema, nameof(schema));
+ var cols = new List();
- for (int iCol = 0; iCol < schema.Count; iCol++)
+ for (int iCol = 0; iCol < schema.Count; iCol++)
+ {
+ if (!schema[iCol].IsHidden)
{
- if (!schema[iCol].IsHidden)
+ // First create the annotations.
+ var mCols = new List();
+ foreach (var annotationColumn in schema[iCol].Annotations.Schema)
{
- // First create the annotations.
- var mCols = new List();
- foreach (var annotationColumn in schema[iCol].Annotations.Schema)
- {
- GetColumnTypeShape(annotationColumn.Type, out var mVecKind, out var mItemType, out var mIsKey);
- mCols.Add(new Column(annotationColumn.Name, mVecKind, mItemType, mIsKey));
- }
- var annotations = mCols.Count > 0 ? new SchemaShape(mCols) : _empty;
- // Next create the single column.
- GetColumnTypeShape(schema[iCol].Type, out var vecKind, out var itemType, out var isKey);
- cols.Add(new Column(schema[iCol].Name, vecKind, itemType, isKey, annotations));
+ GetColumnTypeShape(annotationColumn.Type, out var mVecKind, out var mItemType, out var mIsKey);
+ mCols.Add(new Column(annotationColumn.Name, mVecKind, mItemType, mIsKey));
}
+ var annotations = mCols.Count > 0 ? new SchemaShape(mCols) : _empty;
+ // Next create the single column.
+ GetColumnTypeShape(schema[iCol].Type, out var vecKind, out var itemType, out var isKey);
+ cols.Add(new Column(schema[iCol].Name, vecKind, itemType, isKey, annotations));
}
- return new SchemaShape(cols);
}
+ return new SchemaShape(cols);
+ }
- ///
- /// Returns if there is a column with a specified and if so stores it in .
- ///
- [BestFriend]
- internal bool TryFindColumn(string name, out Column column)
- {
- Contracts.CheckValue(name, nameof(name));
- column = _columns.FirstOrDefault(x => x.Name == name);
- return column.IsValid;
- }
+ /// <summary>
+ /// Returns <see langword="true"/> if there is a column with the specified <paramref name="name"/> and, if so, stores it in <paramref name="column"/>.
+ /// </summary>
+ [BestFriend]
+ internal bool TryFindColumn(string name, out Column column)
+ {
+ Contracts.CheckValue(name, nameof(name));
+ column = _columns.FirstOrDefault(x => x.Name == name);
+ return column.IsValid;
+ }
- public IEnumerator GetEnumerator() => ((IEnumerable)_columns).GetEnumerator();
+ public IEnumerator GetEnumerator() => ((IEnumerable)_columns).GetEnumerator();
- IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
+ IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
- // REVIEW: I think we should have an IsCompatible method to check if it's OK to use one schema shape
- // as an input to another schema shape. I started writing, but realized that there's more than one way to check for
- // the 'compatibility': as in, 'CAN be compatible' vs. 'WILL be compatible'.
- }
+ // REVIEW: I think we should have an IsCompatible method to check if it's OK to use one schema shape
+ // as an input to another schema shape. I started writing, but realized that there's more than one way to check for
+ // the 'compatibility': as in, 'CAN be compatible' vs. 'WILL be compatible'.
+}
+/// <summary>
+/// The 'data loader' takes a certain kind of input and turns it into an <see cref="IDataView"/>.
+/// </summary>
+/// <typeparam name="TSource">The type of input the loader takes.</typeparam>
+public interface IDataLoader : ICanSaveModel
+{
///
- /// The 'data loader' takes a certain kind of input and turns it into an .
+ /// Produce the data view from the specified input.
+ /// Note that 's are lazy, so no actual loading happens here, just schema validation.
///
- /// The type of input the loader takes.
- public interface IDataLoader : ICanSaveModel
- {
- ///
- /// Produce the data view from the specified input.
- /// Note that 's are lazy, so no actual loading happens here, just schema validation.
- ///
- IDataView Load(TSource input);
-
- ///
- /// The output schema of the loader.
- ///
- DataViewSchema GetOutputSchema();
- }
+ IDataView Load(TSource input);
///
- /// Sometimes we need to 'fit' an .
- /// A DataLoader estimator is the object that does it.
+ /// The output schema of the loader.
///
- public interface IDataLoaderEstimator
- where TLoader : IDataLoader
- {
- // REVIEW: you could consider the transformer to take a different , but we don't have such components
- // yet, so why complicate matters?
+ DataViewSchema GetOutputSchema();
+}
- ///
- /// Train and return a data loader.
- ///
- TLoader Fit(TSource input);
+/// <summary>
+/// Sometimes we need to 'fit' an <see cref="IDataLoader{TSource}"/>.
+/// A DataLoader estimator is the object that does it.
+/// </summary>
+public interface IDataLoaderEstimator
+ where TLoader : IDataLoader
+{
+ // REVIEW: you could consider the transformer to take a different , but we don't have such components
+ // yet, so why complicate matters?
- ///
- /// The 'promise' of the output schema.
- /// It will be used for schema propagation.
- ///
- SchemaShape GetOutputSchema();
- }
+ ///
+ /// Train and return a data loader.
+ ///
+ TLoader Fit(TSource input);
///
- /// The transformer is a component that transforms data.
- /// It also supports 'schema propagation' to answer the question of 'how will the data with this schema look, after you transform it?'.
+ /// The 'promise' of the output schema.
+ /// It will be used for schema propagation.
///
- public interface ITransformer : ICanSaveModel
- {
- ///
- /// Schema propagation for transformers.
- /// Returns the output schema of the data, if the input schema is like the one provided.
- ///
- DataViewSchema GetOutputSchema(DataViewSchema inputSchema);
+ SchemaShape GetOutputSchema();
+}
- ///
- /// Take the data in, make transformations, output the data.
- /// Note that 's are lazy, so no actual transformations happen here, just schema validation.
- ///
- IDataView Transform(IDataView input);
+///
+/// The transformer is a component that transforms data.
+/// It also supports 'schema propagation' to answer the question of 'how will the data with this schema look, after you transform it?'.
+///
+public interface ITransformer : ICanSaveModel
+{
+ ///
+ /// Schema propagation for transformers.
+ /// Returns the output schema of the data, if the input schema is like the one provided.
+ ///
+ DataViewSchema GetOutputSchema(DataViewSchema inputSchema);
- ///
- /// Whether a call to should succeed, on an
- /// appropriate schema.
- ///
- bool IsRowToRowMapper { get; }
+ /// <summary>
+ /// Take the data in, make transformations, output the data.
+ /// Note that <see cref="IDataView"/>'s are lazy, so no actual transformations happen here, just schema validation.
+ /// </summary>
+ IDataView Transform(IDataView input);
- ///
- /// Constructs a row-to-row mapper based on an input schema. If
- /// is false, then an exception should be thrown. If the input schema is in any way
- /// unsuitable for constructing the mapper, an exception should likewise be thrown.
- ///
- /// The input schema for which we should get the mapper.
- /// The row to row mapper.
- IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema);
- }
+ /// <summary>
+ /// Whether a call to <see cref="GetRowToRowMapper(DataViewSchema)"/> should succeed, on an
+ /// appropriate schema.
+ /// </summary>
+ bool IsRowToRowMapper { get; }
- [BestFriend]
- internal interface ITransformerWithDifferentMappingAtTrainingTime : ITransformer
- {
- IDataView TransformForTrainingPipeline(IDataView input);
- }
+ /// <summary>
+ /// Constructs a row-to-row mapper based on an input schema. If <see cref="IsRowToRowMapper"/>
+ /// is false, then an exception should be thrown. If the input schema is in any way
+ /// unsuitable for constructing the mapper, an exception should likewise be thrown.
+ /// </summary>
+ /// <param name="inputSchema">The input schema for which we should get the mapper.</param>
+ /// <returns>The row to row mapper.</returns>
+ IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema);
+}
+[BestFriend]
+internal interface ITransformerWithDifferentMappingAtTrainingTime : ITransformer
+{
+ IDataView TransformForTrainingPipeline(IDataView input);
+}
+
+/// <summary>
+/// The estimator (in Spark terminology) is an 'untrained transformer'. It needs to 'fit' on the data to manufacture
+/// a transformer.
+/// It also provides the 'schema propagation' like transformers do, but over <see cref="SchemaShape"/> instead of <see cref="DataViewSchema"/>.
+/// </summary>
+public interface IEstimator
+ where TTransformer : ITransformer
+{
///
- /// The estimator (in Spark terminology) is an 'untrained transformer'. It needs to 'fit' on the data to manufacture
- /// a transformer.
- /// It also provides the 'schema propagation' like transformers do, but over instead of .
+ /// Train and return a transformer.
///
- public interface IEstimator
- where TTransformer : ITransformer
- {
- ///
- /// Train and return a transformer.
- ///
- TTransformer Fit(IDataView input);
+ TTransformer Fit(IDataView input);
- ///
- /// Schema propagation for estimators.
- /// Returns the output schema shape of the estimator, if the input schema shape is like the one provided.
- ///
- SchemaShape GetOutputSchema(SchemaShape inputSchema);
- }
+ ///
+ /// Schema propagation for estimators.
+ /// Returns the output schema shape of the estimator, if the input schema shape is like the one provided.
+ ///
+ SchemaShape GetOutputSchema(SchemaShape inputSchema);
}
diff --git a/src/Microsoft.ML.Core/Data/IFileHandle.cs b/src/Microsoft.ML.Core/Data/IFileHandle.cs
index 5afef22fee..6eac6821a9 100644
--- a/src/Microsoft.ML.Core/Data/IFileHandle.cs
+++ b/src/Microsoft.ML.Core/Data/IFileHandle.cs
@@ -8,191 +8,190 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
-namespace Microsoft.ML.Data
+namespace Microsoft.ML.Data;
+
+///
+/// A file handle.
+///
+public interface IFileHandle : IDisposable
{
///
- /// A file handle.
+ /// Returns whether CreateWriteStream is expected to succeed. Typically, once
+ /// CreateWriteStream has been called once, this will forever more return false.
///
- public interface IFileHandle : IDisposable
- {
- ///
- /// Returns whether CreateWriteStream is expected to succeed. Typically, once
- /// CreateWriteStream has been called once, this will forever more return false.
- ///
- bool CanWrite { get; }
-
- ///
- /// Returns whether OpenReadStream is expected to succeed.
- ///
- bool CanRead { get; }
-
- ///
- /// Create a writable stream for this file handle.
- ///
- Stream CreateWriteStream();
-
- ///
- /// Open a readable stream for this file handle.
- ///
- Stream OpenReadStream();
- }
+ bool CanWrite { get; }
///
- /// A simple disk-based file handle.
+ /// Returns whether OpenReadStream is expected to succeed.
///
- public sealed class SimpleFileHandle : IFileHandle
- {
- private readonly string _fullPath;
+ bool CanRead { get; }
+
+ ///
+ /// Create a writable stream for this file handle.
+ ///
+ Stream CreateWriteStream();
- // Exception context.
- private readonly IExceptionContext _ectx;
+ ///
+ /// Open a readable stream for this file handle.
+ ///
+ Stream OpenReadStream();
+}
- private readonly object _lock;
+///
+/// A simple disk-based file handle.
+///
+public sealed class SimpleFileHandle : IFileHandle
+{
+ private readonly string _fullPath;
- // Whether to delete the file when this is disposed.
- private readonly bool _autoDelete;
+ // Exception context.
+ private readonly IExceptionContext _ectx;
- // Whether this file has contents. This is false if the file needs CreateWriteStream to be
- // called (before OpenReadStream can be called).
- private bool _wrote;
- // If non-null, the active write stream. This should be disposed before the first OpenReadStream call.
- private Stream _streamWrite;
+ private readonly object _lock;
- // This contains the potentially active read streams. This is set to null once this file
- // handle has been disposed.
- private List _streams;
+ // Whether to delete the file when this is disposed.
+ private readonly bool _autoDelete;
- private bool IsDisposed => _streams == null;
+ // Whether this file has contents. This is false if the file needs CreateWriteStream to be
+ // called (before OpenReadStream can be called).
+ private bool _wrote;
+ // If non-null, the active write stream. This should be disposed before the first OpenReadStream call.
+ private Stream _streamWrite;
- public SimpleFileHandle(IExceptionContext ectx, string path, bool needsWrite, bool autoDelete)
- {
- Contracts.CheckValue(ectx, nameof(ectx));
- ectx.CheckNonEmpty(path, nameof(path));
+ // This contains the potentially active read streams. This is set to null once this file
+ // handle has been disposed.
+ private List _streams;
- _ectx = ectx;
- _fullPath = Path.GetFullPath(path);
+ private bool IsDisposed => _streams == null;
- _autoDelete = autoDelete;
+ public SimpleFileHandle(IExceptionContext ectx, string path, bool needsWrite, bool autoDelete)
+ {
+ Contracts.CheckValue(ectx, nameof(ectx));
+ ectx.CheckNonEmpty(path, nameof(path));
- // The file has already been written to iff needsWrite is false.
- _wrote = !needsWrite;
+ _ectx = ectx;
+ _fullPath = Path.GetFullPath(path);
- // REVIEW: Should this do some basic validation? Eg, for output files, ensure that
- // the directory exists (and perhaps even create an empty file); for input files, ensure
- // that the file exists (and perhaps even attempt to open it).
+ _autoDelete = autoDelete;
- _lock = new object();
- _streams = new List();
- }
+ // The file has already been written to iff needsWrite is false.
+ _wrote = !needsWrite;
+
+ // REVIEW: Should this do some basic validation? Eg, for output files, ensure that
+ // the directory exists (and perhaps even create an empty file); for input files, ensure
+ // that the file exists (and perhaps even attempt to open it).
- public bool CanWrite => !_wrote && !IsDisposed;
+ _lock = new object();
+ _streams = new List();
+ }
+
+ public bool CanWrite => !_wrote && !IsDisposed;
- public bool CanRead => _wrote && !IsDisposed;
+ public bool CanRead => _wrote && !IsDisposed;
- public void Dispose()
+ public void Dispose()
+ {
+ lock (_lock)
{
- lock (_lock)
- {
- if (IsDisposed)
- return;
+ if (IsDisposed)
+ return;
- Contracts.Assert(_streams != null);
+ Contracts.Assert(_streams != null);
- // REVIEW: Is it safe to dispose these streams? What if they are
- // being used on other threads? Does that matter?
- if (_streamWrite != null)
+ // REVIEW: Is it safe to dispose these streams? What if they are
+ // being used on other threads? Does that matter?
+ if (_streamWrite != null)
+ {
+ try
{
- try
- {
- _streamWrite.CloseEx();
- _streamWrite.Dispose();
- }
- catch
- {
- // REVIEW: What should we do here?
- Contracts.Assert(false, "Closing a SimpleFileHandle write stream failed!");
- }
- _streamWrite = null;
+ _streamWrite.CloseEx();
+ _streamWrite.Dispose();
}
+ catch
+ {
+ // REVIEW: What should we do here?
+ Contracts.Assert(false, "Closing a SimpleFileHandle write stream failed!");
+ }
+ _streamWrite = null;
+ }
- foreach (var stream in _streams)
+ foreach (var stream in _streams)
+ {
+ try
+ {
+ stream.CloseEx();
+ stream.Dispose();
+ }
+ catch
{
- try
- {
- stream.CloseEx();
- stream.Dispose();
- }
- catch
- {
- // REVIEW: What should we do here?
- Contracts.Assert(false, "Closing a SimpleFileHandle read stream failed!");
- }
+ // REVIEW: What should we do here?
+ Contracts.Assert(false, "Closing a SimpleFileHandle read stream failed!");
}
+ }
- _streams = null;
- Contracts.Assert(IsDisposed);
+ _streams = null;
+ Contracts.Assert(IsDisposed);
- if (_autoDelete)
+ if (_autoDelete)
+ {
+ try
+ {
+ // Finally, delete the file.
+ File.Delete(_fullPath);
+ }
+ catch
{
- try
- {
- // Finally, delete the file.
- File.Delete(_fullPath);
- }
- catch
- {
- // REVIEW: What should we do here?
- Contracts.Assert(false, "Deleting a SimpleFileHandle physical file failed!");
- }
+ // REVIEW: What should we do here?
+ Contracts.Assert(false, "Deleting a SimpleFileHandle physical file failed!");
}
}
}
+ }
- private void CheckNotDisposed()
- {
- if (IsDisposed)
- throw _ectx.Except("SimpleFileHandle has already been disposed");
- }
+ private void CheckNotDisposed()
+ {
+ if (IsDisposed)
+ throw _ectx.Except("SimpleFileHandle has already been disposed");
+ }
- public Stream CreateWriteStream()
+ public Stream CreateWriteStream()
+ {
+ lock (_lock)
{
- lock (_lock)
- {
- CheckNotDisposed();
+ CheckNotDisposed();
- if (_wrote)
- throw _ectx.Except("CreateWriteStream called multiple times on SimpleFileHandle");
+ if (_wrote)
+ throw _ectx.Except("CreateWriteStream called multiple times on SimpleFileHandle");
- Contracts.Assert(_streamWrite == null);
- _streamWrite = new FileStream(_fullPath, FileMode.Create, FileAccess.Write);
- _wrote = true;
- return _streamWrite;
- }
+ Contracts.Assert(_streamWrite == null);
+ _streamWrite = new FileStream(_fullPath, FileMode.Create, FileAccess.Write);
+ _wrote = true;
+ return _streamWrite;
}
+ }
- public Stream OpenReadStream()
+ public Stream OpenReadStream()
+ {
+ lock (_lock)
{
- lock (_lock)
- {
- CheckNotDisposed();
+ CheckNotDisposed();
- if (!_wrote)
- throw _ectx.Except("SimpleFileHandle hasn't been written yet");
+ if (!_wrote)
+ throw _ectx.Except("SimpleFileHandle hasn't been written yet");
- if (_streamWrite != null)
- {
- if (_streamWrite.CanWrite)
- throw _ectx.Except("Write stream for SimpleFileHandle hasn't been disposed");
- _streamWrite = null;
- }
+ if (_streamWrite != null)
+ {
+ if (_streamWrite.CanWrite)
+ throw _ectx.Except("Write stream for SimpleFileHandle hasn't been disposed");
+ _streamWrite = null;
+ }
- // Drop read streams that have already been disposed.
- _streams.RemoveAll(s => !s.CanRead);
+ // Drop read streams that have already been disposed.
+ _streams.RemoveAll(s => !s.CanRead);
- var stream = new FileStream(_fullPath, FileMode.Open, FileAccess.Read, FileShare.Read);
- _streams.Add(stream);
- return stream;
- }
+ var stream = new FileStream(_fullPath, FileMode.Open, FileAccess.Read, FileShare.Read);
+ _streams.Add(stream);
+ return stream;
}
}
}
diff --git a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
index 72a1245485..06fe0b32b4 100644
--- a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
+++ b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
@@ -5,319 +5,318 @@
using System;
using Microsoft.ML.Data;
-namespace Microsoft.ML.Runtime
+namespace Microsoft.ML.Runtime;
+
+///
+/// A channel provider can create new channels and generic information pipes.
+///
+public interface IChannelProvider : IExceptionContext
{
///
- /// A channel provider can create new channels and generic information pipes.
+ /// Start a standard message channel.
///
- public interface IChannelProvider : IExceptionContext
- {
- ///
- /// Start a standard message channel.
- ///
- IChannel Start(string name);
-
- ///
- /// Start a generic information pipe.
- ///
- IPipe StartPipe(string name);
- }
+ IChannel Start(string name);
///
- /// Utility class for IHostEnvironment
+ /// Start a generic information pipe.
///
- [BestFriend]
- internal static class HostEnvironmentExtensions
- {
- ///
- /// Return a file handle for an input "file".
- ///
- public static IFileHandle OpenInputFile(this IHostEnvironment env, string path)
- {
- Contracts.AssertValue(env);
- Contracts.CheckNonWhiteSpace(path, nameof(path));
- return new SimpleFileHandle(env, path, needsWrite: false, autoDelete: false);
- }
+ IPipe StartPipe(string name);
+}
- ///
- /// Create an output "file" and return a handle to it.
- ///
- public static IFileHandle CreateOutputFile(this IHostEnvironment env, string path)
- {
- Contracts.AssertValue(env);
- Contracts.CheckNonWhiteSpace(path, nameof(path));
- return new SimpleFileHandle(env, path, needsWrite: true, autoDelete: false);
- }
+///
+/// Utility class for IHostEnvironment
+///
+[BestFriend]
+internal static class HostEnvironmentExtensions
+{
+ ///
+ /// Return a file handle for an input "file".
+ ///
+ public static IFileHandle OpenInputFile(this IHostEnvironment env, string path)
+ {
+ Contracts.AssertValue(env);
+ Contracts.CheckNonWhiteSpace(path, nameof(path));
+ return new SimpleFileHandle(env, path, needsWrite: false, autoDelete: false);
}
///
- /// The host environment interface creates hosts for components. Note that the methods of
- /// this interface should be called from the main thread for the environment. To get an environment
- /// to service another thread, call Fork and pass the return result to that thread.
+ /// Create an output "file" and return a handle to it.
///
- public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider
+ public static IFileHandle CreateOutputFile(this IHostEnvironment env, string path)
{
- ///
- /// Create a host with the given registration name.
- ///
- IHost Register(string name, int? seed = null, bool? verbose = null);
-
- ///
- /// The catalog of loadable components () that are available in this host.
- ///
- ComponentCatalog ComponentCatalog { get; }
+ Contracts.AssertValue(env);
+ Contracts.CheckNonWhiteSpace(path, nameof(path));
+ return new SimpleFileHandle(env, path, needsWrite: true, autoDelete: false);
}
+}
+
+///
+/// The host environment interface creates hosts for components. Note that the methods of
+/// this interface should be called from the main thread for the environment. To get an environment
+/// to service another thread, call Fork and pass the return result to that thread.
+///
+public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider
+{
+ ///
+ /// Create a host with the given registration name.
+ ///
+ IHost Register(string name, int? seed = null, bool? verbose = null);
+
+ ///
+ /// The catalog of loadable components () that are available in this host.
+ ///
+ ComponentCatalog ComponentCatalog { get; }
+}
+
+[BestFriend]
+internal interface ICancelable
+{
+ ///
+ /// Signal to stop execution in all the hosts.
+ ///
+ void CancelExecution();
+
+ ///
+ /// Flag which indicates host execution has been stopped.
+ ///
+ bool IsCanceled { get; }
+}
+
+[BestFriend]
+internal interface IHostEnvironmentInternal : IHostEnvironment
+{
+ ///
+ /// The seed property that, if assigned, makes components requiring randomness behave deterministically.
+ ///
+ int? Seed { get; }
+
+ ///
+ /// The location for the temp files created by ML.NET
+ ///
+ string TempFilePath { get; set; }
+
+ ///
+ /// Allow falling back to run on CPU if couldn't run on GPU.
+ ///
+ bool FallbackToCpu { get; set; }
+
+ ///
+ /// GPU device ID to run execution on, or <see langword="null"/> to run on CPU.
+ ///
+ int? GpuDeviceId { get; set; }
+}
+
+///
+/// A host is coupled to a component and provides random number generation and concurrency guidance.
+/// Note that the random number generation, like the host environment methods, should be accessed only
+/// from the main thread for the component.
+///
+public interface IHost : IHostEnvironment
+{
+ ///
+ /// The random number generator issued to this component. Note that random number
+ /// generators are NOT thread safe.
+ ///
+ Random Rand { get; }
+}
+
+///
+/// A generic information pipe. Note that pipes are disposable. Generally, Done should
+/// be called before disposing to signal a normal shut-down of the pipe, as opposed
+/// to an aborted completion.
+///
+public interface IPipe : IExceptionContext, IDisposable
+{
+ ///
+ /// The caller relinquishes ownership of the object.
+ ///
+ void Send(TMessage msg);
+}
+
+///
+/// The kinds of standard channel messages.
+/// Note: These values should never be changed. We can add new kinds, but don't change these values.
+/// Other code bases, including native code for other projects, depend on these values.
+///
+public enum ChannelMessageKind
+{
+ Trace = 0,
+ Info = 1,
+ Warning = 2,
+ Error = 3
+}
+
+///
+/// A flag that can be attached to a message or exception to indicate that
+/// it has a certain class of sensitive data. By default, messages should be
+/// specified as being of unknown sensitivity, which is to say, every
+/// sensitivity flag is turned on, corresponding to <see cref="MessageSensitivity.Unknown"/>.
+/// Messages that are totally safe should be marked as <see cref="MessageSensitivity.None"/>.
+/// However, if, say, one prints out data from a file (for example, this might
+/// be done when expressing parse errors), it should be flagged in that case
+/// with <see cref="MessageSensitivity.UserData"/>.
+///
+[Flags]
+public enum MessageSensitivity
+{
+ ///
+ /// For non-sensitive data.
+ ///
+ None = 0,
+
+ ///
+ /// For messages that may contain user-data from data files.
+ ///
+ UserData = 0x1,
+
+ ///
+ /// For messages that contain information like column names from datasets.
+ /// Note that, despite being part of the schema, annotations should be treated
+ /// as user data, since it is often derived from user data. Note also that
+ /// types, despite being part of the schema, are not considered "sensitive"
+ /// as such, in the same way that column names might be.
+ ///
+ Schema = 0x2,
+
+ // REVIEW: Other potentially sensitive things might include
+ // stack traces in certain environments.
+
+ ///
+ /// The default value, unknown, is treated as if everything is sensitive.
+ ///
+ Unknown = ~None,
+
+ ///
+ /// An alias for <see cref="Unknown"/>, so it is functionally the same, except
+ /// semantically it communicates the idea that we want all bits set.
+ ///
+ All = Unknown,
+}
+
+///
+/// A channel message.
+///
+public readonly struct ChannelMessage
+{
+ public readonly ChannelMessageKind Kind;
+ public readonly MessageSensitivity Sensitivity;
+ private readonly string _message;
+ private readonly object[] _args;
+
+ ///
+ /// Line endings may not be normalized.
+ ///
+ public string Message => _args != null ? string.Format(_message, _args) : _message;
[BestFriend]
- internal interface ICancelable
+ internal ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string message)
{
- ///
- /// Signal to stop execution in all the hosts.
- ///
- void CancelExecution();
-
- ///
- /// Flag which indicates host execution has been stopped.
- ///
- bool IsCanceled { get; }
+ Contracts.CheckNonEmpty(message, nameof(message));
+ Kind = kind;
+ Sensitivity = sensitivity;
+ _message = message;
+ _args = null;
}
[BestFriend]
- internal interface IHostEnvironmentInternal : IHostEnvironment
+ internal ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string fmt, params object[] args)
{
- ///
- /// The seed property that, if assigned, makes components requiring randomness behave deterministically.
- ///
- int? Seed { get; }
-
- ///
- /// The location for the temp files created by ML.NET
- ///
- string TempFilePath { get; set; }
-
- ///
- /// Allow falling back to run on CPU if couldn't run on GPU.
- ///
- bool FallbackToCpu { get; set; }
-
- ///
- /// GPU device ID to run execution on, to run on CPU.
- ///
- int? GpuDeviceId { get; set; }
+ Contracts.CheckNonEmpty(fmt, nameof(fmt));
+ Contracts.CheckNonEmpty(args, nameof(args));
+ Kind = kind;
+ Sensitivity = sensitivity;
+ _message = fmt;
+ _args = args;
}
+}
- ///
- /// A host is coupled to a component and provides random number generation and concurrency guidance.
- /// Note that the random number generation, like the host environment methods, should be accessed only
- /// from the main thread for the component.
- ///
- public interface IHost : IHostEnvironment
+///
+/// A standard communication channel.
+///
+public interface IChannel : IPipe
+{
+ void Trace(MessageSensitivity sensitivity, string fmt);
+ void Trace(MessageSensitivity sensitivity, string fmt, params object[] args);
+ void Error(MessageSensitivity sensitivity, string fmt);
+ void Error(MessageSensitivity sensitivity, string fmt, params object[] args);
+ void Warning(MessageSensitivity sensitivity, string fmt);
+ void Warning(MessageSensitivity sensitivity, string fmt, params object[] args);
+ void Info(MessageSensitivity sensitivity, string fmt);
+ void Info(MessageSensitivity sensitivity, string fmt, params object[] args);
+}
+
+///
+/// General utility extension methods for objects in the "host" universe, i.e.,
+/// , , and
+/// that do not belong in more specific areas, for example, or
+/// component creation.
+///
+[BestFriend]
+internal static class HostExtensions
+{
+ public static T Apply(this IHost host, string channelName, Func func)
{
- ///
- /// The random number generator issued to this component. Note that random number
- /// generators are NOT thread safe.
- ///
- Random Rand { get; }
+ T t;
+ using (var ch = host.Start(channelName))
+ {
+ t = func(ch);
+ }
+ return t;
}
///
- /// A generic information pipe. Note that pipes are disposable. Generally, Done should
- /// be called before disposing to signal a normal shut-down of the pipe, as opposed
- /// to an aborted completion.
+ /// Convenience variant of <see cref="IChannel.Trace(MessageSensitivity, string)"/>
+ /// setting <see cref="MessageSensitivity.Unknown"/>.
///
- public interface IPipe : IExceptionContext, IDisposable
- {
- ///
- /// The caller relinquishes ownership of the object.
- ///
- void Send(TMessage msg);
- }
+ public static void Trace(this IChannel ch, string fmt)
+ => ch.Trace(MessageSensitivity.Unknown, fmt);
///
- /// The kinds of standard channel messages.
- /// Note: These values should never be changed. We can add new kinds, but don't change these values.
- /// Other code bases, including native code for other projects depends on these values.
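+ /// Convenience variant of <see cref="IChannel.Trace(MessageSensitivity, string, object[])"/>
+ /// setting <see cref="MessageSensitivity.Unknown"/>.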
///
- public enum ChannelMessageKind
- {
- Trace = 0,
- Info = 1,
- Warning = 2,
- Error = 3
- }
+ public static void Trace(this IChannel ch, string fmt, params object[] args)
+ => ch.Trace(MessageSensitivity.Unknown, fmt, args);
///
- /// A flag that can be attached to a message or exception to indicate that
- /// it has a certain class of sensitive data. By default, messages should be
- /// specified as being of unknown sensitivity, which is to say, every
- /// sensitivity flag is turned on, corresponding to .
- /// Messages that are totally safe should be marked as .
- /// However, if, say, one prints out data from a file (for example, this might
- /// be done when expressing parse errors), it should be flagged in that case
- /// with .
+ /// Convenience variant of <see cref="IChannel.Error(MessageSensitivity, string)"/>
+ /// setting <see cref="MessageSensitivity.Unknown"/>.
///
- [Flags]
- public enum MessageSensitivity
- {
- ///
- /// For non-sensitive data.
- ///
- None = 0,
-
- ///
- /// For messages that may contain user-data from data files.
- ///
- UserData = 0x1,
-
- ///
- /// For messages that contain information like column names from datasets.
- /// Note that, despite being part of the schema, annotations should be treated
- /// as user data, since it is often derived from user data. Note also that
- /// types, despite being part of the schema, are not considered "sensitive"
- /// as such, in the same way that column names might be.
- ///
- Schema = 0x2,
-
- // REVIEW: Other potentially sensitive things might include
- // stack traces in certain environments.
-
- ///
- /// The default value, unknown, is treated as if everything is sensitive.
- ///
- Unknown = ~None,
-
- ///
- /// An alias for , so it is functionally the same, except
- /// semantically it communicates the idea that we want all bits set.
- ///
- All = Unknown,
- }
+ public static void Error(this IChannel ch, string fmt)
+ => ch.Error(MessageSensitivity.Unknown, fmt);
///
- /// A channel message.
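+ /// Convenience variant of <see cref="IChannel.Error(MessageSensitivity, string, object[])"/>
+ /// setting <see cref="MessageSensitivity.Unknown"/>.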
///
- public readonly struct ChannelMessage
- {
- public readonly ChannelMessageKind Kind;
- public readonly MessageSensitivity Sensitivity;
- private readonly string _message;
- private readonly object[] _args;
-
- ///
- /// Line endings may not be normalized.
- ///
- public string Message => _args != null ? string.Format(_message, _args) : _message;
-
- [BestFriend]
- internal ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string message)
- {
- Contracts.CheckNonEmpty(message, nameof(message));
- Kind = kind;
- Sensitivity = sensitivity;
- _message = message;
- _args = null;
- }
+ public static void Error(this IChannel ch, string fmt, params object[] args)
+ => ch.Error(MessageSensitivity.Unknown, fmt, args);
- [BestFriend]
- internal ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string fmt, params object[] args)
- {
- Contracts.CheckNonEmpty(fmt, nameof(fmt));
- Contracts.CheckNonEmpty(args, nameof(args));
- Kind = kind;
- Sensitivity = sensitivity;
- _message = fmt;
- _args = args;
- }
- }
+ ///
+ /// Convenience variant of <see cref="IChannel.Warning(MessageSensitivity, string)"/>
+ /// setting <see cref="MessageSensitivity.Unknown"/>.
+ ///
+ public static void Warning(this IChannel ch, string fmt)
+ => ch.Warning(MessageSensitivity.Unknown, fmt);
///
- /// A standard communication channel.
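+ /// Convenience variant of <see cref="IChannel.Warning(MessageSensitivity, string, object[])"/>
+ /// setting <see cref="MessageSensitivity.Unknown"/>.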
///
- public interface IChannel : IPipe
- {
- void Trace(MessageSensitivity sensitivity, string fmt);
- void Trace(MessageSensitivity sensitivity, string fmt, params object[] args);
- void Error(MessageSensitivity sensitivity, string fmt);
- void Error(MessageSensitivity sensitivity, string fmt, params object[] args);
- void Warning(MessageSensitivity sensitivity, string fmt);
- void Warning(MessageSensitivity sensitivity, string fmt, params object[] args);
- void Info(MessageSensitivity sensitivity, string fmt);
- void Info(MessageSensitivity sensitivity, string fmt, params object[] args);
- }
+ public static void Warning(this IChannel ch, string fmt, params object[] args)
+ => ch.Warning(MessageSensitivity.Unknown, fmt, args);
///
- /// General utility extension methods for objects in the "host" universe, i.e.,
- /// , , and
- /// that do not belong in more specific areas, for example, or
- /// component creation.
+ /// Convenience variant of <see cref="IChannel.Info(MessageSensitivity, string)"/>
+ /// setting <see cref="MessageSensitivity.Unknown"/>.
///
- [BestFriend]
- internal static class HostExtensions
- {
- public static T Apply(this IHost host, string channelName, Func func)
- {
- T t;
- using (var ch = host.Start(channelName))
- {
- t = func(ch);
- }
- return t;
- }
+ public static void Info(this IChannel ch, string fmt)
+ => ch.Info(MessageSensitivity.Unknown, fmt);
- ///
- /// Convenience variant of
- /// setting .
- ///
- public static void Trace(this IChannel ch, string fmt)
- => ch.Trace(MessageSensitivity.Unknown, fmt);
-
- ///
- /// Convenience variant of
- /// setting .
- ///
- public static void Trace(this IChannel ch, string fmt, params object[] args)
- => ch.Trace(MessageSensitivity.Unknown, fmt, args);
-
- ///
- /// Convenience variant of
- /// setting .
- ///
- public static void Error(this IChannel ch, string fmt)
- => ch.Error(MessageSensitivity.Unknown, fmt);
-
- ///
- /// Convenience variant of
- /// setting .
- ///
- public static void Error(this IChannel ch, string fmt, params object[] args)
- => ch.Error(MessageSensitivity.Unknown, fmt, args);
-
- ///
- /// Convenience variant of
- /// setting .
- ///
- public static void Warning(this IChannel ch, string fmt)
- => ch.Warning(MessageSensitivity.Unknown, fmt);
-
- ///
- /// Convenience variant of
- /// setting .
- ///
- public static void Warning(this IChannel ch, string fmt, params object[] args)
- => ch.Warning(MessageSensitivity.Unknown, fmt, args);
-
- ///
- /// Convenience variant of
- /// setting .
- ///
- public static void Info(this IChannel ch, string fmt)
- => ch.Info(MessageSensitivity.Unknown, fmt);
-
- ///
- /// Convenience variant of
- /// setting .
- ///
- public static void Info(this IChannel ch, string fmt, params object[] args)
- => ch.Info(MessageSensitivity.Unknown, fmt, args);
- }
+ ///
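+ /// Convenience variant of <see cref="IChannel.Info(MessageSensitivity, string, object[])"/>
+ /// setting <see cref="MessageSensitivity.Unknown"/>.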
+ ///
+ public static void Info(this IChannel ch, string fmt, params object[] args)
+ => ch.Info(MessageSensitivity.Unknown, fmt, args);
}
diff --git a/src/Microsoft.ML.Core/Data/IProgressChannel.cs b/src/Microsoft.ML.Core/Data/IProgressChannel.cs
index 1cff123f69..ccbbb2a6d7 100644
--- a/src/Microsoft.ML.Core/Data/IProgressChannel.cs
+++ b/src/Microsoft.ML.Core/Data/IProgressChannel.cs
@@ -5,140 +5,139 @@
using System;
using System.Collections.Generic;
-namespace Microsoft.ML.Runtime
+namespace Microsoft.ML.Runtime;
+
+///
+/// This is a factory interface for <see cref="IProgressChannel"/>.
+/// Both <see cref="IHostEnvironment"/> and <see cref="IProgressChannel"/> implement this interface,
+/// to allow for nested progress reporters.
+///
+/// REVIEW: make implement this, instead of the environment?
+///
+public interface IProgressChannelProvider
{
///
- /// This is a factory interface for .
- /// Both and implement this interface,
- /// to allow for nested progress reporters.
- ///
- /// REVIEW: make implement this, instead of the environment?
+ /// Create a progress channel for a computation named <paramref name="name"/>.
///
- public interface IProgressChannelProvider
- {
- ///
- /// Create a progress channel for a computation named .
- ///
- IProgressChannel StartProgressChannel(string name);
- }
+ IProgressChannel StartProgressChannel(string name);
+}
+///
+/// A common interface for progress reporting.
+/// It is expected that the progress channel interface is used from only one thread.
+///
+/// Supported workflow:
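+/// 1) Create the channel via <see cref="IProgressChannelProvider.StartProgressChannel"/>.
+/// 2) Call <see cref="SetHeader"/> as many times as desired (including 0).
+/// Each call to <see cref="SetHeader"/> supersedes the previous one.
+/// 3) Report checkpoints (0 or more) by calling <see cref="Checkpoint"/>.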
+/// 4) Repeat steps 2-3 as often as necessary.
+/// 5) Dispose the channel.
+///
+public interface IProgressChannel : IProgressChannelProvider, IDisposable
+{
///
- /// A common interface for progress reporting.
- /// It is expected that the progress channel interface is used from only one thread.
+ /// Set up the reporting structure:
+ /// - Set the 'header' of the progress reports, defining which progress units and metrics are going to be reported.
+ /// - Provide a thread-safe delegate to be invoked whenever anyone needs to know the progress.
///
- /// Supported workflow:
- /// 1) Create the channel via .
- /// 2) Call as many times as desired (including 0).
- /// Each call to supersedes the previous one.
- /// 3) Report checkpoints (0 or more) by calling .
- /// 4) Repeat steps 2-3 as often as necessary.
- /// 5) Dispose the channel.
+ /// It is acceptable to call <see cref="SetHeader"/> multiple times (or none), regardless of whether the calculation is running
+ /// or not. Because of synchronization, the computation should not deny calls to the 'old'
+ /// delegates even after a new one is provided.
///
- public interface IProgressChannel : IProgressChannelProvider, IDisposable
- {
- ///
- /// Set up the reporting structure:
- /// - Set the 'header' of the progress reports, defining which progress units and metrics are going to be reported.
- /// - Provide a thread-safe delegate to be invoked whenever anyone needs to know the progress.
- ///
- /// It is acceptable to call multiple times (or none), regardless of whether the calculation is running
- /// or not. Because of synchronization, the computation should not deny calls to the 'old'
- /// delegates even after a new one is provided.
- ///
- /// The header object.
- /// The delegate to provide actual progress. The parameter of
- /// the delegate will correspond to the provided .
- void SetHeader(ProgressHeader header, Action fillAction);
-
- ///
- /// Submit a 'checkpoint' entry. These entries are guaranteed to be delivered to the progress listener,
- /// if it is interested. Typically, this would contain some intermediate metrics, that are only calculated
- /// at certain moments ('checkpoints') of the computation.
- ///
- /// For example, SDCA may report a checkpoint every time it computes the loss, or LBFGS may report a checkpoint
- /// every iteration.
- ///
- /// The only parameter, , is interpreted in the following fashion:
- /// * First MetricNames.Length items, if present, are metrics.
- /// * Subsequent ProgressNames.Length items, if present, are progress units.
- /// * Subsequent ProgressNames.Length items, if present, are progress limits.
- /// * If any more values remain, an exception is thrown.
- ///
- /// The metrics, progress units and progress limits.
- void Checkpoint(params Double?[] values);
- }
+ /// The header object.
+ /// The delegate to provide actual progress. The parameter of
+ /// the delegate will correspond to the provided .
+ void SetHeader(ProgressHeader header, Action fillAction);
///
- /// This is the 'header' of the progress report.
+ /// Submit a 'checkpoint' entry. These entries are guaranteed to be delivered to the progress listener,
+ /// if it is interested. Typically, this would contain some intermediate metrics, that are only calculated
+ /// at certain moments ('checkpoints') of the computation.
+ ///
+ /// For example, SDCA may report a checkpoint every time it computes the loss, or LBFGS may report a checkpoint
+ /// every iteration.
+ ///
+ /// The only parameter, <paramref name="values"/>, is interpreted in the following fashion:
+ /// * First MetricNames.Length items, if present, are metrics.
+ /// * Subsequent ProgressNames.Length items, if present, are progress units.
+ /// * Subsequent ProgressNames.Length items, if present, are progress limits.
+ /// * If any more values remain, an exception is thrown.
///
- public sealed class ProgressHeader
- {
- ///
- /// These are the names of the progress 'units', from the least granular to the most granular.
- /// For example, neural network might have {'epoch', 'example'} and FastTree might have {'tree', 'split', 'feature'}.
- /// Will never be null, but can be empty.
- ///
- public readonly IReadOnlyList UnitNames;
+ /// The metrics, progress units and progress limits.
+ void Checkpoint(params Double?[] values);
+}
- ///
- /// These are the names of the reported metrics. For example, this could be the 'loss', 'weight updates/sec' etc.
- /// Will never be null, but can be empty.
- ///
- public readonly IReadOnlyList MetricNames;
+///
+/// This is the 'header' of the progress report.
+///
+public sealed class ProgressHeader
+{
+ ///
+ /// These are the names of the progress 'units', from the least granular to the most granular.
+ /// For example, neural network might have {'epoch', 'example'} and FastTree might have {'tree', 'split', 'feature'}.
+ /// Will never be null, but can be empty.
+ ///
+ public readonly IReadOnlyList UnitNames;
- ///
- /// Initialize the header. This will take ownership of the arrays.
- /// Both arrays can be null, even simultaneously. This 'empty' header indicated that the calculation doesn't report
- /// any units of progress, but the tracker can still track start, stop and elapsed time. Of course, if there's any
- /// progress or metrics to report, it is always better to report them.
- ///
- /// The metrics that the calculation reports. These are completely independent, and there
- /// is no contract on whether the metric values should increase or not. As naming convention,
- /// can have multiple words with spaces, and should be title-cased.
- /// The names of the progress units, listed from least granular to most granular.
- /// The idea is that the progress should be lexicographically increasing (like [0,0], [0,10], [1,0], [1,15], [2,5] etc.).
- /// As naming convention, should be lower-cased and typically plural
- /// (for example, iterations, clusters, examples).
- public ProgressHeader(string[] metricNames, string[] unitNames)
- {
- Contracts.CheckValueOrNull(unitNames);
- Contracts.CheckValueOrNull(metricNames);
+ ///
+ /// These are the names of the reported metrics. For example, this could be the 'loss', 'weight updates/sec' etc.
+ /// Will never be null, but can be empty.
+ ///
+ public readonly IReadOnlyList MetricNames;
- UnitNames = unitNames ?? new string[0];
- MetricNames = metricNames ?? new string[0];
- }
+ ///
+ /// Initialize the header. This will take ownership of the arrays.
+ /// Both arrays can be null, even simultaneously. This 'empty' header indicates that the calculation doesn't report
+ /// any units of progress, but the tracker can still track start, stop and elapsed time. Of course, if there's any
+ /// progress or metrics to report, it is always better to report them.
+ ///
+ /// The metrics that the calculation reports. These are completely independent, and there
+ /// is no contract on whether the metric values should increase or not. As naming convention,
+ /// can have multiple words with spaces, and should be title-cased.
+ /// The names of the progress units, listed from least granular to most granular.
+ /// The idea is that the progress should be lexicographically increasing (like [0,0], [0,10], [1,0], [1,15], [2,5] etc.).
+ /// As naming convention, should be lower-cased and typically plural
+ /// (for example, iterations, clusters, examples).
+ public ProgressHeader(string[] metricNames, string[] unitNames)
+ {
+ Contracts.CheckValueOrNull(unitNames);
+ Contracts.CheckValueOrNull(metricNames);
- ///
- /// A constructor for no metrics, just progress units. As naming convention, should be lower-cased
- /// and typically plural (for example, iterations, clusters, examples).
- ///
- public ProgressHeader(params string[] unitNames)
- : this(null, unitNames)
- {
- }
+ UnitNames = unitNames ?? new string[0];
+ MetricNames = metricNames ?? new string[0];
}
///
- /// A metric/progress holder item.
+ /// A constructor for no metrics, just progress units. As naming convention, should be lower-cased
+ /// and typically plural (for example, iterations, clusters, examples).
///
- public interface IProgressEntry
+ public ProgressHeader(params string[] unitNames)
+ : this(null, unitNames)
{
- ///
- /// Set the progress value for the index to ,
- /// and the limit value for the progress becomes 'unknown'.
- ///
- void SetProgress(int index, Double value);
+ }
+}
+
+///
+/// A metric/progress holder item.
+///
+public interface IProgressEntry
+{
+ ///
+ /// Set the progress value for the index <paramref name="index"/> to <paramref name="value"/>,
+ /// and the limit value for the progress becomes 'unknown'.
+ ///
+ void SetProgress(int index, Double value);
- ///
- /// Set the progress value for the index to ,
- /// and the limit value to . If is a NAN, it is set to null instead.
- ///
- void SetProgress(int index, Double value, Double lim);
+ ///
+ /// Set the progress value for the index <paramref name="index"/> to <paramref name="value"/>,
+ /// and the limit value to <paramref name="lim"/>. If <paramref name="lim"/> is a NaN, it is set to null instead.
+ ///
+ void SetProgress(int index, Double value, Double lim);
- ///
- /// Sets the metric with index to .
- ///
- void SetMetric(int index, Double value);
+ ///
+ /// Sets the metric with index <paramref name="index"/> to <paramref name="value"/>.
+ ///
+ void SetMetric(int index, Double value);
- }
}
diff --git a/src/Microsoft.ML.Core/Data/IRowToRowMapper.cs b/src/Microsoft.ML.Core/Data/IRowToRowMapper.cs
index f98fe70ff5..b48d55f64f 100644
--- a/src/Microsoft.ML.Core/Data/IRowToRowMapper.cs
+++ b/src/Microsoft.ML.Core/Data/IRowToRowMapper.cs
@@ -5,47 +5,46 @@
using System;
using System.Collections.Generic;
-namespace Microsoft.ML.Data
+namespace Microsoft.ML.Data;
+
+///
+/// This interface maps an input to an output . Typically, the output contains
+/// both the input columns and new columns added by the implementing class, although some implementations may
+/// return a subset of the input columns.
+/// This interface is similar to , except it does not have any input role mappings,
+/// so to rebind, the same input column names must be used.
+/// Implementations of this interface are typically created over defined input .
+///
+public interface IRowToRowMapper
{
///
- /// This interface maps an input to an output . Typically, the output contains
- /// both the input columns and new columns added by the implementing class, although some implementations may
- /// return a subset of the input columns.
- /// This interface is similar to , except it does not have any input role mappings,
- /// so to rebind, the same input column names must be used.
- /// Implementations of this interface are typically created over defined input .
+ /// Mappers are defined as accepting inputs with this very specific schema.
///
- public interface IRowToRowMapper
- {
- ///
- /// Mappers are defined as accepting inputs with this very specific schema.
- ///
- DataViewSchema InputSchema { get; }
+ DataViewSchema InputSchema { get; }
- ///
- /// Gets an instance of which describes the columns' names and types in the output generated by this mapper.
- ///
- DataViewSchema OutputSchema { get; }
+ ///
+ /// Gets an instance of <see cref="DataViewSchema"/> which describes the columns' names and types in the output generated by this mapper.
+ ///
+ DataViewSchema OutputSchema { get; }
- ///
- /// Given a set of columns, return the input columns that are needed to generate those output columns.
- ///
- IEnumerable GetDependencies(IEnumerable dependingColumns);
+ ///
+ /// Given a set of columns, return the input columns that are needed to generate those output columns.
+ ///
+ IEnumerable GetDependencies(IEnumerable dependingColumns);
- ///
- /// Get an with the indicated active columns, based on the input .
- /// Getting values on inactive columns of the returned row will throw.
- ///
- /// The of should be the same object as
- /// . Implementors of this method should throw if that is not the case. Conversely,
- /// the returned value must have the same schema as .
- ///
- /// This method creates a live connection between the input and the output . In particular, when the getters of the output are invoked, they invoke the
- /// getters of the input row and base the output values on the current values of the input .
- /// The output values are re-computed when requested through the getters. Also, the returned
- /// will dispose when it is disposed.
- ///
- DataViewRow GetRow(DataViewRow input, IEnumerable activeColumns);
- }
+ ///
+ /// Get an with the indicated active columns, based on the input .
+ /// Getting values on inactive columns of the returned row will throw.
+ ///
+ /// The of should be the same object as
+ /// . Implementors of this method should throw if that is not the case. Conversely,
+ /// the returned value must have the same schema as .
+ ///
+ /// This method creates a live connection between the input and the output . In particular, when the getters of the output are invoked, they invoke the
+ /// getters of the input row and base the output values on the current values of the input .
+ /// The output values are re-computed when requested through the getters. Also, the returned
+ /// will dispose when it is disposed.
+ ///
+ DataViewRow GetRow(DataViewRow input, IEnumerable activeColumns);
}
diff --git a/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs b/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs
index 2396966a6c..58605b050b 100644
--- a/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs
+++ b/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs
@@ -6,85 +6,84 @@
using System.Collections.Generic;
using Microsoft.ML.Runtime;
-namespace Microsoft.ML.Data
+namespace Microsoft.ML.Data;
+
+///
+/// A mapper that can be bound to a (which encapsulates a and has mappings from column kinds
+/// to columns of that schema). Binding an to a produces an
+/// , which is an interface that has methods to return the names and indices of the input columns
+/// needed by the mapper to compute its output. The is an extension to this interface, that
+/// can also produce an output given an input . The produced generally contains only the output columns of the mapper, and not
+/// the input columns (but there is nothing preventing an from mapping input columns directly to outputs).
+/// This interface is implemented by wrappers of IValueMapper based predictors, which are predictors that take a single
+/// features column. New predictors can implement directly. Implementing
+/// includes implementing a corresponding (or ) and a corresponding ISchema
+/// for the output schema of the . In case the interface is implemented,
+/// the SimpleRow class can be used in the method.
+///
+[BestFriend]
+internal interface ISchemaBindableMapper
+{
+ ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema);
+}
+
+///
+/// This interface is used to map a schema from input columns to output columns. The should keep track
+/// of the input columns that are needed for the mapping.
+///
+[BestFriend]
+internal interface ISchemaBoundMapper
{
///
- /// A mapper that can be bound to a (which encapsulates a and has mappings from column kinds
- /// to columns of that schema). Binding an to a produces an
- /// , which is an interface that has methods to return the names and indices of the input columns
- /// needed by the mapper to compute its output. The is an extention to this interface, that
- /// can also produce an output given an input . The produced generally contains only the output columns of the mapper, and not
- /// the input columns (but there is nothing preventing an from mapping input columns directly to outputs).
- /// This interface is implemented by wrappers of IValueMapper based predictors, which are predictors that take a single
- /// features column. New predictors can implement directly. Implementing
- /// includes implementing a corresponding (or ) and a corresponding ISchema
- /// for the output schema of the . In case the interface is implemented,
- /// the SimpleRow class can be used in the method.
+ /// The that was passed to the in the binding process.
///
- [BestFriend]
- internal interface ISchemaBindableMapper
- {
- ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema);
- }
+ RoleMappedSchema InputRoleMappedSchema { get; }
///
- /// This interface is used to map a schema from input columns to output columns. The should keep track
- /// of the input columns that are needed for the mapping.
+ /// Gets schema of this mapper's output.
///
- [BestFriend]
- internal interface ISchemaBoundMapper
- {
- ///
- /// The that was passed to the in the binding process.
- ///
- RoleMappedSchema InputRoleMappedSchema { get; }
+ DataViewSchema OutputSchema { get; }
- ///
- /// Gets schema of this mapper's output.
- ///
- DataViewSchema OutputSchema { get; }
-
- ///
- /// A property to get back the that produced this .
- ///
- ISchemaBindableMapper Bindable { get; }
+ ///
+ /// A property to get back the that produced this .
+ ///
+ ISchemaBindableMapper Bindable { get; }
- ///
- /// This method returns the binding information: which input columns are used and in what roles.
- ///
- IEnumerable> GetInputColumnRoles();
- }
+ ///
+ /// This method returns the binding information: which input columns are used and in what roles.
+ ///
+ IEnumerable> GetInputColumnRoles();
+}
+///
+/// This interface extends .
+///
+[BestFriend]
+internal interface ISchemaBoundRowMapper : ISchemaBoundMapper
+{
///
- /// This interface extends .
+ /// Input schema accepted.
///
- [BestFriend]
- internal interface ISchemaBoundRowMapper : ISchemaBoundMapper
- {
- ///
- /// Input schema accepted.
- ///
- DataViewSchema InputSchema { get; }
+ DataViewSchema InputSchema { get; }
- ///
- /// Given a set of columns, from the newly generated ones, return the input columns that are needed to generate those output columns.
- ///
- IEnumerable GetDependenciesForNewColumns(IEnumerable dependingColumns);
+ ///
+ /// Given a set of columns, from the newly generated ones, return the input columns that are needed to generate those output columns.
+ ///
+ IEnumerable GetDependenciesForNewColumns(IEnumerable dependingColumns);
- ///
- /// Get an with the indicated active columns, based on the input .
- /// Getting values on inactive columns of the returned row will throw.
- ///
- /// The of should be the same object as
- /// . Implementors of this method should throw if that is not the case. Conversely,
- /// the returned value must have the same schema as .
- ///
- /// This method creates a live connection between the input and the output . In particular, when the getters of the output are invoked, they invoke the
- /// getters of the input row and base the output values on the current values of the input .
- /// The output values are re-computed when requested through the getters. Also, the returned
- /// will dispose when it is disposed.
- ///
- DataViewRow GetRow(DataViewRow input, IEnumerable activeColumns);
- }
+ ///
+ /// Get an with the indicated active columns, based on the input .
+ /// Getting values on inactive columns of the returned row will throw.
+ ///
+ /// The of should be the same object as
+ /// . Implementors of this method should throw if that is not the case. Conversely,
+ /// the returned value must have the same schema as .
+ ///
+ /// This method creates a live connection between the input and the output . In particular, when the getters of the output are invoked, they invoke the
+ /// getters of the input row and base the output values on the current values of the input .
+ /// The output values are re-computed when requested through the getters. Also, the returned
+ /// will dispose when it is disposed.
+ ///
+ DataViewRow GetRow(DataViewRow input, IEnumerable activeColumns);
}
diff --git a/src/Microsoft.ML.Core/Data/IValueMapper.cs b/src/Microsoft.ML.Core/Data/IValueMapper.cs
index 6dfea747ca..34a64c5fcf 100644
--- a/src/Microsoft.ML.Core/Data/IValueMapper.cs
+++ b/src/Microsoft.ML.Core/Data/IValueMapper.cs
@@ -2,57 +2,56 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-namespace Microsoft.ML.Data
-{
- ///
- /// Delegate type to map/convert a value.
- ///
- [BestFriend]
- internal delegate void ValueMapper(in TSrc src, ref TDst dst);
+namespace Microsoft.ML.Data;
- ///
- /// Delegate type to map/convert among three values, for example, one input with two
- /// outputs, or two inputs with one output.
- ///
- [BestFriend]
- internal delegate void ValueMapper(in TVal1 val1, ref TVal2 val2, ref TVal3 val3);
+///
+/// Delegate type to map/convert a value.
+///
+[BestFriend]
+internal delegate void ValueMapper(in TSrc src, ref TDst dst);
+
+///
+/// Delegate type to map/convert among three values, for example, one input with two
+/// outputs, or two inputs with one output.
+///
+[BestFriend]
+internal delegate void ValueMapper(in TVal1 val1, ref TVal2 val2, ref TVal3 val3);
+
+///
+/// Interface for mapping a single input value (of an indicated ColumnType) to
+/// an output value (of an indicated ColumnType). This interface is commonly implemented
+/// by predictors. Note that the input and output ColumnTypes determine the proper
+/// type arguments for GetMapper, but typically contain additional information like
+/// vector lengths.
+///
+[BestFriend]
+internal interface IValueMapper
+{
+ DataViewType InputType { get; }
+ DataViewType OutputType { get; }
///
- /// Interface for mapping a single input value (of an indicated ColumnType) to
- /// an output value (of an indicated ColumnType). This interface is commonly implemented
- /// by predictors. Note that the input and output ColumnTypes determine the proper
- /// type arguments for GetMapper, but typically contain additional information like
- /// vector lengths.
+ /// Get a delegate used for mapping from input to output values. Note that the delegate
+ /// should only be used on a single thread - it should NOT be assumed to be safe for concurrency.
///
- [BestFriend]
- internal interface IValueMapper
- {
- DataViewType InputType { get; }
- DataViewType OutputType { get; }
+ ValueMapper GetMapper();
+}
- ///
- /// Get a delegate used for mapping from input to output values. Note that the delegate
- /// should only be used on a single thread - it should NOT be assumed to be safe for concurrency.
- ///
- ValueMapper GetMapper();
- }
+///
+/// Interface for mapping a single input value (of an indicated ColumnType) to an output value
+/// plus distribution value (of indicated ColumnTypes). This interface is commonly implemented
+/// by predictors. Note that the input, output, and distribution ColumnTypes determine the proper
+/// type arguments for GetMapper, but typically contain additional information like
+/// vector lengths.
+///
+[BestFriend]
+internal interface IValueMapperDist : IValueMapper
+{
+ DataViewType DistType { get; }
///
- /// Interface for mapping a single input value (of an indicated ColumnType) to an output value
- /// plus distribution value (of indicated ColumnTypes). This interface is commonly implemented
- /// by predictors. Note that the input, output, and distribution ColumnTypes determine the proper
- /// type arguments for GetMapper, but typically contain additional information like
- /// vector lengths.
+ /// Get a delegate used for mapping from input to output values. Note that the delegate
+ /// should only be used on a single thread - it should NOT be assumed to be safe for concurrency.
///
- [BestFriend]
- internal interface IValueMapperDist : IValueMapper
- {
- DataViewType DistType { get; }
-
- ///
- /// Get a delegate used for mapping from input to output values. Note that the delegate
- /// should only be used on a single thread - it should NOT be assumed to be safe for concurrency.
- ///
- ValueMapper GetMapper();
- }
+ ValueMapper GetMapper();
}
diff --git a/src/Microsoft.ML.Core/Data/InPredicate.cs b/src/Microsoft.ML.Core/Data/InPredicate.cs
index 8e0c27e070..bf8535b1b1 100644
--- a/src/Microsoft.ML.Core/Data/InPredicate.cs
+++ b/src/Microsoft.ML.Core/Data/InPredicate.cs
@@ -2,8 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-namespace Microsoft.ML.Data
-{
- [BestFriend]
- internal delegate bool InPredicate(in T value);
-}
+namespace Microsoft.ML.Data;
+
+[BestFriend]
+internal delegate bool InPredicate(in T value);
diff --git a/src/Microsoft.ML.Core/Data/KeyTypeExtensions.cs b/src/Microsoft.ML.Core/Data/KeyTypeExtensions.cs
index 096e36fdf0..10b5381355 100644
--- a/src/Microsoft.ML.Core/Data/KeyTypeExtensions.cs
+++ b/src/Microsoft.ML.Core/Data/KeyTypeExtensions.cs
@@ -4,21 +4,20 @@
using Microsoft.ML.Runtime;
-namespace Microsoft.ML.Data
+namespace Microsoft.ML.Data;
+
+///
+/// Extension methods related to the class.
+///
+[BestFriend]
+internal static class KeyTypeExtensions
{
///
- /// Extension methods related to the class.
+ /// Sometimes it is necessary to cast the Count to an int. This performs overflow check.
///
- [BestFriend]
- internal static class KeyTypeExtensions
+ public static int GetCountAsInt32(this KeyDataViewType key, IExceptionContext ectx = null)
{
- ///
- /// Sometimes it is necessary to cast the Count to an int. This performs overflow check.
- ///
- public static int GetCountAsInt32(this KeyDataViewType key, IExceptionContext ectx = null)
- {
- ectx.Check(key.Count <= int.MaxValue, nameof(KeyDataViewType) + "." + nameof(KeyDataViewType.Count) + " exceeds int.MaxValue.");
- return (int)key.Count;
- }
+ ectx.Check(key.Count <= int.MaxValue, nameof(KeyDataViewType) + "." + nameof(KeyDataViewType.Count) + " exceeds int.MaxValue.");
+ return (int)key.Count;
}
}
diff --git a/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs b/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs
index 319dbbbbb8..872b339c0c 100644
--- a/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs
+++ b/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs
@@ -4,50 +4,49 @@
using Microsoft.ML.Runtime;
-namespace Microsoft.ML.Data
+namespace Microsoft.ML.Data;
+
+///
+/// Base class for a cursor that has an input cursor, but still needs to do work on .
+///
+[BestFriend]
+internal abstract class LinkedRootCursorBase : RootCursorBase
{
+
+ /// Gets the input cursor.
+ protected DataViewRowCursor Input { get; }
+
///
- /// Base class for a cursor has an input cursor, but still needs to do work on .
+ /// Returns the root cursor of the input. It should be used to perform
+ /// operations, but with the distinction, as compared to , that this is not
+ /// a simple passthrough, but rather very implementation specific. For example, a common usage of this class is
+ /// on filter cursor implementations, where how that input cursor is consumed is very implementation specific.
+ /// That is why this is , not .
///
- [BestFriend]
- internal abstract class LinkedRootCursorBase : RootCursorBase
- {
+ protected DataViewRowCursor Root { get; }
- /// Gets the input cursor.
- protected DataViewRowCursor Input { get; }
+ private bool _disposed;
- ///
- /// Returns the root cursor of the input. It should be used to perform
- /// operations, but with the distinction, as compared to , that this is not
- /// a simple passthrough, but rather very implementation specific. For example, a common usage of this class is
- /// on filter cursor implementations, where how that input cursor is consumed is very implementation specific.
- /// That is why this is , not .
- ///
- protected DataViewRowCursor Root { get; }
+ protected LinkedRootCursorBase(IChannelProvider provider, DataViewRowCursor input)
+ : base(provider)
+ {
+ Ch.AssertValue(input, nameof(input));
- private bool _disposed;
+ Input = input;
+ Root = Input is SynchronizedCursorBase snycInput ? snycInput.Root : input;
+ }
- protected LinkedRootCursorBase(IChannelProvider provider, DataViewRowCursor input)
- : base(provider)
+ protected override void Dispose(bool disposing)
+ {
+ if (_disposed)
+ return;
+ if (disposing)
{
- Ch.AssertValue(input, nameof(input));
+ Input.Dispose();
+ // The base class should set the state to done under these circumstances.
- Input = input;
- Root = Input is SynchronizedCursorBase snycInput ? snycInput.Root : input;
- }
-
- protected override void Dispose(bool disposing)
- {
- if (_disposed)
- return;
- if (disposing)
- {
- Input.Dispose();
- // The base class should set the state to done under these circumstances.
-
- }
- _disposed = true;
- base.Dispose(disposing);
}
+ _disposed = true;
+ base.Dispose(disposing);
}
}
diff --git a/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs b/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs
index 670f37ffe7..e618ca2c62 100644
--- a/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs
+++ b/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs
@@ -4,40 +4,39 @@
using Microsoft.ML.Runtime;
-namespace Microsoft.ML.Data
+namespace Microsoft.ML.Data;
+
+///
+/// Base class for creating a cursor of rows that filters out some input rows.
+///
+[BestFriend]
+internal abstract class LinkedRowFilterCursorBase : LinkedRowRootCursorBase
{
- ///
- /// Base class for creating a cursor of rows that filters out some input rows.
- ///
- [BestFriend]
- internal abstract class LinkedRowFilterCursorBase : LinkedRowRootCursorBase
- {
- public override long Batch => Input.Batch;
+ public override long Batch => Input.Batch;
- protected LinkedRowFilterCursorBase(IChannelProvider provider, DataViewRowCursor input, DataViewSchema schema, bool[] active)
- : base(provider, input, schema, active)
- {
- }
+ protected LinkedRowFilterCursorBase(IChannelProvider provider, DataViewRowCursor input, DataViewSchema schema, bool[] active)
+ : base(provider, input, schema, active)
+ {
+ }
- public override ValueGetter GetIdGetter()
- {
- return Input.GetIdGetter();
- }
+ public override ValueGetter GetIdGetter()
+ {
+ return Input.GetIdGetter();
+ }
- protected override bool MoveNextCore()
+ protected override bool MoveNextCore()
+ {
+ while (Root.MoveNext())
{
- while (Root.MoveNext())
- {
- if (Accept())
- return true;
- }
-
- return false;
+ if (Accept())
+ return true;
}
- ///
- /// Return whether the current input row should be returned by this cursor.
- ///
- protected abstract bool Accept();
+ return false;
}
+
+ ///
+ /// Return whether the current input row should be returned by this cursor.
+ ///
+ protected abstract bool Accept();
}
diff --git a/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs b/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs
index e1bca40a1a..2771f952c6 100644
--- a/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs
+++ b/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs
@@ -4,50 +4,49 @@
using Microsoft.ML.Runtime;
-namespace Microsoft.ML.Data
+namespace Microsoft.ML.Data;
+
+///
+/// A base class for a that has an input cursor, but still needs to do work on
+/// . Note that the default
+/// assumes that each input column is exposed as an
+/// output column with the same column index.
+///
+[BestFriend]
+internal abstract class LinkedRowRootCursorBase : LinkedRootCursorBase
{
- ///
- /// A base class for a that has an input cursor, but still needs to do work on
- /// . Note that the default
- /// assumes that each input column is exposed as an
- /// output column with the same column index.
- ///
- [BestFriend]
- internal abstract class LinkedRowRootCursorBase : LinkedRootCursorBase
- {
- private readonly bool[] _active;
+ private readonly bool[] _active;
- /// Gets row's schema.
- public sealed override DataViewSchema Schema { get; }
+ /// Gets row's schema.
+ public sealed override DataViewSchema Schema { get; }
- protected LinkedRowRootCursorBase(IChannelProvider provider, DataViewRowCursor input, DataViewSchema schema, bool[] active)
- : base(provider, input)
- {
- Ch.CheckValue(schema, nameof(schema));
- Ch.Check(active == null || active.Length == schema.Count);
- _active = active;
- Schema = schema;
- }
+ protected LinkedRowRootCursorBase(IChannelProvider provider, DataViewRowCursor input, DataViewSchema schema, bool[] active)
+ : base(provider, input)
+ {
+ Ch.CheckValue(schema, nameof(schema));
+ Ch.Check(active == null || active.Length == schema.Count);
+ _active = active;
+ Schema = schema;
+ }
- ///
- /// Returns whether the given column is active in this row.
- ///
- public sealed override bool IsColumnActive(DataViewSchema.Column column)
- {
- Ch.Check(column.Index < Schema.Count);
- return _active == null || _active[column.Index];
- }
+ ///
+ /// Returns whether the given column is active in this row.
+ ///
+ public sealed override bool IsColumnActive(DataViewSchema.Column column)
+ {
+ Ch.Check(column.Index < Schema.Count);
+ return _active == null || _active[column.Index];
+ }
- ///
- /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
- /// This throws if the column is not active in this row, or if the type
- /// differs from this column's type.
- ///
- /// is the column's content type.
- /// is the output column whose getter should be returned.
- public override ValueGetter GetGetter(DataViewSchema.Column column)
- {
- return Input.GetGetter(column);
- }
+ ///
+ /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
+ /// This throws if the column is not active in this row, or if the type
+ /// differs from this column's type.
+ ///
+ /// is the column's content type.
+ /// is the output column whose getter should be returned.
+ public override ValueGetter GetGetter(DataViewSchema.Column column)
+ {
+ return Input.GetGetter(column);
}
}
diff --git a/src/Microsoft.ML.Core/Data/ModelHeader.cs b/src/Microsoft.ML.Core/Data/ModelHeader.cs
index 385fb6d068..01a694b501 100644
--- a/src/Microsoft.ML.Core/Data/ModelHeader.cs
+++ b/src/Microsoft.ML.Core/Data/ModelHeader.cs
@@ -9,654 +9,653 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
-namespace Microsoft.ML
+namespace Microsoft.ML;
+
+[BestFriend]
+[StructLayout(LayoutKind.Explicit, Size = Size)]
+internal struct ModelHeader
{
- [BestFriend]
- [StructLayout(LayoutKind.Explicit, Size = Size)]
- internal struct ModelHeader
+ ///
+ /// This spells 'ML MODEL' with zero replacing space (assuming little endian).
+ ///
+ public const ulong SignatureValue = 0x4C45444F4D004C4DUL;
+ public const ulong TailSignatureValue = 0x4D4C004D4F44454CUL;
+
+ private const uint VerAssemblyNameSupported = 0x00010002;
+
+ // These are private since they change over time. If we make them public we risk
+ // another assembly containing a "copy" of their value when the other assembly
+ // was compiled, which might not match the code that can load this.
+ //private const uint VerWrittenCur = 0x00010001; // Initial
+ private const uint VerWrittenCur = 0x00010002; // Added AssemblyName
+ private const uint VerReadableCur = 0x00010002;
+ private const uint VerWeCanReadBack = 0x00010001;
+
+ [FieldOffset(0x00)]
+ public ulong Signature;
+ [FieldOffset(0x08)]
+ public uint VerWritten;
+ [FieldOffset(0x0C)]
+ public uint VerReadable;
+
+ // Location and size (in bytes) of the model block. Note that it is legal for CbModel to be zero.
+ [FieldOffset(0x10)]
+ public long FpModel;
+ [FieldOffset(0x18)]
+ public long CbModel;
+
+ // Location and size (in bytes) of the string table block. If there are no strings, these are both zero.
+ // If there are n strings then CbStringTable is n * sizeof(long), so is divisible by sizeof(long).
+ // Each long is the offset from header.FpStringChars of the "lim" of the characters for that string.
+ // The "min" is the "lim" for the previous string. The 0th string's "min" is offset zero.
+ [FieldOffset(0x20)]
+ public long FpStringTable;
+ [FieldOffset(0x28)]
+ public long CbStringTable;
+
+ // Location and size (in bytes) of the string characters, without any prefix or termination. The characters
+ // are unicode (UTF-16).
+ [FieldOffset(0x30)]
+ public long FpStringChars;
+ [FieldOffset(0x38)]
+ public long CbStringChars;
+
+ // ModelSignature specifies the format of the model block. These values are assigned by
+ // the code that writes the model block.
+ [FieldOffset(0x40)]
+ public ulong ModelSignature;
+ [FieldOffset(0x48)]
+ public uint ModelVerWritten;
+ [FieldOffset(0x4C)]
+ public uint ModelVerReadable;
+
+ // These encode up to two loader signature strings. These are up to 24 ascii characters each.
+ [FieldOffset(0x50)]
+ public ulong LoaderSignature0;
+ [FieldOffset(0x58)]
+ public ulong LoaderSignature1;
+ [FieldOffset(0x60)]
+ public ulong LoaderSignature2;
+ [FieldOffset(0x68)]
+ public ulong LoaderSignatureAlt0;
+ [FieldOffset(0x70)]
+ public ulong LoaderSignatureAlt1;
+ [FieldOffset(0x78)]
+ public ulong LoaderSignatureAlt2;
+
+ // Location of the "tail" signature, which is simply the TailSignatureValue.
+ [FieldOffset(0x80)]
+ public long FpTail;
+ [FieldOffset(0x88)]
+ public long FpLim;
+
+ // Location of the fully qualified assembly name string (in UTF-16).
+ // Note that it is legal for both to be zero.
+ [FieldOffset(0x90)]
+ public long FpAssemblyName;
+ [FieldOffset(0x98)]
+ public uint CbAssemblyName;
+
+ public const int Size = 0x0100;
+
+ // Utilities for writing.
+
+ ///
+ /// Initialize the header and writer for writing. The value of fpMin and header
+ /// should be passed to the other utility methods here.
+ ///
+ public static void BeginWrite(BinaryWriter writer, out long fpMin, out ModelHeader header)
{
- ///
- /// This spells 'ML MODEL' with zero replacing space (assuming little endian).
- ///
- public const ulong SignatureValue = 0x4C45444F4D004C4DUL;
- public const ulong TailSignatureValue = 0x4D4C004D4F44454CUL;
-
- private const uint VerAssemblyNameSupported = 0x00010002;
-
- // These are private since they change over time. If we make them public we risk
- // another assembly containing a "copy" of their value when the other assembly
- // was compiled, which might not match the code that can load this.
- //private const uint VerWrittenCur = 0x00010001; // Initial
- private const uint VerWrittenCur = 0x00010002; // Added AssemblyName
- private const uint VerReadableCur = 0x00010002;
- private const uint VerWeCanReadBack = 0x00010001;
-
- [FieldOffset(0x00)]
- public ulong Signature;
- [FieldOffset(0x08)]
- public uint VerWritten;
- [FieldOffset(0x0C)]
- public uint VerReadable;
-
- // Location and size (in bytes) of the model block. Note that it is legal for CbModel to be zero.
- [FieldOffset(0x10)]
- public long FpModel;
- [FieldOffset(0x18)]
- public long CbModel;
-
- // Location and size (in bytes) of the string table block. If there are no strings, these are both zero.
- // If there are n strings then CbStringTable is n * sizeof(long), so is divisible by sizeof(long).
- // Each long is the offset from header.FpStringChars of the "lim" of the characters for that string.
- // The "min" is the "lim" for the previous string. The 0th string's "min" is offset zero.
- [FieldOffset(0x20)]
- public long FpStringTable;
- [FieldOffset(0x28)]
- public long CbStringTable;
-
- // Location and size (in bytes) of the string characters, without any prefix or termination. The characters
- // unicode (UTF-16).
- [FieldOffset(0x30)]
- public long FpStringChars;
- [FieldOffset(0x38)]
- public long CbStringChars;
-
- // ModelSignature specifies the format of the model block. These values are assigned by
- // the code that writes the model block.
- [FieldOffset(0x40)]
- public ulong ModelSignature;
- [FieldOffset(0x48)]
- public uint ModelVerWritten;
- [FieldOffset(0x4C)]
- public uint ModelVerReadable;
-
- // These encode up to two loader signature strings. These are up to 24 ascii characters each.
- [FieldOffset(0x50)]
- public ulong LoaderSignature0;
- [FieldOffset(0x58)]
- public ulong LoaderSignature1;
- [FieldOffset(0x60)]
- public ulong LoaderSignature2;
- [FieldOffset(0x68)]
- public ulong LoaderSignatureAlt0;
- [FieldOffset(0x70)]
- public ulong LoaderSignatureAlt1;
- [FieldOffset(0x78)]
- public ulong LoaderSignatureAlt2;
-
- // Location of the "tail" signature, which is simply the TailSignatureValue.
- [FieldOffset(0x80)]
- public long FpTail;
- [FieldOffset(0x88)]
- public long FpLim;
-
- // Location of the fully qualified assembly name string (in UTF-16).
- // Note that it is legal for both to be zero.
- [FieldOffset(0x90)]
- public long FpAssemblyName;
- [FieldOffset(0x98)]
- public uint CbAssemblyName;
-
- public const int Size = 0x0100;
-
- // Utilities for writing.
-
- ///
- /// Initialize the header and writer for writing. The value of fpMin and header
- /// should be passed to the other utility methods here.
- ///
- public static void BeginWrite(BinaryWriter writer, out long fpMin, out ModelHeader header)
- {
- Contracts.Assert(Marshal.SizeOf(typeof(ModelHeader)) == Size);
- Contracts.CheckValue(writer, nameof(writer));
-
- fpMin = writer.FpCur();
- header = default(ModelHeader);
- header.Signature = SignatureValue;
- header.VerWritten = VerWrittenCur;
- header.VerReadable = VerReadableCur;
- header.FpModel = ModelHeader.Size;
-
- // Write a blank header - the correct information is written by WriteHeaderAndTail.
- byte[] headerBytes = new byte[ModelHeader.Size];
- writer.Write(headerBytes);
- Contracts.CheckIO(writer.FpCur() == fpMin + ModelHeader.Size);
- }
+ Contracts.Assert(Marshal.SizeOf(typeof(ModelHeader)) == Size);
+ Contracts.CheckValue(writer, nameof(writer));
+
+ fpMin = writer.FpCur();
+ header = default(ModelHeader);
+ header.Signature = SignatureValue;
+ header.VerWritten = VerWrittenCur;
+ header.VerReadable = VerReadableCur;
+ header.FpModel = ModelHeader.Size;
+
+ // Write a blank header - the correct information is written by WriteHeaderAndTail.
+ byte[] headerBytes = new byte[ModelHeader.Size];
+ writer.Write(headerBytes);
+ Contracts.CheckIO(writer.FpCur() == fpMin + ModelHeader.Size);
+ }
- ///
- /// The current writer position should be the end of the model blob. Records the model size, writes the string table,
- /// completes and writes the header, and writes the tail.
- ///
- public static void EndWrite(BinaryWriter writer, long fpMin, ref ModelHeader header, NormStr.Pool pool = null, string loaderAssemblyName = null)
- {
- Contracts.CheckValue(writer, nameof(writer));
- Contracts.CheckParam(fpMin >= 0, nameof(fpMin));
- Contracts.CheckValueOrNull(pool);
+ ///
+ /// The current writer position should be the end of the model blob. Records the model size, writes the string table,
+ /// completes and writes the header, and writes the tail.
+ ///
+ public static void EndWrite(BinaryWriter writer, long fpMin, ref ModelHeader header, NormStr.Pool pool = null, string loaderAssemblyName = null)
+ {
+ Contracts.CheckValue(writer, nameof(writer));
+ Contracts.CheckParam(fpMin >= 0, nameof(fpMin));
+ Contracts.CheckValueOrNull(pool);
- // Record the model size.
- EndModelCore(writer, fpMin, ref header);
+ // Record the model size.
+ EndModelCore(writer, fpMin, ref header);
- Contracts.Check(header.FpStringTable == 0);
- Contracts.Check(header.CbStringTable == 0);
- Contracts.Check(header.FpStringChars == 0);
- Contracts.Check(header.CbStringChars == 0);
+ Contracts.Check(header.FpStringTable == 0);
+ Contracts.Check(header.CbStringTable == 0);
+ Contracts.Check(header.FpStringChars == 0);
+ Contracts.Check(header.CbStringChars == 0);
- // Write the strings.
- if (pool != null && pool.Count > 0)
+ // Write the strings.
+ if (pool != null && pool.Count > 0)
+ {
+ header.FpStringTable = writer.FpCur() - fpMin;
+ long offset = 0;
+ int cv = 0;
+ // REVIEW: Implement an indexer on pool!
+ foreach (var ns in pool)
{
- header.FpStringTable = writer.FpCur() - fpMin;
- long offset = 0;
- int cv = 0;
- // REVIEW: Implement an indexer on pool!
- foreach (var ns in pool)
- {
- Contracts.Assert(ns.Id == cv);
- offset += ns.Value.Length * sizeof(char);
- writer.Write(offset);
- cv++;
- }
- Contracts.Assert(cv == pool.Count);
- header.CbStringTable = pool.Count * sizeof(long);
- header.FpStringChars = writer.FpCur() - fpMin;
- Contracts.Assert(header.FpStringChars == header.FpStringTable + header.CbStringTable);
- foreach (var ns in pool)
- {
- foreach (var ch in ns.Value.Span)
- writer.Write((short)ch);
- }
- header.CbStringChars = writer.FpCur() - header.FpStringChars - fpMin;
- Contracts.Assert(offset == header.CbStringChars);
+ Contracts.Assert(ns.Id == cv);
+ offset += ns.Value.Length * sizeof(char);
+ writer.Write(offset);
+ cv++;
+ }
+ Contracts.Assert(cv == pool.Count);
+ header.CbStringTable = pool.Count * sizeof(long);
+ header.FpStringChars = writer.FpCur() - fpMin;
+ Contracts.Assert(header.FpStringChars == header.FpStringTable + header.CbStringTable);
+ foreach (var ns in pool)
+ {
+ foreach (var ch in ns.Value.Span)
+ writer.Write((short)ch);
}
+ header.CbStringChars = writer.FpCur() - header.FpStringChars - fpMin;
+ Contracts.Assert(offset == header.CbStringChars);
+ }
- WriteLoaderAssemblyName(writer, fpMin, ref header, loaderAssemblyName);
+ WriteLoaderAssemblyName(writer, fpMin, ref header, loaderAssemblyName);
- WriteHeaderAndTailCore(writer, fpMin, ref header);
- }
+ WriteHeaderAndTailCore(writer, fpMin, ref header);
+ }
- private static void WriteLoaderAssemblyName(BinaryWriter writer, long fpMin, ref ModelHeader header, string loaderAssemblyName)
+ private static void WriteLoaderAssemblyName(BinaryWriter writer, long fpMin, ref ModelHeader header, string loaderAssemblyName)
+ {
+ if (!string.IsNullOrEmpty(loaderAssemblyName))
{
- if (!string.IsNullOrEmpty(loaderAssemblyName))
- {
- header.FpAssemblyName = writer.FpCur() - fpMin;
- header.CbAssemblyName = (uint)loaderAssemblyName.Length * sizeof(char);
+ header.FpAssemblyName = writer.FpCur() - fpMin;
+ header.CbAssemblyName = (uint)loaderAssemblyName.Length * sizeof(char);
- foreach (var ch in loaderAssemblyName)
- writer.Write((short)ch);
- }
- else
- {
- header.FpAssemblyName = 0;
- header.CbAssemblyName = 0;
- }
+ foreach (var ch in loaderAssemblyName)
+ writer.Write((short)ch);
}
-
- ///
- /// The current writer position should be where the tail belongs. Writes the header and tail.
- /// Typically this isn't called directly unless you are doing custom string table serialization.
- /// In that case you should have called EndModelCore before writing the string table information.
- ///
- public static void WriteHeaderAndTailCore(BinaryWriter writer, long fpMin, ref ModelHeader header)
+ else
{
- Contracts.CheckValue(writer, nameof(writer));
- Contracts.CheckParam(fpMin >= 0, nameof(fpMin));
-
- header.FpTail = writer.FpCur() - fpMin;
- writer.Write(TailSignatureValue);
- header.FpLim = writer.FpCur() - fpMin;
-
- Exception ex;
- bool res = TryValidate(ref header, header.FpLim, out ex);
- // If this fails, we didn't construct the header correctly. This is both a bug and
- // something we want to protect against at runtime, hence both assert and check.
- Contracts.Assert(res);
- Contracts.Check(res);
-
- // Write the header, then seek back to the end.
- writer.Seek(fpMin);
- byte[] headerBytes = new byte[ModelHeader.Size];
- MarshalToBytes(ref header, headerBytes);
- writer.Write(headerBytes);
- Contracts.Assert(writer.FpCur() == fpMin + ModelHeader.Size);
- writer.Seek(header.FpLim + fpMin);
+ header.FpAssemblyName = 0;
+ header.CbAssemblyName = 0;
}
+ }
- ///
- /// The current writer position should be the end of the model blob. Records the size of the model blob.
- /// Typically this isn't called directly unless you are doing custom string table serialization.
- ///
- public static void EndModelCore(BinaryWriter writer, long fpMin, ref ModelHeader header)
- {
- Contracts.Check(header.FpModel == ModelHeader.Size);
- Contracts.Check(header.CbModel == 0);
+ ///
+ /// The current writer position should be where the tail belongs. Writes the header and tail.
+ /// Typically this isn't called directly unless you are doing custom string table serialization.
+ /// In that case you should have called EndModelCore before writing the string table information.
+ ///
+ public static void WriteHeaderAndTailCore(BinaryWriter writer, long fpMin, ref ModelHeader header)
+ {
+ Contracts.CheckValue(writer, nameof(writer));
+ Contracts.CheckParam(fpMin >= 0, nameof(fpMin));
+
+ header.FpTail = writer.FpCur() - fpMin;
+ writer.Write(TailSignatureValue);
+ header.FpLim = writer.FpCur() - fpMin;
+
+ Exception ex;
+ bool res = TryValidate(ref header, header.FpLim, out ex);
+ // If this fails, we didn't construct the header correctly. This is both a bug and
+ // something we want to protect against at runtime, hence both assert and check.
+ Contracts.Assert(res);
+ Contracts.Check(res);
+
+ // Write the header, then seek back to the end.
+ writer.Seek(fpMin);
+ byte[] headerBytes = new byte[ModelHeader.Size];
+ MarshalToBytes(ref header, headerBytes);
+ writer.Write(headerBytes);
+ Contracts.Assert(writer.FpCur() == fpMin + ModelHeader.Size);
+ writer.Seek(header.FpLim + fpMin);
+ }
- long fpCur = writer.FpCur();
- Contracts.Check(fpCur - fpMin >= header.FpModel);
+ ///
+ /// The current writer position should be the end of the model blob. Records the size of the model blob.
+ /// Typically this isn't called directly unless you are doing custom string table serialization.
+ ///
+ public static void EndModelCore(BinaryWriter writer, long fpMin, ref ModelHeader header)
+ {
+ Contracts.Check(header.FpModel == ModelHeader.Size);
+ Contracts.Check(header.CbModel == 0);
- // Record the size of the model.
- header.CbModel = fpCur - header.FpModel - fpMin;
- }
+ long fpCur = writer.FpCur();
+ Contracts.Check(fpCur - fpMin >= header.FpModel);
- ///
- /// Sets the version information the header.
- ///
- public static void SetVersionInfo(ref ModelHeader header, VersionInfo ver)
- {
- header.ModelSignature = ver.ModelSignature;
- header.ModelVerWritten = ver.VerWrittenCur;
- header.ModelVerReadable = ver.VerReadableCur;
- SetLoaderSig(ref header, ver.LoaderSignature);
- SetLoaderSigAlt(ref header, ver.LoaderSignatureAlt);
- }
+ // Record the size of the model.
+ header.CbModel = fpCur - header.FpModel - fpMin;
+ }
- ///
- /// Record the given loader sig in the header. If sig is null, clears the loader sig.
- ///
- public static void SetLoaderSig(ref ModelHeader header, string sig)
- {
- header.LoaderSignature0 = 0;
- header.LoaderSignature1 = 0;
- header.LoaderSignature2 = 0;
+ ///
+ /// Sets the version information in the header.
+ ///
+ public static void SetVersionInfo(ref ModelHeader header, VersionInfo ver)
+ {
+ header.ModelSignature = ver.ModelSignature;
+ header.ModelVerWritten = ver.VerWrittenCur;
+ header.ModelVerReadable = ver.VerReadableCur;
+ SetLoaderSig(ref header, ver.LoaderSignature);
+ SetLoaderSigAlt(ref header, ver.LoaderSignatureAlt);
+ }
+
+ ///
+ /// Record the given loader sig in the header. If sig is null, clears the loader sig.
+ ///
+ public static void SetLoaderSig(ref ModelHeader header, string sig)
+ {
+ header.LoaderSignature0 = 0;
+ header.LoaderSignature1 = 0;
+ header.LoaderSignature2 = 0;
- if (sig == null)
- return;
+ if (sig == null)
+ return;
- Contracts.Check(sig.Length <= 24);
- for (int ich = 0; ich < sig.Length; ich++)
- {
- char ch = sig[ich];
- Contracts.Check(ch <= 0xFF);
- if (ich < 8)
- header.LoaderSignature0 |= (ulong)ch << (ich * 8);
- else if (ich < 16)
- header.LoaderSignature1 |= (ulong)ch << ((ich - 8) * 8);
- else if (ich < 24)
- header.LoaderSignature2 |= (ulong)ch << ((ich - 16) * 8);
- }
+ Contracts.Check(sig.Length <= 24);
+ for (int ich = 0; ich < sig.Length; ich++)
+ {
+ char ch = sig[ich];
+ Contracts.Check(ch <= 0xFF);
+ if (ich < 8)
+ header.LoaderSignature0 |= (ulong)ch << (ich * 8);
+ else if (ich < 16)
+ header.LoaderSignature1 |= (ulong)ch << ((ich - 8) * 8);
+ else if (ich < 24)
+ header.LoaderSignature2 |= (ulong)ch << ((ich - 16) * 8);
}
+ }
- ///
- /// Record the given alternate loader sig in the header. If sig is null, clears the alternate loader sig.
- ///
- public static void SetLoaderSigAlt(ref ModelHeader header, string sig)
- {
- header.LoaderSignatureAlt0 = 0;
- header.LoaderSignatureAlt1 = 0;
- header.LoaderSignatureAlt2 = 0;
+ ///
+ /// Record the given alternate loader sig in the header. If sig is null, clears the alternate loader sig.
+ ///
+ public static void SetLoaderSigAlt(ref ModelHeader header, string sig)
+ {
+ header.LoaderSignatureAlt0 = 0;
+ header.LoaderSignatureAlt1 = 0;
+ header.LoaderSignatureAlt2 = 0;
- if (sig == null)
- return;
+ if (sig == null)
+ return;
- Contracts.Check(sig.Length <= 24);
- for (int ich = 0; ich < sig.Length; ich++)
- {
- char ch = sig[ich];
- Contracts.Check(ch <= 0xFF);
- if (ich < 8)
- header.LoaderSignatureAlt0 |= (ulong)ch << (ich * 8);
- else if (ich < 16)
- header.LoaderSignatureAlt1 |= (ulong)ch << ((ich - 8) * 8);
- else if (ich < 24)
- header.LoaderSignatureAlt2 |= (ulong)ch << ((ich - 16) * 8);
- }
+ Contracts.Check(sig.Length <= 24);
+ for (int ich = 0; ich < sig.Length; ich++)
+ {
+ char ch = sig[ich];
+ Contracts.Check(ch <= 0xFF);
+ if (ich < 8)
+ header.LoaderSignatureAlt0 |= (ulong)ch << (ich * 8);
+ else if (ich < 16)
+ header.LoaderSignatureAlt1 |= (ulong)ch << ((ich - 8) * 8);
+ else if (ich < 24)
+ header.LoaderSignatureAlt2 |= (ulong)ch << ((ich - 16) * 8);
}
+ }
- ///
- /// Low level method for copying bytes from a header structure into a byte array.
- ///
- public static void MarshalToBytes(ref ModelHeader header, byte[] bytes)
+ ///
+ /// Low level method for copying bytes from a header structure into a byte array.
+ ///
+ public static void MarshalToBytes(ref ModelHeader header, byte[] bytes)
+ {
+ Contracts.Check(Utils.Size(bytes) >= Size);
+ unsafe
{
- Contracts.Check(Utils.Size(bytes) >= Size);
- unsafe
- {
- fixed (ModelHeader* pheader = &header)
- Marshal.Copy((IntPtr)pheader, bytes, 0, Size);
- }
+ fixed (ModelHeader* pheader = &header)
+ Marshal.Copy((IntPtr)pheader, bytes, 0, Size);
}
+ }
- // Utilities for reading.
+ // Utilities for reading.
- ///
- /// Read the model header, strings, etc from reader. Also validates the header (throws if bad).
- /// Leaves the reader position at the beginning of the model blob.
- ///
- public static void BeginRead(out long fpMin, out ModelHeader header, out string[] strings, out string loaderAssemblyName, BinaryReader reader)
- {
- fpMin = reader.FpCur();
+ ///
+ /// Read the model header, strings, etc from reader. Also validates the header (throws if bad).
+ /// Leaves the reader position at the beginning of the model blob.
+ ///
+ public static void BeginRead(out long fpMin, out ModelHeader header, out string[] strings, out string loaderAssemblyName, BinaryReader reader)
+ {
+ fpMin = reader.FpCur();
- byte[] headerBytes = reader.ReadBytes(ModelHeader.Size);
- Contracts.CheckDecode(headerBytes.Length == ModelHeader.Size);
- ModelHeader.MarshalFromBytes(out header, headerBytes);
+ byte[] headerBytes = reader.ReadBytes(ModelHeader.Size);
+ Contracts.CheckDecode(headerBytes.Length == ModelHeader.Size);
+ ModelHeader.MarshalFromBytes(out header, headerBytes);
- Exception ex;
- if (!ModelHeader.TryValidate(ref header, reader, fpMin, out strings, out loaderAssemblyName, out ex))
- throw ex;
+ Exception ex;
+ if (!ModelHeader.TryValidate(ref header, reader, fpMin, out strings, out loaderAssemblyName, out ex))
+ throw ex;
- reader.Seek(header.FpModel + fpMin);
- }
+ reader.Seek(header.FpModel + fpMin);
+ }
- ///
- /// Finish reading. Checks that the current reader position is the end of the model blob.
- /// Seeks to the end of the entire model file (after the tail).
- ///
- public static void EndRead(long fpMin, ref ModelHeader header, BinaryReader reader)
- {
- Contracts.CheckDecode(header.FpModel + header.CbModel == reader.FpCur() - fpMin);
- reader.Seek(header.FpLim + fpMin);
- }
+ ///
+ /// Finish reading. Checks that the current reader position is the end of the model blob.
+ /// Seeks to the end of the entire model file (after the tail).
+ ///
+ public static void EndRead(long fpMin, ref ModelHeader header, BinaryReader reader)
+ {
+ Contracts.CheckDecode(header.FpModel + header.CbModel == reader.FpCur() - fpMin);
+ reader.Seek(header.FpLim + fpMin);
+ }
- ///
- /// Performs standard version validation.
- ///
- public static void CheckVersionInfo(ref ModelHeader header, VersionInfo ver)
+ ///
+ /// Performs standard version validation.
+ ///
+ public static void CheckVersionInfo(ref ModelHeader header, VersionInfo ver)
+ {
+ Contracts.CheckDecode(header.ModelSignature == ver.ModelSignature, "Unknown file type");
+ Contracts.CheckDecode(header.ModelVerReadable <= header.ModelVerWritten, "Corrupt file header");
+ if (header.ModelVerReadable > ver.VerWrittenCur)
+ throw Contracts.ExceptDecode("Cause: ML.NET {0} cannont read component '{1}' of the model, because the model is too new.\n" +
+ "Suggestion: Make sure the model is trained with ML.NET {0} or older.\n" +
+ "Debug details: Maximum expected version {2}, got {3}.",
+ typeof(VersionInfo).Assembly.GetName().Version, ver.LoaderSignature, header.ModelVerReadable, ver.VerWrittenCur);
+ if (header.ModelVerWritten < ver.VerWeCanReadBack)
{
- Contracts.CheckDecode(header.ModelSignature == ver.ModelSignature, "Unknown file type");
- Contracts.CheckDecode(header.ModelVerReadable <= header.ModelVerWritten, "Corrupt file header");
- if (header.ModelVerReadable > ver.VerWrittenCur)
- throw Contracts.ExceptDecode("Cause: ML.NET {0} cannont read component '{1}' of the model, because the model is too new.\n" +
- "Suggestion: Make sure the model is trained with ML.NET {0} or older.\n" +
- "Debug details: Maximum expected version {2}, got {3}.",
- typeof(VersionInfo).Assembly.GetName().Version, ver.LoaderSignature, header.ModelVerReadable, ver.VerWrittenCur);
- if (header.ModelVerWritten < ver.VerWeCanReadBack)
- {
- // Breaking backwards compatibility is something we should avoid if at all possible. If
- // this message is observed, it may be a bug.
- throw Contracts.ExceptDecode("Cause: ML.NET {0} cannot read component '{1}' of the model, because the model is too old.\n" +
- "Suggestion: Make sure the model is trained with ML.NET {0}.\n" +
- "Debug details: Minimum expected version {2}, got {3}.",
- typeof(VersionInfo).Assembly.GetName().Version, ver.LoaderSignature, header.ModelVerReadable, ver.VerWrittenCur);
- }
+ // Breaking backwards compatibility is something we should avoid if at all possible. If
+ // this message is observed, it may be a bug.
+ throw Contracts.ExceptDecode("Cause: ML.NET {0} cannot read component '{1}' of the model, because the model is too old.\n" +
+ "Suggestion: Make sure the model is trained with ML.NET {0}.\n" +
+ "Debug details: Minimum expected version {2}, got {3}.",
+ typeof(VersionInfo).Assembly.GetName().Version, ver.LoaderSignature, header.ModelVerReadable, ver.VerWrittenCur);
}
+ }
- ///
- /// Low level method for copying bytes from a byte array to a header structure.
- ///
- public static void MarshalFromBytes(out ModelHeader header, byte[] bytes)
+ ///
+ /// Low level method for copying bytes from a byte array to a header structure.
+ ///
+ public static void MarshalFromBytes(out ModelHeader header, byte[] bytes)
+ {
+ Contracts.Check(Utils.Size(bytes) >= Size);
+ unsafe
{
- Contracts.Check(Utils.Size(bytes) >= Size);
- unsafe
- {
- fixed (ModelHeader* pheader = &header)
- Marshal.Copy(bytes, 0, (IntPtr)pheader, Size);
- }
+ fixed (ModelHeader* pheader = &header)
+ Marshal.Copy(bytes, 0, (IntPtr)pheader, Size);
}
+ }
- ///
- /// Checks the basic validity of the header, assuming the stream is at least the given size.
- /// Returns false (and the out exception) on failure.
- ///
- public static bool TryValidate(ref ModelHeader header, long size, out Exception ex)
- {
- Contracts.Check(size >= 0);
+ ///
+ /// Checks the basic validity of the header, assuming the stream is at least the given size.
+ /// Returns false (and the out exception) on failure.
+ ///
+ public static bool TryValidate(ref ModelHeader header, long size, out Exception ex)
+ {
+ Contracts.Check(size >= 0);
- try
- {
- Contracts.CheckDecode(header.Signature == SignatureValue, "Wrong file type");
- Contracts.CheckDecode(header.VerReadable <= header.VerWritten, "Corrupt file header");
- Contracts.CheckDecode(header.VerReadable <= VerWrittenCur, "File is too new");
- Contracts.CheckDecode(header.VerWritten >= VerWeCanReadBack, "File is too old");
+ try
+ {
+ Contracts.CheckDecode(header.Signature == SignatureValue, "Wrong file type");
+ Contracts.CheckDecode(header.VerReadable <= header.VerWritten, "Corrupt file header");
+ Contracts.CheckDecode(header.VerReadable <= VerWrittenCur, "File is too new");
+ Contracts.CheckDecode(header.VerWritten >= VerWeCanReadBack, "File is too old");
- // Currently the model always comes immediately after the header.
- Contracts.CheckDecode(header.FpModel == Size);
- Contracts.CheckDecode(header.FpModel + header.CbModel >= header.FpModel);
+ // Currently the model always comes immediately after the header.
+ Contracts.CheckDecode(header.FpModel == Size);
+ Contracts.CheckDecode(header.FpModel + header.CbModel >= header.FpModel);
- if (header.FpStringTable == 0)
+ if (header.FpStringTable == 0)
+ {
+ // No strings.
+ Contracts.CheckDecode(header.CbStringTable == 0);
+ Contracts.CheckDecode(header.FpStringChars == 0);
+ Contracts.CheckDecode(header.CbStringChars == 0);
+ if (header.VerWritten < VerAssemblyNameSupported || header.FpAssemblyName == 0)
{
- // No strings.
- Contracts.CheckDecode(header.CbStringTable == 0);
- Contracts.CheckDecode(header.FpStringChars == 0);
- Contracts.CheckDecode(header.CbStringChars == 0);
- if (header.VerWritten < VerAssemblyNameSupported || header.FpAssemblyName == 0)
- {
- Contracts.CheckDecode(header.FpTail == header.FpModel + header.CbModel);
- }
+ Contracts.CheckDecode(header.FpTail == header.FpModel + header.CbModel);
}
- else
+ }
+ else
+ {
+ // Currently the string table always comes immediately after the model block.
+ Contracts.CheckDecode(header.FpStringTable == header.FpModel + header.CbModel);
+ Contracts.CheckDecode(header.CbStringTable % sizeof(long) == 0);
+ Contracts.CheckDecode(header.CbStringTable / sizeof(long) < int.MaxValue);
+ Contracts.CheckDecode(header.FpStringTable + header.CbStringTable > header.FpStringTable);
+ Contracts.CheckDecode(header.FpStringChars == header.FpStringTable + header.CbStringTable);
+ Contracts.CheckDecode(header.CbStringChars % sizeof(char) == 0);
+ Contracts.CheckDecode(header.FpStringChars + header.CbStringChars >= header.FpStringChars);
+ if (header.VerWritten < VerAssemblyNameSupported || header.FpAssemblyName == 0)
{
- // Currently the string table always comes immediately after the model block.
- Contracts.CheckDecode(header.FpStringTable == header.FpModel + header.CbModel);
- Contracts.CheckDecode(header.CbStringTable % sizeof(long) == 0);
- Contracts.CheckDecode(header.CbStringTable / sizeof(long) < int.MaxValue);
- Contracts.CheckDecode(header.FpStringTable + header.CbStringTable > header.FpStringTable);
- Contracts.CheckDecode(header.FpStringChars == header.FpStringTable + header.CbStringTable);
- Contracts.CheckDecode(header.CbStringChars % sizeof(char) == 0);
- Contracts.CheckDecode(header.FpStringChars + header.CbStringChars >= header.FpStringChars);
- if (header.VerWritten < VerAssemblyNameSupported || header.FpAssemblyName == 0)
- {
- Contracts.CheckDecode(header.FpTail == header.FpStringChars + header.CbStringChars);
- }
+ Contracts.CheckDecode(header.FpTail == header.FpStringChars + header.CbStringChars);
}
+ }
- if (header.VerWritten >= VerAssemblyNameSupported)
+ if (header.VerWritten >= VerAssemblyNameSupported)
+ {
+ if (header.FpAssemblyName == 0)
{
- if (header.FpAssemblyName == 0)
+ Contracts.CheckDecode(header.CbAssemblyName == 0);
+ }
+ else
+ {
+ // The assembly name always comes immediately after the string table, if there is one.
+ if (header.FpStringTable == 0)
{
- Contracts.CheckDecode(header.CbAssemblyName == 0);
+ Contracts.CheckDecode(header.FpAssemblyName == header.FpModel + header.CbModel);
}
else
{
- // the assembly name always immediately after the string table, if there is one
- if (header.FpStringTable == 0)
- {
- Contracts.CheckDecode(header.FpAssemblyName == header.FpModel + header.CbModel);
- }
- else
- {
- Contracts.CheckDecode(header.FpAssemblyName == header.FpStringChars + header.CbStringChars);
- }
- Contracts.CheckDecode(header.CbAssemblyName % sizeof(char) == 0);
- Contracts.CheckDecode(header.FpTail == header.FpAssemblyName + header.CbAssemblyName);
+ Contracts.CheckDecode(header.FpAssemblyName == header.FpStringChars + header.CbStringChars);
}
+ Contracts.CheckDecode(header.CbAssemblyName % sizeof(char) == 0);
+ Contracts.CheckDecode(header.FpTail == header.FpAssemblyName + header.CbAssemblyName);
}
+ }
- Contracts.CheckDecode(header.FpLim == header.FpTail + sizeof(ulong));
- Contracts.CheckDecode(size == 0 || size >= header.FpLim);
+ Contracts.CheckDecode(header.FpLim == header.FpTail + sizeof(ulong));
+ Contracts.CheckDecode(size == 0 || size >= header.FpLim);
- ex = null;
- return true;
- }
- catch (Exception e)
- {
- ex = e;
- return false;
- }
+ ex = null;
+ return true;
+ }
+ catch (Exception e)
+ {
+ ex = e;
+ return false;
}
+ }
+
+ ///
+ /// Checks the validity of the header, reads the string table, etc.
+ ///
+ public static bool TryValidate(ref ModelHeader header, BinaryReader reader, long fpMin, out string[] strings, out string loaderAssemblyName, out Exception ex)
+ {
+ Contracts.CheckValue(reader, nameof(reader));
+ Contracts.Check(fpMin >= 0);
- ///
- /// Checks the validity of the header, reads the string table, etc.
- ///
- public static bool TryValidate(ref ModelHeader header, BinaryReader reader, long fpMin, out string[] strings, out string loaderAssemblyName, out Exception ex)
+ if (!TryValidate(ref header, reader.BaseStream.Length - fpMin, out ex))
{
- Contracts.CheckValue(reader, nameof(reader));
- Contracts.Check(fpMin >= 0);
+ strings = null;
+ loaderAssemblyName = null;
+ return false;
+ }
- if (!TryValidate(ref header, reader.BaseStream.Length - fpMin, out ex))
- {
- strings = null;
- loaderAssemblyName = null;
- return false;
- }
+ try
+ {
+ long fpOrig = reader.FpCur();
- try
+ StringBuilder sb = null;
+ if (header.FpStringTable == 0)
{
- long fpOrig = reader.FpCur();
-
- StringBuilder sb = null;
- if (header.FpStringTable == 0)
- {
- // No strings.
- strings = null;
+ // No strings.
+ strings = null;
- if (header.VerWritten < VerAssemblyNameSupported)
- {
- // Before VerAssemblyNameSupported, if there were no strings in the model,
- // validation ended here. Specifically the FpTail checks below were skipped.
- // There are earlier versions of models that don't have strings, and 'reader' is
- // not at FpTail at this point.
- // Preserve the previous behavior by returning early here.
- loaderAssemblyName = null;
- ex = null;
- return true;
- }
- }
- else
+ if (header.VerWritten < VerAssemblyNameSupported)
{
- reader.Seek(header.FpStringTable + fpMin);
- Contracts.Assert(reader.FpCur() == header.FpStringTable + fpMin);
-
- long cstr = header.CbStringTable / sizeof(long);
- Contracts.Assert(cstr < int.MaxValue);
- long[] offsets = reader.ReadLongArray((int)cstr);
- Contracts.Assert(header.FpStringChars == reader.FpCur() - fpMin);
- Contracts.CheckDecode(offsets[cstr - 1] == header.CbStringChars);
-
- strings = new string[cstr];
- long offset = 0;
- sb = new StringBuilder();
- for (int i = 0; i < offsets.Length; i++)
- {
- Contracts.CheckDecode(header.FpStringChars + offset == reader.FpCur() - fpMin);
-
- long offsetPrev = offset;
- offset = offsets[i];
- Contracts.CheckDecode(offsetPrev <= offset && offset <= header.CbStringChars);
- Contracts.CheckDecode(offset % sizeof(char) == 0);
- long cch = (offset - offsetPrev) / sizeof(char);
- Contracts.CheckDecode(cch < int.MaxValue);
-
- sb.Clear();
- for (long ich = 0; ich < cch; ich++)
- sb.Append((char)reader.ReadUInt16());
- strings[i] = sb.ToString();
- }
- Contracts.CheckDecode(offset == header.CbStringChars);
- Contracts.CheckDecode(header.FpStringChars + header.CbStringChars == reader.FpCur() - fpMin);
+ // Before VerAssemblyNameSupported, if there were no strings in the model,
+ // validation ended here. Specifically the FpTail checks below were skipped.
+ // There are earlier versions of models that don't have strings, and 'reader' is
+ // not at FpTail at this point.
+ // Preserve the previous behavior by returning early here.
+ loaderAssemblyName = null;
+ ex = null;
+ return true;
}
+ }
+ else
+ {
+ reader.Seek(header.FpStringTable + fpMin);
+ Contracts.Assert(reader.FpCur() == header.FpStringTable + fpMin);
- if (header.VerWritten >= VerAssemblyNameSupported && header.FpAssemblyName != 0)
+ long cstr = header.CbStringTable / sizeof(long);
+ Contracts.Assert(cstr < int.MaxValue);
+ long[] offsets = reader.ReadLongArray((int)cstr);
+ Contracts.Assert(header.FpStringChars == reader.FpCur() - fpMin);
+ Contracts.CheckDecode(offsets[cstr - 1] == header.CbStringChars);
+
+ strings = new string[cstr];
+ long offset = 0;
+ sb = new StringBuilder();
+ for (int i = 0; i < offsets.Length; i++)
{
- reader.Seek(header.FpAssemblyName + fpMin);
- int assemblyNameLength = (int)header.CbAssemblyName / sizeof(char);
+ Contracts.CheckDecode(header.FpStringChars + offset == reader.FpCur() - fpMin);
- sb = sb != null ? sb.Clear() : new StringBuilder(assemblyNameLength);
+ long offsetPrev = offset;
+ offset = offsets[i];
+ Contracts.CheckDecode(offsetPrev <= offset && offset <= header.CbStringChars);
+ Contracts.CheckDecode(offset % sizeof(char) == 0);
+ long cch = (offset - offsetPrev) / sizeof(char);
+ Contracts.CheckDecode(cch < int.MaxValue);
- for (long ich = 0; ich < assemblyNameLength; ich++)
+ sb.Clear();
+ for (long ich = 0; ich < cch; ich++)
sb.Append((char)reader.ReadUInt16());
-
- loaderAssemblyName = sb.ToString();
- }
- else
- {
- loaderAssemblyName = null;
+ strings[i] = sb.ToString();
}
+ Contracts.CheckDecode(offset == header.CbStringChars);
+ Contracts.CheckDecode(header.FpStringChars + header.CbStringChars == reader.FpCur() - fpMin);
+ }
- Contracts.CheckDecode(header.FpTail == reader.FpCur() - fpMin);
+ if (header.VerWritten >= VerAssemblyNameSupported && header.FpAssemblyName != 0)
+ {
+ reader.Seek(header.FpAssemblyName + fpMin);
+ int assemblyNameLength = (int)header.CbAssemblyName / sizeof(char);
- ulong tail = reader.ReadUInt64();
- Contracts.CheckDecode(tail == TailSignatureValue, "Corrupt model file tail");
+ sb = sb != null ? sb.Clear() : new StringBuilder(assemblyNameLength);
- ex = null;
+ for (long ich = 0; ich < assemblyNameLength; ich++)
+ sb.Append((char)reader.ReadUInt16());
- reader.Seek(fpOrig);
- return true;
+ loaderAssemblyName = sb.ToString();
}
- catch (Exception e)
+ else
{
- strings = null;
loaderAssemblyName = null;
- ex = e;
- return false;
}
- }
- ///
- /// Extract and return the loader sig from the header, trimming trailing zeros.
- ///
- public static string GetLoaderSig(ref ModelHeader header)
- {
- char[] chars = new char[3 * sizeof(ulong)];
+ Contracts.CheckDecode(header.FpTail == reader.FpCur() - fpMin);
- for (int ich = 0; ich < chars.Length; ich++)
- {
- char ch;
- if (ich < 8)
- ch = (char)((header.LoaderSignature0 >> (ich * 8)) & 0xFF);
- else if (ich < 16)
- ch = (char)((header.LoaderSignature1 >> ((ich - 8) * 8)) & 0xFF);
- else
- ch = (char)((header.LoaderSignature2 >> ((ich - 16) * 8)) & 0xFF);
- chars[ich] = ch;
- }
+ ulong tail = reader.ReadUInt64();
+ Contracts.CheckDecode(tail == TailSignatureValue, "Corrupt model file tail");
- int cch = 24;
- while (cch > 0 && chars[cch - 1] == 0)
- cch--;
- return new string(chars, 0, cch);
- }
+ ex = null;
- ///
- /// Extract and return the alternate loader sig from the header, trimming trailing zeros.
- ///
- public static string GetLoaderSigAlt(ref ModelHeader header)
+ reader.Seek(fpOrig);
+ return true;
+ }
+ catch (Exception e)
{
- char[] chars = new char[3 * sizeof(ulong)];
+ strings = null;
+ loaderAssemblyName = null;
+ ex = e;
+ return false;
+ }
+ }
- for (int ich = 0; ich < chars.Length; ich++)
- {
- char ch;
- if (ich < 8)
- ch = (char)((header.LoaderSignatureAlt0 >> (ich * 8)) & 0xFF);
- else if (ich < 16)
- ch = (char)((header.LoaderSignatureAlt1 >> ((ich - 8) * 8)) & 0xFF);
- else
- ch = (char)((header.LoaderSignatureAlt2 >> ((ich - 16) * 8)) & 0xFF);
- chars[ich] = ch;
- }
+ ///
+ /// Extract and return the loader sig from the header, trimming trailing zeros.
+ ///
+ public static string GetLoaderSig(ref ModelHeader header)
+ {
+ char[] chars = new char[3 * sizeof(ulong)];
- int cch = 24;
- while (cch > 0 && chars[cch - 1] == 0)
- cch--;
- return new string(chars, 0, cch);
+ for (int ich = 0; ich < chars.Length; ich++)
+ {
+ char ch;
+ if (ich < 8)
+ ch = (char)((header.LoaderSignature0 >> (ich * 8)) & 0xFF);
+ else if (ich < 16)
+ ch = (char)((header.LoaderSignature1 >> ((ich - 8) * 8)) & 0xFF);
+ else
+ ch = (char)((header.LoaderSignature2 >> ((ich - 16) * 8)) & 0xFF);
+ chars[ich] = ch;
}
+
+ int cch = 24;
+ while (cch > 0 && chars[cch - 1] == 0)
+ cch--;
+ return new string(chars, 0, cch);
}
///
- /// This is used to simplify version checking boiler-plate code. It is an optional
- /// utility type.
+ /// Extract and return the alternate loader sig from the header, trimming trailing zeros.
///
- [BestFriend]
- internal readonly struct VersionInfo
+ public static string GetLoaderSigAlt(ref ModelHeader header)
{
- public readonly ulong ModelSignature;
- public readonly uint VerWrittenCur;
- public readonly uint VerReadableCur;
- public readonly uint VerWeCanReadBack;
- public readonly string LoaderAssemblyName;
- public readonly string LoaderSignature;
- public readonly string LoaderSignatureAlt;
-
- ///
- /// Construct version info with a string value for modelSignature. The string must be 8 characters
- /// all less than 0x100. Spaces are mapped to zero. This assumes little-endian.
- ///
- public VersionInfo(string modelSignature, uint verWrittenCur, uint verReadableCur, uint verWeCanReadBack,
- string loaderAssemblyName, string loaderSignature = null, string loaderSignatureAlt = null)
+ char[] chars = new char[3 * sizeof(ulong)];
+
+ for (int ich = 0; ich < chars.Length; ich++)
{
- Contracts.Check(Utils.Size(modelSignature) == 8, "Model signature must be eight characters");
- ModelSignature = 0;
- for (int ich = 0; ich < modelSignature.Length; ich++)
- {
- char ch = modelSignature[ich];
- Contracts.Check(ch <= 0xFF);
- // Map space to zero.
- if (ch != ' ')
- ModelSignature |= (ulong)ch << (ich * 8);
- }
+ char ch;
+ if (ich < 8)
+ ch = (char)((header.LoaderSignatureAlt0 >> (ich * 8)) & 0xFF);
+ else if (ich < 16)
+ ch = (char)((header.LoaderSignatureAlt1 >> ((ich - 8) * 8)) & 0xFF);
+ else
+ ch = (char)((header.LoaderSignatureAlt2 >> ((ich - 16) * 8)) & 0xFF);
+ chars[ich] = ch;
+ }
+
+ int cch = 24;
+ while (cch > 0 && chars[cch - 1] == 0)
+ cch--;
+ return new string(chars, 0, cch);
+ }
+}
+
+///
+/// This is used to simplify version checking boiler-plate code. It is an optional
+/// utility type.
+///
+[BestFriend]
+internal readonly struct VersionInfo
+{
+ public readonly ulong ModelSignature;
+ public readonly uint VerWrittenCur;
+ public readonly uint VerReadableCur;
+ public readonly uint VerWeCanReadBack;
+ public readonly string LoaderAssemblyName;
+ public readonly string LoaderSignature;
+ public readonly string LoaderSignatureAlt;
- VerWrittenCur = verWrittenCur;
- VerReadableCur = verReadableCur;
- VerWeCanReadBack = verWeCanReadBack;
- LoaderAssemblyName = loaderAssemblyName;
- LoaderSignature = loaderSignature;
- LoaderSignatureAlt = loaderSignatureAlt;
+ ///
+ /// Construct version info with a string value for modelSignature. The string must be 8 characters
+ /// all less than 0x100. Spaces are mapped to zero. This assumes little-endian.
+ ///
+ public VersionInfo(string modelSignature, uint verWrittenCur, uint verReadableCur, uint verWeCanReadBack,
+ string loaderAssemblyName, string loaderSignature = null, string loaderSignatureAlt = null)
+ {
+ Contracts.Check(Utils.Size(modelSignature) == 8, "Model signature must be eight characters");
+ ModelSignature = 0;
+ for (int ich = 0; ich < modelSignature.Length; ich++)
+ {
+ char ch = modelSignature[ich];
+ Contracts.Check(ch <= 0xFF);
+ // Map space to zero.
+ if (ch != ' ')
+ ModelSignature |= (ulong)ch << (ich * 8);
}
+
+ VerWrittenCur = verWrittenCur;
+ VerReadableCur = verReadableCur;
+ VerWeCanReadBack = verWeCanReadBack;
+ LoaderAssemblyName = loaderAssemblyName;
+ LoaderSignature = loaderSignature;
+ LoaderSignatureAlt = loaderSignatureAlt;
}
}
diff --git a/src/Microsoft.ML.Core/Data/ModelLoadContext.cs b/src/Microsoft.ML.Core/Data/ModelLoadContext.cs
index 50897fcce8..cb7aff08ca 100644
--- a/src/Microsoft.ML.Core/Data/ModelLoadContext.cs
+++ b/src/Microsoft.ML.Core/Data/ModelLoadContext.cs
@@ -8,183 +8,182 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
-namespace Microsoft.ML
+namespace Microsoft.ML;
+
+///
+/// This is a convenience context object for loading models from a repository, for
+/// implementors of ICanSaveModel. It is not mandated but designed to reduce the
+/// amount of boiler plate code. It can also be used when loading from a single stream,
+/// for implementors of ICanSaveInBinaryFormat.
+///
+[BestFriend]
+internal sealed partial class ModelLoadContext : IDisposable
{
///
- /// This is a convenience context object for loading models from a repository, for
- /// implementors of ICanSaveModel. It is not mandated but designed to reduce the
- /// amount of boiler plate code. It can also be used when loading from a single stream,
- /// for implementors of ICanSaveInBinaryFormat.
+ /// When in repository mode, this is the repository we're reading from. It is null when
+ /// in single-stream mode.
+ ///
+ public readonly RepositoryReader Repository;
+
+ ///
+ /// When in repository mode, this is the directory we're reading from. Null means the root
+ /// of the repository. It is always null in single-stream mode.
+ ///
+ public readonly string Directory;
+
+ ///
+ /// The main stream reader.
+ ///
+ public readonly BinaryReader Reader;
+
+ ///
+ /// The strings loaded from the main stream's string table.
+ ///
+ public readonly string[] Strings;
+
+ ///
+ /// The name of the assembly that the loader lives in.
+ ///
+ ///
+ /// This may be null or empty if one was never written to the model, or is an older model version.
+ ///
+ public readonly string LoaderAssemblyName;
+
+ ///
+ /// The main stream's model header.
///
[BestFriend]
- internal sealed partial class ModelLoadContext : IDisposable
+ internal ModelHeader Header;
+
+ ///
+ /// The min file position of the main stream.
+ ///
+ public readonly long FpMin;
+
+ ///
+ /// Exception context provided by Repository (can be null).
+ ///
+ private readonly IExceptionContext _ectx;
+
+ ///
+ /// Returns whether this context is in repository mode (true) or single-stream mode (false).
+ ///
+ public bool InRepository { get { return Repository != null; } }
+
+ ///
+ /// Create a ModelLoadContext supporting loading from a repository, for implementors of ICanSaveModel.
+ ///
+ internal ModelLoadContext(RepositoryReader rep, Repository.Entry ent, string dir)
{
- ///
- /// When in repository mode, this is the repository we're reading from. It is null when
- /// in single-stream mode.
- ///
- public readonly RepositoryReader Repository;
-
- ///
- /// When in repository mode, this is the directory we're reading from. Null means the root
- /// of the repository. It is always null in single-stream mode.
- ///
- public readonly string Directory;
-
- ///
- /// The main stream reader.
- ///
- public readonly BinaryReader Reader;
-
- ///
- /// The strings loaded from the main stream's string table.
- ///
- public readonly string[] Strings;
-
- ///
- /// The name of the assembly that the loader lives in.
- ///
- ///
- /// This may be null or empty if one was never written to the model, or is an older model version.
- ///
- public readonly string LoaderAssemblyName;
-
- ///
- /// The main stream's model header.
- ///
- [BestFriend]
- internal ModelHeader Header;
-
- ///
- /// The min file position of the main stream.
- ///
- public readonly long FpMin;
-
- ///
- /// Exception context provided by Repository (can be null).
- ///
- private readonly IExceptionContext _ectx;
-
- ///
- /// Returns whether this context is in repository mode (true) or single-stream mode (false).
- ///
- public bool InRepository { get { return Repository != null; } }
-
- ///
- /// Create a ModelLoadContext supporting loading from a repository, for implementors of ICanSaveModel.
- ///
- internal ModelLoadContext(RepositoryReader rep, Repository.Entry ent, string dir)
- {
- Contracts.CheckValue(rep, nameof(rep));
- Repository = rep;
- _ectx = rep.ExceptionContext;
-
- _ectx.CheckValue(ent, nameof(ent));
- _ectx.CheckValueOrNull(dir);
-
- Directory = dir;
-
- Reader = new BinaryReader(ent.Stream, Encoding.UTF8, leaveOpen: true);
- try
- {
- ModelHeader.BeginRead(out FpMin, out Header, out Strings, out LoaderAssemblyName, Reader);
- }
- catch
- {
- Reader.Dispose();
- throw;
- }
- }
+ Contracts.CheckValue(rep, nameof(rep));
+ Repository = rep;
+ _ectx = rep.ExceptionContext;
- ///
- /// Create a ModelLoadContext supporting loading from a single-stream, for implementors of ICanSaveInBinaryFormat.
- ///
- internal ModelLoadContext(BinaryReader reader, IExceptionContext ectx = null)
- {
- Contracts.AssertValueOrNull(ectx);
- _ectx = ectx;
- _ectx.CheckValue(reader, nameof(reader));
+ _ectx.CheckValue(ent, nameof(ent));
+ _ectx.CheckValueOrNull(dir);
- Repository = null;
- Directory = null;
- Reader = reader;
- ModelHeader.BeginRead(out FpMin, out Header, out Strings, out LoaderAssemblyName, Reader);
- }
+ Directory = dir;
- public void CheckAtModel()
+ Reader = new BinaryReader(ent.Stream, Encoding.UTF8, leaveOpen: true);
+ try
{
- _ectx.Check(Reader.BaseStream.Position == FpMin + Header.FpModel);
+ ModelHeader.BeginRead(out FpMin, out Header, out Strings, out LoaderAssemblyName, Reader);
}
-
- public void CheckAtModel(VersionInfo ver)
+ catch
{
- _ectx.Check(Reader.BaseStream.Position == FpMin + Header.FpModel);
- ModelHeader.CheckVersionInfo(ref Header, ver);
+ Reader.Dispose();
+ throw;
}
+ }
- ///
- /// Performs version checks.
- ///
- public void CheckVersionInfo(VersionInfo ver)
- {
- ModelHeader.CheckVersionInfo(ref Header, ver);
- }
+ ///
+ /// Create a ModelLoadContext supporting loading from a single-stream, for implementors of ICanSaveInBinaryFormat.
+ ///
+ internal ModelLoadContext(BinaryReader reader, IExceptionContext ectx = null)
+ {
+ Contracts.AssertValueOrNull(ectx);
+ _ectx = ectx;
+ _ectx.CheckValue(reader, nameof(reader));
+
+ Repository = null;
+ Directory = null;
+ Reader = reader;
+ ModelHeader.BeginRead(out FpMin, out Header, out Strings, out LoaderAssemblyName, Reader);
+ }
- ///
- /// Reads an integer from the load context's reader, and returns the associated string,
- /// or null (encoded as -1).
- ///
- public string LoadStringOrNull()
- {
- int id = Reader.ReadInt32();
- // Note that -1 means null. Empty strings are in the string table.
- _ectx.CheckDecode(-1 <= id && id < Utils.Size(Strings));
- if (id >= 0)
- return Strings[id];
- return null;
- }
+ public void CheckAtModel()
+ {
+ _ectx.Check(Reader.BaseStream.Position == FpMin + Header.FpModel);
+ }
- ///
- /// Reads an integer from the load context's reader, and returns the associated string.
- ///
- public string LoadString()
- {
- int id = Reader.ReadInt32();
- Contracts.CheckDecode(0 <= id && id < Utils.Size(Strings));
+ public void CheckAtModel(VersionInfo ver)
+ {
+ _ectx.Check(Reader.BaseStream.Position == FpMin + Header.FpModel);
+ ModelHeader.CheckVersionInfo(ref Header, ver);
+ }
+
+ ///
+ /// Performs version checks.
+ ///
+ public void CheckVersionInfo(VersionInfo ver)
+ {
+ ModelHeader.CheckVersionInfo(ref Header, ver);
+ }
+
+ ///
+ /// Reads an integer from the load context's reader, and returns the associated string,
+ /// or null (encoded as -1).
+ ///
+ public string LoadStringOrNull()
+ {
+ int id = Reader.ReadInt32();
+ // Note that -1 means null. Empty strings are in the string table.
+ _ectx.CheckDecode(-1 <= id && id < Utils.Size(Strings));
+ if (id >= 0)
return Strings[id];
- }
+ return null;
+ }
- ///
- /// Reads an integer from the load context's reader, and returns the associated string.
- /// Throws if the string is empty or null.
- ///
- public string LoadNonEmptyString()
- {
- int id = Reader.ReadInt32();
- _ectx.CheckDecode(0 <= id && id < Utils.Size(Strings));
- var str = Strings[id];
- _ectx.CheckDecode(str.Length > 0);
- return str;
- }
+ ///
+ /// Reads an integer from the load context's reader, and returns the associated string.
+ ///
+ public string LoadString()
+ {
+ int id = Reader.ReadInt32();
+ Contracts.CheckDecode(0 <= id && id < Utils.Size(Strings));
+ return Strings[id];
+ }
- ///
- /// Commit the load operation. This completes reading of the main stream. When in repository
- /// mode, it disposes the Reader (but not the repository).
- ///
- public void Done()
- {
- ModelHeader.EndRead(FpMin, ref Header, Reader);
- Dispose();
- }
+ ///
+ /// Reads an integer from the load context's reader, and returns the associated string.
+ /// Throws if the string is empty or null.
+ ///
+ public string LoadNonEmptyString()
+ {
+ int id = Reader.ReadInt32();
+ _ectx.CheckDecode(0 <= id && id < Utils.Size(Strings));
+ var str = Strings[id];
+ _ectx.CheckDecode(str.Length > 0);
+ return str;
+ }
- ///
- /// When in repository mode, this disposes the Reader (but no the repository).
- ///
- public void Dispose()
- {
- // When in single-stream mode, we don't own the Reader.
- if (InRepository)
- Reader.Dispose();
- }
+ ///
+ /// Commit the load operation. This completes reading of the main stream. When in repository
+ /// mode, it disposes the Reader (but not the repository).
+ ///
+ public void Done()
+ {
+ ModelHeader.EndRead(FpMin, ref Header, Reader);
+ Dispose();
+ }
+
+ ///
+ /// When in repository mode, this disposes the Reader (but not the repository).
+ ///
+ public void Dispose()
+ {
+ // When in single-stream mode, we don't own the Reader.
+ if (InRepository)
+ Reader.Dispose();
}
}
diff --git a/src/Microsoft.ML.Core/Data/ModelLoading.cs b/src/Microsoft.ML.Core/Data/ModelLoading.cs
index c186e91bfe..e620649bb7 100644
--- a/src/Microsoft.ML.Core/Data/ModelLoading.cs
+++ b/src/Microsoft.ML.Core/Data/ModelLoading.cs
@@ -9,357 +9,356 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
-namespace Microsoft.ML
+namespace Microsoft.ML;
+
+///
+/// Signature for a repository based model loader. This is the dual of ICanSaveModel.Save.
+///
+[BestFriend]
+internal delegate void SignatureLoadModel(ModelLoadContext ctx);
+
+internal sealed partial class ModelLoadContext : IDisposable
{
+ public const string ModelStreamName = "Model.key";
+ internal const string NameBinary = "Model.bin";
+
///
- /// Signature for a repository based model loader. This is the dual of .
+ /// Returns the new assembly name to maintain backward compatibility.
///
- [BestFriend]
- internal delegate void SignatureLoadModel(ModelLoadContext ctx);
-
- internal sealed partial class ModelLoadContext : IDisposable
+ private string ForwardedLoaderAssemblyName
{
- public const string ModelStreamName = "Model.key";
- internal const string NameBinary = "Model.bin";
-
- ///
- /// Returns the new assembly name to maintain backward compatibility.
- ///
- private string ForwardedLoaderAssemblyName
+ get
{
- get
+ string[] nameDetails = LoaderAssemblyName.Split(',');
+ switch (nameDetails[0])
{
- string[] nameDetails = LoaderAssemblyName.Split(',');
- switch (nameDetails[0])
- {
- case "Microsoft.ML.HalLearners":
- nameDetails[0] = "Microsoft.ML.Mkl.Components";
- break;
- case "Microsoft.ML.StandardLearners":
- nameDetails[0] = "Microsoft.ML.StandardTrainers";
- break;
- default:
- return LoaderAssemblyName;
- }
-
- return string.Join(",", nameDetails);
+ case "Microsoft.ML.HalLearners":
+ nameDetails[0] = "Microsoft.ML.Mkl.Components";
+ break;
+ case "Microsoft.ML.StandardLearners":
+ nameDetails[0] = "Microsoft.ML.StandardTrainers";
+ break;
+ default:
+ return LoaderAssemblyName;
}
- }
- ///
- /// Return whether this context contains a directory and stream for a sub-model with
- /// the indicated name. This does not attempt to load the sub-model.
- ///
- public bool ContainsModel(string name)
- {
- if (!InRepository)
- return false;
- if (string.IsNullOrEmpty(name))
- return false;
-
- var dir = Path.Combine(Directory ?? "", name);
- var ent = Repository.OpenEntryOrNull(dir, ModelStreamName);
- if (ent != null)
- {
- ent.Dispose();
- return true;
- }
-
- if ((ent = Repository.OpenEntryOrNull(dir, NameBinary)) != null)
- {
- ent.Dispose();
- return true;
- }
-
- return false;
+ return string.Join(",", nameDetails);
}
+ }
- ///
- /// Load an optional object from the repository directory.
- /// Returns false iff no stream was found for the object, iff result is set to null.
- /// Throws if loading fails for any other reason.
- ///
- public static bool LoadModelOrNull(IHostEnvironment env, out TRes result, RepositoryReader rep, string dir, params object[] extra)
- where TRes : class
- {
- Contracts.CheckValue(env, nameof(env));
- env.CheckValue(rep, nameof(rep));
- var ent = rep.OpenEntryOrNull(dir, ModelStreamName);
- if (ent != null)
- {
- using (ent)
- {
- // Provide the repository, entry, and directory name to the loadable class ctor.
- env.Assert(ent.Stream.Position == 0);
- LoadModel(env, out result, rep, ent, dir, extra);
- return true;
- }
- }
-
- if ((ent = rep.OpenEntryOrNull(dir, NameBinary)) != null)
- {
- using (ent)
- {
- env.Assert(ent.Stream.Position == 0);
- LoadModel(env, out result, ent.Stream, extra);
- return true;
- }
- }
-
- result = null;
+ ///
+ /// Return whether this context contains a directory and stream for a sub-model with
+ /// the indicated name. This does not attempt to load the sub-model.
+ ///
+ public bool ContainsModel(string name)
+ {
+ if (!InRepository)
+ return false;
+ if (string.IsNullOrEmpty(name))
return false;
- }
- ///
- /// Load an object from the repository directory.
- ///
- public static void LoadModel(IHostEnvironment env, out TRes result, RepositoryReader rep, string dir, params object[] extra)
- where TRes : class
+ var dir = Path.Combine(Directory ?? "", name);
+ var ent = Repository.OpenEntryOrNull(dir, ModelStreamName);
+ if (ent != null)
{
- Contracts.CheckValue(env, nameof(env));
- env.CheckValue(rep, nameof(rep));
- if (!LoadModelOrNull(env, out result, rep, dir, extra))
- throw env.ExceptDecode("Corrupt model file");
- env.AssertValue(result);
+ ent.Dispose();
+ return true;
}
- ///
- /// Load a sub model from the given sub directory if it exists. This requires InRepository to be true.
- /// Returns false iff no stream was found for the object, iff result is set to null.
- /// Throws if loading fails for any other reason.
- ///
- public bool LoadModelOrNull(IHostEnvironment env, out TRes result, string name, params object[] extra)
- where TRes : class
+ if ((ent = Repository.OpenEntryOrNull(dir, NameBinary)) != null)
{
- _ectx.CheckValue(env, nameof(env));
- _ectx.Check(InRepository, "Can't load a sub-model when reading from a single stream");
- return LoadModelOrNull(env, out result, Repository, Path.Combine(Directory ?? "", name), extra);
+ ent.Dispose();
+ return true;
}
- ///
- /// Load a sub model from the given sub directory. This requires InRepository to be true.
- ///
- public void LoadModel(IHostEnvironment env, out TRes result, string name, params object[] extra)
- where TRes : class
+ return false;
+ }
+
+ ///
+ /// Load an optional object from the repository directory.
+ /// Returns false iff no stream was found for the object, in which case result is set to null.
+ /// Throws if loading fails for any other reason.
+ ///
+ public static bool LoadModelOrNull(IHostEnvironment env, out TRes result, RepositoryReader rep, string dir, params object[] extra)
+ where TRes : class
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(rep, nameof(rep));
+ var ent = rep.OpenEntryOrNull(dir, ModelStreamName);
+ if (ent != null)
{
- _ectx.CheckValue(env, nameof(env));
- if (!LoadModelOrNull(env, out result, name, extra))
- throw _ectx.ExceptDecode("Corrupt model file");
- _ectx.AssertValue(result);
+ using (ent)
+ {
+ // Provide the repository, entry, and directory name to the loadable class ctor.
+ env.Assert(ent.Stream.Position == 0);
+ LoadModel(env, out result, rep, ent, dir, extra);
+ return true;
+ }
}
- ///
- /// Try to load from the given repository entry using the default loader(s) specified in the header.
- /// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
- ///
- private static bool TryLoadModel(IHostEnvironment env, out TRes result, RepositoryReader rep, Repository.Entry ent, string dir, params object[] extra)
- where TRes : class
+ if ((ent = rep.OpenEntryOrNull(dir, NameBinary)) != null)
{
- Contracts.CheckValue(env, nameof(env));
- env.CheckValue(rep, nameof(rep));
- long fp = ent.Stream.Position;
- using (var ctx = new ModelLoadContext(rep, ent, dir))
+ using (ent)
{
- env.Assert(fp == ctx.FpMin);
- if (ctx.TryLoadModelCore(env, out result, extra))
- return true;
+ env.Assert(ent.Stream.Position == 0);
+ LoadModel(env, out result, ent.Stream, extra);
+ return true;
}
+ }
- // TryLoadModelCore should rewind on failure.
- Contracts.Assert(fp == ent.Stream.Position);
+ result = null;
+ return false;
+ }
- return false;
- }
+ ///
+ /// Load an object from the repository directory.
+ ///
+ public static void LoadModel(IHostEnvironment env, out TRes result, RepositoryReader rep, string dir, params object[] extra)
+ where TRes : class
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(rep, nameof(rep));
+ if (!LoadModelOrNull(env, out result, rep, dir, extra))
+ throw env.ExceptDecode("Corrupt model file");
+ env.AssertValue(result);
+ }
- ///
- /// Load from the given repository entry using the default loader(s) specified in the header.
- ///
- public static void LoadModel(IHostEnvironment env, out TRes result, RepositoryReader rep, Repository.Entry ent, string dir, params object[] extra)
- where TRes : class
- {
- Contracts.CheckValue(env, nameof(env));
- env.CheckValue(rep, nameof(rep));
- if (!TryLoadModel(env, out result, rep, ent, dir, extra))
- throw env.ExceptDecode("Couldn't load model: '{0}'", dir);
- }
+ ///
+ /// Load a sub model from the given sub directory if it exists. This requires InRepository to be true.
+ /// Returns false iff no stream was found for the object, in which case result is set to null.
+ /// Throws if loading fails for any other reason.
+ ///
+ public bool LoadModelOrNull(IHostEnvironment env, out TRes result, string name, params object[] extra)
+ where TRes : class
+ {
+ _ectx.CheckValue(env, nameof(env));
+ _ectx.Check(InRepository, "Can't load a sub-model when reading from a single stream");
+ return LoadModelOrNull(env, out result, Repository, Path.Combine(Directory ?? "", name), extra);
+ }
- ///
- /// Try to load from the given stream (non-Repository).
- /// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
- ///
- public static bool TryLoadModel(IHostEnvironment env, out TRes result, Stream stream, params object[] extra)
- where TRes : class
- {
- Contracts.CheckValue(env, nameof(env));
- using (var reader = new BinaryReader(stream, Encoding.UTF8, leaveOpen: true))
- return TryLoadModel(env, out result, reader, extra);
- }
+ ///
+ /// Load a sub model from the given sub directory. This requires InRepository to be true.
+ ///
+ public void LoadModel(IHostEnvironment env, out TRes result, string name, params object[] extra)
+ where TRes : class
+ {
+ _ectx.CheckValue(env, nameof(env));
+ if (!LoadModelOrNull(env, out result, name, extra))
+ throw _ectx.ExceptDecode("Corrupt model file");
+ _ectx.AssertValue(result);
+ }
- ///
- /// Load from the given stream (non-Repository) using the default loader(s) specified in the header.
- ///
- public static void LoadModel(IHostEnvironment env, out TRes result, Stream stream, params object[] extra)
- where TRes : class
+ ///
+ /// Try to load from the given repository entry using the default loader(s) specified in the header.
+ /// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
+ ///
+ private static bool TryLoadModel(IHostEnvironment env, out TRes result, RepositoryReader rep, Repository.Entry ent, string dir, params object[] extra)
+ where TRes : class
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(rep, nameof(rep));
+ long fp = ent.Stream.Position;
+ using (var ctx = new ModelLoadContext(rep, ent, dir))
{
- Contracts.CheckValue(env, nameof(env));
- if (!TryLoadModel(env, out result, stream, extra))
- throw Contracts.ExceptDecode("Couldn't load model");
+ env.Assert(fp == ctx.FpMin);
+ if (ctx.TryLoadModelCore(env, out result, extra))
+ return true;
}
- ///
- /// Try to load from the given reader (non-Repository).
- /// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
- ///
- public static bool TryLoadModel(IHostEnvironment env, out TRes result, BinaryReader reader, params object[] extra)
- where TRes : class
- {
- Contracts.CheckValue(env, nameof(env));
- long fp = reader.BaseStream.Position;
- using (var ctx = new ModelLoadContext(reader))
- {
- Contracts.Assert(fp == ctx.FpMin);
- return ctx.TryLoadModelCore(env, out result, extra);
- }
- }
+ // TryLoadModelCore should rewind on failure.
+ Contracts.Assert(fp == ent.Stream.Position);
+
+ return false;
+ }
- ///
- /// Load from the given reader (non-Repository) using the default loader(s) specified in the header.
- ///
- public static void LoadModel(IHostEnvironment env, out TRes result, BinaryReader reader, params object[] extra)
- where TRes : class
+ ///
+ /// Load from the given repository entry using the default loader(s) specified in the header.
+ ///
+ public static void LoadModel(IHostEnvironment env, out TRes result, RepositoryReader rep, Repository.Entry ent, string dir, params object[] extra)
+ where TRes : class
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(rep, nameof(rep));
+ if (!TryLoadModel(env, out result, rep, ent, dir, extra))
+ throw env.ExceptDecode("Couldn't load model: '{0}'", dir);
+ }
+
+ ///
+ /// Try to load from the given stream (non-Repository).
+ /// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
+ ///
+ public static bool TryLoadModel(IHostEnvironment env, out TRes result, Stream stream, params object[] extra)
+ where TRes : class
+ {
+ Contracts.CheckValue(env, nameof(env));
+ using (var reader = new BinaryReader(stream, Encoding.UTF8, leaveOpen: true))
+ return TryLoadModel(env, out result, reader, extra);
+ }
+
+ ///
+ /// Load from the given stream (non-Repository) using the default loader(s) specified in the header.
+ ///
+ public static void LoadModel(IHostEnvironment env, out TRes result, Stream stream, params object[] extra)
+ where TRes : class
+ {
+ Contracts.CheckValue(env, nameof(env));
+ if (!TryLoadModel(env, out result, stream, extra))
+ throw Contracts.ExceptDecode("Couldn't load model");
+ }
+
+ ///
+ /// Try to load from the given reader (non-Repository).
+ /// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
+ ///
+ public static bool TryLoadModel(IHostEnvironment env, out TRes result, BinaryReader reader, params object[] extra)
+ where TRes : class
+ {
+ Contracts.CheckValue(env, nameof(env));
+ long fp = reader.BaseStream.Position;
+ using (var ctx = new ModelLoadContext(reader))
{
- Contracts.CheckValue(env, nameof(env));
- if (!TryLoadModel(env, out result, reader, extra))
- throw Contracts.ExceptDecode("Couldn't load model");
+ Contracts.Assert(fp == ctx.FpMin);
+ return ctx.TryLoadModelCore(env, out result, extra);
}
+ }
- ///
- /// Tries to load.
- /// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
- ///
- private bool TryLoadModelCore(IHostEnvironment env, out TRes result, params object[] extra)
- where TRes : class
- {
- _ectx.AssertValue(env, "env");
- _ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel);
+ ///
+ /// Load from the given reader (non-Repository) using the default loader(s) specified in the header.
+ ///
+ public static void LoadModel(IHostEnvironment env, out TRes result, BinaryReader reader, params object[] extra)
+ where TRes : class
+ {
+ Contracts.CheckValue(env, nameof(env));
+ if (!TryLoadModel(env, out result, reader, extra))
+ throw Contracts.ExceptDecode("Couldn't load model");
+ }
- var args = ConcatArgsRev(extra, this);
+ ///
+ /// Tries to load.
+ /// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
+ ///
+ private bool TryLoadModelCore(IHostEnvironment env, out TRes result, params object[] extra)
+ where TRes : class
+ {
+ _ectx.AssertValue(env, "env");
+ _ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel);
- EnsureLoaderAssemblyIsRegistered(env.ComponentCatalog);
+ var args = ConcatArgsRev(extra, this);
- object tmp;
- string sig = ModelHeader.GetLoaderSig(ref Header);
- if (!string.IsNullOrWhiteSpace(sig) &&
- ComponentCatalog.TryCreateInstance