diff --git a/src/Microsoft.ML.Core/Data/AnnotationBuilderExtensions.cs b/src/Microsoft.ML.Core/Data/AnnotationBuilderExtensions.cs index 63e893569f..a93e588970 100644 --- a/src/Microsoft.ML.Core/Data/AnnotationBuilderExtensions.cs +++ b/src/Microsoft.ML.Core/Data/AnnotationBuilderExtensions.cs @@ -4,29 +4,28 @@ using System; -namespace Microsoft.ML.Data +namespace Microsoft.ML.Data; + +[BestFriend] +internal static class AnnotationBuilderExtensions { - [BestFriend] - internal static class AnnotationBuilderExtensions - { - /// - /// Add slot names annotation. - /// - /// The to which to add the slot names. - /// The size of the slot names vector. - /// The getter delegate for the slot names. - public static void AddSlotNames(this DataViewSchema.Annotations.Builder builder, int size, ValueGetter>> getter) - => builder.Add(AnnotationUtils.Kinds.SlotNames, new VectorDataViewType(TextDataViewType.Instance, size), getter); + /// + /// Add slot names annotation. + /// + /// The to which to add the slot names. + /// The size of the slot names vector. + /// The getter delegate for the slot names. + public static void AddSlotNames(this DataViewSchema.Annotations.Builder builder, int size, ValueGetter>> getter) + => builder.Add(AnnotationUtils.Kinds.SlotNames, new VectorDataViewType(TextDataViewType.Instance, size), getter); - /// - /// Add key values annotation. - /// - /// The value type of key values. - /// The to which to add the key values. - /// The size of key values vector. - /// The value type of key values. Its raw type must match . - /// The getter delegate for the key values. - public static void AddKeyValues(this DataViewSchema.Annotations.Builder builder, int size, PrimitiveDataViewType valueType, ValueGetter> getter) - => builder.Add(AnnotationUtils.Kinds.KeyValues, new VectorDataViewType(valueType, size), getter); - } + /// + /// Add key values annotation. + /// + /// The value type of key values. + /// The to which to add the key values. 
+ /// The size of key values vector. + /// The value type of key values. Its raw type must match . + /// The getter delegate for the key values. + public static void AddKeyValues(this DataViewSchema.Annotations.Builder builder, int size, PrimitiveDataViewType valueType, ValueGetter> getter) + => builder.Add(AnnotationUtils.Kinds.KeyValues, new VectorDataViewType(valueType, size), getter); } diff --git a/src/Microsoft.ML.Core/Data/AnnotationUtils.cs b/src/Microsoft.ML.Core/Data/AnnotationUtils.cs index f942e4755e..257eb73728 100644 --- a/src/Microsoft.ML.Core/Data/AnnotationUtils.cs +++ b/src/Microsoft.ML.Core/Data/AnnotationUtils.cs @@ -9,493 +9,492 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; -namespace Microsoft.ML.Data +namespace Microsoft.ML.Data; + +/// +/// Utilities for implementing and using the annotation API of . +/// +[BestFriend] +internal static class AnnotationUtils { /// - /// Utilities for implementing and using the annotation API of . + /// This class lists the canonical annotation kinds /// - [BestFriend] - internal static class AnnotationUtils + public static class Kinds { /// - /// This class lists the canonical annotation kinds + /// Annotation kind for names associated with slots/positions in a vector-valued column. + /// The associated annotation type is typically fixed-sized vector of Text. /// - public static class Kinds - { - /// - /// Annotation kind for names associated with slots/positions in a vector-valued column. - /// The associated annotation type is typically fixed-sized vector of Text. - /// - public const string SlotNames = "SlotNames"; - - /// - /// Annotation kind for values associated with the key indices when the column type's item type - /// is a key type. The associated annotation type is typically fixed-sized vector of a primitive - /// type. The primitive type is frequently Text, but can be anything. 
- /// - public const string KeyValues = "KeyValues"; - - /// - /// Annotation kind for sets of score columns. The value is typically a with raw type U4. - /// - public const string ScoreColumnSetId = "ScoreColumnSetId"; - - /// - /// Annotation kind that indicates the prediction kind as a string. For example, "BinaryClassification". - /// The value is typically a ReadOnlyMemory<char>. - /// - public const string ScoreColumnKind = "ScoreColumnKind"; - - /// - /// Annotation kind that indicates the value kind of the score column as a string. For example, "Score", "PredictedLabel", "Probability". The value is typically a ReadOnlyMemory. - /// - public const string ScoreValueKind = "ScoreValueKind"; - - /// - /// Annotation kind that indicates if a column is normalized. The value is typically a Bool. - /// - public const string IsNormalized = "IsNormalized"; - - /// - /// Annotation kind that indicates if a column is visible to the users. The value is typically a Bool. - /// Not to be confused with IsHidden() that determines if a column is masked. - /// - public const string IsUserVisible = "IsUserVisible"; - - /// - /// Annotation kind for the label values used in training to be used for the predicted label. - /// The value is typically a fixed-sized vector of Text. - /// - public const string TrainingLabelValues = "TrainingLabelValues"; - - /// - /// Annotation kind that indicates the ranges within a column that are categorical features. - /// The value is a vector type of ints with dimension of two. The first dimension - /// represents the number of categorical features and second dimension represents the range - /// and is of size two. The range has start and end index(both inclusive) of categorical - /// slots within that column. 
- /// - public const string CategoricalSlotRanges = "CategoricalSlotRanges"; - } + public const string SlotNames = "SlotNames"; /// - /// This class holds all pre-defined string values that can be found in canonical annotations + /// Annotation kind for values associated with the key indices when the column type's item type + /// is a key type. The associated annotation type is typically fixed-sized vector of a primitive + /// type. The primitive type is frequently Text, but can be anything. /// - public static class Const - { - public static class ScoreColumnKind - { - public const string BinaryClassification = "BinaryClassification"; - public const string MulticlassClassification = "MulticlassClassification"; - public const string Regression = "Regression"; - public const string Ranking = "Ranking"; - public const string Clustering = "Clustering"; - public const string MultiOutputRegression = "MultiOutputRegression"; - public const string AnomalyDetection = "AnomalyDetection"; - public const string SequenceClassification = "SequenceClassification"; - public const string QuantileRegression = "QuantileRegression"; - public const string Recommender = "Recommender"; - public const string ItemSimilarity = "ItemSimilarity"; - public const string FeatureContribution = "FeatureContribution"; - } - - public static class ScoreValueKind - { - public const string Score = "Score"; - public const string PredictedLabel = "PredictedLabel"; - public const string Probability = "Probability"; - } - } + public const string KeyValues = "KeyValues"; /// - /// Helper delegate for marshaling from generic land to specific types. Used by the Marshal method below. + /// Annotation kind for sets of score columns. The value is typically a with raw type U4. /// - public delegate void AnnotationGetter(int col, ref TValue dst); + public const string ScoreColumnSetId = "ScoreColumnSetId"; /// - /// Returns a standard exception for responding to an invalid call to GetAnnotation. 
+ /// Annotation kind that indicates the prediction kind as a string. For example, "BinaryClassification". + /// The value is typically a ReadOnlyMemory<char>. /// - public static Exception ExceptGetAnnotation() => Contracts.Except("Invalid call to GetAnnotation"); + public const string ScoreColumnKind = "ScoreColumnKind"; /// - /// Returns a standard exception for responding to an invalid call to GetAnnotation. + /// Annotation kind that indicates the value kind of the score column as a string. For example, "Score", "PredictedLabel", "Probability". The value is typically a ReadOnlyMemory. /// - public static Exception ExceptGetAnnotation(this IExceptionContext ctx) => ctx.Except("Invalid call to GetAnnotation"); + public const string ScoreValueKind = "ScoreValueKind"; /// - /// Helper to marshal a call to GetAnnotation{TValue} to a specific type. + /// Annotation kind that indicates if a column is normalized. The value is typically a Bool. /// - public static void Marshal(this AnnotationGetter getter, int col, ref TNeed dst) - { - Contracts.CheckValue(getter, nameof(getter)); - - if (typeof(TNeed) != typeof(THave)) - throw ExceptGetAnnotation(); - var get = (AnnotationGetter)(Delegate)getter; - get(col, ref dst); - } + public const string IsNormalized = "IsNormalized"; /// - /// Returns a vector type with item type text and the given size. The size must be positive. - /// This is a standard type for annotation consisting of multiple text values, eg SlotNames. + /// Annotation kind that indicates if a column is visible to the users. The value is typically a Bool. + /// Not to be confused with IsHidden() that determines if a column is masked. 
/// - public static VectorDataViewType GetNamesType(int size) - { - Contracts.CheckParam(size > 0, nameof(size), "must be known size"); - return new VectorDataViewType(TextDataViewType.Instance, size); - } + public const string IsUserVisible = "IsUserVisible"; /// - /// Returns a vector type with item type int and the given size. - /// The range count must be a positive integer. - /// This is a standard type for annotation consisting of multiple int values that represent - /// categorical slot ranges with in a column. + /// Annotation kind for the label values used in training to be used for the predicted label. + /// The value is typically a fixed-sized vector of Text. /// - public static VectorDataViewType GetCategoricalType(int rangeCount) - { - Contracts.CheckParam(rangeCount > 0, nameof(rangeCount), "must be known size"); - return new VectorDataViewType(NumberDataViewType.Int32, rangeCount, 2); - } - - private static volatile KeyDataViewType _scoreColumnSetIdType; + public const string TrainingLabelValues = "TrainingLabelValues"; /// - /// The type of the ScoreColumnSetId annotation. + /// Annotation kind that indicates the ranges within a column that are categorical features. + /// The value is a vector type of ints with dimension of two. The first dimension + /// represents the number of categorical features and second dimension represents the range + /// and is of size two. The range has start and end index(both inclusive) of categorical + /// slots within that column. /// - public static KeyDataViewType ScoreColumnSetIdType - { - get - { - return _scoreColumnSetIdType ?? - Interlocked.CompareExchange(ref _scoreColumnSetIdType, new KeyDataViewType(typeof(uint), int.MaxValue), null) ?? - _scoreColumnSetIdType; - } - } + public const string CategoricalSlotRanges = "CategoricalSlotRanges"; + } - /// - /// Returns a key-value pair useful when implementing GetAnnotationTypes(col). 
- /// - public static KeyValuePair GetSlotNamesPair(int size) + /// + /// This class holds all pre-defined string values that can be found in canonical annotations + /// + public static class Const + { + public static class ScoreColumnKind { - return GetNamesType(size).GetPair(Kinds.SlotNames); + public const string BinaryClassification = "BinaryClassification"; + public const string MulticlassClassification = "MulticlassClassification"; + public const string Regression = "Regression"; + public const string Ranking = "Ranking"; + public const string Clustering = "Clustering"; + public const string MultiOutputRegression = "MultiOutputRegression"; + public const string AnomalyDetection = "AnomalyDetection"; + public const string SequenceClassification = "SequenceClassification"; + public const string QuantileRegression = "QuantileRegression"; + public const string Recommender = "Recommender"; + public const string ItemSimilarity = "ItemSimilarity"; + public const string FeatureContribution = "FeatureContribution"; } - /// - /// Returns a key-value pair useful when implementing GetAnnotationTypes(col). This assumes - /// that the values of the key type are Text. - /// - public static KeyValuePair GetKeyNamesPair(int size) + public static class ScoreValueKind { - return GetNamesType(size).GetPair(Kinds.KeyValues); + public const string Score = "Score"; + public const string PredictedLabel = "PredictedLabel"; + public const string Probability = "Probability"; } + } - /// - /// Given a type and annotation kind string, returns a key-value pair. This is useful when - /// implementing GetAnnotationTypes(col). - /// - public static KeyValuePair GetPair(this DataViewType type, string kind) - { - Contracts.CheckValue(type, nameof(type)); - return new KeyValuePair(kind, type); - } + /// + /// Helper delegate for marshaling from generic land to specific types. Used by the Marshal method below. 
+ /// + public delegate void AnnotationGetter(int col, ref TValue dst); - // REVIEW: This should be in some general utility code. + /// + /// Returns a standard exception for responding to an invalid call to GetAnnotation. + /// + public static Exception ExceptGetAnnotation() => Contracts.Except("Invalid call to GetAnnotation"); - /// - /// Prepends a params array to an enumerable. Useful when implementing GetAnnotationTypes. - /// - public static IEnumerable Prepend(this IEnumerable tail, params T[] head) + /// + /// Returns a standard exception for responding to an invalid call to GetAnnotation. + /// + public static Exception ExceptGetAnnotation(this IExceptionContext ctx) => ctx.Except("Invalid call to GetAnnotation"); + + /// + /// Helper to marshal a call to GetAnnotation{TValue} to a specific type. + /// + public static void Marshal(this AnnotationGetter getter, int col, ref TNeed dst) + { + Contracts.CheckValue(getter, nameof(getter)); + + if (typeof(TNeed) != typeof(THave)) + throw ExceptGetAnnotation(); + var get = (AnnotationGetter)(Delegate)getter; + get(col, ref dst); + } + + /// + /// Returns a vector type with item type text and the given size. The size must be positive. + /// This is a standard type for annotation consisting of multiple text values, eg SlotNames. + /// + public static VectorDataViewType GetNamesType(int size) + { + Contracts.CheckParam(size > 0, nameof(size), "must be known size"); + return new VectorDataViewType(TextDataViewType.Instance, size); + } + + /// + /// Returns a vector type with item type int and the given size. + /// The range count must be a positive integer. + /// This is a standard type for annotation consisting of multiple int values that represent + /// categorical slot ranges with in a column. 
+ /// + public static VectorDataViewType GetCategoricalType(int rangeCount) + { + Contracts.CheckParam(rangeCount > 0, nameof(rangeCount), "must be known size"); + return new VectorDataViewType(NumberDataViewType.Int32, rangeCount, 2); + } + + private static volatile KeyDataViewType _scoreColumnSetIdType; + + /// + /// The type of the ScoreColumnSetId annotation. + /// + public static KeyDataViewType ScoreColumnSetIdType + { + get { - return head.Concat(tail); + return _scoreColumnSetIdType ?? + Interlocked.CompareExchange(ref _scoreColumnSetIdType, new KeyDataViewType(typeof(uint), int.MaxValue), null) ?? + _scoreColumnSetIdType; } + } - /// - /// Returns the max value for the specified annotation kind. - /// The annotation type should be a with raw type U4. - /// colMax will be set to the first column that has the max value for the specified annotation. - /// If no column has the specified annotation, colMax is set to -1 and the method returns zero. - /// The filter function is called for each column, passing in the schema and the column index, and returns - /// true if the column should be considered, false if the column should be skipped. - /// - public static uint GetMaxAnnotationKind(this DataViewSchema schema, out int colMax, string annotationKind, Func filterFunc = null) + /// + /// Returns a key-value pair useful when implementing GetAnnotationTypes(col). + /// + public static KeyValuePair GetSlotNamesPair(int size) + { + return GetNamesType(size).GetPair(Kinds.SlotNames); + } + + /// + /// Returns a key-value pair useful when implementing GetAnnotationTypes(col). This assumes + /// that the values of the key type are Text. + /// + public static KeyValuePair GetKeyNamesPair(int size) + { + return GetNamesType(size).GetPair(Kinds.KeyValues); + } + + /// + /// Given a type and annotation kind string, returns a key-value pair. This is useful when + /// implementing GetAnnotationTypes(col). 
+ /// + public static KeyValuePair GetPair(this DataViewType type, string kind) + { + Contracts.CheckValue(type, nameof(type)); + return new KeyValuePair(kind, type); + } + + // REVIEW: This should be in some general utility code. + + /// + /// Prepends a params array to an enumerable. Useful when implementing GetAnnotationTypes. + /// + public static IEnumerable Prepend(this IEnumerable tail, params T[] head) + { + return head.Concat(tail); + } + + /// + /// Returns the max value for the specified annotation kind. + /// The annotation type should be a with raw type U4. + /// colMax will be set to the first column that has the max value for the specified annotation. + /// If no column has the specified annotation, colMax is set to -1 and the method returns zero. + /// The filter function is called for each column, passing in the schema and the column index, and returns + /// true if the column should be considered, false if the column should be skipped. + /// + public static uint GetMaxAnnotationKind(this DataViewSchema schema, out int colMax, string annotationKind, Func filterFunc = null) + { + uint max = 0; + colMax = -1; + for (int col = 0; col < schema.Count; col++) { - uint max = 0; - colMax = -1; - for (int col = 0; col < schema.Count; col++) + var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type; + if (!(columnType is KeyDataViewType) || columnType.RawType != typeof(uint)) + continue; + if (filterFunc != null && !filterFunc(schema, col)) + continue; + uint value = 0; + schema[col].Annotations.GetValue(annotationKind, ref value); + if (max < value) { - var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type; - if (!(columnType is KeyDataViewType) || columnType.RawType != typeof(uint)) - continue; - if (filterFunc != null && !filterFunc(schema, col)) - continue; - uint value = 0; - schema[col].Annotations.GetValue(annotationKind, ref value); - if (max < value) - { - max = value; - colMax = col; - } 
+ max = value; + colMax = col; } - return max; } + return max; + } - /// - /// Returns the set of column ids which match the value of specified annotation kind. - /// The annotation type should be a with raw type U4. - /// - public static IEnumerable GetColumnSet(this DataViewSchema schema, string annotationKind, uint value) + /// + /// Returns the set of column ids which match the value of specified annotation kind. + /// The annotation type should be a with raw type U4. + /// + public static IEnumerable GetColumnSet(this DataViewSchema schema, string annotationKind, uint value) + { + for (int col = 0; col < schema.Count; col++) { - for (int col = 0; col < schema.Count; col++) + var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type; + if (columnType is KeyDataViewType && columnType.RawType == typeof(uint)) { - var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type; - if (columnType is KeyDataViewType && columnType.RawType == typeof(uint)) - { - uint val = 0; - schema[col].Annotations.GetValue(annotationKind, ref val); - if (val == value) - yield return col; - } + uint val = 0; + schema[col].Annotations.GetValue(annotationKind, ref val); + if (val == value) + yield return col; } } + } - /// - /// Returns the set of column ids which match the value of specified annotation kind. - /// The annotation type should be of type text. - /// - public static IEnumerable GetColumnSet(this DataViewSchema schema, string annotationKind, string value) + /// + /// Returns the set of column ids which match the value of specified annotation kind. + /// The annotation type should be of type text. 
+ /// + public static IEnumerable GetColumnSet(this DataViewSchema schema, string annotationKind, string value) + { + for (int col = 0; col < schema.Count; col++) { - for (int col = 0; col < schema.Count; col++) + var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type; + if (columnType is TextDataViewType) { - var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type; - if (columnType is TextDataViewType) - { - ReadOnlyMemory val = default; - schema[col].Annotations.GetValue(annotationKind, ref val); - if (ReadOnlyMemoryUtils.EqualsStr(value, val)) - yield return col; - } + ReadOnlyMemory val = default; + schema[col].Annotations.GetValue(annotationKind, ref val); + if (ReadOnlyMemoryUtils.EqualsStr(value, val)) + yield return col; } } + } - /// - /// Returns true if the specified column: - /// * has a SlotNames annotation - /// * annotation type is VBuffer<ReadOnlyMemory<char>> of length . - /// - public static bool HasSlotNames(this DataViewSchema.Column column, int vectorSize) - { - if (vectorSize == 0) - return false; - - var metaColumn = column.Annotations.Schema.GetColumnOrNull(Kinds.SlotNames); - return - metaColumn != null - && metaColumn.Value.Type is VectorDataViewType vectorType - && vectorType.Size == vectorSize - && vectorType.ItemType is TextDataViewType; - } + /// + /// Returns true if the specified column: + /// * has a SlotNames annotation + /// * annotation type is VBuffer<ReadOnlyMemory<char>> of length . 
+ /// + public static bool HasSlotNames(this DataViewSchema.Column column, int vectorSize) + { + if (vectorSize == 0) + return false; + + var metaColumn = column.Annotations.Schema.GetColumnOrNull(Kinds.SlotNames); + return + metaColumn != null + && metaColumn.Value.Type is VectorDataViewType vectorType + && vectorType.Size == vectorSize + && vectorType.ItemType is TextDataViewType; + } - public static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.ColumnRole role, int vectorSize, ref VBuffer> slotNames) - { - Contracts.CheckValueOrNull(schema); - Contracts.CheckParam(vectorSize >= 0, nameof(vectorSize)); - - IReadOnlyList list = schema?.GetColumns(role); - if (list?.Count != 1 || !schema.Schema[list[0].Index].HasSlotNames(vectorSize)) - VBufferUtils.Resize(ref slotNames, vectorSize, 0); - else - schema.Schema[list[0].Index].Annotations.GetValue(Kinds.SlotNames, ref slotNames); - } + public static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.ColumnRole role, int vectorSize, ref VBuffer> slotNames) + { + Contracts.CheckValueOrNull(schema); + Contracts.CheckParam(vectorSize >= 0, nameof(vectorSize)); + + IReadOnlyList list = schema?.GetColumns(role); + if (list?.Count != 1 || !schema.Schema[list[0].Index].HasSlotNames(vectorSize)) + VBufferUtils.Resize(ref slotNames, vectorSize, 0); + else + schema.Schema[list[0].Index].Annotations.GetValue(Kinds.SlotNames, ref slotNames); + } - public static bool NeedsSlotNames(this SchemaShape.Column col) - { - return col.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol) - && metaCol.Kind == SchemaShape.Column.VectorKind.Vector - && metaCol.ItemType is TextDataViewType; - } + public static bool NeedsSlotNames(this SchemaShape.Column col) + { + return col.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol) + && metaCol.Kind == SchemaShape.Column.VectorKind.Vector + && metaCol.ItemType is TextDataViewType; + } - /// - /// Returns whether a column has the annotation indicated by - /// the 
schema shape. - /// - /// The schema shape column to query - /// True if and only if the column has the annotation - /// of a scalar type, which we assume, if set, should be true. - public static bool IsNormalized(this SchemaShape.Column column) - { - Contracts.CheckParam(column.IsValid, nameof(column), "struct not initialized properly"); - return column.Annotations.TryFindColumn(Kinds.IsNormalized, out var metaCol) - && metaCol.Kind == SchemaShape.Column.VectorKind.Scalar && !metaCol.IsKey - && metaCol.ItemType == BooleanDataViewType.Instance; - } + /// + /// Returns whether a column has the annotation indicated by + /// the schema shape. + /// + /// The schema shape column to query + /// True if and only if the column has the annotation + /// of a scalar type, which we assume, if set, should be true. + public static bool IsNormalized(this SchemaShape.Column column) + { + Contracts.CheckParam(column.IsValid, nameof(column), "struct not initialized properly"); + return column.Annotations.TryFindColumn(Kinds.IsNormalized, out var metaCol) + && metaCol.Kind == SchemaShape.Column.VectorKind.Scalar && !metaCol.IsKey + && metaCol.ItemType == BooleanDataViewType.Instance; + } - /// - /// Returns whether a column has the annotation indicated by - /// the schema shape. - /// - /// The schema shape column to query - /// True if and only if the column is a definite sized vector type, has the - /// annotation of definite sized vectors of text. - public static bool HasSlotNames(this SchemaShape.Column col) - { - Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly"); - return col.Kind == SchemaShape.Column.VectorKind.Vector - && col.Annotations.TryFindColumn(Kinds.SlotNames, out var metaCol) - && metaCol.Kind == SchemaShape.Column.VectorKind.Vector && !metaCol.IsKey - && metaCol.ItemType == TextDataViewType.Instance; - } + /// + /// Returns whether a column has the annotation indicated by + /// the schema shape. 
+ /// + /// The schema shape column to query + /// True if and only if the column is a definite sized vector type, has the + /// annotation of definite sized vectors of text. + public static bool HasSlotNames(this SchemaShape.Column col) + { + Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly"); + return col.Kind == SchemaShape.Column.VectorKind.Vector + && col.Annotations.TryFindColumn(Kinds.SlotNames, out var metaCol) + && metaCol.Kind == SchemaShape.Column.VectorKind.Vector && !metaCol.IsKey + && metaCol.ItemType == TextDataViewType.Instance; + } - /// - /// Tries to get the annotation kind of the specified type for a column. - /// - /// The raw type of the annotation, should match the PrimitiveType type - /// The schema - /// The type of the annotation - /// The annotation kind - /// The column - /// The value to return, if successful - /// True if the annotation of the right type exists, false otherwise - public static bool TryGetAnnotation(this DataViewSchema schema, PrimitiveDataViewType type, string kind, int col, ref T value) - { - Contracts.CheckValue(schema, nameof(schema)); - Contracts.CheckValue(type, nameof(type)); - - var annotationType = schema[col].Annotations.Schema.GetColumnOrNull(kind)?.Type; - if (!type.Equals(annotationType)) - return false; - schema[col].Annotations.GetValue(kind, ref value); - return true; - } + /// + /// Tries to get the annotation kind of the specified type for a column. 
+ /// + /// The raw type of the annotation, should match the PrimitiveType type + /// The schema + /// The type of the annotation + /// The annotation kind + /// The column + /// The value to return, if successful + /// True if the annotation of the right type exists, false otherwise + public static bool TryGetAnnotation(this DataViewSchema schema, PrimitiveDataViewType type, string kind, int col, ref T value) + { + Contracts.CheckValue(schema, nameof(schema)); + Contracts.CheckValue(type, nameof(type)); + + var annotationType = schema[col].Annotations.Schema.GetColumnOrNull(kind)?.Type; + if (!type.Equals(annotationType)) + return false; + schema[col].Annotations.GetValue(kind, ref value); + return true; + } - /// - /// The categoricalFeatures is a vector of the indices of categorical features slots. - /// This vector should always have an even number of elements, and the elements should be parsed in groups of two consecutive numbers. - /// So if its value is the range of numbers: 0,2,3,4,8,9 - /// look at it as [0,2],[3,4],[8,9]. - /// The way to interpret that is: feature with indices 0, 1, and 2 are one categorical - /// Features with indices 3 and 4 are another categorical. Features 5 and 6 don't appear there, so they are not categoricals. - /// - public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int colIndex, out int[] categoricalFeatures) - { - Contracts.CheckValue(schema, nameof(schema)); - Contracts.Check(colIndex >= 0, nameof(colIndex)); + /// + /// The categoricalFeatures is a vector of the indices of categorical features slots. + /// This vector should always have an even number of elements, and the elements should be parsed in groups of two consecutive numbers. + /// So if its value is the range of numbers: 0,2,3,4,8,9 + /// look at it as [0,2],[3,4],[8,9]. + /// The way to interpret that is: feature with indices 0, 1, and 2 are one categorical + /// Features with indices 3 and 4 are another categorical. 
Features 5 and 6 don't appear there, so they are not categoricals. + /// + public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int colIndex, out int[] categoricalFeatures) + { + Contracts.CheckValue(schema, nameof(schema)); + Contracts.Check(colIndex >= 0, nameof(colIndex)); - bool isValid = false; - categoricalFeatures = null; - if (!(schema[colIndex].Type is VectorDataViewType vecType && vecType.Size > 0)) - return isValid; + bool isValid = false; + categoricalFeatures = null; + if (!(schema[colIndex].Type is VectorDataViewType vecType && vecType.Size > 0)) + return isValid; - var type = schema[colIndex].Annotations.Schema.GetColumnOrNull(Kinds.CategoricalSlotRanges)?.Type; - if (type?.RawType == typeof(VBuffer)) + var type = schema[colIndex].Annotations.Schema.GetColumnOrNull(Kinds.CategoricalSlotRanges)?.Type; + if (type?.RawType == typeof(VBuffer)) + { + VBuffer catIndices = default(VBuffer); + schema[colIndex].Annotations.GetValue(Kinds.CategoricalSlotRanges, ref catIndices); + VBufferUtils.Densify(ref catIndices); + int columnSlotsCount = vecType.Size; + if (catIndices.Length > 0 && catIndices.Length % 2 == 0 && catIndices.Length <= columnSlotsCount * 2) { - VBuffer catIndices = default(VBuffer); - schema[colIndex].Annotations.GetValue(Kinds.CategoricalSlotRanges, ref catIndices); - VBufferUtils.Densify(ref catIndices); - int columnSlotsCount = vecType.Size; - if (catIndices.Length > 0 && catIndices.Length % 2 == 0 && catIndices.Length <= columnSlotsCount * 2) + int previousEndIndex = -1; + isValid = true; + var catIndicesValues = catIndices.GetValues(); + for (int i = 0; i < catIndicesValues.Length; i += 2) { - int previousEndIndex = -1; - isValid = true; - var catIndicesValues = catIndices.GetValues(); - for (int i = 0; i < catIndicesValues.Length; i += 2) + if (catIndicesValues[i] > catIndicesValues[i + 1] || + catIndicesValues[i] <= previousEndIndex || + catIndicesValues[i] >= columnSlotsCount || + catIndicesValues[i + 1] >= 
columnSlotsCount) { - if (catIndicesValues[i] > catIndicesValues[i + 1] || - catIndicesValues[i] <= previousEndIndex || - catIndicesValues[i] >= columnSlotsCount || - catIndicesValues[i + 1] >= columnSlotsCount) - { - isValid = false; - break; - } - - previousEndIndex = catIndicesValues[i + 1]; + isValid = false; + break; } - if (isValid) - categoricalFeatures = catIndicesValues.ToArray(); + + previousEndIndex = catIndicesValues[i + 1]; } + if (isValid) + categoricalFeatures = catIndicesValues.ToArray(); } - - return isValid; } - /// - /// Produces sequence of columns that are generated by trainer estimators. - /// - /// whether we should also append 'IsNormalized' (typically for probability column) - public static IEnumerable GetTrainerOutputAnnotation(bool isNormalized = false) - { - var cols = new List(); - cols.Add(new SchemaShape.Column(Kinds.ScoreColumnSetId, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true)); - cols.Add(new SchemaShape.Column(Kinds.ScoreColumnKind, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false)); - cols.Add(new SchemaShape.Column(Kinds.ScoreValueKind, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false)); - if (isNormalized) - cols.Add(new SchemaShape.Column(Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false)); - return cols; - } + return isValid; + } - /// - /// Produces annotations for the score column generated by trainer estimators for multiclass classification. - /// If input LabelColumn is not available it produces slotnames annotation by default. - /// - /// Label column. - public static IEnumerable AnnotationsForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null) + /// + /// Produces sequence of columns that are generated by trainer estimators. 
+ /// + /// whether we should also append 'IsNormalized' (typically for probability column) + public static IEnumerable GetTrainerOutputAnnotation(bool isNormalized = false) + { + var cols = new List(); + cols.Add(new SchemaShape.Column(Kinds.ScoreColumnSetId, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true)); + cols.Add(new SchemaShape.Column(Kinds.ScoreColumnKind, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false)); + cols.Add(new SchemaShape.Column(Kinds.ScoreValueKind, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false)); + if (isNormalized) + cols.Add(new SchemaShape.Column(Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false)); + return cols; + } + + /// + /// Produces annotations for the score column generated by trainer estimators for multiclass classification. + /// If input LabelColumn is not available it produces slotnames annotation by default. + /// + /// Label column. + public static IEnumerable AnnotationsForMulticlassScoreColumn(SchemaShape.Column? 
labelColumn = null) + { + var cols = new List(); + if (labelColumn != null && labelColumn.Value.IsKey) { - var cols = new List(); - if (labelColumn != null && labelColumn.Value.IsKey) + if (labelColumn.Value.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol) && + metaCol.Kind == SchemaShape.Column.VectorKind.Vector) { - if (labelColumn.Value.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol) && - metaCol.Kind == SchemaShape.Column.VectorKind.Vector) - { - if (metaCol.ItemType is TextDataViewType) - cols.Add(new SchemaShape.Column(Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false)); - cols.Add(new SchemaShape.Column(Kinds.TrainingLabelValues, SchemaShape.Column.VectorKind.Vector, metaCol.ItemType, false)); - } + if (metaCol.ItemType is TextDataViewType) + cols.Add(new SchemaShape.Column(Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false)); + cols.Add(new SchemaShape.Column(Kinds.TrainingLabelValues, SchemaShape.Column.VectorKind.Vector, metaCol.ItemType, false)); } - cols.AddRange(GetTrainerOutputAnnotation()); - return cols; } + cols.AddRange(GetTrainerOutputAnnotation()); + return cols; + } + + private sealed class AnnotationRow : DataViewRow + { + private readonly DataViewSchema.Annotations _annotations; - private sealed class AnnotationRow : DataViewRow + public AnnotationRow(DataViewSchema.Annotations annotations) { - private readonly DataViewSchema.Annotations _annotations; + Contracts.AssertValue(annotations); + _annotations = annotations; + } - public AnnotationRow(DataViewSchema.Annotations annotations) - { - Contracts.AssertValue(annotations); - _annotations = annotations; - } + public override DataViewSchema Schema => _annotations.Schema; + public override long Position => 0; + public override long Batch => 0; - public override DataViewSchema Schema => _annotations.Schema; - public override long Position => 0; - public override long Batch => 0; - - /// - /// 
Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row. - /// This throws if the column is not active in this row, or if the type - /// differs from this column's type. - /// - /// is the column's content type. - /// is the output column whose getter should be returned. - public override ValueGetter GetGetter(DataViewSchema.Column column) => _annotations.GetGetter(column); - - public override ValueGetter GetIdGetter() => (ref DataViewRowId dst) => dst = default; - - /// - /// Returns whether the given column is active in this row. - /// - public override bool IsColumnActive(DataViewSchema.Column column) => true; - } + /// + /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row. + /// This throws if the column is not active in this row, or if the type + /// differs from this column's type. + /// + /// is the column's content type. + /// is the output column whose getter should be returned. + public override ValueGetter GetGetter(DataViewSchema.Column column) => _annotations.GetGetter(column); + + public override ValueGetter GetIdGetter() => (ref DataViewRowId dst) => dst = default; /// - /// Presents a as a an . + /// Returns whether the given column is active in this row. /// - /// The annotations to wrap. - /// A row that wraps an input annotations. - [BestFriend] - internal static DataViewRow AnnotationsAsRow(DataViewSchema.Annotations annotations) - { - Contracts.CheckValue(annotations, nameof(annotations)); - return new AnnotationRow(annotations); - } + public override bool IsColumnActive(DataViewSchema.Column column) => true; + } + + /// + /// Presents a as a an . + /// + /// The annotations to wrap. + /// A row that wraps an input annotations. 
+ [BestFriend] + internal static DataViewRow AnnotationsAsRow(DataViewSchema.Annotations annotations) + { + Contracts.CheckValue(annotations, nameof(annotations)); + return new AnnotationRow(annotations); } } diff --git a/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs b/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs index d1465f2c03..d595ff4d06 100644 --- a/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs +++ b/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs @@ -5,165 +5,164 @@ using System; using Microsoft.ML.Runtime; -namespace Microsoft.ML.Data +namespace Microsoft.ML.Data; + +/// +/// Extension methods related to the ColumnType class. +/// +[BestFriend] +internal static class ColumnTypeExtensions { /// - /// Extension methods related to the ColumnType class. + /// Whether this type is a standard scalar type completely determined by its + /// (not a or , etc). + /// + public static bool IsStandardScalar(this DataViewType columnType) => + (columnType is NumberDataViewType) || (columnType is TextDataViewType) || (columnType is BooleanDataViewType) || + (columnType is RowIdDataViewType) || (columnType is TimeSpanDataViewType) || + (columnType is DateTimeDataViewType) || (columnType is DateTimeOffsetDataViewType); + + /// + /// Zero return means it's not a key type. + /// + public static ulong GetKeyCount(this DataViewType columnType) => (columnType as KeyDataViewType)?.Count ?? 0; + + /// + /// Sometimes it is necessary to cast the Count to an int. This performs overflow check. + /// Zero return means it's not a key type. /// - [BestFriend] - internal static class ColumnTypeExtensions + public static int GetKeyCountAsInt32(this DataViewType columnType, IExceptionContext ectx = null) { - /// - /// Whether this type is a standard scalar type completely determined by its - /// (not a or , etc). 
- /// - public static bool IsStandardScalar(this DataViewType columnType) => - (columnType is NumberDataViewType) || (columnType is TextDataViewType) || (columnType is BooleanDataViewType) || - (columnType is RowIdDataViewType) || (columnType is TimeSpanDataViewType) || - (columnType is DateTimeDataViewType) || (columnType is DateTimeOffsetDataViewType); - - /// - /// Zero return means it's not a key type. - /// - public static ulong GetKeyCount(this DataViewType columnType) => (columnType as KeyDataViewType)?.Count ?? 0; - - /// - /// Sometimes it is necessary to cast the Count to an int. This performs overflow check. - /// Zero return means it's not a key type. - /// - public static int GetKeyCountAsInt32(this DataViewType columnType, IExceptionContext ectx = null) - { - ulong count = columnType.GetKeyCount(); - ectx.Check(count <= int.MaxValue, nameof(KeyDataViewType) + "." + nameof(KeyDataViewType.Count) + " exceeds int.MaxValue."); - return (int)count; - } + ulong count = columnType.GetKeyCount(); + ectx.Check(count <= int.MaxValue, nameof(KeyDataViewType) + "." + nameof(KeyDataViewType.Count) + " exceeds int.MaxValue."); + return (int)count; + } - /// - /// For non-vector types, this returns the column type itself (i.e., return ). - /// For vector types, this returns the type of the items stored as values in vector. - /// - public static DataViewType GetItemType(this DataViewType columnType) => (columnType as VectorDataViewType)?.ItemType ?? columnType; - - /// - /// Zero return means either it's not a vector or the size is unknown. - /// - public static int GetVectorSize(this DataViewType columnType) => (columnType as VectorDataViewType)?.Size ?? 0; - - /// - /// For non-vectors, this returns one. For unknown size vectors, it returns zero. - /// For known sized vectors, it returns size. - /// - public static int GetValueCount(this DataViewType columnType) => (columnType as VectorDataViewType)?.Size ?? 
1; - - /// - /// Whether this is a vector type with known size. Returns false for non-vector types. - /// Equivalent to > 0. - /// - public static bool IsKnownSizeVector(this DataViewType columnType) => columnType.GetVectorSize() > 0; - - /// - /// Gets the equivalent for the 's RawType. - /// This can return default() if the RawType doesn't have a corresponding - /// . - /// - public static InternalDataKind GetRawKind(this DataViewType columnType) - { - columnType.RawType.TryGetDataKind(out InternalDataKind result); - return result; - } + /// + /// For non-vector types, this returns the column type itself (i.e., return ). + /// For vector types, this returns the type of the items stored as values in vector. + /// + public static DataViewType GetItemType(this DataViewType columnType) => (columnType as VectorDataViewType)?.ItemType ?? columnType; - /// - /// Equivalent to calling Equals(ColumnType) for non-vector types. For vector type, - /// returns true if current and other vector types have the same size and item type. - /// - public static bool SameSizeAndItemType(this DataViewType columnType, DataViewType other) - { - if (other == null) - return false; - - if (columnType.Equals(other)) - return true; - - // For vector types, we don't care about the factoring of the dimensions. - if (!(columnType is VectorDataViewType vectorType) || !(other is VectorDataViewType otherVectorType)) - return false; - if (!vectorType.ItemType.Equals(otherVectorType.ItemType)) - return false; - return vectorType.Size == otherVectorType.Size; - } + /// + /// Zero return means either it's not a vector or the size is unknown. + /// + public static int GetVectorSize(this DataViewType columnType) => (columnType as VectorDataViewType)?.Size ?? 
0; - public static PrimitiveDataViewType PrimitiveTypeFromType(Type type) - { - if (type == typeof(ReadOnlyMemory) || type == typeof(string)) - return TextDataViewType.Instance; - if (type == typeof(bool)) - return BooleanDataViewType.Instance; - if (type == typeof(TimeSpan)) - return TimeSpanDataViewType.Instance; - if (type == typeof(DateTime)) - return DateTimeDataViewType.Instance; - if (type == typeof(DateTimeOffset)) - return DateTimeOffsetDataViewType.Instance; - if (type == typeof(DataViewRowId)) - return RowIdDataViewType.Instance; - return NumberTypeFromType(type); - } + /// + /// For non-vectors, this returns one. For unknown size vectors, it returns zero. + /// For known sized vectors, it returns size. + /// + public static int GetValueCount(this DataViewType columnType) => (columnType as VectorDataViewType)?.Size ?? 1; - public static PrimitiveDataViewType PrimitiveTypeFromKind(InternalDataKind kind) - { - if (kind == InternalDataKind.TX) - return TextDataViewType.Instance; - if (kind == InternalDataKind.BL) - return BooleanDataViewType.Instance; - if (kind == InternalDataKind.TS) - return TimeSpanDataViewType.Instance; - if (kind == InternalDataKind.DT) - return DateTimeDataViewType.Instance; - if (kind == InternalDataKind.DZ) - return DateTimeOffsetDataViewType.Instance; - if (kind == InternalDataKind.UG) - return RowIdDataViewType.Instance; - return NumberTypeFromKind(kind); - } + /// + /// Whether this is a vector type with known size. Returns false for non-vector types. + /// Equivalent to > 0. + /// + public static bool IsKnownSizeVector(this DataViewType columnType) => columnType.GetVectorSize() > 0; - public static NumberDataViewType NumberTypeFromType(Type type) - { - InternalDataKind kind; - if (type.TryGetDataKind(out kind)) - return NumberTypeFromKind(kind); + /// + /// Gets the equivalent for the 's RawType. + /// This can return default() if the RawType doesn't have a corresponding + /// . 
+ /// + public static InternalDataKind GetRawKind(this DataViewType columnType) + { + columnType.RawType.TryGetDataKind(out InternalDataKind result); + return result; + } - Contracts.Assert(false); - throw new InvalidOperationException($"Bad type in {nameof(ColumnTypeExtensions)}.{nameof(NumberTypeFromType)}: {type}"); - } + /// + /// Equivalent to calling Equals(ColumnType) for non-vector types. For vector type, + /// returns true if current and other vector types have the same size and item type. + /// + public static bool SameSizeAndItemType(this DataViewType columnType, DataViewType other) + { + if (other == null) + return false; + + if (columnType.Equals(other)) + return true; + + // For vector types, we don't care about the factoring of the dimensions. + if (!(columnType is VectorDataViewType vectorType) || !(other is VectorDataViewType otherVectorType)) + return false; + if (!vectorType.ItemType.Equals(otherVectorType.ItemType)) + return false; + return vectorType.Size == otherVectorType.Size; + } - private static NumberDataViewType NumberTypeFromKind(InternalDataKind kind) + public static PrimitiveDataViewType PrimitiveTypeFromType(Type type) + { + if (type == typeof(ReadOnlyMemory) || type == typeof(string)) + return TextDataViewType.Instance; + if (type == typeof(bool)) + return BooleanDataViewType.Instance; + if (type == typeof(TimeSpan)) + return TimeSpanDataViewType.Instance; + if (type == typeof(DateTime)) + return DateTimeDataViewType.Instance; + if (type == typeof(DateTimeOffset)) + return DateTimeOffsetDataViewType.Instance; + if (type == typeof(DataViewRowId)) + return RowIdDataViewType.Instance; + return NumberTypeFromType(type); + } + + public static PrimitiveDataViewType PrimitiveTypeFromKind(InternalDataKind kind) + { + if (kind == InternalDataKind.TX) + return TextDataViewType.Instance; + if (kind == InternalDataKind.BL) + return BooleanDataViewType.Instance; + if (kind == InternalDataKind.TS) + return TimeSpanDataViewType.Instance; + if 
(kind == InternalDataKind.DT) + return DateTimeDataViewType.Instance; + if (kind == InternalDataKind.DZ) + return DateTimeOffsetDataViewType.Instance; + if (kind == InternalDataKind.UG) + return RowIdDataViewType.Instance; + return NumberTypeFromKind(kind); + } + + public static NumberDataViewType NumberTypeFromType(Type type) + { + InternalDataKind kind; + if (type.TryGetDataKind(out kind)) + return NumberTypeFromKind(kind); + + Contracts.Assert(false); + throw new InvalidOperationException($"Bad type in {nameof(ColumnTypeExtensions)}.{nameof(NumberTypeFromType)}: {type}"); + } + + private static NumberDataViewType NumberTypeFromKind(InternalDataKind kind) + { + switch (kind) { - switch (kind) - { - case InternalDataKind.I1: - return NumberDataViewType.SByte; - case InternalDataKind.U1: - return NumberDataViewType.Byte; - case InternalDataKind.I2: - return NumberDataViewType.Int16; - case InternalDataKind.U2: - return NumberDataViewType.UInt16; - case InternalDataKind.I4: - return NumberDataViewType.Int32; - case InternalDataKind.U4: - return NumberDataViewType.UInt32; - case InternalDataKind.I8: - return NumberDataViewType.Int64; - case InternalDataKind.U8: - return NumberDataViewType.UInt64; - case InternalDataKind.R4: - return NumberDataViewType.Single; - case InternalDataKind.R8: - return NumberDataViewType.Double; - } - - Contracts.Assert(false); - throw new InvalidOperationException($"Bad data kind in {nameof(ColumnTypeExtensions)}.{nameof(NumberTypeFromKind)}: {kind}"); + case InternalDataKind.I1: + return NumberDataViewType.SByte; + case InternalDataKind.U1: + return NumberDataViewType.Byte; + case InternalDataKind.I2: + return NumberDataViewType.Int16; + case InternalDataKind.U2: + return NumberDataViewType.UInt16; + case InternalDataKind.I4: + return NumberDataViewType.Int32; + case InternalDataKind.U4: + return NumberDataViewType.UInt32; + case InternalDataKind.I8: + return NumberDataViewType.Int64; + case InternalDataKind.U8: + return 
NumberDataViewType.UInt64; + case InternalDataKind.R4: + return NumberDataViewType.Single; + case InternalDataKind.R8: + return NumberDataViewType.Double; } + + Contracts.Assert(false); + throw new InvalidOperationException($"Bad data kind in {nameof(ColumnTypeExtensions)}.{nameof(NumberTypeFromKind)}: {kind}"); } } diff --git a/src/Microsoft.ML.Core/Data/DataKind.cs b/src/Microsoft.ML.Core/Data/DataKind.cs index 3baf7b5511..567819278e 100644 --- a/src/Microsoft.ML.Core/Data/DataKind.cs +++ b/src/Microsoft.ML.Core/Data/DataKind.cs @@ -4,380 +4,379 @@ using System; using Microsoft.ML.Runtime; -namespace Microsoft.ML.Data +namespace Microsoft.ML.Data; + +/// +/// Specifies a simple data type. +/// +/// +/// or [text](xref:Microsoft.ML.Data.TextDataViewType) | Empty or `null` string (both result in empty `System.ReadOnlyMemory` | | +/// | [Key](xref:Microsoft.ML.Data.KeyDataViewType) type (supported by the unsigned integer types in `DataKind`) | Not defined | Always `false` | +/// | All other types | Default value of the corresponding system type as defined by .NET standard. In C#, default value expression `default(T)` provides that value. | Equality test with the default value | +/// +/// The table below shows the missing value definition for each of the data types. +/// +/// | Type | Missing Value | IsMissing Indicator | +/// | -- | -- | -- | +/// | or [text](xref:Microsoft.ML.Data.TextDataViewType) | Not defined | Always `false` | +/// | [Key](xref:Microsoft.ML.Data.KeyDataViewType) type (supported by the unsigned integer types in `DataKind`) | `0` | Equality test with `0` | +/// | | | | +/// | | | | +/// | All other types | Not defined | Always `false` | +/// +/// ]]> +/// +/// +// Data type specifiers mainly used in creating text loader and type converter. +public enum DataKind : byte { + /// 1-byte integer, type of . + SByte = 1, + /// 1-byte unsigned integer, type of . + Byte = 2, + /// 2-byte integer, type of . 
+ Int16 = 3, + /// 2-byte unsigned integer, type of . + UInt16 = 4, + /// 4-byte integer, type of . + Int32 = 5, + /// 4-byte unsigned integer, type of . + UInt32 = 6, + /// 8-byte integer, type of . + Int64 = 7, + /// 8-byte unsigned integer, type of . + UInt64 = 8, + /// 4-byte floating-point number, type of . + Single = 9, + /// 8-byte floating-point number, type of . + Double = 10, /// - /// Specifies a simple data type. + /// string, type of , where T is . + /// Also compatible with . /// - /// - /// or [text](xref:Microsoft.ML.Data.TextDataViewType) | Empty or `null` string (both result in empty `System.ReadOnlyMemory` | | - /// | [Key](xref:Microsoft.ML.Data.KeyDataViewType) type (supported by the unsigned integer types in `DataKind`) | Not defined | Always `false` | - /// | All other types | Default value of the corresponding system type as defined by .NET standard. In C#, default value expression `default(T)` provides that value. | Equality test with the default value | - /// - /// The table below shows the missing value definition for each of the data types. - /// - /// | Type | Missing Value | IsMissing Indicator | - /// | -- | -- | -- | - /// | or [text](xref:Microsoft.ML.Data.TextDataViewType) | Not defined | Always `false` | - /// | [Key](xref:Microsoft.ML.Data.KeyDataViewType) type (supported by the unsigned integer types in `DataKind`) | `0` | Equality test with `0` | - /// | | | | - /// | | | | - /// | All other types | Not defined | Always `false` | - /// - /// ]]> - /// - /// - // Data type specifiers mainly used in creating text loader and type converter. - public enum DataKind : byte - { - /// 1-byte integer, type of . - SByte = 1, - /// 1-byte unsigned integer, type of . - Byte = 2, - /// 2-byte integer, type of . - Int16 = 3, - /// 2-byte unsigned integer, type of . - UInt16 = 4, - /// 4-byte integer, type of . - Int32 = 5, - /// 4-byte unsigned integer, type of . - UInt32 = 6, - /// 8-byte integer, type of . 
- Int64 = 7, - /// 8-byte unsigned integer, type of . - UInt64 = 8, - /// 4-byte floating-point number, type of . - Single = 9, - /// 8-byte floating-point number, type of . - Double = 10, - /// - /// string, type of , where T is . - /// Also compatible with . - /// - String = 11, - /// boolean variable type, type of . - Boolean = 12, - /// type of . - TimeSpan = 13, - /// type of . - DateTime = 14, - /// type of . - DateTimeOffset = 15, - } + String = 11, + /// boolean variable type, type of . + Boolean = 12, + /// type of . + TimeSpan = 13, + /// type of . + DateTime = 14, + /// type of . + DateTimeOffset = 15, +} - /// - /// Data type specifier used in command line. is the underlying version of - /// used for command line and entry point BC. - /// - [BestFriend] - internal enum InternalDataKind : byte - { - // Notes: - // * These values are serialized, so changing them breaks binary formats. - // * We intentionally skip zero. - // * Some code depends on sizeof(DataKind) == sizeof(byte). +/// +/// Data type specifier used in command line. is the underlying version of +/// used for command line and entry point BC. +/// +[BestFriend] +internal enum InternalDataKind : byte +{ + // Notes: + // * These values are serialized, so changing them breaks binary formats. + // * We intentionally skip zero. + // * Some code depends on sizeof(DataKind) == sizeof(byte). 
- I1 = DataKind.SByte, - U1 = DataKind.Byte, - I2 = DataKind.Int16, - U2 = DataKind.UInt16, - I4 = DataKind.Int32, - U4 = DataKind.UInt32, - I8 = DataKind.Int64, - U8 = DataKind.UInt64, - R4 = DataKind.Single, - R8 = DataKind.Double, - Num = R4, + I1 = DataKind.SByte, + U1 = DataKind.Byte, + I2 = DataKind.Int16, + U2 = DataKind.UInt16, + I4 = DataKind.Int32, + U4 = DataKind.UInt32, + I8 = DataKind.Int64, + U8 = DataKind.UInt64, + R4 = DataKind.Single, + R8 = DataKind.Double, + Num = R4, - TX = DataKind.String, + TX = DataKind.String, #pragma warning disable MSML_GeneralName // The data kind enum has its own logic, independent of C# naming conventions. - TXT = TX, - Text = TX, + TXT = TX, + Text = TX, - BL = DataKind.Boolean, - Bool = BL, + BL = DataKind.Boolean, + Bool = BL, - TS = DataKind.TimeSpan, - TimeSpan = TS, - DT = DataKind.DateTime, - DateTime = DT, - DZ = DataKind.DateTimeOffset, - DateTimeZone = DZ, + TS = DataKind.TimeSpan, + TimeSpan = TS, + DT = DataKind.DateTime, + DateTime = DT, + DZ = DataKind.DateTimeOffset, + DateTimeZone = DZ, - UG = 16, // Unsigned 16-byte integer. - U16 = UG, + UG = 16, // Unsigned 16-byte integer. + U16 = UG, #pragma warning restore MSML_GeneralName - } +} + +/// +/// Extension methods related to the DataKind enum. +/// +[BestFriend] +internal static class InternalDataKindExtensions +{ + public const InternalDataKind KindMin = InternalDataKind.I1; + public const InternalDataKind KindLim = InternalDataKind.U16 + 1; + public const int KindCount = KindLim - KindMin; /// - /// Extension methods related to the DataKind enum. + /// Maps a DataKind to a value suitable for indexing into an array of size KindCount. 
/// - [BestFriend] - internal static class InternalDataKindExtensions + public static int ToIndex(this InternalDataKind kind) { - public const InternalDataKind KindMin = InternalDataKind.I1; - public const InternalDataKind KindLim = InternalDataKind.U16 + 1; - public const int KindCount = KindLim - KindMin; - - /// - /// Maps a DataKind to a value suitable for indexing into an array of size KindCount. - /// - public static int ToIndex(this InternalDataKind kind) - { - return kind - KindMin; - } - - /// - /// Maps from an index into an array of size KindCount to the corresponding DataKind - /// - public static InternalDataKind FromIndex(int index) - { - Contracts.Check(0 <= index && index < KindCount); - return (InternalDataKind)(index + (int)KindMin); - } - - /// - /// This function converts to . - /// Because is a subset of , the conversion is straightforward. - /// - public static InternalDataKind ToInternalDataKind(this DataKind dataKind) => (InternalDataKind)dataKind; + return kind - KindMin; + } - /// - /// This function converts to . - /// Because is a subset of , we should check if - /// can be found in . - /// - public static DataKind ToDataKind(this InternalDataKind kind) - { - Contracts.Check(kind != InternalDataKind.UG); - return (DataKind)kind; - } + /// + /// Maps from an index into an array of size KindCount to the corresponding DataKind + /// + public static InternalDataKind FromIndex(int index) + { + Contracts.Check(0 <= index && index < KindCount); + return (InternalDataKind)(index + (int)KindMin); + } - /// - /// For integer DataKinds, this returns the maximum legal value. For un-supported kinds, - /// it returns zero. 
- /// - public static ulong ToMaxInt(this InternalDataKind kind) - { - switch (kind) - { - case InternalDataKind.I1: - return (ulong)sbyte.MaxValue; - case InternalDataKind.U1: - return byte.MaxValue; - case InternalDataKind.I2: - return (ulong)short.MaxValue; - case InternalDataKind.U2: - return ushort.MaxValue; - case InternalDataKind.I4: - return int.MaxValue; - case InternalDataKind.U4: - return uint.MaxValue; - case InternalDataKind.I8: - return long.MaxValue; - case InternalDataKind.U8: - return ulong.MaxValue; - } + /// + /// This function converts to . + /// Because is a subset of , the conversion is straightforward. + /// + public static InternalDataKind ToInternalDataKind(this DataKind dataKind) => (InternalDataKind)dataKind; - return 0; - } + /// + /// This function converts to . + /// Because is a subset of , we should check if + /// can be found in . + /// + public static DataKind ToDataKind(this InternalDataKind kind) + { + Contracts.Check(kind != InternalDataKind.UG); + return (DataKind)kind; + } - /// - /// For integer Types, this returns the maximum legal value. For un-supported Types, - /// it returns zero. - /// - public static ulong ToMaxInt(this Type type) + /// + /// For integer DataKinds, this returns the maximum legal value. For un-supported kinds, + /// it returns zero. 
+ /// + public static ulong ToMaxInt(this InternalDataKind kind) + { + switch (kind) { - if (type == typeof(sbyte)) + case InternalDataKind.I1: return (ulong)sbyte.MaxValue; - else if (type == typeof(byte)) + case InternalDataKind.U1: return byte.MaxValue; - else if (type == typeof(short)) + case InternalDataKind.I2: return (ulong)short.MaxValue; - else if (type == typeof(ushort)) + case InternalDataKind.U2: return ushort.MaxValue; - else if (type == typeof(int)) + case InternalDataKind.I4: return int.MaxValue; - else if (type == typeof(uint)) + case InternalDataKind.U4: return uint.MaxValue; - else if (type == typeof(long)) + case InternalDataKind.I8: return long.MaxValue; - else if (type == typeof(ulong)) + case InternalDataKind.U8: return ulong.MaxValue; - - return 0; } - /// - /// For integer DataKinds, this returns the minimum legal value. For un-supported kinds, - /// it returns one. - /// - public static long ToMinInt(this InternalDataKind kind) - { - switch (kind) - { - case InternalDataKind.I1: - return sbyte.MinValue; - case InternalDataKind.U1: - return byte.MinValue; - case InternalDataKind.I2: - return short.MinValue; - case InternalDataKind.U2: - return ushort.MinValue; - case InternalDataKind.I4: - return int.MinValue; - case InternalDataKind.U4: - return uint.MinValue; - case InternalDataKind.I8: - return long.MinValue; - case InternalDataKind.U8: - return 0; - } + return 0; + } - return 1; - } + /// + /// For integer Types, this returns the maximum legal value. For un-supported Types, + /// it returns zero. 
+ /// + public static ulong ToMaxInt(this Type type) + { + if (type == typeof(sbyte)) + return (ulong)sbyte.MaxValue; + else if (type == typeof(byte)) + return byte.MaxValue; + else if (type == typeof(short)) + return (ulong)short.MaxValue; + else if (type == typeof(ushort)) + return ushort.MaxValue; + else if (type == typeof(int)) + return int.MaxValue; + else if (type == typeof(uint)) + return uint.MaxValue; + else if (type == typeof(long)) + return long.MaxValue; + else if (type == typeof(ulong)) + return ulong.MaxValue; - /// - /// Maps a DataKind to the associated .Net representation type. - /// - public static Type ToType(this InternalDataKind kind) - { - switch (kind) - { - case InternalDataKind.I1: - return typeof(sbyte); - case InternalDataKind.U1: - return typeof(byte); - case InternalDataKind.I2: - return typeof(short); - case InternalDataKind.U2: - return typeof(ushort); - case InternalDataKind.I4: - return typeof(int); - case InternalDataKind.U4: - return typeof(uint); - case InternalDataKind.I8: - return typeof(long); - case InternalDataKind.U8: - return typeof(ulong); - case InternalDataKind.R4: - return typeof(Single); - case InternalDataKind.R8: - return typeof(Double); - case InternalDataKind.TX: - return typeof(ReadOnlyMemory); - case InternalDataKind.BL: - return typeof(bool); - case InternalDataKind.TS: - return typeof(TimeSpan); - case InternalDataKind.DT: - return typeof(DateTime); - case InternalDataKind.DZ: - return typeof(DateTimeOffset); - case InternalDataKind.UG: - return typeof(DataViewRowId); - } + return 0; + } - return null; + /// + /// For integer DataKinds, this returns the minimum legal value. For un-supported kinds, + /// it returns one. 
+ /// + public static long ToMinInt(this InternalDataKind kind) + { + switch (kind) + { + case InternalDataKind.I1: + return sbyte.MinValue; + case InternalDataKind.U1: + return byte.MinValue; + case InternalDataKind.I2: + return short.MinValue; + case InternalDataKind.U2: + return ushort.MinValue; + case InternalDataKind.I4: + return int.MinValue; + case InternalDataKind.U4: + return uint.MinValue; + case InternalDataKind.I8: + return long.MinValue; + case InternalDataKind.U8: + return 0; } - /// - /// Try to map a System.Type to a corresponding DataKind value. - /// - public static bool TryGetDataKind(this Type type, out InternalDataKind kind) + return 1; + } + + /// + /// Maps a DataKind to the associated .Net representation type. + /// + public static Type ToType(this InternalDataKind kind) + { + switch (kind) { - Contracts.CheckValueOrNull(type); + case InternalDataKind.I1: + return typeof(sbyte); + case InternalDataKind.U1: + return typeof(byte); + case InternalDataKind.I2: + return typeof(short); + case InternalDataKind.U2: + return typeof(ushort); + case InternalDataKind.I4: + return typeof(int); + case InternalDataKind.U4: + return typeof(uint); + case InternalDataKind.I8: + return typeof(long); + case InternalDataKind.U8: + return typeof(ulong); + case InternalDataKind.R4: + return typeof(Single); + case InternalDataKind.R8: + return typeof(Double); + case InternalDataKind.TX: + return typeof(ReadOnlyMemory); + case InternalDataKind.BL: + return typeof(bool); + case InternalDataKind.TS: + return typeof(TimeSpan); + case InternalDataKind.DT: + return typeof(DateTime); + case InternalDataKind.DZ: + return typeof(DateTimeOffset); + case InternalDataKind.UG: + return typeof(DataViewRowId); + } + + return null; + } - // REVIEW: Make this more efficient. Should we have a global dictionary? 
- if (type == typeof(sbyte)) - kind = InternalDataKind.I1; - else if (type == typeof(byte)) - kind = InternalDataKind.U1; - else if (type == typeof(short)) - kind = InternalDataKind.I2; - else if (type == typeof(ushort)) - kind = InternalDataKind.U2; - else if (type == typeof(int)) - kind = InternalDataKind.I4; - else if (type == typeof(uint)) - kind = InternalDataKind.U4; - else if (type == typeof(long)) - kind = InternalDataKind.I8; - else if (type == typeof(ulong)) - kind = InternalDataKind.U8; - else if (type == typeof(Single)) - kind = InternalDataKind.R4; - else if (type == typeof(Double)) - kind = InternalDataKind.R8; - else if (type == typeof(ReadOnlyMemory) || type == typeof(string)) - kind = InternalDataKind.TX; - else if (type == typeof(bool)) - kind = InternalDataKind.BL; - else if (type == typeof(TimeSpan)) - kind = InternalDataKind.TS; - else if (type == typeof(DateTime)) - kind = InternalDataKind.DT; - else if (type == typeof(DateTimeOffset)) - kind = InternalDataKind.DZ; - else if (type == typeof(DataViewRowId)) - kind = InternalDataKind.UG; - else - { - kind = default(InternalDataKind); - return false; - } + /// + /// Try to map a System.Type to a corresponding DataKind value. + /// + public static bool TryGetDataKind(this Type type, out InternalDataKind kind) + { + Contracts.CheckValueOrNull(type); - return true; + // REVIEW: Make this more efficient. Should we have a global dictionary? 
+ if (type == typeof(sbyte)) + kind = InternalDataKind.I1; + else if (type == typeof(byte)) + kind = InternalDataKind.U1; + else if (type == typeof(short)) + kind = InternalDataKind.I2; + else if (type == typeof(ushort)) + kind = InternalDataKind.U2; + else if (type == typeof(int)) + kind = InternalDataKind.I4; + else if (type == typeof(uint)) + kind = InternalDataKind.U4; + else if (type == typeof(long)) + kind = InternalDataKind.I8; + else if (type == typeof(ulong)) + kind = InternalDataKind.U8; + else if (type == typeof(Single)) + kind = InternalDataKind.R4; + else if (type == typeof(Double)) + kind = InternalDataKind.R8; + else if (type == typeof(ReadOnlyMemory) || type == typeof(string)) + kind = InternalDataKind.TX; + else if (type == typeof(bool)) + kind = InternalDataKind.BL; + else if (type == typeof(TimeSpan)) + kind = InternalDataKind.TS; + else if (type == typeof(DateTime)) + kind = InternalDataKind.DT; + else if (type == typeof(DateTimeOffset)) + kind = InternalDataKind.DZ; + else if (type == typeof(DataViewRowId)) + kind = InternalDataKind.UG; + else + { + kind = default(InternalDataKind); + return false; } - /// - /// Get the canonical string for a DataKind. Note that using DataKind.ToString() is not stable - /// and is also slow, so use this instead. - /// - public static string GetString(this InternalDataKind kind) + return true; + } + + /// + /// Get the canonical string for a DataKind. Note that using DataKind.ToString() is not stable + /// and is also slow, so use this instead. 
+ /// + public static string GetString(this InternalDataKind kind) + { + switch (kind) { - switch (kind) - { - case InternalDataKind.I1: - return "I1"; - case InternalDataKind.I2: - return "I2"; - case InternalDataKind.I4: - return "I4"; - case InternalDataKind.I8: - return "I8"; - case InternalDataKind.U1: - return "U1"; - case InternalDataKind.U2: - return "U2"; - case InternalDataKind.U4: - return "U4"; - case InternalDataKind.U8: - return "U8"; - case InternalDataKind.R4: - return "R4"; - case InternalDataKind.R8: - return "R8"; - case InternalDataKind.BL: - return "BL"; - case InternalDataKind.TX: - return "TX"; - case InternalDataKind.TS: - return "TS"; - case InternalDataKind.DT: - return "DT"; - case InternalDataKind.DZ: - return "DZ"; - case InternalDataKind.UG: - return "UG"; - } - return ""; + case InternalDataKind.I1: + return "I1"; + case InternalDataKind.I2: + return "I2"; + case InternalDataKind.I4: + return "I4"; + case InternalDataKind.I8: + return "I8"; + case InternalDataKind.U1: + return "U1"; + case InternalDataKind.U2: + return "U2"; + case InternalDataKind.U4: + return "U4"; + case InternalDataKind.U8: + return "U8"; + case InternalDataKind.R4: + return "R4"; + case InternalDataKind.R8: + return "R8"; + case InternalDataKind.BL: + return "BL"; + case InternalDataKind.TX: + return "TX"; + case InternalDataKind.TS: + return "TS"; + case InternalDataKind.DT: + return "DT"; + case InternalDataKind.DZ: + return "DZ"; + case InternalDataKind.UG: + return "UG"; } + return ""; } } diff --git a/src/Microsoft.ML.Core/Data/ICommand.cs b/src/Microsoft.ML.Core/Data/ICommand.cs index 7d52222647..f71360510f 100644 --- a/src/Microsoft.ML.Core/Data/ICommand.cs +++ b/src/Microsoft.ML.Core/Data/ICommand.cs @@ -2,17 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -namespace Microsoft.ML.Command -{ - /// - /// The signature for commands. 
- /// - [BestFriend] - internal delegate void SignatureCommand(); +namespace Microsoft.ML.Command; + +/// +/// The signature for commands. +/// +[BestFriend] +internal delegate void SignatureCommand(); - [BestFriend] - internal interface ICommand - { - void Run(); - } +[BestFriend] +internal interface ICommand +{ + void Run(); } diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs index 7dcbab4e75..1917bbeb15 100644 --- a/src/Microsoft.ML.Core/Data/IEstimator.cs +++ b/src/Microsoft.ML.Core/Data/IEstimator.cs @@ -8,315 +8,314 @@ using Microsoft.ML.Data; using Microsoft.ML.Runtime; -namespace Microsoft.ML +namespace Microsoft.ML; + +/// +/// A set of 'requirements' to the incoming schema, as well as a set of 'promises' of the outgoing schema. +/// This is more relaxed than the proper , since it's only a subset of the columns, +/// and also since it doesn't specify exact 's for vectors and keys. +/// +public sealed class SchemaShape : IReadOnlyList { - /// - /// A set of 'requirements' to the incoming schema, as well as a set of 'promises' of the outgoing schema. - /// This is more relaxed than the proper , since it's only a subset of the columns, - /// and also since it doesn't specify exact 's for vectors and keys. - /// - public sealed class SchemaShape : IReadOnlyList - { - private readonly Column[] _columns; + private readonly Column[] _columns; - private static readonly SchemaShape _empty = new SchemaShape(Enumerable.Empty()); + private static readonly SchemaShape _empty = new SchemaShape(Enumerable.Empty()); - public int Count => _columns.Count(); + public int Count => _columns.Count(); - public Column this[int index] => _columns[index]; + public Column this[int index] => _columns[index]; - public struct Column + public struct Column + { + public enum VectorKind { - public enum VectorKind - { - Scalar, - Vector, - VariableVector - } + Scalar, + Vector, + VariableVector + } - /// - /// The column name. 
- /// - public readonly string Name; - - /// - /// The type of the column: scalar, fixed vector or variable vector. - /// - public readonly VectorKind Kind; - - /// - /// The 'raw' type of column item: must be a primitive type or a structured type. - /// - public readonly DataViewType ItemType; - /// - /// The flag whether the column is actually a key. If yes, is representing - /// the underlying primitive type. - /// - public readonly bool IsKey; - /// - /// The annotations that are present for this column. - /// - public readonly SchemaShape Annotations; - - [BestFriend] - internal Column(string name, VectorKind vecKind, DataViewType itemType, bool isKey, SchemaShape annotations = null) - { - Contracts.CheckNonEmpty(name, nameof(name)); - Contracts.CheckValueOrNull(annotations); - Contracts.CheckParam(!(itemType is KeyDataViewType), nameof(itemType), "Item type cannot be a key"); - Contracts.CheckParam(!(itemType is VectorDataViewType), nameof(itemType), "Item type cannot be a vector"); - Contracts.CheckParam(!isKey || KeyDataViewType.IsValidDataType(itemType.RawType), nameof(itemType), "The item type must be valid for a key"); - - Name = name; - Kind = vecKind; - ItemType = itemType; - IsKey = isKey; - Annotations = annotations ?? _empty; - } + /// + /// The column name. + /// + public readonly string Name; + + /// + /// The type of the column: scalar, fixed vector or variable vector. + /// + public readonly VectorKind Kind; - /// - /// Returns whether is a valid input, if this object represents a - /// requirement. - /// - /// Namely, it returns true iff: - /// - The , , , fields match. - /// - The columns of of is a superset of our columns. - /// - Each such annotation column is itself compatible with the input annotation column. - /// - [BestFriend] - internal bool IsCompatibleWith(Column source) + /// + /// The 'raw' type of column item: must be a primitive type or a structured type. 
+ /// + public readonly DataViewType ItemType; + /// + /// The flag whether the column is actually a key. If yes, is representing + /// the underlying primitive type. + /// + public readonly bool IsKey; + /// + /// The annotations that are present for this column. + /// + public readonly SchemaShape Annotations; + + [BestFriend] + internal Column(string name, VectorKind vecKind, DataViewType itemType, bool isKey, SchemaShape annotations = null) + { + Contracts.CheckNonEmpty(name, nameof(name)); + Contracts.CheckValueOrNull(annotations); + Contracts.CheckParam(!(itemType is KeyDataViewType), nameof(itemType), "Item type cannot be a key"); + Contracts.CheckParam(!(itemType is VectorDataViewType), nameof(itemType), "Item type cannot be a vector"); + Contracts.CheckParam(!isKey || KeyDataViewType.IsValidDataType(itemType.RawType), nameof(itemType), "The item type must be valid for a key"); + + Name = name; + Kind = vecKind; + ItemType = itemType; + IsKey = isKey; + Annotations = annotations ?? _empty; + } + + /// + /// Returns whether is a valid input, if this object represents a + /// requirement. + /// + /// Namely, it returns true iff: + /// - The , , , fields match. + /// - The columns of of is a superset of our columns. + /// - Each such annotation column is itself compatible with the input annotation column. 
+ /// + [BestFriend] + internal bool IsCompatibleWith(Column source) + { + Contracts.Check(source.IsValid, nameof(source)); + if (Name != source.Name) + return false; + if (Kind != source.Kind) + return false; + if (!ItemType.Equals(source.ItemType)) + return false; + if (IsKey != source.IsKey) + return false; + foreach (var annotationCol in Annotations) { - Contracts.Check(source.IsValid, nameof(source)); - if (Name != source.Name) + if (!source.Annotations.TryFindColumn(annotationCol.Name, out var inputAnnotationCol)) return false; - if (Kind != source.Kind) + if (!annotationCol.IsCompatibleWith(inputAnnotationCol)) return false; - if (!ItemType.Equals(source.ItemType)) - return false; - if (IsKey != source.IsKey) - return false; - foreach (var annotationCol in Annotations) - { - if (!source.Annotations.TryFindColumn(annotationCol.Name, out var inputAnnotationCol)) - return false; - if (!annotationCol.IsCompatibleWith(inputAnnotationCol)) - return false; - } - return true; } - - [BestFriend] - internal string GetTypeString() - { - string result = ItemType.ToString(); - if (IsKey) - result = $"Key<{result}>"; - if (Kind == VectorKind.Vector) - result = $"Vector<{result}>"; - else if (Kind == VectorKind.VariableVector) - result = $"VarVector<{result}>"; - return result; - } - - /// - /// Return if this structure is not identical to the default value of . If true, - /// it means this structure is initialized properly and therefore considered as valid. 
- /// - [BestFriend] - internal bool IsValid => Name != null; + return true; } - public SchemaShape(IEnumerable columns) + [BestFriend] + internal string GetTypeString() { - Contracts.CheckValue(columns, nameof(columns)); - _columns = columns.ToArray(); - Contracts.CheckParam(columns.All(c => c.IsValid), nameof(columns), "Some items are not initialized properly."); + string result = ItemType.ToString(); + if (IsKey) + result = $"Key<{result}>"; + if (Kind == VectorKind.Vector) + result = $"Vector<{result}>"; + else if (Kind == VectorKind.VariableVector) + result = $"VarVector<{result}>"; + return result; } /// - /// Given a , extract the type parameters that describe this type - /// as a 's column type. + /// Return if this structure is not identical to the default value of . If true, + /// it means this structure is initialized properly and therefore considered as valid. /// - /// The actual column type to process. - /// The vector kind of . - /// The item type of . - /// Whether (or its item type) is a key. [BestFriend] - internal static void GetColumnTypeShape(DataViewType type, - out Column.VectorKind vecKind, - out DataViewType itemType, - out bool isKey) + internal bool IsValid => Name != null; + } + + public SchemaShape(IEnumerable columns) + { + Contracts.CheckValue(columns, nameof(columns)); + _columns = columns.ToArray(); + Contracts.CheckParam(columns.All(c => c.IsValid), nameof(columns), "Some items are not initialized properly."); + } + + /// + /// Given a , extract the type parameters that describe this type + /// as a 's column type. + /// + /// The actual column type to process. + /// The vector kind of . + /// The item type of . + /// Whether (or its item type) is a key. 
+ [BestFriend] + internal static void GetColumnTypeShape(DataViewType type, + out Column.VectorKind vecKind, + out DataViewType itemType, + out bool isKey) + { + if (type is VectorDataViewType vectorType) { - if (type is VectorDataViewType vectorType) + if (vectorType.IsKnownSize) { - if (vectorType.IsKnownSize) - { - vecKind = Column.VectorKind.Vector; - } - else - { - vecKind = Column.VectorKind.VariableVector; - } - - itemType = vectorType.ItemType; + vecKind = Column.VectorKind.Vector; } else { - vecKind = Column.VectorKind.Scalar; - itemType = type; + vecKind = Column.VectorKind.VariableVector; } - isKey = itemType is KeyDataViewType; - if (isKey) - itemType = ColumnTypeExtensions.PrimitiveTypeFromType(itemType.RawType); + itemType = vectorType.ItemType; } - - /// - /// Create a schema shape out of the fully defined schema. - /// - [BestFriend] - internal static SchemaShape Create(DataViewSchema schema) + else { - Contracts.CheckValue(schema, nameof(schema)); - var cols = new List(); + vecKind = Column.VectorKind.Scalar; + itemType = type; + } + + isKey = itemType is KeyDataViewType; + if (isKey) + itemType = ColumnTypeExtensions.PrimitiveTypeFromType(itemType.RawType); + } + + /// + /// Create a schema shape out of the fully defined schema. + /// + [BestFriend] + internal static SchemaShape Create(DataViewSchema schema) + { + Contracts.CheckValue(schema, nameof(schema)); + var cols = new List(); - for (int iCol = 0; iCol < schema.Count; iCol++) + for (int iCol = 0; iCol < schema.Count; iCol++) + { + if (!schema[iCol].IsHidden) { - if (!schema[iCol].IsHidden) + // First create the annotations. + var mCols = new List(); + foreach (var annotationColumn in schema[iCol].Annotations.Schema) { - // First create the annotations. 
- var mCols = new List(); - foreach (var annotationColumn in schema[iCol].Annotations.Schema) - { - GetColumnTypeShape(annotationColumn.Type, out var mVecKind, out var mItemType, out var mIsKey); - mCols.Add(new Column(annotationColumn.Name, mVecKind, mItemType, mIsKey)); - } - var annotations = mCols.Count > 0 ? new SchemaShape(mCols) : _empty; - // Next create the single column. - GetColumnTypeShape(schema[iCol].Type, out var vecKind, out var itemType, out var isKey); - cols.Add(new Column(schema[iCol].Name, vecKind, itemType, isKey, annotations)); + GetColumnTypeShape(annotationColumn.Type, out var mVecKind, out var mItemType, out var mIsKey); + mCols.Add(new Column(annotationColumn.Name, mVecKind, mItemType, mIsKey)); } + var annotations = mCols.Count > 0 ? new SchemaShape(mCols) : _empty; + // Next create the single column. + GetColumnTypeShape(schema[iCol].Type, out var vecKind, out var itemType, out var isKey); + cols.Add(new Column(schema[iCol].Name, vecKind, itemType, isKey, annotations)); } - return new SchemaShape(cols); } + return new SchemaShape(cols); + } - /// - /// Returns if there is a column with a specified and if so stores it in . - /// - [BestFriend] - internal bool TryFindColumn(string name, out Column column) - { - Contracts.CheckValue(name, nameof(name)); - column = _columns.FirstOrDefault(x => x.Name == name); - return column.IsValid; - } + /// + /// Returns if there is a column with a specified and if so stores it in . 
+ /// + [BestFriend] + internal bool TryFindColumn(string name, out Column column) + { + Contracts.CheckValue(name, nameof(name)); + column = _columns.FirstOrDefault(x => x.Name == name); + return column.IsValid; + } - public IEnumerator GetEnumerator() => ((IEnumerable)_columns).GetEnumerator(); + public IEnumerator GetEnumerator() => ((IEnumerable)_columns).GetEnumerator(); - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - // REVIEW: I think we should have an IsCompatible method to check if it's OK to use one schema shape - // as an input to another schema shape. I started writing, but realized that there's more than one way to check for - // the 'compatibility': as in, 'CAN be compatible' vs. 'WILL be compatible'. - } + // REVIEW: I think we should have an IsCompatible method to check if it's OK to use one schema shape + // as an input to another schema shape. I started writing, but realized that there's more than one way to check for + // the 'compatibility': as in, 'CAN be compatible' vs. 'WILL be compatible'. +} +/// +/// The 'data loader' takes a certain kind of input and turns it into an . +/// +/// The type of input the loader takes. +public interface IDataLoader : ICanSaveModel +{ /// - /// The 'data loader' takes a certain kind of input and turns it into an . + /// Produce the data view from the specified input. + /// Note that 's are lazy, so no actual loading happens here, just schema validation. /// - /// The type of input the loader takes. - public interface IDataLoader : ICanSaveModel - { - /// - /// Produce the data view from the specified input. - /// Note that 's are lazy, so no actual loading happens here, just schema validation. - /// - IDataView Load(TSource input); - - /// - /// The output schema of the loader. - /// - DataViewSchema GetOutputSchema(); - } + IDataView Load(TSource input); /// - /// Sometimes we need to 'fit' an . 
- /// A DataLoader estimator is the object that does it. + /// The output schema of the loader. /// - public interface IDataLoaderEstimator - where TLoader : IDataLoader - { - // REVIEW: you could consider the transformer to take a different , but we don't have such components - // yet, so why complicate matters? + DataViewSchema GetOutputSchema(); +} - /// - /// Train and return a data loader. - /// - TLoader Fit(TSource input); +/// +/// Sometimes we need to 'fit' an . +/// A DataLoader estimator is the object that does it. +/// +public interface IDataLoaderEstimator + where TLoader : IDataLoader +{ + // REVIEW: you could consider the transformer to take a different , but we don't have such components + // yet, so why complicate matters? - /// - /// The 'promise' of the output schema. - /// It will be used for schema propagation. - /// - SchemaShape GetOutputSchema(); - } + /// + /// Train and return a data loader. + /// + TLoader Fit(TSource input); /// - /// The transformer is a component that transforms data. - /// It also supports 'schema propagation' to answer the question of 'how will the data with this schema look, after you transform it?'. + /// The 'promise' of the output schema. + /// It will be used for schema propagation. /// - public interface ITransformer : ICanSaveModel - { - /// - /// Schema propagation for transformers. - /// Returns the output schema of the data, if the input schema is like the one provided. - /// - DataViewSchema GetOutputSchema(DataViewSchema inputSchema); + SchemaShape GetOutputSchema(); +} - /// - /// Take the data in, make transformations, output the data. - /// Note that 's are lazy, so no actual transformations happen here, just schema validation. - /// - IDataView Transform(IDataView input); +/// +/// The transformer is a component that transforms data. +/// It also supports 'schema propagation' to answer the question of 'how will the data with this schema look, after you transform it?'. 
+/// +public interface ITransformer : ICanSaveModel +{ + /// + /// Schema propagation for transformers. + /// Returns the output schema of the data, if the input schema is like the one provided. + /// + DataViewSchema GetOutputSchema(DataViewSchema inputSchema); - /// - /// Whether a call to should succeed, on an - /// appropriate schema. - /// - bool IsRowToRowMapper { get; } + /// + /// Take the data in, make transformations, output the data. + /// Note that 's are lazy, so no actual transformations happen here, just schema validation. + /// + IDataView Transform(IDataView input); - /// - /// Constructs a row-to-row mapper based on an input schema. If - /// is false, then an exception should be thrown. If the input schema is in any way - /// unsuitable for constructing the mapper, an exception should likewise be thrown. - /// - /// The input schema for which we should get the mapper. - /// The row to row mapper. - IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema); - } + /// + /// Whether a call to should succeed, on an + /// appropriate schema. + /// + bool IsRowToRowMapper { get; } - [BestFriend] - internal interface ITransformerWithDifferentMappingAtTrainingTime : ITransformer - { - IDataView TransformForTrainingPipeline(IDataView input); - } + /// + /// Constructs a row-to-row mapper based on an input schema. If + /// is false, then an exception should be thrown. If the input schema is in any way + /// unsuitable for constructing the mapper, an exception should likewise be thrown. + /// + /// The input schema for which we should get the mapper. + /// The row to row mapper. + IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema); +} +[BestFriend] +internal interface ITransformerWithDifferentMappingAtTrainingTime : ITransformer +{ + IDataView TransformForTrainingPipeline(IDataView input); +} + +/// +/// The estimator (in Spark terminology) is an 'untrained transformer'. It needs to 'fit' on the data to manufacture +/// a transformer. 
+/// It also provides the 'schema propagation' like transformers do, but over instead of . +/// +public interface IEstimator + where TTransformer : ITransformer +{ /// - /// The estimator (in Spark terminology) is an 'untrained transformer'. It needs to 'fit' on the data to manufacture - /// a transformer. - /// It also provides the 'schema propagation' like transformers do, but over instead of . + /// Train and return a transformer. /// - public interface IEstimator - where TTransformer : ITransformer - { - /// - /// Train and return a transformer. - /// - TTransformer Fit(IDataView input); + TTransformer Fit(IDataView input); - /// - /// Schema propagation for estimators. - /// Returns the output schema shape of the estimator, if the input schema shape is like the one provided. - /// - SchemaShape GetOutputSchema(SchemaShape inputSchema); - } + /// + /// Schema propagation for estimators. + /// Returns the output schema shape of the estimator, if the input schema shape is like the one provided. + /// + SchemaShape GetOutputSchema(SchemaShape inputSchema); } diff --git a/src/Microsoft.ML.Core/Data/IFileHandle.cs b/src/Microsoft.ML.Core/Data/IFileHandle.cs index 5afef22fee..6eac6821a9 100644 --- a/src/Microsoft.ML.Core/Data/IFileHandle.cs +++ b/src/Microsoft.ML.Core/Data/IFileHandle.cs @@ -8,191 +8,190 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; -namespace Microsoft.ML.Data +namespace Microsoft.ML.Data; + +/// +/// A file handle. +/// +public interface IFileHandle : IDisposable { /// - /// A file handle. + /// Returns whether CreateWriteStream is expected to succeed. Typically, once + /// CreateWriteStream has been called once, this will forever more return false. /// - public interface IFileHandle : IDisposable - { - /// - /// Returns whether CreateWriteStream is expected to succeed. Typically, once - /// CreateWriteStream has been called once, this will forever more return false. 
- /// - bool CanWrite { get; } - - /// - /// Returns whether OpenReadStream is expected to succeed. - /// - bool CanRead { get; } - - /// - /// Create a writable stream for this file handle. - /// - Stream CreateWriteStream(); - - /// - /// Open a readable stream for this file handle. - /// - Stream OpenReadStream(); - } + bool CanWrite { get; } /// - /// A simple disk-based file handle. + /// Returns whether OpenReadStream is expected to succeed. /// - public sealed class SimpleFileHandle : IFileHandle - { - private readonly string _fullPath; + bool CanRead { get; } + + /// + /// Create a writable stream for this file handle. + /// + Stream CreateWriteStream(); - // Exception context. - private readonly IExceptionContext _ectx; + /// + /// Open a readable stream for this file handle. + /// + Stream OpenReadStream(); +} - private readonly object _lock; +/// +/// A simple disk-based file handle. +/// +public sealed class SimpleFileHandle : IFileHandle +{ + private readonly string _fullPath; - // Whether to delete the file when this is disposed. - private readonly bool _autoDelete; + // Exception context. + private readonly IExceptionContext _ectx; - // Whether this file has contents. This is false if the file needs CreateWriteStream to be - // called (before OpenReadStream can be called). - private bool _wrote; - // If non-null, the active write stream. This should be disposed before the first OpenReadStream call. - private Stream _streamWrite; + private readonly object _lock; - // This contains the potentially active read streams. This is set to null once this file - // handle has been disposed. - private List _streams; + // Whether to delete the file when this is disposed. + private readonly bool _autoDelete; - private bool IsDisposed => _streams == null; + // Whether this file has contents. This is false if the file needs CreateWriteStream to be + // called (before OpenReadStream can be called). + private bool _wrote; + // If non-null, the active write stream. 
This should be disposed before the first OpenReadStream call. + private Stream _streamWrite; - public SimpleFileHandle(IExceptionContext ectx, string path, bool needsWrite, bool autoDelete) - { - Contracts.CheckValue(ectx, nameof(ectx)); - ectx.CheckNonEmpty(path, nameof(path)); + // This contains the potentially active read streams. This is set to null once this file + // handle has been disposed. + private List _streams; - _ectx = ectx; - _fullPath = Path.GetFullPath(path); + private bool IsDisposed => _streams == null; - _autoDelete = autoDelete; + public SimpleFileHandle(IExceptionContext ectx, string path, bool needsWrite, bool autoDelete) + { + Contracts.CheckValue(ectx, nameof(ectx)); + ectx.CheckNonEmpty(path, nameof(path)); - // The file has already been written to iff needsWrite is false. - _wrote = !needsWrite; + _ectx = ectx; + _fullPath = Path.GetFullPath(path); - // REVIEW: Should this do some basic validation? Eg, for output files, ensure that - // the directory exists (and perhaps even create an empty file); for input files, ensure - // that the file exists (and perhaps even attempt to open it). + _autoDelete = autoDelete; - _lock = new object(); - _streams = new List(); - } + // The file has already been written to iff needsWrite is false. + _wrote = !needsWrite; + + // REVIEW: Should this do some basic validation? Eg, for output files, ensure that + // the directory exists (and perhaps even create an empty file); for input files, ensure + // that the file exists (and perhaps even attempt to open it). 
- public bool CanWrite => !_wrote && !IsDisposed; + _lock = new object(); + _streams = new List(); + } + + public bool CanWrite => !_wrote && !IsDisposed; - public bool CanRead => _wrote && !IsDisposed; + public bool CanRead => _wrote && !IsDisposed; - public void Dispose() + public void Dispose() + { + lock (_lock) { - lock (_lock) - { - if (IsDisposed) - return; + if (IsDisposed) + return; - Contracts.Assert(_streams != null); + Contracts.Assert(_streams != null); - // REVIEW: Is it safe to dispose these streams? What if they are - // being used on other threads? Does that matter? - if (_streamWrite != null) + // REVIEW: Is it safe to dispose these streams? What if they are + // being used on other threads? Does that matter? + if (_streamWrite != null) + { + try { - try - { - _streamWrite.CloseEx(); - _streamWrite.Dispose(); - } - catch - { - // REVIEW: What should we do here? - Contracts.Assert(false, "Closing a SimpleFileHandle write stream failed!"); - } - _streamWrite = null; + _streamWrite.CloseEx(); + _streamWrite.Dispose(); } + catch + { + // REVIEW: What should we do here? + Contracts.Assert(false, "Closing a SimpleFileHandle write stream failed!"); + } + _streamWrite = null; + } - foreach (var stream in _streams) + foreach (var stream in _streams) + { + try + { + stream.CloseEx(); + stream.Dispose(); + } + catch { - try - { - stream.CloseEx(); - stream.Dispose(); - } - catch - { - // REVIEW: What should we do here? - Contracts.Assert(false, "Closing a SimpleFileHandle read stream failed!"); - } + // REVIEW: What should we do here? + Contracts.Assert(false, "Closing a SimpleFileHandle read stream failed!"); } + } - _streams = null; - Contracts.Assert(IsDisposed); + _streams = null; + Contracts.Assert(IsDisposed); - if (_autoDelete) + if (_autoDelete) + { + try + { + // Finally, delete the file. + File.Delete(_fullPath); + } + catch { - try - { - // Finally, delete the file. - File.Delete(_fullPath); - } - catch - { - // REVIEW: What should we do here? 
- Contracts.Assert(false, "Deleting a SimpleFileHandle physical file failed!"); - } + // REVIEW: What should we do here? + Contracts.Assert(false, "Deleting a SimpleFileHandle physical file failed!"); } } } + } - private void CheckNotDisposed() - { - if (IsDisposed) - throw _ectx.Except("SimpleFileHandle has already been disposed"); - } + private void CheckNotDisposed() + { + if (IsDisposed) + throw _ectx.Except("SimpleFileHandle has already been disposed"); + } - public Stream CreateWriteStream() + public Stream CreateWriteStream() + { + lock (_lock) { - lock (_lock) - { - CheckNotDisposed(); + CheckNotDisposed(); - if (_wrote) - throw _ectx.Except("CreateWriteStream called multiple times on SimpleFileHandle"); + if (_wrote) + throw _ectx.Except("CreateWriteStream called multiple times on SimpleFileHandle"); - Contracts.Assert(_streamWrite == null); - _streamWrite = new FileStream(_fullPath, FileMode.Create, FileAccess.Write); - _wrote = true; - return _streamWrite; - } + Contracts.Assert(_streamWrite == null); + _streamWrite = new FileStream(_fullPath, FileMode.Create, FileAccess.Write); + _wrote = true; + return _streamWrite; } + } - public Stream OpenReadStream() + public Stream OpenReadStream() + { + lock (_lock) { - lock (_lock) - { - CheckNotDisposed(); + CheckNotDisposed(); - if (!_wrote) - throw _ectx.Except("SimpleFileHandle hasn't been written yet"); + if (!_wrote) + throw _ectx.Except("SimpleFileHandle hasn't been written yet"); - if (_streamWrite != null) - { - if (_streamWrite.CanWrite) - throw _ectx.Except("Write stream for SimpleFileHandle hasn't been disposed"); - _streamWrite = null; - } + if (_streamWrite != null) + { + if (_streamWrite.CanWrite) + throw _ectx.Except("Write stream for SimpleFileHandle hasn't been disposed"); + _streamWrite = null; + } - // Drop read streams that have already been disposed. - _streams.RemoveAll(s => !s.CanRead); + // Drop read streams that have already been disposed. 
+ _streams.RemoveAll(s => !s.CanRead); - var stream = new FileStream(_fullPath, FileMode.Open, FileAccess.Read, FileShare.Read); - _streams.Add(stream); - return stream; - } + var stream = new FileStream(_fullPath, FileMode.Open, FileAccess.Read, FileShare.Read); + _streams.Add(stream); + return stream; } } } diff --git a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs index 72a1245485..06fe0b32b4 100644 --- a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs +++ b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs @@ -5,319 +5,318 @@ using System; using Microsoft.ML.Data; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML.Runtime; + +/// +/// A channel provider can create new channels and generic information pipes. +/// +public interface IChannelProvider : IExceptionContext { /// - /// A channel provider can create new channels and generic information pipes. + /// Start a standard message channel. /// - public interface IChannelProvider : IExceptionContext - { - /// - /// Start a standard message channel. - /// - IChannel Start(string name); - - /// - /// Start a generic information pipe. - /// - IPipe StartPipe(string name); - } + IChannel Start(string name); /// - /// Utility class for IHostEnvironment + /// Start a generic information pipe. /// - [BestFriend] - internal static class HostEnvironmentExtensions - { - /// - /// Return a file handle for an input "file". - /// - public static IFileHandle OpenInputFile(this IHostEnvironment env, string path) - { - Contracts.AssertValue(env); - Contracts.CheckNonWhiteSpace(path, nameof(path)); - return new SimpleFileHandle(env, path, needsWrite: false, autoDelete: false); - } + IPipe StartPipe(string name); +} - /// - /// Create an output "file" and return a handle to it. 
- /// - public static IFileHandle CreateOutputFile(this IHostEnvironment env, string path) - { - Contracts.AssertValue(env); - Contracts.CheckNonWhiteSpace(path, nameof(path)); - return new SimpleFileHandle(env, path, needsWrite: true, autoDelete: false); - } +/// +/// Utility class for IHostEnvironment +/// +[BestFriend] +internal static class HostEnvironmentExtensions +{ + /// + /// Return a file handle for an input "file". + /// + public static IFileHandle OpenInputFile(this IHostEnvironment env, string path) + { + Contracts.AssertValue(env); + Contracts.CheckNonWhiteSpace(path, nameof(path)); + return new SimpleFileHandle(env, path, needsWrite: false, autoDelete: false); } /// - /// The host environment interface creates hosts for components. Note that the methods of - /// this interface should be called from the main thread for the environment. To get an environment - /// to service another thread, call Fork and pass the return result to that thread. + /// Create an output "file" and return a handle to it. /// - public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider + public static IFileHandle CreateOutputFile(this IHostEnvironment env, string path) { - /// - /// Create a host with the given registration name. - /// - IHost Register(string name, int? seed = null, bool? verbose = null); - - /// - /// The catalog of loadable components () that are available in this host. - /// - ComponentCatalog ComponentCatalog { get; } + Contracts.AssertValue(env); + Contracts.CheckNonWhiteSpace(path, nameof(path)); + return new SimpleFileHandle(env, path, needsWrite: true, autoDelete: false); } +} + +/// +/// The host environment interface creates hosts for components. Note that the methods of +/// this interface should be called from the main thread for the environment. To get an environment +/// to service another thread, call Fork and pass the return result to that thread. 
+/// +public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider +{ + /// + /// Create a host with the given registration name. + /// + IHost Register(string name, int? seed = null, bool? verbose = null); + + /// + /// The catalog of loadable components () that are available in this host. + /// + ComponentCatalog ComponentCatalog { get; } +} + +[BestFriend] +internal interface ICancelable +{ + /// + /// Signal to stop execution in all the hosts. + /// + void CancelExecution(); + + /// + /// Flag which indicates host execution has been stopped. + /// + bool IsCanceled { get; } +} + +[BestFriend] +internal interface IHostEnvironmentInternal : IHostEnvironment +{ + /// + /// The seed property that, if assigned, makes components requiring randomness behave deterministically. + /// + int? Seed { get; } + + /// + /// The location for the temp files created by ML.NET + /// + string TempFilePath { get; set; } + + /// + /// Allow falling back to run on CPU if couldn't run on GPU. + /// + bool FallbackToCpu { get; set; } + + /// + /// GPU device ID to run execution on, or null to run on CPU. + /// + int? GpuDeviceId { get; set; } +} + +/// +/// A host is coupled to a component and provides random number generation and concurrency guidance. +/// Note that the random number generation, like the host environment methods, should be accessed only +/// from the main thread for the component. +/// +public interface IHost : IHostEnvironment +{ + /// + /// The random number generator issued to this component. Note that random number + /// generators are NOT thread safe. + /// + Random Rand { get; } +} + +/// +/// A generic information pipe. Note that pipes are disposable. Generally, Done should +/// be called before disposing to signal a normal shut-down of the pipe, as opposed +/// to an aborted completion. +/// +public interface IPipe : IExceptionContext, IDisposable +{ + /// + /// The caller relinquishes ownership of the object. 
+ /// + void Send(TMessage msg); +} + +/// +/// The kinds of standard channel messages. +/// Note: These values should never be changed. We can add new kinds, but don't change these values. +/// Other code bases, including native code for other projects, depend on these values. +/// +public enum ChannelMessageKind +{ + Trace = 0, + Info = 1, + Warning = 2, + Error = 3 +} + +/// +/// A flag that can be attached to a message or exception to indicate that +/// it has a certain class of sensitive data. By default, messages should be +/// specified as being of unknown sensitivity, which is to say, every +/// sensitivity flag is turned on, corresponding to <see cref="Unknown"/>. +/// Messages that are totally safe should be marked as <see cref="None"/>. +/// However, if, say, one prints out data from a file (for example, this might +/// be done when expressing parse errors), it should be flagged in that case +/// with <see cref="UserData"/>. +/// +[Flags] +public enum MessageSensitivity +{ + /// + /// For non-sensitive data. + /// + None = 0, + + /// + /// For messages that may contain user-data from data files. + /// + UserData = 0x1, + + /// + /// For messages that contain information like column names from datasets. + /// Note that, despite being part of the schema, annotations should be treated + /// as user data, since it is often derived from user data. Note also that + /// types, despite being part of the schema, are not considered "sensitive" + /// as such, in the same way that column names might be. + /// + Schema = 0x2, + + // REVIEW: Other potentially sensitive things might include + // stack traces in certain environments. + + /// + /// The default value, unknown, is treated as if everything is sensitive. + /// + Unknown = ~None, + + /// + /// An alias for <see cref="Unknown"/>, so it is functionally the same, except + /// semantically it communicates the idea that we want all bits set. + /// + All = Unknown, +} + +/// +/// A channel message. 
+/// +public readonly struct ChannelMessage +{ + public readonly ChannelMessageKind Kind; + public readonly MessageSensitivity Sensitivity; + private readonly string _message; + private readonly object[] _args; + + /// + /// Line endings may not be normalized. + /// + public string Message => _args != null ? string.Format(_message, _args) : _message; [BestFriend] - internal interface ICancelable + internal ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string message) { - /// - /// Signal to stop execution in all the hosts. - /// - void CancelExecution(); - - /// - /// Flag which indicates host execution has been stopped. - /// - bool IsCanceled { get; } + Contracts.CheckNonEmpty(message, nameof(message)); + Kind = kind; + Sensitivity = sensitivity; + _message = message; + _args = null; } [BestFriend] - internal interface IHostEnvironmentInternal : IHostEnvironment + internal ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string fmt, params object[] args) { - /// - /// The seed property that, if assigned, makes components requiring randomness behave deterministically. - /// - int? Seed { get; } - - /// - /// The location for the temp files created by ML.NET - /// - string TempFilePath { get; set; } - - /// - /// Allow falling back to run on CPU if couldn't run on GPU. - /// - bool FallbackToCpu { get; set; } - - /// - /// GPU device ID to run execution on, to run on CPU. - /// - int? GpuDeviceId { get; set; } + Contracts.CheckNonEmpty(fmt, nameof(fmt)); + Contracts.CheckNonEmpty(args, nameof(args)); + Kind = kind; + Sensitivity = sensitivity; + _message = fmt; + _args = args; } +} - /// - /// A host is coupled to a component and provides random number generation and concurrency guidance. - /// Note that the random number generation, like the host environment methods, should be accessed only - /// from the main thread for the component. 
- /// - public interface IHost : IHostEnvironment +/// +/// A standard communication channel. +/// +public interface IChannel : IPipe +{ + void Trace(MessageSensitivity sensitivity, string fmt); + void Trace(MessageSensitivity sensitivity, string fmt, params object[] args); + void Error(MessageSensitivity sensitivity, string fmt); + void Error(MessageSensitivity sensitivity, string fmt, params object[] args); + void Warning(MessageSensitivity sensitivity, string fmt); + void Warning(MessageSensitivity sensitivity, string fmt, params object[] args); + void Info(MessageSensitivity sensitivity, string fmt); + void Info(MessageSensitivity sensitivity, string fmt, params object[] args); +} + +/// +/// General utility extension methods for objects in the "host" universe, i.e., +/// , , and +/// that do not belong in more specific areas, for example, or +/// component creation. +/// +[BestFriend] +internal static class HostExtensions +{ + public static T Apply(this IHost host, string channelName, Func func) { - /// - /// The random number generator issued to this component. Note that random number - /// generators are NOT thread safe. - /// - Random Rand { get; } + T t; + using (var ch = host.Start(channelName)) + { + t = func(ch); + } + return t; } /// - /// A generic information pipe. Note that pipes are disposable. Generally, Done should - /// be called before disposing to signal a normal shut-down of the pipe, as opposed - /// to an aborted completion. + /// Convenience variant of + /// setting . /// - public interface IPipe : IExceptionContext, IDisposable - { - /// - /// The caller relinquishes ownership of the object. - /// - void Send(TMessage msg); - } + public static void Trace(this IChannel ch, string fmt) + => ch.Trace(MessageSensitivity.Unknown, fmt); /// - /// The kinds of standard channel messages. - /// Note: These values should never be changed. We can add new kinds, but don't change these values. 
- /// Other code bases, including native code for other projects depends on these values. + /// Convenience variant of + /// setting . /// - public enum ChannelMessageKind - { - Trace = 0, - Info = 1, - Warning = 2, - Error = 3 - } + public static void Trace(this IChannel ch, string fmt, params object[] args) + => ch.Trace(MessageSensitivity.Unknown, fmt, args); /// - /// A flag that can be attached to a message or exception to indicate that - /// it has a certain class of sensitive data. By default, messages should be - /// specified as being of unknown sensitivity, which is to say, every - /// sensitivity flag is turned on, corresponding to . - /// Messages that are totally safe should be marked as . - /// However, if, say, one prints out data from a file (for example, this might - /// be done when expressing parse errors), it should be flagged in that case - /// with . + /// Convenience variant of + /// setting . /// - [Flags] - public enum MessageSensitivity - { - /// - /// For non-sensitive data. - /// - None = 0, - - /// - /// For messages that may contain user-data from data files. - /// - UserData = 0x1, - - /// - /// For messages that contain information like column names from datasets. - /// Note that, despite being part of the schema, annotations should be treated - /// as user data, since it is often derived from user data. Note also that - /// types, despite being part of the schema, are not considered "sensitive" - /// as such, in the same way that column names might be. - /// - Schema = 0x2, - - // REVIEW: Other potentially sensitive things might include - // stack traces in certain environments. - - /// - /// The default value, unknown, is treated as if everything is sensitive. - /// - Unknown = ~None, - - /// - /// An alias for , so it is functionally the same, except - /// semantically it communicates the idea that we want all bits set. 
- /// - All = Unknown, - } + public static void Error(this IChannel ch, string fmt) + => ch.Error(MessageSensitivity.Unknown, fmt); /// - /// A channel message. + /// Convenience variant of + /// setting . /// - public readonly struct ChannelMessage - { - public readonly ChannelMessageKind Kind; - public readonly MessageSensitivity Sensitivity; - private readonly string _message; - private readonly object[] _args; - - /// - /// Line endings may not be normalized. - /// - public string Message => _args != null ? string.Format(_message, _args) : _message; - - [BestFriend] - internal ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string message) - { - Contracts.CheckNonEmpty(message, nameof(message)); - Kind = kind; - Sensitivity = sensitivity; - _message = message; - _args = null; - } + public static void Error(this IChannel ch, string fmt, params object[] args) + => ch.Error(MessageSensitivity.Unknown, fmt, args); - [BestFriend] - internal ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string fmt, params object[] args) - { - Contracts.CheckNonEmpty(fmt, nameof(fmt)); - Contracts.CheckNonEmpty(args, nameof(args)); - Kind = kind; - Sensitivity = sensitivity; - _message = fmt; - _args = args; - } - } + /// + /// Convenience variant of + /// setting . + /// + public static void Warning(this IChannel ch, string fmt) + => ch.Warning(MessageSensitivity.Unknown, fmt); /// - /// A standard communication channel. + /// Convenience variant of + /// setting . 
/// - public interface IChannel : IPipe - { - void Trace(MessageSensitivity sensitivity, string fmt); - void Trace(MessageSensitivity sensitivity, string fmt, params object[] args); - void Error(MessageSensitivity sensitivity, string fmt); - void Error(MessageSensitivity sensitivity, string fmt, params object[] args); - void Warning(MessageSensitivity sensitivity, string fmt); - void Warning(MessageSensitivity sensitivity, string fmt, params object[] args); - void Info(MessageSensitivity sensitivity, string fmt); - void Info(MessageSensitivity sensitivity, string fmt, params object[] args); - } + public static void Warning(this IChannel ch, string fmt, params object[] args) + => ch.Warning(MessageSensitivity.Unknown, fmt, args); /// - /// General utility extension methods for objects in the "host" universe, i.e., - /// , , and - /// that do not belong in more specific areas, for example, or - /// component creation. + /// Convenience variant of + /// setting . /// - [BestFriend] - internal static class HostExtensions - { - public static T Apply(this IHost host, string channelName, Func func) - { - T t; - using (var ch = host.Start(channelName)) - { - t = func(ch); - } - return t; - } + public static void Info(this IChannel ch, string fmt) + => ch.Info(MessageSensitivity.Unknown, fmt); - /// - /// Convenience variant of - /// setting . - /// - public static void Trace(this IChannel ch, string fmt) - => ch.Trace(MessageSensitivity.Unknown, fmt); - - /// - /// Convenience variant of - /// setting . - /// - public static void Trace(this IChannel ch, string fmt, params object[] args) - => ch.Trace(MessageSensitivity.Unknown, fmt, args); - - /// - /// Convenience variant of - /// setting . - /// - public static void Error(this IChannel ch, string fmt) - => ch.Error(MessageSensitivity.Unknown, fmt); - - /// - /// Convenience variant of - /// setting . 
- /// - public static void Error(this IChannel ch, string fmt, params object[] args) - => ch.Error(MessageSensitivity.Unknown, fmt, args); - - /// - /// Convenience variant of - /// setting . - /// - public static void Warning(this IChannel ch, string fmt) - => ch.Warning(MessageSensitivity.Unknown, fmt); - - /// - /// Convenience variant of - /// setting . - /// - public static void Warning(this IChannel ch, string fmt, params object[] args) - => ch.Warning(MessageSensitivity.Unknown, fmt, args); - - /// - /// Convenience variant of - /// setting . - /// - public static void Info(this IChannel ch, string fmt) - => ch.Info(MessageSensitivity.Unknown, fmt); - - /// - /// Convenience variant of - /// setting . - /// - public static void Info(this IChannel ch, string fmt, params object[] args) - => ch.Info(MessageSensitivity.Unknown, fmt, args); - } + /// + /// Convenience variant of + /// setting . + /// + public static void Info(this IChannel ch, string fmt, params object[] args) + => ch.Info(MessageSensitivity.Unknown, fmt, args); } diff --git a/src/Microsoft.ML.Core/Data/IProgressChannel.cs b/src/Microsoft.ML.Core/Data/IProgressChannel.cs index 1cff123f69..ccbbb2a6d7 100644 --- a/src/Microsoft.ML.Core/Data/IProgressChannel.cs +++ b/src/Microsoft.ML.Core/Data/IProgressChannel.cs @@ -5,140 +5,139 @@ using System; using System.Collections.Generic; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML.Runtime; + +/// +/// This is a factory interface for . +/// Both and implement this interface, +/// to allow for nested progress reporters. +/// +/// REVIEW: make implement this, instead of the environment? +/// +public interface IProgressChannelProvider { /// - /// This is a factory interface for . - /// Both and implement this interface, - /// to allow for nested progress reporters. - /// - /// REVIEW: make implement this, instead of the environment? + /// Create a progress channel for a computation named . 
/// - public interface IProgressChannelProvider - { - /// - /// Create a progress channel for a computation named . - /// - IProgressChannel StartProgressChannel(string name); - } + IProgressChannel StartProgressChannel(string name); +} +/// +/// A common interface for progress reporting. +/// It is expected that the progress channel interface is used from only one thread. +/// +/// Supported workflow: +/// 1) Create the channel via . +/// 2) Call as many times as desired (including 0). +/// Each call to supersedes the previous one. +/// 3) Report checkpoints (0 or more) by calling . +/// 4) Repeat steps 2-3 as often as necessary. +/// 5) Dispose the channel. +/// +public interface IProgressChannel : IProgressChannelProvider, IDisposable +{ /// - /// A common interface for progress reporting. - /// It is expected that the progress channel interface is used from only one thread. + /// Set up the reporting structure: + /// - Set the 'header' of the progress reports, defining which progress units and metrics are going to be reported. + /// - Provide a thread-safe delegate to be invoked whenever anyone needs to know the progress. /// - /// Supported workflow: - /// 1) Create the channel via . - /// 2) Call as many times as desired (including 0). - /// Each call to supersedes the previous one. - /// 3) Report checkpoints (0 or more) by calling . - /// 4) Repeat steps 2-3 as often as necessary. - /// 5) Dispose the channel. + /// It is acceptable to call multiple times (or none), regardless of whether the calculation is running + /// or not. Because of synchronization, the computation should not deny calls to the 'old' + /// delegates even after a new one is provided. /// - public interface IProgressChannel : IProgressChannelProvider, IDisposable - { - /// - /// Set up the reporting structure: - /// - Set the 'header' of the progress reports, defining which progress units and metrics are going to be reported. 
- /// - Provide a thread-safe delegate to be invoked whenever anyone needs to know the progress. - /// - /// It is acceptable to call multiple times (or none), regardless of whether the calculation is running - /// or not. Because of synchronization, the computation should not deny calls to the 'old' - /// delegates even after a new one is provided. - /// - /// The header object. - /// The delegate to provide actual progress. The parameter of - /// the delegate will correspond to the provided . - void SetHeader(ProgressHeader header, Action fillAction); - - /// - /// Submit a 'checkpoint' entry. These entries are guaranteed to be delivered to the progress listener, - /// if it is interested. Typically, this would contain some intermediate metrics, that are only calculated - /// at certain moments ('checkpoints') of the computation. - /// - /// For example, SDCA may report a checkpoint every time it computes the loss, or LBFGS may report a checkpoint - /// every iteration. - /// - /// The only parameter, , is interpreted in the following fashion: - /// * First MetricNames.Length items, if present, are metrics. - /// * Subsequent ProgressNames.Length items, if present, are progress units. - /// * Subsequent ProgressNames.Length items, if present, are progress limits. - /// * If any more values remain, an exception is thrown. - /// - /// The metrics, progress units and progress limits. - void Checkpoint(params Double?[] values); - } + /// The header object. + /// The delegate to provide actual progress. The parameter of + /// the delegate will correspond to the provided . + void SetHeader(ProgressHeader header, Action fillAction); /// - /// This is the 'header' of the progress report. + /// Submit a 'checkpoint' entry. These entries are guaranteed to be delivered to the progress listener, + /// if it is interested. Typically, this would contain some intermediate metrics, that are only calculated + /// at certain moments ('checkpoints') of the computation. 
+ /// + /// For example, SDCA may report a checkpoint every time it computes the loss, or LBFGS may report a checkpoint + /// every iteration. + /// + /// The only parameter, , is interpreted in the following fashion: + /// * First MetricNames.Length items, if present, are metrics. + /// * Subsequent ProgressNames.Length items, if present, are progress units. + /// * Subsequent ProgressNames.Length items, if present, are progress limits. + /// * If any more values remain, an exception is thrown. /// - public sealed class ProgressHeader - { - /// - /// These are the names of the progress 'units', from the least granular to the most granular. - /// For example, neural network might have {'epoch', 'example'} and FastTree might have {'tree', 'split', 'feature'}. - /// Will never be null, but can be empty. - /// - public readonly IReadOnlyList UnitNames; + /// The metrics, progress units and progress limits. + void Checkpoint(params Double?[] values); +} - /// - /// These are the names of the reported metrics. For example, this could be the 'loss', 'weight updates/sec' etc. - /// Will never be null, but can be empty. - /// - public readonly IReadOnlyList MetricNames; +/// +/// This is the 'header' of the progress report. +/// +public sealed class ProgressHeader +{ + /// + /// These are the names of the progress 'units', from the least granular to the most granular. + /// For example, neural network might have {'epoch', 'example'} and FastTree might have {'tree', 'split', 'feature'}. + /// Will never be null, but can be empty. + /// + public readonly IReadOnlyList UnitNames; - /// - /// Initialize the header. This will take ownership of the arrays. - /// Both arrays can be null, even simultaneously. This 'empty' header indicated that the calculation doesn't report - /// any units of progress, but the tracker can still track start, stop and elapsed time. Of course, if there's any - /// progress or metrics to report, it is always better to report them. 
- /// - /// The metrics that the calculation reports. These are completely independent, and there - /// is no contract on whether the metric values should increase or not. As naming convention, - /// can have multiple words with spaces, and should be title-cased. - /// The names of the progress units, listed from least granular to most granular. - /// The idea is that the progress should be lexicographically increasing (like [0,0], [0,10], [1,0], [1,15], [2,5] etc.). - /// As naming convention, should be lower-cased and typically plural - /// (for example, iterations, clusters, examples). - public ProgressHeader(string[] metricNames, string[] unitNames) - { - Contracts.CheckValueOrNull(unitNames); - Contracts.CheckValueOrNull(metricNames); + /// + /// These are the names of the reported metrics. For example, this could be the 'loss', 'weight updates/sec' etc. + /// Will never be null, but can be empty. + /// + public readonly IReadOnlyList MetricNames; - UnitNames = unitNames ?? new string[0]; - MetricNames = metricNames ?? new string[0]; - } + /// + /// Initialize the header. This will take ownership of the arrays. + /// Both arrays can be null, even simultaneously. This 'empty' header indicated that the calculation doesn't report + /// any units of progress, but the tracker can still track start, stop and elapsed time. Of course, if there's any + /// progress or metrics to report, it is always better to report them. + /// + /// The metrics that the calculation reports. These are completely independent, and there + /// is no contract on whether the metric values should increase or not. As naming convention, + /// can have multiple words with spaces, and should be title-cased. + /// The names of the progress units, listed from least granular to most granular. + /// The idea is that the progress should be lexicographically increasing (like [0,0], [0,10], [1,0], [1,15], [2,5] etc.). 
+ /// As naming convention, should be lower-cased and typically plural + /// (for example, iterations, clusters, examples). + public ProgressHeader(string[] metricNames, string[] unitNames) + { + Contracts.CheckValueOrNull(unitNames); + Contracts.CheckValueOrNull(metricNames); - /// - /// A constructor for no metrics, just progress units. As naming convention, should be lower-cased - /// and typically plural (for example, iterations, clusters, examples). - /// - public ProgressHeader(params string[] unitNames) - : this(null, unitNames) - { - } + UnitNames = unitNames ?? new string[0]; + MetricNames = metricNames ?? new string[0]; } /// - /// A metric/progress holder item. + /// A constructor for no metrics, just progress units. As naming convention, should be lower-cased + /// and typically plural (for example, iterations, clusters, examples). /// - public interface IProgressEntry + public ProgressHeader(params string[] unitNames) + : this(null, unitNames) { - /// - /// Set the progress value for the index to , - /// and the limit value for the progress becomes 'unknown'. - /// - void SetProgress(int index, Double value); + } +} + +/// +/// A metric/progress holder item. +/// +public interface IProgressEntry +{ + /// + /// Set the progress value for the index to , + /// and the limit value for the progress becomes 'unknown'. + /// + void SetProgress(int index, Double value); - /// - /// Set the progress value for the index to , - /// and the limit value to . If is a NAN, it is set to null instead. - /// - void SetProgress(int index, Double value, Double lim); + /// + /// Set the progress value for the index to , + /// and the limit value to . If is a NAN, it is set to null instead. + /// + void SetProgress(int index, Double value, Double lim); - /// - /// Sets the metric with index to . - /// - void SetMetric(int index, Double value); + /// + /// Sets the metric with index to . 
+ /// + void SetMetric(int index, Double value); - } } diff --git a/src/Microsoft.ML.Core/Data/IRowToRowMapper.cs b/src/Microsoft.ML.Core/Data/IRowToRowMapper.cs index f98fe70ff5..b48d55f64f 100644 --- a/src/Microsoft.ML.Core/Data/IRowToRowMapper.cs +++ b/src/Microsoft.ML.Core/Data/IRowToRowMapper.cs @@ -5,47 +5,46 @@ using System; using System.Collections.Generic; -namespace Microsoft.ML.Data +namespace Microsoft.ML.Data; + +/// +/// This interface maps an input to an output . Typically, the output contains +/// both the input columns and new columns added by the implementing class, although some implementations may +/// return a subset of the input columns. +/// This interface is similar to , except it does not have any input role mappings, +/// so to rebind, the same input column names must be used. +/// Implementations of this interface are typically created over defined input . +/// +public interface IRowToRowMapper { /// - /// This interface maps an input to an output . Typically, the output contains - /// both the input columns and new columns added by the implementing class, although some implementations may - /// return a subset of the input columns. - /// This interface is similar to , except it does not have any input role mappings, - /// so to rebind, the same input column names must be used. - /// Implementations of this interface are typically created over defined input . + /// Mappers are defined as accepting inputs with this very specific schema. /// - public interface IRowToRowMapper - { - /// - /// Mappers are defined as accepting inputs with this very specific schema. - /// - DataViewSchema InputSchema { get; } + DataViewSchema InputSchema { get; } - /// - /// Gets an instance of which describes the columns' names and types in the output generated by this mapper. - /// - DataViewSchema OutputSchema { get; } + /// + /// Gets an instance of which describes the columns' names and types in the output generated by this mapper. 
+ /// + DataViewSchema OutputSchema { get; } - /// - /// Given a set of columns, return the input columns that are needed to generate those output columns. - /// - IEnumerable GetDependencies(IEnumerable dependingColumns); + /// + /// Given a set of columns, return the input columns that are needed to generate those output columns. + /// + IEnumerable GetDependencies(IEnumerable dependingColumns); - /// - /// Get an with the indicated active columns, based on the input . - /// Getting values on inactive columns of the returned row will throw. - /// - /// The of should be the same object as - /// . Implementors of this method should throw if that is not the case. Conversely, - /// the returned value must have the same schema as . - /// - /// This method creates a live connection between the input and the output . In particular, when the getters of the output are invoked, they invoke the - /// getters of the input row and base the output values on the current values of the input . - /// The output values are re-computed when requested through the getters. Also, the returned - /// will dispose when it is disposed. - /// - DataViewRow GetRow(DataViewRow input, IEnumerable activeColumns); - } + /// + /// Get an with the indicated active columns, based on the input . + /// Getting values on inactive columns of the returned row will throw. + /// + /// The of should be the same object as + /// . Implementors of this method should throw if that is not the case. Conversely, + /// the returned value must have the same schema as . + /// + /// This method creates a live connection between the input and the output . In particular, when the getters of the output are invoked, they invoke the + /// getters of the input row and base the output values on the current values of the input . + /// The output values are re-computed when requested through the getters. Also, the returned + /// will dispose when it is disposed. 
+ /// + DataViewRow GetRow(DataViewRow input, IEnumerable activeColumns); } diff --git a/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs b/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs index 2396966a6c..58605b050b 100644 --- a/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs +++ b/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs @@ -6,85 +6,84 @@ using System.Collections.Generic; using Microsoft.ML.Runtime; -namespace Microsoft.ML.Data +namespace Microsoft.ML.Data; + +/// +/// A mapper that can be bound to a (which encapsulates a and has mappings from column kinds +/// to columns of that schema). Binding an to a produces an +/// , which is an interface that has methods to return the names and indices of the input columns +/// needed by the mapper to compute its output. The is an extension to this interface, that +/// can also produce an output given an input . The produced generally contains only the output columns of the mapper, and not +/// the input columns (but there is nothing preventing an from mapping input columns directly to outputs). +/// This interface is implemented by wrappers of IValueMapper based predictors, which are predictors that take a single +/// features column. New predictors can implement directly. Implementing +/// includes implementing a corresponding (or ) and a corresponding ISchema +/// for the output schema of the . In case the interface is implemented, +/// the SimpleRow class can be used in the method. +/// +[BestFriend] +internal interface ISchemaBindableMapper +{ + ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema); +} + +/// +/// This interface is used to map a schema from input columns to output columns. The should keep track +/// of the input columns that are needed for the mapping. +/// +[BestFriend] +internal interface ISchemaBoundMapper { /// - /// A mapper that can be bound to a (which encapsulates a and has mappings from column kinds - /// to columns of that schema). 
Binding an to a produces an - /// , which is an interface that has methods to return the names and indices of the input columns - /// needed by the mapper to compute its output. The is an extention to this interface, that - /// can also produce an output given an input . The produced generally contains only the output columns of the mapper, and not - /// the input columns (but there is nothing preventing an from mapping input columns directly to outputs). - /// This interface is implemented by wrappers of IValueMapper based predictors, which are predictors that take a single - /// features column. New predictors can implement directly. Implementing - /// includes implementing a corresponding (or ) and a corresponding ISchema - /// for the output schema of the . In case the interface is implemented, - /// the SimpleRow class can be used in the method. + /// The that was passed to the in the binding process. /// - [BestFriend] - internal interface ISchemaBindableMapper - { - ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema); - } + RoleMappedSchema InputRoleMappedSchema { get; } /// - /// This interface is used to map a schema from input columns to output columns. The should keep track - /// of the input columns that are needed for the mapping. + /// Gets schema of this mapper's output. /// - [BestFriend] - internal interface ISchemaBoundMapper - { - /// - /// The that was passed to the in the binding process. - /// - RoleMappedSchema InputRoleMappedSchema { get; } + DataViewSchema OutputSchema { get; } - /// - /// Gets schema of this mapper's output. - /// - DataViewSchema OutputSchema { get; } - - /// - /// A property to get back the that produced this . - /// - ISchemaBindableMapper Bindable { get; } + /// + /// A property to get back the that produced this . + /// + ISchemaBindableMapper Bindable { get; } - /// - /// This method returns the binding information: which input columns are used and in what roles. 
- /// - IEnumerable> GetInputColumnRoles(); - } + /// + /// This method returns the binding information: which input columns are used and in what roles. + /// + IEnumerable> GetInputColumnRoles(); +} +/// +/// This interface extends . +/// +[BestFriend] +internal interface ISchemaBoundRowMapper : ISchemaBoundMapper +{ /// - /// This interface extends . + /// Input schema accepted. /// - [BestFriend] - internal interface ISchemaBoundRowMapper : ISchemaBoundMapper - { - /// - /// Input schema accepted. - /// - DataViewSchema InputSchema { get; } + DataViewSchema InputSchema { get; } - /// - /// Given a set of columns, from the newly generated ones, return the input columns that are needed to generate those output columns. - /// - IEnumerable GetDependenciesForNewColumns(IEnumerable dependingColumns); + /// + /// Given a set of columns, from the newly generated ones, return the input columns that are needed to generate those output columns. + /// + IEnumerable GetDependenciesForNewColumns(IEnumerable dependingColumns); - /// - /// Get an with the indicated active columns, based on the input . - /// Getting values on inactive columns of the returned row will throw. - /// - /// The of should be the same object as - /// . Implementors of this method should throw if that is not the case. Conversely, - /// the returned value must have the same schema as . - /// - /// This method creates a live connection between the input and the output . In particular, when the getters of the output are invoked, they invoke the - /// getters of the input row and base the output values on the current values of the input . - /// The output values are re-computed when requested through the getters. Also, the returned - /// will dispose when it is disposed. - /// - DataViewRow GetRow(DataViewRow input, IEnumerable activeColumns); - } + /// + /// Get an with the indicated active columns, based on the input . + /// Getting values on inactive columns of the returned row will throw. 
+ /// + /// The of should be the same object as + /// . Implementors of this method should throw if that is not the case. Conversely, + /// the returned value must have the same schema as . + /// + /// This method creates a live connection between the input and the output . In particular, when the getters of the output are invoked, they invoke the + /// getters of the input row and base the output values on the current values of the input . + /// The output values are re-computed when requested through the getters. Also, the returned + /// will dispose when it is disposed. + /// + DataViewRow GetRow(DataViewRow input, IEnumerable activeColumns); } diff --git a/src/Microsoft.ML.Core/Data/IValueMapper.cs b/src/Microsoft.ML.Core/Data/IValueMapper.cs index 6dfea747ca..34a64c5fcf 100644 --- a/src/Microsoft.ML.Core/Data/IValueMapper.cs +++ b/src/Microsoft.ML.Core/Data/IValueMapper.cs @@ -2,57 +2,56 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -namespace Microsoft.ML.Data -{ - /// - /// Delegate type to map/convert a value. - /// - [BestFriend] - internal delegate void ValueMapper(in TSrc src, ref TDst dst); +namespace Microsoft.ML.Data; - /// - /// Delegate type to map/convert among three values, for example, one input with two - /// outputs, or two inputs with one output. - /// - [BestFriend] - internal delegate void ValueMapper(in TVal1 val1, ref TVal2 val2, ref TVal3 val3); +/// +/// Delegate type to map/convert a value. +/// +[BestFriend] +internal delegate void ValueMapper(in TSrc src, ref TDst dst); + +/// +/// Delegate type to map/convert among three values, for example, one input with two +/// outputs, or two inputs with one output. +/// +[BestFriend] +internal delegate void ValueMapper(in TVal1 val1, ref TVal2 val2, ref TVal3 val3); + +/// +/// Interface for mapping a single input value (of an indicated ColumnType) to +/// an output value (of an indicated ColumnType). 
This interface is commonly implemented +/// by predictors. Note that the input and output ColumnTypes determine the proper +/// type arguments for GetMapper, but typically contain additional information like +/// vector lengths. +/// +[BestFriend] +internal interface IValueMapper +{ + DataViewType InputType { get; } + DataViewType OutputType { get; } /// - /// Interface for mapping a single input value (of an indicated ColumnType) to - /// an output value (of an indicated ColumnType). This interface is commonly implemented - /// by predictors. Note that the input and output ColumnTypes determine the proper - /// type arguments for GetMapper, but typically contain additional information like - /// vector lengths. + /// Get a delegate used for mapping from input to output values. Note that the delegate + /// should only be used on a single thread - it should NOT be assumed to be safe for concurrency. /// - [BestFriend] - internal interface IValueMapper - { - DataViewType InputType { get; } - DataViewType OutputType { get; } + ValueMapper GetMapper(); +} - /// - /// Get a delegate used for mapping from input to output values. Note that the delegate - /// should only be used on a single thread - it should NOT be assumed to be safe for concurrency. - /// - ValueMapper GetMapper(); - } +/// +/// Interface for mapping a single input value (of an indicated ColumnType) to an output value +/// plus distribution value (of indicated ColumnTypes). This interface is commonly implemented +/// by predictors. Note that the input, output, and distribution ColumnTypes determine the proper +/// type arguments for GetMapper, but typically contain additional information like +/// vector lengths. +/// +[BestFriend] +internal interface IValueMapperDist : IValueMapper +{ + DataViewType DistType { get; } /// - /// Interface for mapping a single input value (of an indicated ColumnType) to an output value - /// plus distribution value (of indicated ColumnTypes). 
This interface is commonly implemented - /// by predictors. Note that the input, output, and distribution ColumnTypes determine the proper - /// type arguments for GetMapper, but typically contain additional information like - /// vector lengths. + /// Get a delegate used for mapping from input to output values. Note that the delegate + /// should only be used on a single thread - it should NOT be assumed to be safe for concurrency. /// - [BestFriend] - internal interface IValueMapperDist : IValueMapper - { - DataViewType DistType { get; } - - /// - /// Get a delegate used for mapping from input to output values. Note that the delegate - /// should only be used on a single thread - it should NOT be assumed to be safe for concurrency. - /// - ValueMapper GetMapper(); - } + ValueMapper GetMapper(); } diff --git a/src/Microsoft.ML.Core/Data/InPredicate.cs b/src/Microsoft.ML.Core/Data/InPredicate.cs index 8e0c27e070..bf8535b1b1 100644 --- a/src/Microsoft.ML.Core/Data/InPredicate.cs +++ b/src/Microsoft.ML.Core/Data/InPredicate.cs @@ -2,8 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -namespace Microsoft.ML.Data -{ - [BestFriend] - internal delegate bool InPredicate(in T value); -} +namespace Microsoft.ML.Data; + +[BestFriend] +internal delegate bool InPredicate(in T value); diff --git a/src/Microsoft.ML.Core/Data/KeyTypeExtensions.cs b/src/Microsoft.ML.Core/Data/KeyTypeExtensions.cs index 096e36fdf0..10b5381355 100644 --- a/src/Microsoft.ML.Core/Data/KeyTypeExtensions.cs +++ b/src/Microsoft.ML.Core/Data/KeyTypeExtensions.cs @@ -4,21 +4,20 @@ using Microsoft.ML.Runtime; -namespace Microsoft.ML.Data +namespace Microsoft.ML.Data; + +/// +/// Extension methods related to the class. +/// +[BestFriend] +internal static class KeyTypeExtensions { /// - /// Extension methods related to the class. + /// Sometimes it is necessary to cast the Count to an int. 
This performs overflow check. /// - [BestFriend] - internal static class KeyTypeExtensions + public static int GetCountAsInt32(this KeyDataViewType key, IExceptionContext ectx = null) { - /// - /// Sometimes it is necessary to cast the Count to an int. This performs overflow check. - /// - public static int GetCountAsInt32(this KeyDataViewType key, IExceptionContext ectx = null) - { - ectx.Check(key.Count <= int.MaxValue, nameof(KeyDataViewType) + "." + nameof(KeyDataViewType.Count) + " exceeds int.MaxValue."); - return (int)key.Count; - } + ectx.Check(key.Count <= int.MaxValue, nameof(KeyDataViewType) + "." + nameof(KeyDataViewType.Count) + " exceeds int.MaxValue."); + return (int)key.Count; } } diff --git a/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs b/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs index 319dbbbbb8..872b339c0c 100644 --- a/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs @@ -4,50 +4,49 @@ using Microsoft.ML.Runtime; -namespace Microsoft.ML.Data +namespace Microsoft.ML.Data; + +/// +/// Base class for a cursor has an input cursor, but still needs to do work on . +/// +[BestFriend] +internal abstract class LinkedRootCursorBase : RootCursorBase { + + /// Gets the input cursor. + protected DataViewRowCursor Input { get; } + /// - /// Base class for a cursor has an input cursor, but still needs to do work on . + /// Returns the root cursor of the input. It should be used to perform + /// operations, but with the distinction, as compared to , that this is not + /// a simple passthrough, but rather very implementation specific. For example, a common usage of this class is + /// on filter cursor implementations, where how that input cursor is consumed is very implementation specific. + /// That is why this is , not . /// - [BestFriend] - internal abstract class LinkedRootCursorBase : RootCursorBase - { + protected DataViewRowCursor Root { get; } - /// Gets the input cursor. 
- protected DataViewRowCursor Input { get; } + private bool _disposed; - /// - /// Returns the root cursor of the input. It should be used to perform - /// operations, but with the distinction, as compared to , that this is not - /// a simple passthrough, but rather very implementation specific. For example, a common usage of this class is - /// on filter cursor implementations, where how that input cursor is consumed is very implementation specific. - /// That is why this is , not . - /// - protected DataViewRowCursor Root { get; } + protected LinkedRootCursorBase(IChannelProvider provider, DataViewRowCursor input) + : base(provider) + { + Ch.AssertValue(input, nameof(input)); - private bool _disposed; + Input = input; + Root = Input is SynchronizedCursorBase snycInput ? snycInput.Root : input; + } - protected LinkedRootCursorBase(IChannelProvider provider, DataViewRowCursor input) - : base(provider) + protected override void Dispose(bool disposing) + { + if (_disposed) + return; + if (disposing) { - Ch.AssertValue(input, nameof(input)); + Input.Dispose(); + // The base class should set the state to done under these circumstances. - Input = input; - Root = Input is SynchronizedCursorBase snycInput ? snycInput.Root : input; - } - - protected override void Dispose(bool disposing) - { - if (_disposed) - return; - if (disposing) - { - Input.Dispose(); - // The base class should set the state to done under these circumstances. 
- - } - _disposed = true; - base.Dispose(disposing); } + _disposed = true; + base.Dispose(disposing); } } diff --git a/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs b/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs index 670f37ffe7..e618ca2c62 100644 --- a/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs @@ -4,40 +4,39 @@ using Microsoft.ML.Runtime; -namespace Microsoft.ML.Data +namespace Microsoft.ML.Data; + +/// +/// Base class for creating a cursor of rows that filters out some input rows. +/// +[BestFriend] +internal abstract class LinkedRowFilterCursorBase : LinkedRowRootCursorBase { - /// - /// Base class for creating a cursor of rows that filters out some input rows. - /// - [BestFriend] - internal abstract class LinkedRowFilterCursorBase : LinkedRowRootCursorBase - { - public override long Batch => Input.Batch; + public override long Batch => Input.Batch; - protected LinkedRowFilterCursorBase(IChannelProvider provider, DataViewRowCursor input, DataViewSchema schema, bool[] active) - : base(provider, input, schema, active) - { - } + protected LinkedRowFilterCursorBase(IChannelProvider provider, DataViewRowCursor input, DataViewSchema schema, bool[] active) + : base(provider, input, schema, active) + { + } - public override ValueGetter GetIdGetter() - { - return Input.GetIdGetter(); - } + public override ValueGetter GetIdGetter() + { + return Input.GetIdGetter(); + } - protected override bool MoveNextCore() + protected override bool MoveNextCore() + { + while (Root.MoveNext()) { - while (Root.MoveNext()) - { - if (Accept()) - return true; - } - - return false; + if (Accept()) + return true; } - /// - /// Return whether the current input row should be returned by this cursor. - /// - protected abstract bool Accept(); + return false; } + + /// + /// Return whether the current input row should be returned by this cursor. 
+ /// + protected abstract bool Accept(); } diff --git a/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs b/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs index e1bca40a1a..2771f952c6 100644 --- a/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs @@ -4,50 +4,49 @@ using Microsoft.ML.Runtime; -namespace Microsoft.ML.Data +namespace Microsoft.ML.Data; + +/// +/// A base class for a that has an input cursor, but still needs to do work on +/// . Note that the default +/// assumes that each input column is exposed as an +/// output column with the same column index. +/// +[BestFriend] +internal abstract class LinkedRowRootCursorBase : LinkedRootCursorBase { - /// - /// A base class for a that has an input cursor, but still needs to do work on - /// . Note that the default - /// assumes that each input column is exposed as an - /// output column with the same column index. - /// - [BestFriend] - internal abstract class LinkedRowRootCursorBase : LinkedRootCursorBase - { - private readonly bool[] _active; + private readonly bool[] _active; - /// Gets row's schema. - public sealed override DataViewSchema Schema { get; } + /// Gets row's schema. + public sealed override DataViewSchema Schema { get; } - protected LinkedRowRootCursorBase(IChannelProvider provider, DataViewRowCursor input, DataViewSchema schema, bool[] active) - : base(provider, input) - { - Ch.CheckValue(schema, nameof(schema)); - Ch.Check(active == null || active.Length == schema.Count); - _active = active; - Schema = schema; - } + protected LinkedRowRootCursorBase(IChannelProvider provider, DataViewRowCursor input, DataViewSchema schema, bool[] active) + : base(provider, input) + { + Ch.CheckValue(schema, nameof(schema)); + Ch.Check(active == null || active.Length == schema.Count); + _active = active; + Schema = schema; + } - /// - /// Returns whether the given column is active in this row. 
- /// - public sealed override bool IsColumnActive(DataViewSchema.Column column) - { - Ch.Check(column.Index < Schema.Count); - return _active == null || _active[column.Index]; - } + /// + /// Returns whether the given column is active in this row. + /// + public sealed override bool IsColumnActive(DataViewSchema.Column column) + { + Ch.Check(column.Index < Schema.Count); + return _active == null || _active[column.Index]; + } - /// - /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row. - /// This throws if the column is not active in this row, or if the type - /// differs from this column's type. - /// - /// is the column's content type. - /// is the output column whose getter should be returned. - public override ValueGetter GetGetter(DataViewSchema.Column column) - { - return Input.GetGetter(column); - } + /// + /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row. + /// This throws if the column is not active in this row, or if the type + /// differs from this column's type. + /// + /// is the column's content type. + /// is the output column whose getter should be returned. + public override ValueGetter GetGetter(DataViewSchema.Column column) + { + return Input.GetGetter(column); } } diff --git a/src/Microsoft.ML.Core/Data/ModelHeader.cs b/src/Microsoft.ML.Core/Data/ModelHeader.cs index 385fb6d068..01a694b501 100644 --- a/src/Microsoft.ML.Core/Data/ModelHeader.cs +++ b/src/Microsoft.ML.Core/Data/ModelHeader.cs @@ -9,654 +9,653 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; -namespace Microsoft.ML +namespace Microsoft.ML; + +[BestFriend] +[StructLayout(LayoutKind.Explicit, Size = Size)] +internal struct ModelHeader { - [BestFriend] - [StructLayout(LayoutKind.Explicit, Size = Size)] - internal struct ModelHeader + /// + /// This spells 'ML MODEL' with zero replacing space (assuming little endian). 
+ /// + public const ulong SignatureValue = 0x4C45444F4D004C4DUL; + public const ulong TailSignatureValue = 0x4D4C004D4F44454CUL; + + private const uint VerAssemblyNameSupported = 0x00010002; + + // These are private since they change over time. If we make them public we risk + // another assembly containing a "copy" of their value when the other assembly + // was compiled, which might not match the code that can load this. + //private const uint VerWrittenCur = 0x00010001; // Initial + private const uint VerWrittenCur = 0x00010002; // Added AssemblyName + private const uint VerReadableCur = 0x00010002; + private const uint VerWeCanReadBack = 0x00010001; + + [FieldOffset(0x00)] + public ulong Signature; + [FieldOffset(0x08)] + public uint VerWritten; + [FieldOffset(0x0C)] + public uint VerReadable; + + // Location and size (in bytes) of the model block. Note that it is legal for CbModel to be zero. + [FieldOffset(0x10)] + public long FpModel; + [FieldOffset(0x18)] + public long CbModel; + + // Location and size (in bytes) of the string table block. If there are no strings, these are both zero. + // If there are n strings then CbStringTable is n * sizeof(long), so is divisible by sizeof(long). + // Each long is the offset from header.FpStringChars of the "lim" of the characters for that string. + // The "min" is the "lim" for the previous string. The 0th string's "min" is offset zero. + [FieldOffset(0x20)] + public long FpStringTable; + [FieldOffset(0x28)] + public long CbStringTable; + + // Location and size (in bytes) of the string characters, without any prefix or termination. The characters + // unicode (UTF-16). + [FieldOffset(0x30)] + public long FpStringChars; + [FieldOffset(0x38)] + public long CbStringChars; + + // ModelSignature specifies the format of the model block. These values are assigned by + // the code that writes the model block. 
+ [FieldOffset(0x40)] + public ulong ModelSignature; + [FieldOffset(0x48)] + public uint ModelVerWritten; + [FieldOffset(0x4C)] + public uint ModelVerReadable; + + // These encode up to two loader signature strings. These are up to 24 ascii characters each. + [FieldOffset(0x50)] + public ulong LoaderSignature0; + [FieldOffset(0x58)] + public ulong LoaderSignature1; + [FieldOffset(0x60)] + public ulong LoaderSignature2; + [FieldOffset(0x68)] + public ulong LoaderSignatureAlt0; + [FieldOffset(0x70)] + public ulong LoaderSignatureAlt1; + [FieldOffset(0x78)] + public ulong LoaderSignatureAlt2; + + // Location of the "tail" signature, which is simply the TailSignatureValue. + [FieldOffset(0x80)] + public long FpTail; + [FieldOffset(0x88)] + public long FpLim; + + // Location of the fully qualified assembly name string (in UTF-16). + // Note that it is legal for both to be zero. + [FieldOffset(0x90)] + public long FpAssemblyName; + [FieldOffset(0x98)] + public uint CbAssemblyName; + + public const int Size = 0x0100; + + // Utilities for writing. + + /// + /// Initialize the header and writer for writing. The value of fpMin and header + /// should be passed to the other utility methods here. + /// + public static void BeginWrite(BinaryWriter writer, out long fpMin, out ModelHeader header) { - /// - /// This spells 'ML MODEL' with zero replacing space (assuming little endian). - /// - public const ulong SignatureValue = 0x4C45444F4D004C4DUL; - public const ulong TailSignatureValue = 0x4D4C004D4F44454CUL; - - private const uint VerAssemblyNameSupported = 0x00010002; - - // These are private since they change over time. If we make them public we risk - // another assembly containing a "copy" of their value when the other assembly - // was compiled, which might not match the code that can load this. 
- //private const uint VerWrittenCur = 0x00010001; // Initial - private const uint VerWrittenCur = 0x00010002; // Added AssemblyName - private const uint VerReadableCur = 0x00010002; - private const uint VerWeCanReadBack = 0x00010001; - - [FieldOffset(0x00)] - public ulong Signature; - [FieldOffset(0x08)] - public uint VerWritten; - [FieldOffset(0x0C)] - public uint VerReadable; - - // Location and size (in bytes) of the model block. Note that it is legal for CbModel to be zero. - [FieldOffset(0x10)] - public long FpModel; - [FieldOffset(0x18)] - public long CbModel; - - // Location and size (in bytes) of the string table block. If there are no strings, these are both zero. - // If there are n strings then CbStringTable is n * sizeof(long), so is divisible by sizeof(long). - // Each long is the offset from header.FpStringChars of the "lim" of the characters for that string. - // The "min" is the "lim" for the previous string. The 0th string's "min" is offset zero. - [FieldOffset(0x20)] - public long FpStringTable; - [FieldOffset(0x28)] - public long CbStringTable; - - // Location and size (in bytes) of the string characters, without any prefix or termination. The characters - // unicode (UTF-16). - [FieldOffset(0x30)] - public long FpStringChars; - [FieldOffset(0x38)] - public long CbStringChars; - - // ModelSignature specifies the format of the model block. These values are assigned by - // the code that writes the model block. - [FieldOffset(0x40)] - public ulong ModelSignature; - [FieldOffset(0x48)] - public uint ModelVerWritten; - [FieldOffset(0x4C)] - public uint ModelVerReadable; - - // These encode up to two loader signature strings. These are up to 24 ascii characters each. 
- [FieldOffset(0x50)] - public ulong LoaderSignature0; - [FieldOffset(0x58)] - public ulong LoaderSignature1; - [FieldOffset(0x60)] - public ulong LoaderSignature2; - [FieldOffset(0x68)] - public ulong LoaderSignatureAlt0; - [FieldOffset(0x70)] - public ulong LoaderSignatureAlt1; - [FieldOffset(0x78)] - public ulong LoaderSignatureAlt2; - - // Location of the "tail" signature, which is simply the TailSignatureValue. - [FieldOffset(0x80)] - public long FpTail; - [FieldOffset(0x88)] - public long FpLim; - - // Location of the fully qualified assembly name string (in UTF-16). - // Note that it is legal for both to be zero. - [FieldOffset(0x90)] - public long FpAssemblyName; - [FieldOffset(0x98)] - public uint CbAssemblyName; - - public const int Size = 0x0100; - - // Utilities for writing. - - /// - /// Initialize the header and writer for writing. The value of fpMin and header - /// should be passed to the other utility methods here. - /// - public static void BeginWrite(BinaryWriter writer, out long fpMin, out ModelHeader header) - { - Contracts.Assert(Marshal.SizeOf(typeof(ModelHeader)) == Size); - Contracts.CheckValue(writer, nameof(writer)); - - fpMin = writer.FpCur(); - header = default(ModelHeader); - header.Signature = SignatureValue; - header.VerWritten = VerWrittenCur; - header.VerReadable = VerReadableCur; - header.FpModel = ModelHeader.Size; - - // Write a blank header - the correct information is written by WriteHeaderAndTail. 
- byte[] headerBytes = new byte[ModelHeader.Size]; - writer.Write(headerBytes); - Contracts.CheckIO(writer.FpCur() == fpMin + ModelHeader.Size); - } + Contracts.Assert(Marshal.SizeOf(typeof(ModelHeader)) == Size); + Contracts.CheckValue(writer, nameof(writer)); + + fpMin = writer.FpCur(); + header = default(ModelHeader); + header.Signature = SignatureValue; + header.VerWritten = VerWrittenCur; + header.VerReadable = VerReadableCur; + header.FpModel = ModelHeader.Size; + + // Write a blank header - the correct information is written by WriteHeaderAndTail. + byte[] headerBytes = new byte[ModelHeader.Size]; + writer.Write(headerBytes); + Contracts.CheckIO(writer.FpCur() == fpMin + ModelHeader.Size); + } - /// - /// The current writer position should be the end of the model blob. Records the model size, writes the string table, - /// completes and writes the header, and writes the tail. - /// - public static void EndWrite(BinaryWriter writer, long fpMin, ref ModelHeader header, NormStr.Pool pool = null, string loaderAssemblyName = null) - { - Contracts.CheckValue(writer, nameof(writer)); - Contracts.CheckParam(fpMin >= 0, nameof(fpMin)); - Contracts.CheckValueOrNull(pool); + /// + /// The current writer position should be the end of the model blob. Records the model size, writes the string table, + /// completes and writes the header, and writes the tail. + /// + public static void EndWrite(BinaryWriter writer, long fpMin, ref ModelHeader header, NormStr.Pool pool = null, string loaderAssemblyName = null) + { + Contracts.CheckValue(writer, nameof(writer)); + Contracts.CheckParam(fpMin >= 0, nameof(fpMin)); + Contracts.CheckValueOrNull(pool); - // Record the model size. - EndModelCore(writer, fpMin, ref header); + // Record the model size. 
+ EndModelCore(writer, fpMin, ref header); - Contracts.Check(header.FpStringTable == 0); - Contracts.Check(header.CbStringTable == 0); - Contracts.Check(header.FpStringChars == 0); - Contracts.Check(header.CbStringChars == 0); + Contracts.Check(header.FpStringTable == 0); + Contracts.Check(header.CbStringTable == 0); + Contracts.Check(header.FpStringChars == 0); + Contracts.Check(header.CbStringChars == 0); - // Write the strings. - if (pool != null && pool.Count > 0) + // Write the strings. + if (pool != null && pool.Count > 0) + { + header.FpStringTable = writer.FpCur() - fpMin; + long offset = 0; + int cv = 0; + // REVIEW: Implement an indexer on pool! + foreach (var ns in pool) { - header.FpStringTable = writer.FpCur() - fpMin; - long offset = 0; - int cv = 0; - // REVIEW: Implement an indexer on pool! - foreach (var ns in pool) - { - Contracts.Assert(ns.Id == cv); - offset += ns.Value.Length * sizeof(char); - writer.Write(offset); - cv++; - } - Contracts.Assert(cv == pool.Count); - header.CbStringTable = pool.Count * sizeof(long); - header.FpStringChars = writer.FpCur() - fpMin; - Contracts.Assert(header.FpStringChars == header.FpStringTable + header.CbStringTable); - foreach (var ns in pool) - { - foreach (var ch in ns.Value.Span) - writer.Write((short)ch); - } - header.CbStringChars = writer.FpCur() - header.FpStringChars - fpMin; - Contracts.Assert(offset == header.CbStringChars); + Contracts.Assert(ns.Id == cv); + offset += ns.Value.Length * sizeof(char); + writer.Write(offset); + cv++; + } + Contracts.Assert(cv == pool.Count); + header.CbStringTable = pool.Count * sizeof(long); + header.FpStringChars = writer.FpCur() - fpMin; + Contracts.Assert(header.FpStringChars == header.FpStringTable + header.CbStringTable); + foreach (var ns in pool) + { + foreach (var ch in ns.Value.Span) + writer.Write((short)ch); } + header.CbStringChars = writer.FpCur() - header.FpStringChars - fpMin; + Contracts.Assert(offset == header.CbStringChars); + } - 
WriteLoaderAssemblyName(writer, fpMin, ref header, loaderAssemblyName); + WriteLoaderAssemblyName(writer, fpMin, ref header, loaderAssemblyName); - WriteHeaderAndTailCore(writer, fpMin, ref header); - } + WriteHeaderAndTailCore(writer, fpMin, ref header); + } - private static void WriteLoaderAssemblyName(BinaryWriter writer, long fpMin, ref ModelHeader header, string loaderAssemblyName) + private static void WriteLoaderAssemblyName(BinaryWriter writer, long fpMin, ref ModelHeader header, string loaderAssemblyName) + { + if (!string.IsNullOrEmpty(loaderAssemblyName)) { - if (!string.IsNullOrEmpty(loaderAssemblyName)) - { - header.FpAssemblyName = writer.FpCur() - fpMin; - header.CbAssemblyName = (uint)loaderAssemblyName.Length * sizeof(char); + header.FpAssemblyName = writer.FpCur() - fpMin; + header.CbAssemblyName = (uint)loaderAssemblyName.Length * sizeof(char); - foreach (var ch in loaderAssemblyName) - writer.Write((short)ch); - } - else - { - header.FpAssemblyName = 0; - header.CbAssemblyName = 0; - } + foreach (var ch in loaderAssemblyName) + writer.Write((short)ch); } - - /// - /// The current writer position should be where the tail belongs. Writes the header and tail. - /// Typically this isn't called directly unless you are doing custom string table serialization. - /// In that case you should have called EndModelCore before writing the string table information. - /// - public static void WriteHeaderAndTailCore(BinaryWriter writer, long fpMin, ref ModelHeader header) + else { - Contracts.CheckValue(writer, nameof(writer)); - Contracts.CheckParam(fpMin >= 0, nameof(fpMin)); - - header.FpTail = writer.FpCur() - fpMin; - writer.Write(TailSignatureValue); - header.FpLim = writer.FpCur() - fpMin; - - Exception ex; - bool res = TryValidate(ref header, header.FpLim, out ex); - // If this fails, we didn't construct the header correctly. This is both a bug and - // something we want to protect against at runtime, hence both assert and check. 
- Contracts.Assert(res); - Contracts.Check(res); - - // Write the header, then seek back to the end. - writer.Seek(fpMin); - byte[] headerBytes = new byte[ModelHeader.Size]; - MarshalToBytes(ref header, headerBytes); - writer.Write(headerBytes); - Contracts.Assert(writer.FpCur() == fpMin + ModelHeader.Size); - writer.Seek(header.FpLim + fpMin); + header.FpAssemblyName = 0; + header.CbAssemblyName = 0; } + } - /// - /// The current writer position should be the end of the model blob. Records the size of the model blob. - /// Typically this isn't called directly unless you are doing custom string table serialization. - /// - public static void EndModelCore(BinaryWriter writer, long fpMin, ref ModelHeader header) - { - Contracts.Check(header.FpModel == ModelHeader.Size); - Contracts.Check(header.CbModel == 0); + /// + /// The current writer position should be where the tail belongs. Writes the header and tail. + /// Typically this isn't called directly unless you are doing custom string table serialization. + /// In that case you should have called EndModelCore before writing the string table information. + /// + public static void WriteHeaderAndTailCore(BinaryWriter writer, long fpMin, ref ModelHeader header) + { + Contracts.CheckValue(writer, nameof(writer)); + Contracts.CheckParam(fpMin >= 0, nameof(fpMin)); + + header.FpTail = writer.FpCur() - fpMin; + writer.Write(TailSignatureValue); + header.FpLim = writer.FpCur() - fpMin; + + Exception ex; + bool res = TryValidate(ref header, header.FpLim, out ex); + // If this fails, we didn't construct the header correctly. This is both a bug and + // something we want to protect against at runtime, hence both assert and check. + Contracts.Assert(res); + Contracts.Check(res); + + // Write the header, then seek back to the end. 
+ writer.Seek(fpMin); + byte[] headerBytes = new byte[ModelHeader.Size]; + MarshalToBytes(ref header, headerBytes); + writer.Write(headerBytes); + Contracts.Assert(writer.FpCur() == fpMin + ModelHeader.Size); + writer.Seek(header.FpLim + fpMin); + } - long fpCur = writer.FpCur(); - Contracts.Check(fpCur - fpMin >= header.FpModel); + /// + /// The current writer position should be the end of the model blob. Records the size of the model blob. + /// Typically this isn't called directly unless you are doing custom string table serialization. + /// + public static void EndModelCore(BinaryWriter writer, long fpMin, ref ModelHeader header) + { + Contracts.Check(header.FpModel == ModelHeader.Size); + Contracts.Check(header.CbModel == 0); - // Record the size of the model. - header.CbModel = fpCur - header.FpModel - fpMin; - } + long fpCur = writer.FpCur(); + Contracts.Check(fpCur - fpMin >= header.FpModel); - /// - /// Sets the version information the header. - /// - public static void SetVersionInfo(ref ModelHeader header, VersionInfo ver) - { - header.ModelSignature = ver.ModelSignature; - header.ModelVerWritten = ver.VerWrittenCur; - header.ModelVerReadable = ver.VerReadableCur; - SetLoaderSig(ref header, ver.LoaderSignature); - SetLoaderSigAlt(ref header, ver.LoaderSignatureAlt); - } + // Record the size of the model. + header.CbModel = fpCur - header.FpModel - fpMin; + } - /// - /// Record the given loader sig in the header. If sig is null, clears the loader sig. - /// - public static void SetLoaderSig(ref ModelHeader header, string sig) - { - header.LoaderSignature0 = 0; - header.LoaderSignature1 = 0; - header.LoaderSignature2 = 0; + /// + /// Sets the version information the header. 
+ /// + public static void SetVersionInfo(ref ModelHeader header, VersionInfo ver) + { + header.ModelSignature = ver.ModelSignature; + header.ModelVerWritten = ver.VerWrittenCur; + header.ModelVerReadable = ver.VerReadableCur; + SetLoaderSig(ref header, ver.LoaderSignature); + SetLoaderSigAlt(ref header, ver.LoaderSignatureAlt); + } + + /// + /// Record the given loader sig in the header. If sig is null, clears the loader sig. + /// + public static void SetLoaderSig(ref ModelHeader header, string sig) + { + header.LoaderSignature0 = 0; + header.LoaderSignature1 = 0; + header.LoaderSignature2 = 0; - if (sig == null) - return; + if (sig == null) + return; - Contracts.Check(sig.Length <= 24); - for (int ich = 0; ich < sig.Length; ich++) - { - char ch = sig[ich]; - Contracts.Check(ch <= 0xFF); - if (ich < 8) - header.LoaderSignature0 |= (ulong)ch << (ich * 8); - else if (ich < 16) - header.LoaderSignature1 |= (ulong)ch << ((ich - 8) * 8); - else if (ich < 24) - header.LoaderSignature2 |= (ulong)ch << ((ich - 16) * 8); - } + Contracts.Check(sig.Length <= 24); + for (int ich = 0; ich < sig.Length; ich++) + { + char ch = sig[ich]; + Contracts.Check(ch <= 0xFF); + if (ich < 8) + header.LoaderSignature0 |= (ulong)ch << (ich * 8); + else if (ich < 16) + header.LoaderSignature1 |= (ulong)ch << ((ich - 8) * 8); + else if (ich < 24) + header.LoaderSignature2 |= (ulong)ch << ((ich - 16) * 8); } + } - /// - /// Record the given alternate loader sig in the header. If sig is null, clears the alternate loader sig. - /// - public static void SetLoaderSigAlt(ref ModelHeader header, string sig) - { - header.LoaderSignatureAlt0 = 0; - header.LoaderSignatureAlt1 = 0; - header.LoaderSignatureAlt2 = 0; + /// + /// Record the given alternate loader sig in the header. If sig is null, clears the alternate loader sig. 
+ /// + public static void SetLoaderSigAlt(ref ModelHeader header, string sig) + { + header.LoaderSignatureAlt0 = 0; + header.LoaderSignatureAlt1 = 0; + header.LoaderSignatureAlt2 = 0; - if (sig == null) - return; + if (sig == null) + return; - Contracts.Check(sig.Length <= 24); - for (int ich = 0; ich < sig.Length; ich++) - { - char ch = sig[ich]; - Contracts.Check(ch <= 0xFF); - if (ich < 8) - header.LoaderSignatureAlt0 |= (ulong)ch << (ich * 8); - else if (ich < 16) - header.LoaderSignatureAlt1 |= (ulong)ch << ((ich - 8) * 8); - else if (ich < 24) - header.LoaderSignatureAlt2 |= (ulong)ch << ((ich - 16) * 8); - } + Contracts.Check(sig.Length <= 24); + for (int ich = 0; ich < sig.Length; ich++) + { + char ch = sig[ich]; + Contracts.Check(ch <= 0xFF); + if (ich < 8) + header.LoaderSignatureAlt0 |= (ulong)ch << (ich * 8); + else if (ich < 16) + header.LoaderSignatureAlt1 |= (ulong)ch << ((ich - 8) * 8); + else if (ich < 24) + header.LoaderSignatureAlt2 |= (ulong)ch << ((ich - 16) * 8); } + } - /// - /// Low level method for copying bytes from a header structure into a byte array. - /// - public static void MarshalToBytes(ref ModelHeader header, byte[] bytes) + /// + /// Low level method for copying bytes from a header structure into a byte array. + /// + public static void MarshalToBytes(ref ModelHeader header, byte[] bytes) + { + Contracts.Check(Utils.Size(bytes) >= Size); + unsafe { - Contracts.Check(Utils.Size(bytes) >= Size); - unsafe - { - fixed (ModelHeader* pheader = &header) - Marshal.Copy((IntPtr)pheader, bytes, 0, Size); - } + fixed (ModelHeader* pheader = &header) + Marshal.Copy((IntPtr)pheader, bytes, 0, Size); } + } - // Utilities for reading. + // Utilities for reading. - /// - /// Read the model header, strings, etc from reader. Also validates the header (throws if bad). - /// Leaves the reader position at the beginning of the model blob. 
- /// - public static void BeginRead(out long fpMin, out ModelHeader header, out string[] strings, out string loaderAssemblyName, BinaryReader reader) - { - fpMin = reader.FpCur(); + /// + /// Read the model header, strings, etc from reader. Also validates the header (throws if bad). + /// Leaves the reader position at the beginning of the model blob. + /// + public static void BeginRead(out long fpMin, out ModelHeader header, out string[] strings, out string loaderAssemblyName, BinaryReader reader) + { + fpMin = reader.FpCur(); - byte[] headerBytes = reader.ReadBytes(ModelHeader.Size); - Contracts.CheckDecode(headerBytes.Length == ModelHeader.Size); - ModelHeader.MarshalFromBytes(out header, headerBytes); + byte[] headerBytes = reader.ReadBytes(ModelHeader.Size); + Contracts.CheckDecode(headerBytes.Length == ModelHeader.Size); + ModelHeader.MarshalFromBytes(out header, headerBytes); - Exception ex; - if (!ModelHeader.TryValidate(ref header, reader, fpMin, out strings, out loaderAssemblyName, out ex)) - throw ex; + Exception ex; + if (!ModelHeader.TryValidate(ref header, reader, fpMin, out strings, out loaderAssemblyName, out ex)) + throw ex; - reader.Seek(header.FpModel + fpMin); - } + reader.Seek(header.FpModel + fpMin); + } - /// - /// Finish reading. Checks that the current reader position is the end of the model blob. - /// Seeks to the end of the entire model file (after the tail). - /// - public static void EndRead(long fpMin, ref ModelHeader header, BinaryReader reader) - { - Contracts.CheckDecode(header.FpModel + header.CbModel == reader.FpCur() - fpMin); - reader.Seek(header.FpLim + fpMin); - } + /// + /// Finish reading. Checks that the current reader position is the end of the model blob. + /// Seeks to the end of the entire model file (after the tail). 
+ /// + public static void EndRead(long fpMin, ref ModelHeader header, BinaryReader reader) + { + Contracts.CheckDecode(header.FpModel + header.CbModel == reader.FpCur() - fpMin); + reader.Seek(header.FpLim + fpMin); + } - /// - /// Performs standard version validation. - /// - public static void CheckVersionInfo(ref ModelHeader header, VersionInfo ver) + /// + /// Performs standard version validation. + /// + public static void CheckVersionInfo(ref ModelHeader header, VersionInfo ver) + { + Contracts.CheckDecode(header.ModelSignature == ver.ModelSignature, "Unknown file type"); + Contracts.CheckDecode(header.ModelVerReadable <= header.ModelVerWritten, "Corrupt file header"); + if (header.ModelVerReadable > ver.VerWrittenCur) + throw Contracts.ExceptDecode("Cause: ML.NET {0} cannont read component '{1}' of the model, because the model is too new.\n" + + "Suggestion: Make sure the model is trained with ML.NET {0} or older.\n" + + "Debug details: Maximum expected version {2}, got {3}.", + typeof(VersionInfo).Assembly.GetName().Version, ver.LoaderSignature, header.ModelVerReadable, ver.VerWrittenCur); + if (header.ModelVerWritten < ver.VerWeCanReadBack) { - Contracts.CheckDecode(header.ModelSignature == ver.ModelSignature, "Unknown file type"); - Contracts.CheckDecode(header.ModelVerReadable <= header.ModelVerWritten, "Corrupt file header"); - if (header.ModelVerReadable > ver.VerWrittenCur) - throw Contracts.ExceptDecode("Cause: ML.NET {0} cannont read component '{1}' of the model, because the model is too new.\n" + - "Suggestion: Make sure the model is trained with ML.NET {0} or older.\n" + - "Debug details: Maximum expected version {2}, got {3}.", - typeof(VersionInfo).Assembly.GetName().Version, ver.LoaderSignature, header.ModelVerReadable, ver.VerWrittenCur); - if (header.ModelVerWritten < ver.VerWeCanReadBack) - { - // Breaking backwards compatibility is something we should avoid if at all possible. If - // this message is observed, it may be a bug. 
- throw Contracts.ExceptDecode("Cause: ML.NET {0} cannot read component '{1}' of the model, because the model is too old.\n" + - "Suggestion: Make sure the model is trained with ML.NET {0}.\n" + - "Debug details: Minimum expected version {2}, got {3}.", - typeof(VersionInfo).Assembly.GetName().Version, ver.LoaderSignature, header.ModelVerReadable, ver.VerWrittenCur); - } + // Breaking backwards compatibility is something we should avoid if at all possible. If + // this message is observed, it may be a bug. + throw Contracts.ExceptDecode("Cause: ML.NET {0} cannot read component '{1}' of the model, because the model is too old.\n" + + "Suggestion: Make sure the model is trained with ML.NET {0}.\n" + + "Debug details: Minimum expected version {2}, got {3}.", + typeof(VersionInfo).Assembly.GetName().Version, ver.LoaderSignature, header.ModelVerReadable, ver.VerWrittenCur); } + } - /// - /// Low level method for copying bytes from a byte array to a header structure. - /// - public static void MarshalFromBytes(out ModelHeader header, byte[] bytes) + /// + /// Low level method for copying bytes from a byte array to a header structure. + /// + public static void MarshalFromBytes(out ModelHeader header, byte[] bytes) + { + Contracts.Check(Utils.Size(bytes) >= Size); + unsafe { - Contracts.Check(Utils.Size(bytes) >= Size); - unsafe - { - fixed (ModelHeader* pheader = &header) - Marshal.Copy(bytes, 0, (IntPtr)pheader, Size); - } + fixed (ModelHeader* pheader = &header) + Marshal.Copy(bytes, 0, (IntPtr)pheader, Size); } + } - /// - /// Checks the basic validity of the header, assuming the stream is at least the given size. - /// Returns false (and the out exception) on failure. - /// - public static bool TryValidate(ref ModelHeader header, long size, out Exception ex) - { - Contracts.Check(size >= 0); + /// + /// Checks the basic validity of the header, assuming the stream is at least the given size. + /// Returns false (and the out exception) on failure. 
+ /// + public static bool TryValidate(ref ModelHeader header, long size, out Exception ex) + { + Contracts.Check(size >= 0); - try - { - Contracts.CheckDecode(header.Signature == SignatureValue, "Wrong file type"); - Contracts.CheckDecode(header.VerReadable <= header.VerWritten, "Corrupt file header"); - Contracts.CheckDecode(header.VerReadable <= VerWrittenCur, "File is too new"); - Contracts.CheckDecode(header.VerWritten >= VerWeCanReadBack, "File is too old"); + try + { + Contracts.CheckDecode(header.Signature == SignatureValue, "Wrong file type"); + Contracts.CheckDecode(header.VerReadable <= header.VerWritten, "Corrupt file header"); + Contracts.CheckDecode(header.VerReadable <= VerWrittenCur, "File is too new"); + Contracts.CheckDecode(header.VerWritten >= VerWeCanReadBack, "File is too old"); - // Currently the model always comes immediately after the header. - Contracts.CheckDecode(header.FpModel == Size); - Contracts.CheckDecode(header.FpModel + header.CbModel >= header.FpModel); + // Currently the model always comes immediately after the header. + Contracts.CheckDecode(header.FpModel == Size); + Contracts.CheckDecode(header.FpModel + header.CbModel >= header.FpModel); - if (header.FpStringTable == 0) + if (header.FpStringTable == 0) + { + // No strings. + Contracts.CheckDecode(header.CbStringTable == 0); + Contracts.CheckDecode(header.FpStringChars == 0); + Contracts.CheckDecode(header.CbStringChars == 0); + if (header.VerWritten < VerAssemblyNameSupported || header.FpAssemblyName == 0) { - // No strings. 
- Contracts.CheckDecode(header.CbStringTable == 0); - Contracts.CheckDecode(header.FpStringChars == 0); - Contracts.CheckDecode(header.CbStringChars == 0); - if (header.VerWritten < VerAssemblyNameSupported || header.FpAssemblyName == 0) - { - Contracts.CheckDecode(header.FpTail == header.FpModel + header.CbModel); - } + Contracts.CheckDecode(header.FpTail == header.FpModel + header.CbModel); } - else + } + else + { + // Currently the string table always comes immediately after the model block. + Contracts.CheckDecode(header.FpStringTable == header.FpModel + header.CbModel); + Contracts.CheckDecode(header.CbStringTable % sizeof(long) == 0); + Contracts.CheckDecode(header.CbStringTable / sizeof(long) < int.MaxValue); + Contracts.CheckDecode(header.FpStringTable + header.CbStringTable > header.FpStringTable); + Contracts.CheckDecode(header.FpStringChars == header.FpStringTable + header.CbStringTable); + Contracts.CheckDecode(header.CbStringChars % sizeof(char) == 0); + Contracts.CheckDecode(header.FpStringChars + header.CbStringChars >= header.FpStringChars); + if (header.VerWritten < VerAssemblyNameSupported || header.FpAssemblyName == 0) { - // Currently the string table always comes immediately after the model block. 
- Contracts.CheckDecode(header.FpStringTable == header.FpModel + header.CbModel); - Contracts.CheckDecode(header.CbStringTable % sizeof(long) == 0); - Contracts.CheckDecode(header.CbStringTable / sizeof(long) < int.MaxValue); - Contracts.CheckDecode(header.FpStringTable + header.CbStringTable > header.FpStringTable); - Contracts.CheckDecode(header.FpStringChars == header.FpStringTable + header.CbStringTable); - Contracts.CheckDecode(header.CbStringChars % sizeof(char) == 0); - Contracts.CheckDecode(header.FpStringChars + header.CbStringChars >= header.FpStringChars); - if (header.VerWritten < VerAssemblyNameSupported || header.FpAssemblyName == 0) - { - Contracts.CheckDecode(header.FpTail == header.FpStringChars + header.CbStringChars); - } + Contracts.CheckDecode(header.FpTail == header.FpStringChars + header.CbStringChars); } + } - if (header.VerWritten >= VerAssemblyNameSupported) + if (header.VerWritten >= VerAssemblyNameSupported) + { + if (header.FpAssemblyName == 0) { - if (header.FpAssemblyName == 0) + Contracts.CheckDecode(header.CbAssemblyName == 0); + } + else + { + // the assembly name always immediately after the string table, if there is one + if (header.FpStringTable == 0) { - Contracts.CheckDecode(header.CbAssemblyName == 0); + Contracts.CheckDecode(header.FpAssemblyName == header.FpModel + header.CbModel); } else { - // the assembly name always immediately after the string table, if there is one - if (header.FpStringTable == 0) - { - Contracts.CheckDecode(header.FpAssemblyName == header.FpModel + header.CbModel); - } - else - { - Contracts.CheckDecode(header.FpAssemblyName == header.FpStringChars + header.CbStringChars); - } - Contracts.CheckDecode(header.CbAssemblyName % sizeof(char) == 0); - Contracts.CheckDecode(header.FpTail == header.FpAssemblyName + header.CbAssemblyName); + Contracts.CheckDecode(header.FpAssemblyName == header.FpStringChars + header.CbStringChars); } + Contracts.CheckDecode(header.CbAssemblyName % sizeof(char) == 0); + 
Contracts.CheckDecode(header.FpTail == header.FpAssemblyName + header.CbAssemblyName); } + } - Contracts.CheckDecode(header.FpLim == header.FpTail + sizeof(ulong)); - Contracts.CheckDecode(size == 0 || size >= header.FpLim); + Contracts.CheckDecode(header.FpLim == header.FpTail + sizeof(ulong)); + Contracts.CheckDecode(size == 0 || size >= header.FpLim); - ex = null; - return true; - } - catch (Exception e) - { - ex = e; - return false; - } + ex = null; + return true; + } + catch (Exception e) + { + ex = e; + return false; } + } + + /// + /// Checks the validity of the header, reads the string table, etc. + /// + public static bool TryValidate(ref ModelHeader header, BinaryReader reader, long fpMin, out string[] strings, out string loaderAssemblyName, out Exception ex) + { + Contracts.CheckValue(reader, nameof(reader)); + Contracts.Check(fpMin >= 0); - /// - /// Checks the validity of the header, reads the string table, etc. - /// - public static bool TryValidate(ref ModelHeader header, BinaryReader reader, long fpMin, out string[] strings, out string loaderAssemblyName, out Exception ex) + if (!TryValidate(ref header, reader.BaseStream.Length - fpMin, out ex)) { - Contracts.CheckValue(reader, nameof(reader)); - Contracts.Check(fpMin >= 0); + strings = null; + loaderAssemblyName = null; + return false; + } - if (!TryValidate(ref header, reader.BaseStream.Length - fpMin, out ex)) - { - strings = null; - loaderAssemblyName = null; - return false; - } + try + { + long fpOrig = reader.FpCur(); - try + StringBuilder sb = null; + if (header.FpStringTable == 0) { - long fpOrig = reader.FpCur(); - - StringBuilder sb = null; - if (header.FpStringTable == 0) - { - // No strings. - strings = null; + // No strings. + strings = null; - if (header.VerWritten < VerAssemblyNameSupported) - { - // Before VerAssemblyNameSupported, if there were no strings in the model, - // validation ended here. Specifically the FpTail checks below were skipped. 
- // There are earlier versions of models that don't have strings, and 'reader' is - // not at FpTail at this point. - // Preserve the previous behavior by returning early here. - loaderAssemblyName = null; - ex = null; - return true; - } - } - else + if (header.VerWritten < VerAssemblyNameSupported) { - reader.Seek(header.FpStringTable + fpMin); - Contracts.Assert(reader.FpCur() == header.FpStringTable + fpMin); - - long cstr = header.CbStringTable / sizeof(long); - Contracts.Assert(cstr < int.MaxValue); - long[] offsets = reader.ReadLongArray((int)cstr); - Contracts.Assert(header.FpStringChars == reader.FpCur() - fpMin); - Contracts.CheckDecode(offsets[cstr - 1] == header.CbStringChars); - - strings = new string[cstr]; - long offset = 0; - sb = new StringBuilder(); - for (int i = 0; i < offsets.Length; i++) - { - Contracts.CheckDecode(header.FpStringChars + offset == reader.FpCur() - fpMin); - - long offsetPrev = offset; - offset = offsets[i]; - Contracts.CheckDecode(offsetPrev <= offset && offset <= header.CbStringChars); - Contracts.CheckDecode(offset % sizeof(char) == 0); - long cch = (offset - offsetPrev) / sizeof(char); - Contracts.CheckDecode(cch < int.MaxValue); - - sb.Clear(); - for (long ich = 0; ich < cch; ich++) - sb.Append((char)reader.ReadUInt16()); - strings[i] = sb.ToString(); - } - Contracts.CheckDecode(offset == header.CbStringChars); - Contracts.CheckDecode(header.FpStringChars + header.CbStringChars == reader.FpCur() - fpMin); + // Before VerAssemblyNameSupported, if there were no strings in the model, + // validation ended here. Specifically the FpTail checks below were skipped. + // There are earlier versions of models that don't have strings, and 'reader' is + // not at FpTail at this point. + // Preserve the previous behavior by returning early here. 
+ loaderAssemblyName = null; + ex = null; + return true; } + } + else + { + reader.Seek(header.FpStringTable + fpMin); + Contracts.Assert(reader.FpCur() == header.FpStringTable + fpMin); - if (header.VerWritten >= VerAssemblyNameSupported && header.FpAssemblyName != 0) + long cstr = header.CbStringTable / sizeof(long); + Contracts.Assert(cstr < int.MaxValue); + long[] offsets = reader.ReadLongArray((int)cstr); + Contracts.Assert(header.FpStringChars == reader.FpCur() - fpMin); + Contracts.CheckDecode(offsets[cstr - 1] == header.CbStringChars); + + strings = new string[cstr]; + long offset = 0; + sb = new StringBuilder(); + for (int i = 0; i < offsets.Length; i++) { - reader.Seek(header.FpAssemblyName + fpMin); - int assemblyNameLength = (int)header.CbAssemblyName / sizeof(char); + Contracts.CheckDecode(header.FpStringChars + offset == reader.FpCur() - fpMin); - sb = sb != null ? sb.Clear() : new StringBuilder(assemblyNameLength); + long offsetPrev = offset; + offset = offsets[i]; + Contracts.CheckDecode(offsetPrev <= offset && offset <= header.CbStringChars); + Contracts.CheckDecode(offset % sizeof(char) == 0); + long cch = (offset - offsetPrev) / sizeof(char); + Contracts.CheckDecode(cch < int.MaxValue); - for (long ich = 0; ich < assemblyNameLength; ich++) + sb.Clear(); + for (long ich = 0; ich < cch; ich++) sb.Append((char)reader.ReadUInt16()); - - loaderAssemblyName = sb.ToString(); - } - else - { - loaderAssemblyName = null; + strings[i] = sb.ToString(); } + Contracts.CheckDecode(offset == header.CbStringChars); + Contracts.CheckDecode(header.FpStringChars + header.CbStringChars == reader.FpCur() - fpMin); + } - Contracts.CheckDecode(header.FpTail == reader.FpCur() - fpMin); + if (header.VerWritten >= VerAssemblyNameSupported && header.FpAssemblyName != 0) + { + reader.Seek(header.FpAssemblyName + fpMin); + int assemblyNameLength = (int)header.CbAssemblyName / sizeof(char); - ulong tail = reader.ReadUInt64(); - Contracts.CheckDecode(tail == TailSignatureValue, 
"Corrupt model file tail"); + sb = sb != null ? sb.Clear() : new StringBuilder(assemblyNameLength); - ex = null; + for (long ich = 0; ich < assemblyNameLength; ich++) + sb.Append((char)reader.ReadUInt16()); - reader.Seek(fpOrig); - return true; + loaderAssemblyName = sb.ToString(); } - catch (Exception e) + else { - strings = null; loaderAssemblyName = null; - ex = e; - return false; } - } - /// - /// Extract and return the loader sig from the header, trimming trailing zeros. - /// - public static string GetLoaderSig(ref ModelHeader header) - { - char[] chars = new char[3 * sizeof(ulong)]; + Contracts.CheckDecode(header.FpTail == reader.FpCur() - fpMin); - for (int ich = 0; ich < chars.Length; ich++) - { - char ch; - if (ich < 8) - ch = (char)((header.LoaderSignature0 >> (ich * 8)) & 0xFF); - else if (ich < 16) - ch = (char)((header.LoaderSignature1 >> ((ich - 8) * 8)) & 0xFF); - else - ch = (char)((header.LoaderSignature2 >> ((ich - 16) * 8)) & 0xFF); - chars[ich] = ch; - } + ulong tail = reader.ReadUInt64(); + Contracts.CheckDecode(tail == TailSignatureValue, "Corrupt model file tail"); - int cch = 24; - while (cch > 0 && chars[cch - 1] == 0) - cch--; - return new string(chars, 0, cch); - } + ex = null; - /// - /// Extract and return the alternate loader sig from the header, trimming trailing zeros. 
- /// - public static string GetLoaderSigAlt(ref ModelHeader header) + reader.Seek(fpOrig); + return true; + } + catch (Exception e) { - char[] chars = new char[3 * sizeof(ulong)]; + strings = null; + loaderAssemblyName = null; + ex = e; + return false; + } + } - for (int ich = 0; ich < chars.Length; ich++) - { - char ch; - if (ich < 8) - ch = (char)((header.LoaderSignatureAlt0 >> (ich * 8)) & 0xFF); - else if (ich < 16) - ch = (char)((header.LoaderSignatureAlt1 >> ((ich - 8) * 8)) & 0xFF); - else - ch = (char)((header.LoaderSignatureAlt2 >> ((ich - 16) * 8)) & 0xFF); - chars[ich] = ch; - } + /// + /// Extract and return the loader sig from the header, trimming trailing zeros. + /// + public static string GetLoaderSig(ref ModelHeader header) + { + char[] chars = new char[3 * sizeof(ulong)]; - int cch = 24; - while (cch > 0 && chars[cch - 1] == 0) - cch--; - return new string(chars, 0, cch); + for (int ich = 0; ich < chars.Length; ich++) + { + char ch; + if (ich < 8) + ch = (char)((header.LoaderSignature0 >> (ich * 8)) & 0xFF); + else if (ich < 16) + ch = (char)((header.LoaderSignature1 >> ((ich - 8) * 8)) & 0xFF); + else + ch = (char)((header.LoaderSignature2 >> ((ich - 16) * 8)) & 0xFF); + chars[ich] = ch; } + + int cch = 24; + while (cch > 0 && chars[cch - 1] == 0) + cch--; + return new string(chars, 0, cch); } /// - /// This is used to simplify version checking boiler-plate code. It is an optional - /// utility type. + /// Extract and return the alternate loader sig from the header, trimming trailing zeros. 
/// - [BestFriend] - internal readonly struct VersionInfo + public static string GetLoaderSigAlt(ref ModelHeader header) { - public readonly ulong ModelSignature; - public readonly uint VerWrittenCur; - public readonly uint VerReadableCur; - public readonly uint VerWeCanReadBack; - public readonly string LoaderAssemblyName; - public readonly string LoaderSignature; - public readonly string LoaderSignatureAlt; - - /// - /// Construct version info with a string value for modelSignature. The string must be 8 characters - /// all less than 0x100. Spaces are mapped to zero. This assumes little-endian. - /// - public VersionInfo(string modelSignature, uint verWrittenCur, uint verReadableCur, uint verWeCanReadBack, - string loaderAssemblyName, string loaderSignature = null, string loaderSignatureAlt = null) + char[] chars = new char[3 * sizeof(ulong)]; + + for (int ich = 0; ich < chars.Length; ich++) { - Contracts.Check(Utils.Size(modelSignature) == 8, "Model signature must be eight characters"); - ModelSignature = 0; - for (int ich = 0; ich < modelSignature.Length; ich++) - { - char ch = modelSignature[ich]; - Contracts.Check(ch <= 0xFF); - // Map space to zero. - if (ch != ' ') - ModelSignature |= (ulong)ch << (ich * 8); - } + char ch; + if (ich < 8) + ch = (char)((header.LoaderSignatureAlt0 >> (ich * 8)) & 0xFF); + else if (ich < 16) + ch = (char)((header.LoaderSignatureAlt1 >> ((ich - 8) * 8)) & 0xFF); + else + ch = (char)((header.LoaderSignatureAlt2 >> ((ich - 16) * 8)) & 0xFF); + chars[ich] = ch; + } + + int cch = 24; + while (cch > 0 && chars[cch - 1] == 0) + cch--; + return new string(chars, 0, cch); + } +} + +/// +/// This is used to simplify version checking boiler-plate code. It is an optional +/// utility type. 
+/// +[BestFriend] +internal readonly struct VersionInfo +{ + public readonly ulong ModelSignature; + public readonly uint VerWrittenCur; + public readonly uint VerReadableCur; + public readonly uint VerWeCanReadBack; + public readonly string LoaderAssemblyName; + public readonly string LoaderSignature; + public readonly string LoaderSignatureAlt; - VerWrittenCur = verWrittenCur; - VerReadableCur = verReadableCur; - VerWeCanReadBack = verWeCanReadBack; - LoaderAssemblyName = loaderAssemblyName; - LoaderSignature = loaderSignature; - LoaderSignatureAlt = loaderSignatureAlt; + /// + /// Construct version info with a string value for modelSignature. The string must be 8 characters + /// all less than 0x100. Spaces are mapped to zero. This assumes little-endian. + /// + public VersionInfo(string modelSignature, uint verWrittenCur, uint verReadableCur, uint verWeCanReadBack, + string loaderAssemblyName, string loaderSignature = null, string loaderSignatureAlt = null) + { + Contracts.Check(Utils.Size(modelSignature) == 8, "Model signature must be eight characters"); + ModelSignature = 0; + for (int ich = 0; ich < modelSignature.Length; ich++) + { + char ch = modelSignature[ich]; + Contracts.Check(ch <= 0xFF); + // Map space to zero. 
+ if (ch != ' ') + ModelSignature |= (ulong)ch << (ich * 8); } + + VerWrittenCur = verWrittenCur; + VerReadableCur = verReadableCur; + VerWeCanReadBack = verWeCanReadBack; + LoaderAssemblyName = loaderAssemblyName; + LoaderSignature = loaderSignature; + LoaderSignatureAlt = loaderSignatureAlt; } } diff --git a/src/Microsoft.ML.Core/Data/ModelLoadContext.cs b/src/Microsoft.ML.Core/Data/ModelLoadContext.cs index 50897fcce8..cb7aff08ca 100644 --- a/src/Microsoft.ML.Core/Data/ModelLoadContext.cs +++ b/src/Microsoft.ML.Core/Data/ModelLoadContext.cs @@ -8,183 +8,182 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; -namespace Microsoft.ML +namespace Microsoft.ML; + +/// +/// This is a convenience context object for loading models from a repository, for +/// implementors of ICanSaveModel. It is not mandated but designed to reduce the +/// amount of boiler plate code. It can also be used when loading from a single stream, +/// for implementors of ICanSaveInBinaryFormat. +/// +[BestFriend] +internal sealed partial class ModelLoadContext : IDisposable { /// - /// This is a convenience context object for loading models from a repository, for - /// implementors of ICanSaveModel. It is not mandated but designed to reduce the - /// amount of boiler plate code. It can also be used when loading from a single stream, - /// for implementors of ICanSaveInBinaryFormat. + /// When in repository mode, this is the repository we're reading from. It is null when + /// in single-stream mode. + /// + public readonly RepositoryReader Repository; + + /// + /// When in repository mode, this is the directory we're reading from. Null means the root + /// of the repository. It is always null in single-stream mode. + /// + public readonly string Directory; + + /// + /// The main stream reader. + /// + public readonly BinaryReader Reader; + + /// + /// The strings loaded from the main stream's string table. 
+ /// + public readonly string[] Strings; + + /// + /// The name of the assembly that the loader lives in. + /// + /// + /// This may be null or empty if one was never written to the model, or is an older model version. + /// + public readonly string LoaderAssemblyName; + + /// + /// The main stream's model header. /// [BestFriend] - internal sealed partial class ModelLoadContext : IDisposable + internal ModelHeader Header; + + /// + /// The min file position of the main stream. + /// + public readonly long FpMin; + + /// + /// Exception context provided by Repository (can be null). + /// + private readonly IExceptionContext _ectx; + + /// + /// Returns whether this context is in repository mode (true) or single-stream mode (false). + /// + public bool InRepository { get { return Repository != null; } } + + /// + /// Create a ModelLoadContext supporting loading from a repository, for implementors of ICanSaveModel. + /// + internal ModelLoadContext(RepositoryReader rep, Repository.Entry ent, string dir) { - /// - /// When in repository mode, this is the repository we're reading from. It is null when - /// in single-stream mode. - /// - public readonly RepositoryReader Repository; - - /// - /// When in repository mode, this is the directory we're reading from. Null means the root - /// of the repository. It is always null in single-stream mode. - /// - public readonly string Directory; - - /// - /// The main stream reader. - /// - public readonly BinaryReader Reader; - - /// - /// The strings loaded from the main stream's string table. - /// - public readonly string[] Strings; - - /// - /// The name of the assembly that the loader lives in. - /// - /// - /// This may be null or empty if one was never written to the model, or is an older model version. - /// - public readonly string LoaderAssemblyName; - - /// - /// The main stream's model header. - /// - [BestFriend] - internal ModelHeader Header; - - /// - /// The min file position of the main stream. 
- /// - public readonly long FpMin; - - /// - /// Exception context provided by Repository (can be null). - /// - private readonly IExceptionContext _ectx; - - /// - /// Returns whether this context is in repository mode (true) or single-stream mode (false). - /// - public bool InRepository { get { return Repository != null; } } - - /// - /// Create a ModelLoadContext supporting loading from a repository, for implementors of ICanSaveModel. - /// - internal ModelLoadContext(RepositoryReader rep, Repository.Entry ent, string dir) - { - Contracts.CheckValue(rep, nameof(rep)); - Repository = rep; - _ectx = rep.ExceptionContext; - - _ectx.CheckValue(ent, nameof(ent)); - _ectx.CheckValueOrNull(dir); - - Directory = dir; - - Reader = new BinaryReader(ent.Stream, Encoding.UTF8, leaveOpen: true); - try - { - ModelHeader.BeginRead(out FpMin, out Header, out Strings, out LoaderAssemblyName, Reader); - } - catch - { - Reader.Dispose(); - throw; - } - } + Contracts.CheckValue(rep, nameof(rep)); + Repository = rep; + _ectx = rep.ExceptionContext; - /// - /// Create a ModelLoadContext supporting loading from a single-stream, for implementors of ICanSaveInBinaryFormat. 
- /// - internal ModelLoadContext(BinaryReader reader, IExceptionContext ectx = null) - { - Contracts.AssertValueOrNull(ectx); - _ectx = ectx; - _ectx.CheckValue(reader, nameof(reader)); + _ectx.CheckValue(ent, nameof(ent)); + _ectx.CheckValueOrNull(dir); - Repository = null; - Directory = null; - Reader = reader; - ModelHeader.BeginRead(out FpMin, out Header, out Strings, out LoaderAssemblyName, Reader); - } + Directory = dir; - public void CheckAtModel() + Reader = new BinaryReader(ent.Stream, Encoding.UTF8, leaveOpen: true); + try { - _ectx.Check(Reader.BaseStream.Position == FpMin + Header.FpModel); + ModelHeader.BeginRead(out FpMin, out Header, out Strings, out LoaderAssemblyName, Reader); } - - public void CheckAtModel(VersionInfo ver) + catch { - _ectx.Check(Reader.BaseStream.Position == FpMin + Header.FpModel); - ModelHeader.CheckVersionInfo(ref Header, ver); + Reader.Dispose(); + throw; } + } - /// - /// Performs version checks. - /// - public void CheckVersionInfo(VersionInfo ver) - { - ModelHeader.CheckVersionInfo(ref Header, ver); - } + /// + /// Create a ModelLoadContext supporting loading from a single-stream, for implementors of ICanSaveInBinaryFormat. + /// + internal ModelLoadContext(BinaryReader reader, IExceptionContext ectx = null) + { + Contracts.AssertValueOrNull(ectx); + _ectx = ectx; + _ectx.CheckValue(reader, nameof(reader)); + + Repository = null; + Directory = null; + Reader = reader; + ModelHeader.BeginRead(out FpMin, out Header, out Strings, out LoaderAssemblyName, Reader); + } - /// - /// Reads an integer from the load context's reader, and returns the associated string, - /// or null (encoded as -1). - /// - public string LoadStringOrNull() - { - int id = Reader.ReadInt32(); - // Note that -1 means null. Empty strings are in the string table. 
- _ectx.CheckDecode(-1 <= id && id < Utils.Size(Strings)); - if (id >= 0) - return Strings[id]; - return null; - } + public void CheckAtModel() + { + _ectx.Check(Reader.BaseStream.Position == FpMin + Header.FpModel); + } - /// - /// Reads an integer from the load context's reader, and returns the associated string. - /// - public string LoadString() - { - int id = Reader.ReadInt32(); - Contracts.CheckDecode(0 <= id && id < Utils.Size(Strings)); + public void CheckAtModel(VersionInfo ver) + { + _ectx.Check(Reader.BaseStream.Position == FpMin + Header.FpModel); + ModelHeader.CheckVersionInfo(ref Header, ver); + } + + /// + /// Performs version checks. + /// + public void CheckVersionInfo(VersionInfo ver) + { + ModelHeader.CheckVersionInfo(ref Header, ver); + } + + /// + /// Reads an integer from the load context's reader, and returns the associated string, + /// or null (encoded as -1). + /// + public string LoadStringOrNull() + { + int id = Reader.ReadInt32(); + // Note that -1 means null. Empty strings are in the string table. + _ectx.CheckDecode(-1 <= id && id < Utils.Size(Strings)); + if (id >= 0) return Strings[id]; - } + return null; + } - /// - /// Reads an integer from the load context's reader, and returns the associated string. - /// Throws if the string is empty or null. - /// - public string LoadNonEmptyString() - { - int id = Reader.ReadInt32(); - _ectx.CheckDecode(0 <= id && id < Utils.Size(Strings)); - var str = Strings[id]; - _ectx.CheckDecode(str.Length > 0); - return str; - } + /// + /// Reads an integer from the load context's reader, and returns the associated string. + /// + public string LoadString() + { + int id = Reader.ReadInt32(); + Contracts.CheckDecode(0 <= id && id < Utils.Size(Strings)); + return Strings[id]; + } - /// - /// Commit the load operation. This completes reading of the main stream. When in repository - /// mode, it disposes the Reader (but not the repository). 
- /// - public void Done() - { - ModelHeader.EndRead(FpMin, ref Header, Reader); - Dispose(); - } + /// + /// Reads an integer from the load context's reader, and returns the associated string. + /// Throws if the string is empty or null. + /// + public string LoadNonEmptyString() + { + int id = Reader.ReadInt32(); + _ectx.CheckDecode(0 <= id && id < Utils.Size(Strings)); + var str = Strings[id]; + _ectx.CheckDecode(str.Length > 0); + return str; + } - /// - /// When in repository mode, this disposes the Reader (but no the repository). - /// - public void Dispose() - { - // When in single-stream mode, we don't own the Reader. - if (InRepository) - Reader.Dispose(); - } + /// + /// Commit the load operation. This completes reading of the main stream. When in repository + /// mode, it disposes the Reader (but not the repository). + /// + public void Done() + { + ModelHeader.EndRead(FpMin, ref Header, Reader); + Dispose(); + } + + /// + /// When in repository mode, this disposes the Reader (but no the repository). + /// + public void Dispose() + { + // When in single-stream mode, we don't own the Reader. + if (InRepository) + Reader.Dispose(); } } diff --git a/src/Microsoft.ML.Core/Data/ModelLoading.cs b/src/Microsoft.ML.Core/Data/ModelLoading.cs index c186e91bfe..e620649bb7 100644 --- a/src/Microsoft.ML.Core/Data/ModelLoading.cs +++ b/src/Microsoft.ML.Core/Data/ModelLoading.cs @@ -9,357 +9,356 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; -namespace Microsoft.ML +namespace Microsoft.ML; + +/// +/// Signature for a repository based model loader. This is the dual of . +/// +[BestFriend] +internal delegate void SignatureLoadModel(ModelLoadContext ctx); + +internal sealed partial class ModelLoadContext : IDisposable { + public const string ModelStreamName = "Model.key"; + internal const string NameBinary = "Model.bin"; + /// - /// Signature for a repository based model loader. This is the dual of . 
+ /// Returns the new assembly name to maintain backward compatibility. /// - [BestFriend] - internal delegate void SignatureLoadModel(ModelLoadContext ctx); - - internal sealed partial class ModelLoadContext : IDisposable + private string ForwardedLoaderAssemblyName { - public const string ModelStreamName = "Model.key"; - internal const string NameBinary = "Model.bin"; - - /// - /// Returns the new assembly name to maintain backward compatibility. - /// - private string ForwardedLoaderAssemblyName + get { - get + string[] nameDetails = LoaderAssemblyName.Split(','); + switch (nameDetails[0]) { - string[] nameDetails = LoaderAssemblyName.Split(','); - switch (nameDetails[0]) - { - case "Microsoft.ML.HalLearners": - nameDetails[0] = "Microsoft.ML.Mkl.Components"; - break; - case "Microsoft.ML.StandardLearners": - nameDetails[0] = "Microsoft.ML.StandardTrainers"; - break; - default: - return LoaderAssemblyName; - } - - return string.Join(",", nameDetails); + case "Microsoft.ML.HalLearners": + nameDetails[0] = "Microsoft.ML.Mkl.Components"; + break; + case "Microsoft.ML.StandardLearners": + nameDetails[0] = "Microsoft.ML.StandardTrainers"; + break; + default: + return LoaderAssemblyName; } - } - /// - /// Return whether this context contains a directory and stream for a sub-model with - /// the indicated name. This does not attempt to load the sub-model. - /// - public bool ContainsModel(string name) - { - if (!InRepository) - return false; - if (string.IsNullOrEmpty(name)) - return false; - - var dir = Path.Combine(Directory ?? "", name); - var ent = Repository.OpenEntryOrNull(dir, ModelStreamName); - if (ent != null) - { - ent.Dispose(); - return true; - } - - if ((ent = Repository.OpenEntryOrNull(dir, NameBinary)) != null) - { - ent.Dispose(); - return true; - } - - return false; + return string.Join(",", nameDetails); } + } - /// - /// Load an optional object from the repository directory. 
- /// Returns false iff no stream was found for the object, iff result is set to null. - /// Throws if loading fails for any other reason. - /// - public static bool LoadModelOrNull(IHostEnvironment env, out TRes result, RepositoryReader rep, string dir, params object[] extra) - where TRes : class - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(rep, nameof(rep)); - var ent = rep.OpenEntryOrNull(dir, ModelStreamName); - if (ent != null) - { - using (ent) - { - // Provide the repository, entry, and directory name to the loadable class ctor. - env.Assert(ent.Stream.Position == 0); - LoadModel(env, out result, rep, ent, dir, extra); - return true; - } - } - - if ((ent = rep.OpenEntryOrNull(dir, NameBinary)) != null) - { - using (ent) - { - env.Assert(ent.Stream.Position == 0); - LoadModel(env, out result, ent.Stream, extra); - return true; - } - } - - result = null; + /// + /// Return whether this context contains a directory and stream for a sub-model with + /// the indicated name. This does not attempt to load the sub-model. + /// + public bool ContainsModel(string name) + { + if (!InRepository) + return false; + if (string.IsNullOrEmpty(name)) return false; - } - /// - /// Load an object from the repository directory. - /// - public static void LoadModel(IHostEnvironment env, out TRes result, RepositoryReader rep, string dir, params object[] extra) - where TRes : class + var dir = Path.Combine(Directory ?? "", name); + var ent = Repository.OpenEntryOrNull(dir, ModelStreamName); + if (ent != null) { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(rep, nameof(rep)); - if (!LoadModelOrNull(env, out result, rep, dir, extra)) - throw env.ExceptDecode("Corrupt model file"); - env.AssertValue(result); + ent.Dispose(); + return true; } - /// - /// Load a sub model from the given sub directory if it exists. This requires InRepository to be true. - /// Returns false iff no stream was found for the object, iff result is set to null. 
- /// Throws if loading fails for any other reason. - /// - public bool LoadModelOrNull(IHostEnvironment env, out TRes result, string name, params object[] extra) - where TRes : class + if ((ent = Repository.OpenEntryOrNull(dir, NameBinary)) != null) { - _ectx.CheckValue(env, nameof(env)); - _ectx.Check(InRepository, "Can't load a sub-model when reading from a single stream"); - return LoadModelOrNull(env, out result, Repository, Path.Combine(Directory ?? "", name), extra); + ent.Dispose(); + return true; } - /// - /// Load a sub model from the given sub directory. This requires InRepository to be true. - /// - public void LoadModel(IHostEnvironment env, out TRes result, string name, params object[] extra) - where TRes : class + return false; + } + + /// + /// Load an optional object from the repository directory. + /// Returns false iff no stream was found for the object, iff result is set to null. + /// Throws if loading fails for any other reason. + /// + public static bool LoadModelOrNull(IHostEnvironment env, out TRes result, RepositoryReader rep, string dir, params object[] extra) + where TRes : class + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(rep, nameof(rep)); + var ent = rep.OpenEntryOrNull(dir, ModelStreamName); + if (ent != null) { - _ectx.CheckValue(env, nameof(env)); - if (!LoadModelOrNull(env, out result, name, extra)) - throw _ectx.ExceptDecode("Corrupt model file"); - _ectx.AssertValue(result); + using (ent) + { + // Provide the repository, entry, and directory name to the loadable class ctor. + env.Assert(ent.Stream.Position == 0); + LoadModel(env, out result, rep, ent, dir, extra); + return true; + } } - /// - /// Try to load from the given repository entry using the default loader(s) specified in the header. - /// Returns false iff the default loader(s) could not be bound to a compatible loadable class. 
- /// - private static bool TryLoadModel(IHostEnvironment env, out TRes result, RepositoryReader rep, Repository.Entry ent, string dir, params object[] extra) - where TRes : class + if ((ent = rep.OpenEntryOrNull(dir, NameBinary)) != null) { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(rep, nameof(rep)); - long fp = ent.Stream.Position; - using (var ctx = new ModelLoadContext(rep, ent, dir)) + using (ent) { - env.Assert(fp == ctx.FpMin); - if (ctx.TryLoadModelCore(env, out result, extra)) - return true; + env.Assert(ent.Stream.Position == 0); + LoadModel(env, out result, ent.Stream, extra); + return true; } + } - // TryLoadModelCore should rewind on failure. - Contracts.Assert(fp == ent.Stream.Position); + result = null; + return false; + } - return false; - } + /// + /// Load an object from the repository directory. + /// + public static void LoadModel(IHostEnvironment env, out TRes result, RepositoryReader rep, string dir, params object[] extra) + where TRes : class + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(rep, nameof(rep)); + if (!LoadModelOrNull(env, out result, rep, dir, extra)) + throw env.ExceptDecode("Corrupt model file"); + env.AssertValue(result); + } - /// - /// Load from the given repository entry using the default loader(s) specified in the header. - /// - public static void LoadModel(IHostEnvironment env, out TRes result, RepositoryReader rep, Repository.Entry ent, string dir, params object[] extra) - where TRes : class - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(rep, nameof(rep)); - if (!TryLoadModel(env, out result, rep, ent, dir, extra)) - throw env.ExceptDecode("Couldn't load model: '{0}'", dir); - } + /// + /// Load a sub model from the given sub directory if it exists. This requires InRepository to be true. + /// Returns false iff no stream was found for the object, iff result is set to null. + /// Throws if loading fails for any other reason. 
+ /// + public bool LoadModelOrNull(IHostEnvironment env, out TRes result, string name, params object[] extra) + where TRes : class + { + _ectx.CheckValue(env, nameof(env)); + _ectx.Check(InRepository, "Can't load a sub-model when reading from a single stream"); + return LoadModelOrNull(env, out result, Repository, Path.Combine(Directory ?? "", name), extra); + } - /// - /// Try to load from the given stream (non-Repository). - /// Returns false iff the default loader(s) could not be bound to a compatible loadable class. - /// - public static bool TryLoadModel(IHostEnvironment env, out TRes result, Stream stream, params object[] extra) - where TRes : class - { - Contracts.CheckValue(env, nameof(env)); - using (var reader = new BinaryReader(stream, Encoding.UTF8, leaveOpen: true)) - return TryLoadModel(env, out result, reader, extra); - } + /// + /// Load a sub model from the given sub directory. This requires InRepository to be true. + /// + public void LoadModel(IHostEnvironment env, out TRes result, string name, params object[] extra) + where TRes : class + { + _ectx.CheckValue(env, nameof(env)); + if (!LoadModelOrNull(env, out result, name, extra)) + throw _ectx.ExceptDecode("Corrupt model file"); + _ectx.AssertValue(result); + } - /// - /// Load from the given stream (non-Repository) using the default loader(s) specified in the header. - /// - public static void LoadModel(IHostEnvironment env, out TRes result, Stream stream, params object[] extra) - where TRes : class + /// + /// Try to load from the given repository entry using the default loader(s) specified in the header. + /// Returns false iff the default loader(s) could not be bound to a compatible loadable class. 
+ /// + private static bool TryLoadModel(IHostEnvironment env, out TRes result, RepositoryReader rep, Repository.Entry ent, string dir, params object[] extra) + where TRes : class + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(rep, nameof(rep)); + long fp = ent.Stream.Position; + using (var ctx = new ModelLoadContext(rep, ent, dir)) { - Contracts.CheckValue(env, nameof(env)); - if (!TryLoadModel(env, out result, stream, extra)) - throw Contracts.ExceptDecode("Couldn't load model"); + env.Assert(fp == ctx.FpMin); + if (ctx.TryLoadModelCore(env, out result, extra)) + return true; } - /// - /// Try to load from the given reader (non-Repository). - /// Returns false iff the default loader(s) could not be bound to a compatible loadable class. - /// - public static bool TryLoadModel(IHostEnvironment env, out TRes result, BinaryReader reader, params object[] extra) - where TRes : class - { - Contracts.CheckValue(env, nameof(env)); - long fp = reader.BaseStream.Position; - using (var ctx = new ModelLoadContext(reader)) - { - Contracts.Assert(fp == ctx.FpMin); - return ctx.TryLoadModelCore(env, out result, extra); - } - } + // TryLoadModelCore should rewind on failure. + Contracts.Assert(fp == ent.Stream.Position); + + return false; + } - /// - /// Load from the given reader (non-Repository) using the default loader(s) specified in the header. - /// - public static void LoadModel(IHostEnvironment env, out TRes result, BinaryReader reader, params object[] extra) - where TRes : class + /// + /// Load from the given repository entry using the default loader(s) specified in the header. 
+ /// + public static void LoadModel(IHostEnvironment env, out TRes result, RepositoryReader rep, Repository.Entry ent, string dir, params object[] extra) + where TRes : class + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(rep, nameof(rep)); + if (!TryLoadModel(env, out result, rep, ent, dir, extra)) + throw env.ExceptDecode("Couldn't load model: '{0}'", dir); + } + + /// + /// Try to load from the given stream (non-Repository). + /// Returns false iff the default loader(s) could not be bound to a compatible loadable class. + /// + public static bool TryLoadModel(IHostEnvironment env, out TRes result, Stream stream, params object[] extra) + where TRes : class + { + Contracts.CheckValue(env, nameof(env)); + using (var reader = new BinaryReader(stream, Encoding.UTF8, leaveOpen: true)) + return TryLoadModel(env, out result, reader, extra); + } + + /// + /// Load from the given stream (non-Repository) using the default loader(s) specified in the header. + /// + public static void LoadModel(IHostEnvironment env, out TRes result, Stream stream, params object[] extra) + where TRes : class + { + Contracts.CheckValue(env, nameof(env)); + if (!TryLoadModel(env, out result, stream, extra)) + throw Contracts.ExceptDecode("Couldn't load model"); + } + + /// + /// Try to load from the given reader (non-Repository). + /// Returns false iff the default loader(s) could not be bound to a compatible loadable class. + /// + public static bool TryLoadModel(IHostEnvironment env, out TRes result, BinaryReader reader, params object[] extra) + where TRes : class + { + Contracts.CheckValue(env, nameof(env)); + long fp = reader.BaseStream.Position; + using (var ctx = new ModelLoadContext(reader)) { - Contracts.CheckValue(env, nameof(env)); - if (!TryLoadModel(env, out result, reader, extra)) - throw Contracts.ExceptDecode("Couldn't load model"); + Contracts.Assert(fp == ctx.FpMin); + return ctx.TryLoadModelCore(env, out result, extra); } + } - /// - /// Tries to load. 
- /// Returns false iff the default loader(s) could not be bound to a compatible loadable class. - /// - private bool TryLoadModelCore(IHostEnvironment env, out TRes result, params object[] extra) - where TRes : class - { - _ectx.AssertValue(env, "env"); - _ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel); + /// + /// Load from the given reader (non-Repository) using the default loader(s) specified in the header. + /// + public static void LoadModel(IHostEnvironment env, out TRes result, BinaryReader reader, params object[] extra) + where TRes : class + { + Contracts.CheckValue(env, nameof(env)); + if (!TryLoadModel(env, out result, reader, extra)) + throw Contracts.ExceptDecode("Couldn't load model"); + } - var args = ConcatArgsRev(extra, this); + /// + /// Tries to load. + /// Returns false iff the default loader(s) could not be bound to a compatible loadable class. + /// + private bool TryLoadModelCore(IHostEnvironment env, out TRes result, params object[] extra) + where TRes : class + { + _ectx.AssertValue(env, "env"); + _ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel); - EnsureLoaderAssemblyIsRegistered(env.ComponentCatalog); + var args = ConcatArgsRev(extra, this); - object tmp; - string sig = ModelHeader.GetLoaderSig(ref Header); - if (!string.IsNullOrWhiteSpace(sig) && - ComponentCatalog.TryCreateInstance(env, out tmp, sig, "", args)) - { - result = tmp as TRes; - if (result != null) - { - Done(); - return true; - } - // REVIEW: Should this fall through? 
- } - _ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel); + EnsureLoaderAssemblyIsRegistered(env.ComponentCatalog); - string sigAlt = ModelHeader.GetLoaderSigAlt(ref Header); - if (!string.IsNullOrWhiteSpace(sigAlt) && - ComponentCatalog.TryCreateInstance(env, out tmp, sigAlt, "", args)) + object tmp; + string sig = ModelHeader.GetLoaderSig(ref Header); + if (!string.IsNullOrWhiteSpace(sig) && + ComponentCatalog.TryCreateInstance(env, out tmp, sig, "", args)) + { + result = tmp as TRes; + if (result != null) { - result = tmp as TRes; - if (result != null) - { - Done(); - return true; - } - // REVIEW: Should this fall through? + Done(); + return true; } - _ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel); - - Reader.BaseStream.Position = FpMin; - result = null; - return false; + // REVIEW: Should this fall through? } + _ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel); - private void EnsureLoaderAssemblyIsRegistered(ComponentCatalog catalog) + string sigAlt = ModelHeader.GetLoaderSigAlt(ref Header); + if (!string.IsNullOrWhiteSpace(sigAlt) && + ComponentCatalog.TryCreateInstance(env, out tmp, sigAlt, "", args)) { - if (!string.IsNullOrEmpty(LoaderAssemblyName)) + result = tmp as TRes; + if (result != null) { - var assembly = Assembly.Load(ForwardedLoaderAssemblyName); - catalog.RegisterAssembly(assembly); + Done(); + return true; } + // REVIEW: Should this fall through? 
} + _ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel); - private static object[] ConcatArgsRev(object[] args2, params object[] args1) + Reader.BaseStream.Position = FpMin; + result = null; + return false; + } + + private void EnsureLoaderAssemblyIsRegistered(ComponentCatalog catalog) + { + if (!string.IsNullOrEmpty(LoaderAssemblyName)) { - Contracts.AssertNonEmpty(args1); - return Utils.Concat(args1, args2); + var assembly = Assembly.Load(ForwardedLoaderAssemblyName); + catalog.RegisterAssembly(assembly); } + } - /// - /// Try to load a sub model from the given sub directory. This requires InRepository to be true. - /// - public bool TryProcessSubModel(string dir, Action action) - { - _ectx.Check(InRepository, "Can't Load a sub-model when reading from a single stream"); - _ectx.CheckNonEmpty(dir, nameof(dir)); - _ectx.CheckValue(action, nameof(action)); + private static object[] ConcatArgsRev(object[] args2, params object[] args1) + { + Contracts.AssertNonEmpty(args1); + return Utils.Concat(args1, args2); + } - string path = Path.Combine(Directory, dir); - var ent = Repository.OpenEntryOrNull(path, ModelStreamName); - if (ent == null) - return false; + /// + /// Try to load a sub model from the given sub directory. This requires InRepository to be true. + /// + public bool TryProcessSubModel(string dir, Action action) + { + _ectx.Check(InRepository, "Can't Load a sub-model when reading from a single stream"); + _ectx.CheckNonEmpty(dir, nameof(dir)); + _ectx.CheckValue(action, nameof(action)); - using (ent) - { - // Provide the repository, entry, and directory name to the loadable class ctor. - _ectx.Assert(ent.Stream.Position == 0); - using (var ctx = new ModelLoadContext(Repository, ent, path)) - action(ctx); - } - return true; - } + string path = Path.Combine(Directory, dir); + var ent = Repository.OpenEntryOrNull(path, ModelStreamName); + if (ent == null) + return false; - /// - /// Try to load a binary stream from the current directory. 
This requires InRepository to be true. - /// - public bool TryLoadBinaryStream(string name, Action action) + using (ent) { - _ectx.Check(InRepository, "Can't Load a sub-model when reading from a single stream"); - _ectx.CheckNonEmpty(name, nameof(name)); - _ectx.CheckValue(action, nameof(action)); + // Provide the repository, entry, and directory name to the loadable class ctor. + _ectx.Assert(ent.Stream.Position == 0); + using (var ctx = new ModelLoadContext(Repository, ent, path)) + action(ctx); + } + return true; + } + + /// + /// Try to load a binary stream from the current directory. This requires InRepository to be true. + /// + public bool TryLoadBinaryStream(string name, Action action) + { + _ectx.Check(InRepository, "Can't Load a sub-model when reading from a single stream"); + _ectx.CheckNonEmpty(name, nameof(name)); + _ectx.CheckValue(action, nameof(action)); - var ent = Repository.OpenEntryOrNull(Directory, name); - if (ent == null) - return false; + var ent = Repository.OpenEntryOrNull(Directory, name); + if (ent == null) + return false; - using (ent) - using (var reader = new BinaryReader(ent.Stream, Encoding.UTF8, leaveOpen: true)) - { - action(reader); - } - return true; + using (ent) + using (var reader = new BinaryReader(ent.Stream, Encoding.UTF8, leaveOpen: true)) + { + action(reader); } + return true; + } - /// - /// Try to load a text stream from the current directory. This requires InRepository to be true. - /// - public bool TryLoadTextStream(string name, Action action) - { - _ectx.Check(InRepository, "Can't Load a sub-model when reading from a single stream"); - _ectx.CheckNonEmpty(name, nameof(name)); - _ectx.CheckValue(action, nameof(action)); + /// + /// Try to load a text stream from the current directory. This requires InRepository to be true. 
+ /// + public bool TryLoadTextStream(string name, Action action) + { + _ectx.Check(InRepository, "Can't Load a sub-model when reading from a single stream"); + _ectx.CheckNonEmpty(name, nameof(name)); + _ectx.CheckValue(action, nameof(action)); - var ent = Repository.OpenEntryOrNull(Directory, name); - if (ent == null) - return false; + var ent = Repository.OpenEntryOrNull(Directory, name); + if (ent == null) + return false; - using (ent) - using (var reader = new StreamReader(ent.Stream)) - { - action(reader); - } - return true; + using (ent) + using (var reader = new StreamReader(ent.Stream)) + { + action(reader); } + return true; } } diff --git a/src/Microsoft.ML.Core/Data/ModelSaveContext.cs b/src/Microsoft.ML.Core/Data/ModelSaveContext.cs index 245e788702..21b3926885 100644 --- a/src/Microsoft.ML.Core/Data/ModelSaveContext.cs +++ b/src/Microsoft.ML.Core/Data/ModelSaveContext.cs @@ -8,253 +8,252 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; -namespace Microsoft.ML +namespace Microsoft.ML; + +/// +/// Convenience context object for saving models to a repository, for +/// implementors of . +/// +/// +/// This class reduces the amount of boiler plate code needed to implement . +/// It can also be used when saving to a single stream, for implementors of . +/// +public sealed partial class ModelSaveContext : IDisposable { /// - /// Convenience context object for saving models to a repository, for - /// implementors of . + /// When in repository mode, this is the repository we're writing to. It is null when + /// in single-stream mode. /// - /// - /// This class reduces the amount of boiler plate code needed to implement . - /// It can also be used when saving to a single stream, for implementors of . - /// - public sealed partial class ModelSaveContext : IDisposable + [BestFriend] + internal readonly RepositoryWriter Repository; + + /// + /// When in repository mode, this is the directory we're reading from. 
Null means the root + /// of the repository. It is always null in single-stream mode. + /// + [BestFriend] + internal readonly string Directory; + + /// + /// The main stream writer. + /// + [BestFriend] + internal readonly BinaryWriter Writer; + + /// + /// The strings that will be saved in the main stream's string table. + /// + [BestFriend] + internal readonly NormStr.Pool Strings; + + /// + /// The main stream's model header. + /// + [BestFriend] + internal ModelHeader Header; + + /// + /// The min file position of the main stream. + /// + [BestFriend] + internal readonly long FpMin; + + /// + /// The wrapped entry. + /// + private readonly Repository.Entry _ent; + + /// + /// Exception context provided by Repository (can be null). + /// + private readonly IExceptionContext _ectx; + + /// + /// The assembly name where the loader resides. + /// + private string _loaderAssemblyName; + + /// + /// Returns whether this context is in repository mode (true) or single-stream mode (false). + /// + [BestFriend] + internal bool InRepository => Repository != null; + + /// + /// Create a supporting saving to a repository, for implementors of . + /// + internal ModelSaveContext(RepositoryWriter rep, string dir, string name) { - /// - /// When in repository mode, this is the repository we're writing to. It is null when - /// in single-stream mode. - /// - [BestFriend] - internal readonly RepositoryWriter Repository; - - /// - /// When in repository mode, this is the directory we're reading from. Null means the root - /// of the repository. It is always null in single-stream mode. - /// - [BestFriend] - internal readonly string Directory; - - /// - /// The main stream writer. - /// - [BestFriend] - internal readonly BinaryWriter Writer; - - /// - /// The strings that will be saved in the main stream's string table. - /// - [BestFriend] - internal readonly NormStr.Pool Strings; - - /// - /// The main stream's model header. 
- /// - [BestFriend] - internal ModelHeader Header; - - /// - /// The min file position of the main stream. - /// - [BestFriend] - internal readonly long FpMin; - - /// - /// The wrapped entry. - /// - private readonly Repository.Entry _ent; - - /// - /// Exception context provided by Repository (can be null). - /// - private readonly IExceptionContext _ectx; - - /// - /// The assembly name where the loader resides. - /// - private string _loaderAssemblyName; - - /// - /// Returns whether this context is in repository mode (true) or single-stream mode (false). - /// - [BestFriend] - internal bool InRepository => Repository != null; - - /// - /// Create a supporting saving to a repository, for implementors of . - /// - internal ModelSaveContext(RepositoryWriter rep, string dir, string name) - { - Contracts.CheckValue(rep, nameof(rep)); - Repository = rep; - _ectx = rep.ExceptionContext; + Contracts.CheckValue(rep, nameof(rep)); + Repository = rep; + _ectx = rep.ExceptionContext; - _ectx.CheckValueOrNull(dir); - _ectx.CheckNonEmpty(name, nameof(name)); + _ectx.CheckValueOrNull(dir); + _ectx.CheckNonEmpty(name, nameof(name)); - Directory = dir; - Strings = new NormStr.Pool(); + Directory = dir; + Strings = new NormStr.Pool(); - _ent = rep.CreateEntry(dir, name); + _ent = rep.CreateEntry(dir, name); + try + { + Writer = new BinaryWriter(_ent.Stream, Encoding.UTF8, leaveOpen: true); try { - Writer = new BinaryWriter(_ent.Stream, Encoding.UTF8, leaveOpen: true); - try - { - ModelHeader.BeginWrite(Writer, out FpMin, out Header); - } - catch - { - Writer.Dispose(); - throw; - } + ModelHeader.BeginWrite(Writer, out FpMin, out Header); } catch { - _ent.Dispose(); + Writer.Dispose(); throw; } } - - /// - /// Create a supporting saving to a single-stream, for implementors of . 
- /// - internal ModelSaveContext(BinaryWriter writer, IExceptionContext ectx = null) + catch { - Contracts.AssertValueOrNull(ectx); - _ectx = ectx; - _ectx.CheckValue(writer, nameof(writer)); + _ent.Dispose(); + throw; + } + } - Repository = null; - Directory = null; - _ent = null; + /// + /// Create a supporting saving to a single-stream, for implementors of . + /// + internal ModelSaveContext(BinaryWriter writer, IExceptionContext ectx = null) + { + Contracts.AssertValueOrNull(ectx); + _ectx = ectx; + _ectx.CheckValue(writer, nameof(writer)); - Strings = new NormStr.Pool(); - Writer = writer; - ModelHeader.BeginWrite(Writer, out FpMin, out Header); - } + Repository = null; + Directory = null; + _ent = null; - [BestFriend] - internal void CheckAtModel() - { - _ectx.Check(Writer.BaseStream.Position == FpMin + Header.FpModel); - } + Strings = new NormStr.Pool(); + Writer = writer; + ModelHeader.BeginWrite(Writer, out FpMin, out Header); + } - /// - /// Set the version information in the main stream's header. This should be called before - /// is called. - /// - /// - [BestFriend] - internal void SetVersionInfo(VersionInfo ver) - { - ModelHeader.SetVersionInfo(ref Header, ver); - _loaderAssemblyName = ver.LoaderAssemblyName; - } + [BestFriend] + internal void CheckAtModel() + { + _ectx.Check(Writer.BaseStream.Position == FpMin + Header.FpModel); + } - [BestFriend] - internal void SaveTextStream(string name, Action action) - { - _ectx.Check(InRepository, "Can't save a text stream when writing to a single stream"); - _ectx.CheckNonEmpty(name, nameof(name)); - _ectx.CheckValue(action, nameof(action)); - - // I verified in the CLR source that the default buffer size is 1024. It's unfortunate - // that to set leaveOpen to true, we have to specify the buffer size.... 
- using (var ent = Repository.CreateEntry(Directory, name)) - using (var writer = Utils.OpenWriter(ent.Stream)) - { - action(writer); - } - } + /// + /// Set the version information in the main stream's header. This should be called before + /// is called. + /// + /// + [BestFriend] + internal void SetVersionInfo(VersionInfo ver) + { + ModelHeader.SetVersionInfo(ref Header, ver); + _loaderAssemblyName = ver.LoaderAssemblyName; + } - [BestFriend] - internal void SaveBinaryStream(string name, Action action) + [BestFriend] + internal void SaveTextStream(string name, Action action) + { + _ectx.Check(InRepository, "Can't save a text stream when writing to a single stream"); + _ectx.CheckNonEmpty(name, nameof(name)); + _ectx.CheckValue(action, nameof(action)); + + // I verified in the CLR source that the default buffer size is 1024. It's unfortunate + // that to set leaveOpen to true, we have to specify the buffer size.... + using (var ent = Repository.CreateEntry(Directory, name)) + using (var writer = Utils.OpenWriter(ent.Stream)) { - _ectx.Check(InRepository, "Can't save a text stream when writing to a single stream"); - _ectx.CheckNonEmpty(name, nameof(name)); - _ectx.CheckValue(action, nameof(action)); - - // I verified in the CLR source that the default buffer size is 1024. It's unfortunate - // that to set leaveOpen to true, we have to specify the buffer size.... - using (var ent = Repository.CreateEntry(Directory, name)) - using (var writer = new BinaryWriter(ent.Stream, Encoding.UTF8, leaveOpen: true)) - { - action(writer); - } + action(writer); } + } - /// - /// Puts a string into the context pool, and writes the integer code of the string ID - /// to the write stream. If str is null, this writes -1 and doesn't add it to the pool. 
- /// - [BestFriend] - internal void SaveStringOrNull(string str) + [BestFriend] + internal void SaveBinaryStream(string name, Action action) + { + _ectx.Check(InRepository, "Can't save a text stream when writing to a single stream"); + _ectx.CheckNonEmpty(name, nameof(name)); + _ectx.CheckValue(action, nameof(action)); + + // I verified in the CLR source that the default buffer size is 1024. It's unfortunate + // that to set leaveOpen to true, we have to specify the buffer size.... + using (var ent = Repository.CreateEntry(Directory, name)) + using (var writer = new BinaryWriter(ent.Stream, Encoding.UTF8, leaveOpen: true)) { - if (str == null) - Writer.Write(-1); - else - Writer.Write(Strings.Add(str).Id); + action(writer); } + } - /// - /// Puts a string into the context pool, and writes the integer code of the string ID - /// to the write stream. Checks that str is not null. - /// - [BestFriend] - internal void SaveString(string str) - { - _ectx.CheckValue(str, nameof(str)); + /// + /// Puts a string into the context pool, and writes the integer code of the string ID + /// to the write stream. If str is null, this writes -1 and doesn't add it to the pool. + /// + [BestFriend] + internal void SaveStringOrNull(string str) + { + if (str == null) + Writer.Write(-1); + else Writer.Write(Strings.Add(str).Id); - } + } - [BestFriend] - internal void SaveString(ReadOnlyMemory str) - { - Writer.Write(Strings.Add(str).Id); - } + /// + /// Puts a string into the context pool, and writes the integer code of the string ID + /// to the write stream. Checks that str is not null. + /// + [BestFriend] + internal void SaveString(string str) + { + _ectx.CheckValue(str, nameof(str)); + Writer.Write(Strings.Add(str).Id); + } - /// - /// Puts a string into the context pool, and writes the integer code of the string ID - /// to the write stream. 
- /// - [BestFriend] - internal void SaveNonEmptyString(string str) - { - _ectx.CheckParam(!string.IsNullOrEmpty(str), nameof(str)); - Writer.Write(Strings.Add(str).Id); - } + [BestFriend] + internal void SaveString(ReadOnlyMemory str) + { + Writer.Write(Strings.Add(str).Id); + } - [BestFriend] - internal void SaveNonEmptyString(ReadOnlyMemory str) - { - Writer.Write(Strings.Add(str).Id); - } + /// + /// Puts a string into the context pool, and writes the integer code of the string ID + /// to the write stream. + /// + [BestFriend] + internal void SaveNonEmptyString(string str) + { + _ectx.CheckParam(!string.IsNullOrEmpty(str), nameof(str)); + Writer.Write(Strings.Add(str).Id); + } - /// - /// Commit the save operation. This completes writing of the main stream. When in repository - /// mode, it disposes (but not ). - /// - [BestFriend] - internal void Done() - { - _ectx.Check(Header.ModelSignature != 0, "ModelSignature not specified!"); - ModelHeader.EndWrite(Writer, FpMin, ref Header, Strings, _loaderAssemblyName); - Dispose(); - } + [BestFriend] + internal void SaveNonEmptyString(ReadOnlyMemory str) + { + Writer.Write(Strings.Add(str).Id); + } - /// - /// When in repository mode, this disposes the Writer (but not the repository). - /// - public void Dispose() - { - _ectx.Assert((_ent == null) == !InRepository); + /// + /// Commit the save operation. This completes writing of the main stream. When in repository + /// mode, it disposes (but not ). + /// + [BestFriend] + internal void Done() + { + _ectx.Check(Header.ModelSignature != 0, "ModelSignature not specified!"); + ModelHeader.EndWrite(Writer, FpMin, ref Header, Strings, _loaderAssemblyName); + Dispose(); + } - // When in single stream mode, we don't own the Writer. - if (InRepository) - { - Writer.Dispose(); - _ent.Dispose(); - } + /// + /// When in repository mode, this disposes the Writer (but not the repository). 
+ /// + public void Dispose() + { + _ectx.Assert((_ent == null) == !InRepository); + + // When in single stream mode, we don't own the Writer. + if (InRepository) + { + Writer.Dispose(); + _ent.Dispose(); } } } diff --git a/src/Microsoft.ML.Core/Data/ModelSaving.cs b/src/Microsoft.ML.Core/Data/ModelSaving.cs index e0f80965dd..b1d1ea19fa 100644 --- a/src/Microsoft.ML.Core/Data/ModelSaving.cs +++ b/src/Microsoft.ML.Core/Data/ModelSaving.cs @@ -7,85 +7,84 @@ using System.Text; using Microsoft.ML.Runtime; -namespace Microsoft.ML +namespace Microsoft.ML; + +public sealed partial class ModelSaveContext : IDisposable { - public sealed partial class ModelSaveContext : IDisposable + /// + /// Save a sub model to the given sub directory. This requires to be . + /// + [BestFriend] + internal void SaveModel(T value, string name) + where T : class { - /// - /// Save a sub model to the given sub directory. This requires to be . - /// - [BestFriend] - internal void SaveModel(T value, string name) - where T : class - { - _ectx.Check(InRepository, "Can't save a sub-model when writing to a single stream"); - SaveModel(Repository, value, Path.Combine(Directory ?? "", name)); - } - - /// - /// Save the object by calling TrySaveModel then falling back to .net serialization. - /// - [BestFriend] - internal static void SaveModel(RepositoryWriter rep, T value, string path) - where T : class - { - if (value == null) - return; + _ectx.Check(InRepository, "Can't save a sub-model when writing to a single stream"); + SaveModel(Repository, value, Path.Combine(Directory ?? "", name)); + } - var sm = value as ICanSaveModel; - if (sm != null) - { - using (var ctx = new ModelSaveContext(rep, path, ModelLoadContext.ModelStreamName)) - { - sm.Save(ctx); - ctx.Done(); - } - return; - } + /// + /// Save the object by calling TrySaveModel then falling back to .net serialization. 
+ /// + [BestFriend] + internal static void SaveModel(RepositoryWriter rep, T value, string path) + where T : class + { + if (value == null) + return; - var sb = value as ICanSaveInBinaryFormat; - if (sb != null) + var sm = value as ICanSaveModel; + if (sm != null) + { + using (var ctx = new ModelSaveContext(rep, path, ModelLoadContext.ModelStreamName)) { - using (var ent = rep.CreateEntry(path, ModelLoadContext.NameBinary)) - using (var writer = new BinaryWriter(ent.Stream, Encoding.UTF8, leaveOpen: true)) - { - sb.SaveAsBinary(writer); - } + sm.Save(ctx); + ctx.Done(); } + return; } - /// - /// Save to a single-stream by invoking the given action. - /// - [BestFriend] - internal static void Save(BinaryWriter writer, Action fn) + var sb = value as ICanSaveInBinaryFormat; + if (sb != null) { - Contracts.CheckValue(writer, nameof(writer)); - Contracts.CheckValue(fn, nameof(fn)); - - using (var ctx = new ModelSaveContext(writer)) + using (var ent = rep.CreateEntry(path, ModelLoadContext.NameBinary)) + using (var writer = new BinaryWriter(ent.Stream, Encoding.UTF8, leaveOpen: true)) { - fn(ctx); - ctx.Done(); + sb.SaveAsBinary(writer); } } + } + + /// + /// Save to a single-stream by invoking the given action. + /// + [BestFriend] + internal static void Save(BinaryWriter writer, Action fn) + { + Contracts.CheckValue(writer, nameof(writer)); + Contracts.CheckValue(fn, nameof(fn)); - /// - /// Save to the given sub directory by invoking the given action. This requires - /// to be . - /// - [BestFriend] - internal void SaveSubModel(string dir, Action fn) + using (var ctx = new ModelSaveContext(writer)) { - _ectx.Check(InRepository, "Can't save a sub-model when writing to a single stream"); - _ectx.CheckNonEmpty(dir, nameof(dir)); - _ectx.CheckValue(fn, nameof(fn)); + fn(ctx); + ctx.Done(); + } + } - using (var ctx = new ModelSaveContext(Repository, Path.Combine(Directory ?? 
"", dir), ModelLoadContext.ModelStreamName)) - { - fn(ctx); - ctx.Done(); - } + /// + /// Save to the given sub directory by invoking the given action. This requires + /// to be . + /// + [BestFriend] + internal void SaveSubModel(string dir, Action fn) + { + _ectx.Check(InRepository, "Can't save a sub-model when writing to a single stream"); + _ectx.CheckNonEmpty(dir, nameof(dir)); + _ectx.CheckValue(fn, nameof(fn)); + + using (var ctx = new ModelSaveContext(Repository, Path.Combine(Directory ?? "", dir), ModelLoadContext.ModelStreamName)) + { + fn(ctx); + ctx.Done(); } } } diff --git a/src/Microsoft.ML.Core/Data/ProgressReporter.cs b/src/Microsoft.ML.Core/Data/ProgressReporter.cs index e00dd2183d..a3fd950837 100644 --- a/src/Microsoft.ML.Core/Data/ProgressReporter.cs +++ b/src/Microsoft.ML.Core/Data/ProgressReporter.cs @@ -9,115 +9,215 @@ using System.Threading; using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML.Runtime; + +/// +/// The progress reporting classes used by descendants. +/// +[BestFriend] +internal static class ProgressReporting { /// - /// The progress reporting classes used by descendants. + /// The progress channel for . + /// This is coupled with a that aggregates all events and returns them on demand. /// - [BestFriend] - internal static class ProgressReporting + public sealed class ProgressChannel : IProgressChannel { + private readonly IExceptionContext _ectx; + /// - /// The progress channel for . - /// This is coupled with a that aggregates all events and returns them on demand. + /// The pair of (header, fill action) is updated atomically. /// - public sealed class ProgressChannel : IProgressChannel - { - private readonly IExceptionContext _ectx; + private Tuple> _headerAndAction; - /// - /// The pair of (header, fill action) is updated atomically. 
- /// - private Tuple> _headerAndAction; + /// + /// Normally this should be readonly field, but we want to null it in Dispose to prevent memory leaking. + /// + private ProgressTracker _tracker; - /// - /// Normally this should be readonly field, but we want to null it in Dispose to prevent memory leaking. - /// - private ProgressTracker _tracker; + private readonly ConcurrentDictionary _subChannels; + private volatile int _maxSubId; + private bool _isDisposed; - private readonly ConcurrentDictionary _subChannels; - private volatile int _maxSubId; - private bool _isDisposed; + public string Name { get; } - public string Name { get; } + /// + /// Initialize a for the process identified by . + /// + /// The exception context. + /// The tracker to couple with. + /// The computation name. + public ProgressChannel(IExceptionContext ectx, ProgressTracker tracker, string computationName) + { + Contracts.CheckValueOrNull(ectx); + _ectx = ectx; + _ectx.CheckValue(tracker, nameof(tracker)); + _ectx.CheckNonEmpty(computationName, nameof(computationName)); + + Name = computationName; + _tracker = tracker; + _subChannels = new ConcurrentDictionary(); + _maxSubId = 0; + + _headerAndAction = Tuple.Create>(new ProgressHeader(null), null); + Start(); + } - /// - /// Initialize a for the process identified by . - /// - /// The exception context. - /// The tracker to couple with. - /// The computation name. 
- public ProgressChannel(IExceptionContext ectx, ProgressTracker tracker, string computationName) - { - Contracts.CheckValueOrNull(ectx); - _ectx = ectx; - _ectx.CheckValue(tracker, nameof(tracker)); - _ectx.CheckNonEmpty(computationName, nameof(computationName)); + public void SetHeader(ProgressHeader header, Action fillAction) + { + _headerAndAction = Tuple.Create(header, fillAction); + } - Name = computationName; - _tracker = tracker; - _subChannels = new ConcurrentDictionary(); - _maxSubId = 0; + public void Checkpoint(params double?[] values) + { + _ectx.AssertValueOrNull(values); + _ectx.Check(!_isDisposed, "Can't report checkpoints after disposing"); + var entry = new ProgressEntry(true, _headerAndAction.Item1); - _headerAndAction = Tuple.Create>(new ProgressHeader(null), null); - Start(); - } + int n = Utils.Size(values); + int iSrc = 0; - public void SetHeader(ProgressHeader header, Action fillAction) + for (int iDst = 0; iDst < entry.Metrics.Length && iSrc < n;) + entry.Metrics[iDst++] = values[iSrc++]; + + for (int iDst = 0; iDst < entry.Progress.Length && iSrc < n;) + entry.Progress[iDst++] = values[iSrc++]; + + for (int iDst = 0; iDst < entry.ProgressLim.Length && iSrc < n;) { - _headerAndAction = Tuple.Create(header, fillAction); + var lim = values[iSrc++]; + if (Double.IsNaN(lim.GetValueOrDefault())) + lim = null; + entry.ProgressLim[iDst++] = lim; } - public void Checkpoint(params double?[] values) - { - _ectx.AssertValueOrNull(values); - _ectx.Check(!_isDisposed, "Can't report checkpoints after disposing"); - var entry = new ProgressEntry(true, _headerAndAction.Item1); + _ectx.Check(iSrc == n, "Too many values provided in Checkpoint"); + _tracker.Log(this, ProgressEvent.EventKind.Progress, entry); + } - int n = Utils.Size(values); - int iSrc = 0; + private void Start() + { + _tracker.Log(this, ProgressEvent.EventKind.Start, null); + } - for (int iDst = 0; iDst < entry.Metrics.Length && iSrc < n;) - entry.Metrics[iDst++] = values[iSrc++]; + private 
void Stop() + { + _tracker.Log(this, ProgressEvent.EventKind.Stop, null); + } - for (int iDst = 0; iDst < entry.Progress.Length && iSrc < n;) - entry.Progress[iDst++] = values[iSrc++]; + public void Dispose() + { + if (_isDisposed) + return; + _isDisposed = true; + Stop(); + + Contracts.Assert(_subChannels.Count == 0); + // The 'get progress' action could potentially reference additional objects via closures. + // This constitutes a memory leak potential, if the progress tracker object is retained for longer than the operation was running. + _headerAndAction = null; + _tracker = null; + } - for (int iDst = 0; iDst < entry.ProgressLim.Length && iSrc < n;) - { - var lim = values[iSrc++]; - if (Double.IsNaN(lim.GetValueOrDefault())) - lim = null; - entry.ProgressLim[iDst++] = lim; - } + /// + /// Pull the current progress by invoking the fill delegate, if any. + /// + public ProgressEntry GetProgress() + { + // Make sure we get header and action from the same pair, even if outdated. + var cache = _headerAndAction; + var fillAction = cache.Item2; + var entry = new ProgressEntry(false, cache.Item1); - _ectx.Check(iSrc == n, "Too many values provided in Checkpoint"); - _tracker.Log(this, ProgressEvent.EventKind.Progress, entry); - } + if (fillAction == null) + Contracts.Assert(entry.Header.MetricNames.Count == 0 && entry.Header.UnitNames.Count == 0); + else + fillAction(entry); - private void Start() - { - _tracker.Log(this, ProgressEvent.EventKind.Start, null); - } + return BuildJointEntry(entry); + } - private void Stop() + public IProgressChannel StartProgressChannel(string name) + { + return StartProgressChannel(1); + } + + private IProgressChannel StartProgressChannel(int level) + { + var newId = Interlocked.Increment(ref _maxSubId); + return new SubChannel(this, level, newId); + } + + private void SubChannelStopped(int id) + { + SubChannel channel; + _subChannels.TryRemove(id, out channel); + // Duplicate removal is OK, so we don't inspect return value. 
+ } + + private void SubChannelStarted(int id, SubChannel channel) + { + var res = _subChannels.GetOrAdd(id, channel); + Contracts.Assert(res == channel); + } + + private ProgressEntry BuildJointEntry(ProgressEntry rootEntry) + { + if (_maxSubId == 0 || _subChannels.Count == 0) + return rootEntry; + + // REVIEW: consider caching the headers, in case the sub-reporters haven't changed. + // This is not anticipated to be a perf-critical path though. + var hProgress = new List(); + var hMetrics = new List(); + var progress = new List(); + var progressLim = new List(); + var metrics = new List(); + + hProgress.AddRange(rootEntry.Header.UnitNames); + hMetrics.AddRange(rootEntry.Header.MetricNames); + progress.AddRange(rootEntry.Progress); + progressLim.AddRange(rootEntry.ProgressLim); + metrics.AddRange(rootEntry.Metrics); + + foreach (var subChannel in _subChannels.Values.ToArray().OrderBy(x => x.Level)) { - _tracker.Log(this, ProgressEvent.EventKind.Stop, null); + var entry = subChannel.GetProgress(); + hProgress.AddRange(entry.Header.UnitNames); + hMetrics.AddRange(entry.Header.MetricNames); + progress.AddRange(entry.Progress); + progressLim.AddRange(entry.ProgressLim); + metrics.AddRange(entry.Metrics); } - public void Dispose() - { - if (_isDisposed) - return; - _isDisposed = true; - Stop(); + var jointEntry = new ProgressEntry(false, new ProgressHeader(hMetrics.ToArray(), hProgress.ToArray())); + progress.CopyTo(jointEntry.Progress); + progressLim.CopyTo(jointEntry.ProgressLim); + metrics.CopyTo(jointEntry.Metrics); + return jointEntry; + } - Contracts.Assert(_subChannels.Count == 0); - // The 'get progress' action could potentially reference additional objects via closures. - // This constitutes a memory leak potential, if the progress tracker object is retained for longer than the operation was running. - _headerAndAction = null; - _tracker = null; - } + /// + /// This is a 'derived' or 'subordinate' progress channel. 
+ /// + /// The subordinates' Start/Stop events and checkpoints will not be propagated. + /// When the status is requested, all of the subordinate channels are also invoked, + /// and the resulting metrics are then returned in the order of their 'subordinate level'. + /// If there's more than one channel with the same level, the order is not defined. + /// + private sealed class SubChannel : IProgressChannel + { + private readonly ProgressChannel _root; + private readonly int _id; + // The 'depth' of subordinate. + private readonly int _level; + + /// + /// The pair of (header, fill action) is updated atomically. + /// + private Tuple> _headerAndAction; + + public int Level { get { return _level; } } /// /// Pull the current progress by invoking the fill delegate, if any. @@ -133,482 +233,381 @@ public ProgressEntry GetProgress() Contracts.Assert(entry.Header.MetricNames.Count == 0 && entry.Header.UnitNames.Count == 0); else fillAction(entry); + return entry; + } - return BuildJointEntry(entry); + public SubChannel(ProgressChannel root, int id, int level) + { + Contracts.AssertValue(root); + Contracts.Assert(level >= 0); + _root = root; + _id = id; + _level = level; + _headerAndAction = Tuple.Create>(new ProgressHeader(null), null); + Start(); } public IProgressChannel StartProgressChannel(string name) { - return StartProgressChannel(1); + return _root.StartProgressChannel(_level + 1); } - private IProgressChannel StartProgressChannel(int level) + public void Dispose() { - var newId = Interlocked.Increment(ref _maxSubId); - return new SubChannel(this, level, newId); + Stop(); } - private void SubChannelStopped(int id) + public void SetHeader(ProgressHeader header, Action fillAction) { - SubChannel channel; - _subChannels.TryRemove(id, out channel); - // Duplicate removal is OK, so we don't inspect return value. 
+ _headerAndAction = Tuple.Create(header, fillAction); } - private void SubChannelStarted(int id, SubChannel channel) + private void Start() { - var res = _subChannels.GetOrAdd(id, channel); - Contracts.Assert(res == channel); + _root.SubChannelStarted(_id, this); } - private ProgressEntry BuildJointEntry(ProgressEntry rootEntry) + private void Stop() { - if (_maxSubId == 0 || _subChannels.Count == 0) - return rootEntry; - - // REVIEW: consider caching the headers, in case the sub-reporters haven't changed. - // This is not anticipated to be a perf-critical path though. - var hProgress = new List(); - var hMetrics = new List(); - var progress = new List(); - var progressLim = new List(); - var metrics = new List(); - - hProgress.AddRange(rootEntry.Header.UnitNames); - hMetrics.AddRange(rootEntry.Header.MetricNames); - progress.AddRange(rootEntry.Progress); - progressLim.AddRange(rootEntry.ProgressLim); - metrics.AddRange(rootEntry.Metrics); - - foreach (var subChannel in _subChannels.Values.ToArray().OrderBy(x => x.Level)) - { - var entry = subChannel.GetProgress(); - hProgress.AddRange(entry.Header.UnitNames); - hMetrics.AddRange(entry.Header.MetricNames); - progress.AddRange(entry.Progress); - progressLim.AddRange(entry.ProgressLim); - metrics.AddRange(entry.Metrics); - } - - var jointEntry = new ProgressEntry(false, new ProgressHeader(hMetrics.ToArray(), hProgress.ToArray())); - progress.CopyTo(jointEntry.Progress); - progressLim.CopyTo(jointEntry.ProgressLim); - metrics.CopyTo(jointEntry.Metrics); - return jointEntry; + _root.SubChannelStopped(_id); } - /// - /// This is a 'derived' or 'subordinate' progress channel. - /// - /// The subordinates' Start/Stop events and checkpoints will not be propagated. - /// When the status is requested, all of the subordinate channels are also invoked, - /// and the resulting metrics are then returned in the order of their 'subordinate level'. 
- /// If there's more than one channel with the same level, the order is not defined. - /// - private sealed class SubChannel : IProgressChannel + public void Checkpoint(params Double?[] values) { - private readonly ProgressChannel _root; - private readonly int _id; - // The 'depth' of subordinate. - private readonly int _level; - - /// - /// The pair of (header, fill action) is updated atomically. - /// - private Tuple> _headerAndAction; - - public int Level { get { return _level; } } - - /// - /// Pull the current progress by invoking the fill delegate, if any. - /// - public ProgressEntry GetProgress() - { - // Make sure we get header and action from the same pair, even if outdated. - var cache = _headerAndAction; - var fillAction = cache.Item2; - var entry = new ProgressEntry(false, cache.Item1); - - if (fillAction == null) - Contracts.Assert(entry.Header.MetricNames.Count == 0 && entry.Header.UnitNames.Count == 0); - else - fillAction(entry); - return entry; - } - - public SubChannel(ProgressChannel root, int id, int level) - { - Contracts.AssertValue(root); - Contracts.Assert(level >= 0); - _root = root; - _id = id; - _level = level; - _headerAndAction = Tuple.Create>(new ProgressHeader(null), null); - Start(); - } - - public IProgressChannel StartProgressChannel(string name) - { - return _root.StartProgressChannel(_level + 1); - } + // We are ignoring all checkpoints from subordinates. + // REVIEW: maybe this could be changed in the future. Right now it seems that + // this limitation is reasonable. + } + } + } - public void Dispose() - { - Stop(); - } + /// + /// This class listens to the progress reporting channels, caches all checkpoints and + /// start/stop events and, on demand, requests current progress on all active calculations. + /// + /// The public methods of this class should only be called from one thread. 
+ /// + public sealed class ProgressTracker + { + private readonly IExceptionContext _ectx; + private readonly object _lock; - public void SetHeader(ProgressHeader header, Action fillAction) - { - _headerAndAction = Tuple.Create(header, fillAction); - } + /// + /// Log of pending events. + /// + private readonly ConcurrentQueue _pendingEvents; - private void Start() - { - _root.SubChannelStarted(_id, this); - } + /// + /// For each calculation, its properties. + /// This list is protected by , and it's updated every time a new calculation starts. + /// The entries are cleaned up when the start and stop events are reported (that is, after the first + /// pull request after the calculation's 'Stop' event). + /// + private readonly List _infos; - private void Stop() - { - _root.SubChannelStopped(_id); - } + /// + /// This is a 'process index' that gets incremented whenever a new calculation is started. + /// + private int _index; - public void Checkpoint(params Double?[] values) - { - // We are ignoring all checkpoints from subordinates. - // REVIEW: maybe this could be changed in the future. Right now it seems that - // this limitation is reasonable. - } - } - } + /// + /// The set of used process names. + /// + private readonly HashSet _namesUsed; /// - /// This class listens to the progress reporting channels, caches all checkpoints and - /// start/stop events and, on demand, requests current progress on all active calculations. + /// This class is an 'event log' for one calculation. /// - /// The public methods of this class should only be called from one thread. + /// Every time a calculation is 'started', it gets its own log, so if there are multiple 'start' calls, + /// there will be multiple logs. /// - public sealed class ProgressTracker + private sealed class CalculationInfo { - private readonly IExceptionContext _ectx; - private readonly object _lock; - /// - /// Log of pending events. + /// Auto-assigned index to serve as a unique ID. 
/// - private readonly ConcurrentQueue _pendingEvents; + public readonly int Index; /// - /// For each calculation, its properties. - /// This list is protected by , and it's updated every time a new calculation starts. - /// The entries are cleaned up when the start and stop events are reported (that is, after the first - /// pull request after the calculation's 'Stop' event). + /// Name is auto-modified from the calculation name provided by the pipe. /// - private readonly List _infos; + public readonly string Name; - /// - /// This is a 'process index' that gets incremented whenever a new calculation is started. - /// - private int _index; + public readonly DateTime StartTime; + + public readonly ProgressChannel Channel; /// - /// The set of used process names. + /// A log of pending checkpoint entries. /// - private readonly HashSet _namesUsed; + public readonly ConcurrentQueue> PendingCheckpoints; /// - /// This class is an 'event log' for one calculation. - /// - /// Every time a calculation is 'started', it gets its own log, so if there are multiple 'start' calls, - /// there will be multiple logs. + /// Whether the calculation has finished. /// - private sealed class CalculationInfo - { - /// - /// Auto-assigned index to serve as a unique ID. - /// - public readonly int Index; - - /// - /// Name is auto-modified from the calculation name provided by the pipe. - /// - public readonly string Name; - - public readonly DateTime StartTime; + public bool IsFinished; - public readonly ProgressChannel Channel; - - /// - /// A log of pending checkpoint entries. - /// - public readonly ConcurrentQueue> PendingCheckpoints; - - /// - /// Whether the calculation has finished. 
- /// - public bool IsFinished; + public CalculationInfo(int index, string name, ProgressChannel channel) + { + Contracts.Assert(index > 0); + Contracts.AssertNonEmpty(name); + Contracts.AssertValue(channel); - public CalculationInfo(int index, string name, ProgressChannel channel) - { - Contracts.Assert(index > 0); - Contracts.AssertNonEmpty(name); - Contracts.AssertValue(channel); - - Index = index; - Name = name; - PendingCheckpoints = new ConcurrentQueue>(); - StartTime = DateTime.UtcNow; - Channel = channel; - } + Index = index; + Name = name; + PendingCheckpoints = new ConcurrentQueue>(); + StartTime = DateTime.UtcNow; + Channel = channel; } + } - public ProgressTracker(IExceptionContext ectx) - { - Contracts.CheckValue(ectx, nameof(ectx)); - _ectx = ectx; - _lock = new object(); - _pendingEvents = new ConcurrentQueue(); - _infos = new List(); - _namesUsed = new HashSet(); - } + public ProgressTracker(IExceptionContext ectx) + { + Contracts.CheckValue(ectx, nameof(ectx)); + _ectx = ectx; + _lock = new object(); + _pendingEvents = new ConcurrentQueue(); + _infos = new List(); + _namesUsed = new HashSet(); + } - public void Log(ProgressChannel source, ProgressEvent.EventKind kind, ProgressEntry entry) - { - _ectx.AssertValue(source); - _ectx.AssertValueOrNull(entry); + public void Log(ProgressChannel source, ProgressEvent.EventKind kind, ProgressEntry entry) + { + _ectx.AssertValue(source); + _ectx.AssertValueOrNull(entry); - if (kind == ProgressEvent.EventKind.Start) + if (kind == ProgressEvent.EventKind.Start) + { + _ectx.Assert(entry == null); + lock (_lock) { - _ectx.Assert(entry == null); - lock (_lock) + // Figure out an appropriate name. + int i = 1; + var name = source.Name; + string nameCandidate = name; + while (!_namesUsed.Add(nameCandidate)) { - // Figure out an appropriate name. 
- int i = 1; - var name = source.Name; - string nameCandidate = name; - while (!_namesUsed.Add(nameCandidate)) - { - i++; - nameCandidate = string.Format("{0} #{1}", name, i); - } - var newInfo = new CalculationInfo(++_index, nameCandidate, source); - _infos.Add(newInfo); - _pendingEvents.Enqueue(new ProgressEvent(newInfo.Index, newInfo.Name, newInfo.StartTime, ProgressEvent.EventKind.Start)); - return; + i++; + nameCandidate = string.Format("{0} #{1}", name, i); } - } - - // Not a start event, so we won't modify the _infos. - CalculationInfo info; - lock (_lock) - { - info = _infos.FirstOrDefault(x => x.Channel == source); - if (info == null) - throw _ectx.Except("Event sent after the calculation lifetime expired."); - } - switch (kind) - { - case ProgressEvent.EventKind.Stop: - _ectx.Assert(entry == null); - info.IsFinished = true; - _pendingEvents.Enqueue(new ProgressEvent(info.Index, info.Name, info.StartTime, ProgressEvent.EventKind.Stop)); - break; - default: - _ectx.Assert(entry != null); - _ectx.Assert(kind == ProgressEvent.EventKind.Progress); - _ectx.Assert(!info.IsFinished); - _pendingEvents.Enqueue(new ProgressEvent(info.Index, info.Name, info.StartTime, entry)); - break; + var newInfo = new CalculationInfo(++_index, nameCandidate, source); + _infos.Add(newInfo); + _pendingEvents.Enqueue(new ProgressEvent(newInfo.Index, newInfo.Name, newInfo.StartTime, ProgressEvent.EventKind.Start)); + return; } } - /// - /// Get progress reports from all current calculations. - /// For every calculation the following events will be returned: - /// * A start event. - /// * Each checkpoint. - /// * If the calculation is finished, the stop event. - /// - /// Each of the above events will be returned exactly once. - /// If, for one calculation, there's no events in the above categories, the tracker will - /// request ('pull') the current progress and return this as an event. - /// - public List GetAllProgress() + // Not a start event, so we won't modify the _infos. 
+ CalculationInfo info; + lock (_lock) { - var list = new List(); - var seen = new HashSet(); - ProgressEvent cur; - while (_pendingEvents.TryDequeue(out cur)) - { - seen.Add(cur.Index); - list.Add(cur); - } - - // Get unseen calculations to pull progress from. - CalculationInfo[] unseen; - lock (_lock) - { - unseen = _infos.Where(x => !seen.Contains(x.Index)).ToArray(); - _infos.RemoveAll(x => x.IsFinished); - } - - foreach (var info in unseen) - { - // The calculation might finish while we're inside the GetAllProgress. We will report the finish - // event in the next status, but we make a half-hearted effort not to call the delegate on a finished - // calculation. - if (info.IsFinished) - continue; - - var entry = info.Channel.GetProgress(); - list.Add(new ProgressEvent(info.Index, info.Name, info.StartTime, entry)); - } - - return list; + info = _infos.FirstOrDefault(x => x.Channel == source); + if (info == null) + throw _ectx.Except("Event sent after the calculation lifetime expired."); } - - public void Reset() + switch (kind) { - lock (_lock) - { - while (!_pendingEvents.IsEmpty) - _pendingEvents.TryDequeue(out var res); - _namesUsed.Clear(); - _index = 0; - } + case ProgressEvent.EventKind.Stop: + _ectx.Assert(entry == null); + info.IsFinished = true; + _pendingEvents.Enqueue(new ProgressEvent(info.Index, info.Name, info.StartTime, ProgressEvent.EventKind.Stop)); + break; + default: + _ectx.Assert(entry != null); + _ectx.Assert(kind == ProgressEvent.EventKind.Progress); + _ectx.Assert(!info.IsFinished); + _pendingEvents.Enqueue(new ProgressEvent(info.Index, info.Name, info.StartTime, entry)); + break; } } /// - /// An array-backed implementation of . + /// Get progress reports from all current calculations. + /// For every calculation the following events will be returned: + /// * A start event. + /// * Each checkpoint. + /// * If the calculation is finished, the stop event. + /// + /// Each of the above events will be returned exactly once. 
+ /// If, for one calculation, there's no events in the above categories, the tracker will + /// request ('pull') the current progress and return this as an event. /// - public sealed class ProgressEntry : IProgressEntry + public List GetAllProgress() { - /// - /// The header (names of metrics and units). - /// The contents of the header should be treated as read-only. The calculation itself doesn't even - /// need to access the header, since it will know it anyway. - /// - public readonly ProgressHeader Header; - - /// - /// Whether the progress entry is a 'checkpoint' (that is, it's being pushed by the component). - /// - public readonly bool IsCheckpoint; - - /// - /// The actual progress (amount of completed units), in the units that are contained in the header. - /// Parallel to the header's . Null value indicates 'not applicable now'. - /// - /// The computation should not modify these arrays directly, and instead rely on , - /// and . - /// - public readonly Double?[] Progress; - - /// - /// The lim values of each progress unit. - /// Parallel to the header's . Null value indicates unbounded or unknown. - /// - public readonly Double?[] ProgressLim; - - /// - /// The reported metrics. Parallel to the header's . - /// Null value indicates unknown. - /// - public readonly Double?[] Metrics; - - /// - /// Set the progress value for the index to , - /// and the limit value for the progress becomes 'unknown'. - /// - public void SetProgress(int index, Double value) + var list = new List(); + var seen = new HashSet(); + ProgressEvent cur; + while (_pendingEvents.TryDequeue(out cur)) { - Contracts.Check(0 <= index && index < Progress.Length); - Progress[index] = value; - ProgressLim[index] = null; + seen.Add(cur.Index); + list.Add(cur); } - /// - /// Set the progress value for the index to , - /// and the limit value to . - /// - public void SetProgress(int index, Double value, Double lim) + // Get unseen calculations to pull progress from. 
+ CalculationInfo[] unseen; + lock (_lock) { - Contracts.Check(0 <= index && index < Progress.Length); - Contracts.Assert(0 <= index && index < Progress.Length); - Progress[index] = value; - ProgressLim[index] = Double.IsNaN(lim) ? (Double?)null : lim; + unseen = _infos.Where(x => !seen.Contains(x.Index)).ToArray(); + _infos.RemoveAll(x => x.IsFinished); } - /// - /// Sets the metric with index to . - /// - public void SetMetric(int index, Double value) + foreach (var info in unseen) { - Contracts.Check(0 <= index && index < Metrics.Length); - Metrics[index] = value; + // The calculation might finish while we're inside the GetAllProgress. We will report the finish + // event in the next status, but we make a half-hearted effort not to call the delegate on a finished + // calculation. + if (info.IsFinished) + continue; + + var entry = info.Channel.GetProgress(); + list.Add(new ProgressEvent(info.Index, info.Name, info.StartTime, entry)); } - /// - /// Creates the progress entry corresponding to a given header. - /// - public ProgressEntry(bool isCheckpoint, ProgressHeader header) + return list; + } + + public void Reset() + { + lock (_lock) { - Contracts.CheckValue(header, nameof(header)); - Header = header; - IsCheckpoint = isCheckpoint; - Progress = new Double?[header.UnitNames.Count]; - ProgressLim = new Double?[header.UnitNames.Count]; - Metrics = new Double?[header.MetricNames.Count]; + while (!_pendingEvents.IsEmpty) + _pendingEvents.TryDequeue(out var res); + _namesUsed.Clear(); + _index = 0; } } + } + + /// + /// An array-backed implementation of . + /// + public sealed class ProgressEntry : IProgressEntry + { + /// + /// The header (names of metrics and units). + /// The contents of the header should be treated as read-only. The calculation itself doesn't even + /// need to access the header, since it will know it anyway. + /// + public readonly ProgressHeader Header; /// - /// An event about calculation progress. 
It could be either start/stop of the calculation, or a progress entry. + /// Whether the progress entry is a 'checkpoint' (that is, it's being pushed by the component). /// - public sealed class ProgressEvent + public readonly bool IsCheckpoint; + + /// + /// The actual progress (amount of completed units), in the units that are contained in the header. + /// Parallel to the header's . Null value indicates 'not applicable now'. + /// + /// The computation should not modify these arrays directly, and instead rely on , + /// and . + /// + public readonly Double?[] Progress; + + /// + /// The lim values of each progress unit. + /// Parallel to the header's . Null value indicates unbounded or unknown. + /// + public readonly Double?[] ProgressLim; + + /// + /// The reported metrics. Parallel to the header's . + /// Null value indicates unknown. + /// + public readonly Double?[] Metrics; + + /// + /// Set the progress value for the index to , + /// and the limit value for the progress becomes 'unknown'. + /// + public void SetProgress(int index, Double value) { - // REVIEW: Separate kind for checkpoint? - public enum EventKind - { - Start, - Progress, - Stop - } + Contracts.Check(0 <= index && index < Progress.Length); + Progress[index] = value; + ProgressLim[index] = null; + } - public readonly int Index; - public readonly string Name; - // REVIEW: Maybe switch to the stopwatch-based wall clock? - public readonly DateTime StartTime; - public readonly DateTime EventTime; - public readonly EventKind Kind; - public readonly ProgressEntry ProgressEntry; + /// + /// Set the progress value for the index to , + /// and the limit value to . + /// + public void SetProgress(int index, Double value, Double lim) + { + Contracts.Check(0 <= index && index < Progress.Length); + Contracts.Assert(0 <= index && index < Progress.Length); + Progress[index] = value; + ProgressLim[index] = Double.IsNaN(lim) ? 
(Double?)null : lim; + } - public ProgressEvent(int index, string name, DateTime startTime, ProgressEntry entry) - { - Contracts.CheckParam(index >= 0, nameof(index)); - Contracts.CheckNonEmpty(name, nameof(name)); - Contracts.CheckValue(entry, nameof(entry)); + /// + /// Sets the metric with index to . + /// + public void SetMetric(int index, Double value) + { + Contracts.Check(0 <= index && index < Metrics.Length); + Metrics[index] = value; + } - Index = index; - Name = name; - StartTime = startTime; - EventTime = DateTime.UtcNow; - Kind = EventKind.Progress; - ProgressEntry = entry; - } + /// + /// Creates the progress entry corresponding to a given header. + /// + public ProgressEntry(bool isCheckpoint, ProgressHeader header) + { + Contracts.CheckValue(header, nameof(header)); + Header = header; + IsCheckpoint = isCheckpoint; + Progress = new Double?[header.UnitNames.Count]; + ProgressLim = new Double?[header.UnitNames.Count]; + Metrics = new Double?[header.MetricNames.Count]; + } + } - public ProgressEvent(int index, string name, DateTime startTime, EventKind kind) - { - Contracts.CheckParam(index >= 0, nameof(index)); - Contracts.CheckNonEmpty(name, nameof(name)); + /// + /// An event about calculation progress. It could be either start/stop of the calculation, or a progress entry. + /// + public sealed class ProgressEvent + { + // REVIEW: Separate kind for checkpoint? + public enum EventKind + { + Start, + Progress, + Stop + } - Index = index; - Name = name; - StartTime = startTime; - EventTime = DateTime.UtcNow; - Kind = kind; - ProgressEntry = null; - } + public readonly int Index; + public readonly string Name; + // REVIEW: Maybe switch to the stopwatch-based wall clock? 
+ public readonly DateTime StartTime; + public readonly DateTime EventTime; + public readonly EventKind Kind; + public readonly ProgressEntry ProgressEntry; + + public ProgressEvent(int index, string name, DateTime startTime, ProgressEntry entry) + { + Contracts.CheckParam(index >= 0, nameof(index)); + Contracts.CheckNonEmpty(name, nameof(name)); + Contracts.CheckValue(entry, nameof(entry)); + + Index = index; + Name = name; + StartTime = startTime; + EventTime = DateTime.UtcNow; + Kind = EventKind.Progress; + ProgressEntry = entry; + } + + public ProgressEvent(int index, string name, DateTime startTime, EventKind kind) + { + Contracts.CheckParam(index >= 0, nameof(index)); + Contracts.CheckNonEmpty(name, nameof(name)); + + Index = index; + Name = name; + StartTime = startTime; + EventTime = DateTime.UtcNow; + Kind = kind; + ProgressEntry = null; } } } diff --git a/src/Microsoft.ML.Core/Data/ReadOnlyMemoryUtils.cs b/src/Microsoft.ML.Core/Data/ReadOnlyMemoryUtils.cs index 5082f106bf..9f3d0190e2 100644 --- a/src/Microsoft.ML.Core/Data/ReadOnlyMemoryUtils.cs +++ b/src/Microsoft.ML.Core/Data/ReadOnlyMemoryUtils.cs @@ -9,264 +9,263 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; -namespace Microsoft.ML.Data +namespace Microsoft.ML.Data; + +[BestFriend] +internal static class ReadOnlyMemoryUtils { - [BestFriend] - internal static class ReadOnlyMemoryUtils - { - /// - /// Compare equality with the given system string value. - /// - public static bool EqualsStr(string s, ReadOnlyMemory memory) - { - Contracts.CheckValueOrNull(s); + /// + /// Compare equality with the given system string value. 
+ /// + public static bool EqualsStr(string s, ReadOnlyMemory memory) + { + Contracts.CheckValueOrNull(s); - if (s == null) - return memory.Length == 0; + if (s == null) + return memory.Length == 0; - if (s.Length != memory.Length) - return false; + if (s.Length != memory.Length) + return false; - return memory.Span.SequenceEqual(s.AsSpan()); - } + return memory.Span.SequenceEqual(s.AsSpan()); + } - public static IEnumerable> Split(ReadOnlyMemory memory, char[] separators) - { - Contracts.CheckValueOrNull(separators); + public static IEnumerable> Split(ReadOnlyMemory memory, char[] separators) + { + Contracts.CheckValueOrNull(separators); - if (memory.IsEmpty) - yield break; + if (memory.IsEmpty) + yield break; - if (separators == null || separators.Length == 0) - { - yield return memory; - yield break; - } + if (separators == null || separators.Length == 0) + { + yield return memory; + yield break; + } - var span = memory.Span; - if (separators.Length == 1) + var span = memory.Span; + if (separators.Length == 1) + { + char chSep = separators[0]; + for (int ichCur = 0; ;) { - char chSep = separators[0]; - for (int ichCur = 0; ;) + int nextSep = span.IndexOf(chSep); + if (nextSep == -1) { - int nextSep = span.IndexOf(chSep); - if (nextSep == -1) - { - yield return memory.Slice(ichCur); - yield break; - } - - yield return memory.Slice(ichCur, nextSep); - - // Skip the separator. - ichCur += nextSep + 1; - span = memory.Slice(ichCur).Span; + yield return memory.Slice(ichCur); + yield break; } + + yield return memory.Slice(ichCur, nextSep); + + // Skip the separator. + ichCur += nextSep + 1; + span = memory.Slice(ichCur).Span; } - else + } + else + { + for (int ichCur = 0; ;) { - for (int ichCur = 0; ;) + int nextSep = span.IndexOfAny(separators); + if (nextSep == -1) { - int nextSep = span.IndexOfAny(separators); - if (nextSep == -1) - { - yield return memory.Slice(ichCur); - yield break; - } - - yield return memory.Slice(ichCur, nextSep); - - // Skip the separator. 
- ichCur += nextSep + 1; - span = memory.Slice(ichCur).Span; + yield return memory.Slice(ichCur); + yield break; } - } - } - /// - /// Splits on the left-most occurrence of separator and produces the left - /// and right of values. If does not contain the separator character, - /// this returns false and sets to this instance and - /// to the default of value. - /// - public static bool SplitOne(ReadOnlyMemory memory, char separator, out ReadOnlyMemory left, out ReadOnlyMemory right) - { - if (memory.IsEmpty) - { - left = memory; - right = default; - return false; - } + yield return memory.Slice(ichCur, nextSep); - int index = memory.Span.IndexOf(separator); - if (index == -1) - { - left = memory; - right = default; - return false; + // Skip the separator. + ichCur += nextSep + 1; + span = memory.Slice(ichCur).Span; } - - left = memory.Slice(0, index); - right = memory.Slice(index + 1, memory.Length - index - 1); - return true; } + } - /// - /// Splits on the left-most occurrence of an element of separators character array and - /// produces the left and right of values. If does not contain any of the - /// characters in separators, this return false and initializes to this instance - /// and to the default of value. - /// - public static bool SplitOne(ReadOnlyMemory memory, char[] separators, out ReadOnlyMemory left, out ReadOnlyMemory right) + /// + /// Splits on the left-most occurrence of separator and produces the left + /// and right of values. If does not contain the separator character, + /// this returns false and sets to this instance and + /// to the default of value. 
+ /// + public static bool SplitOne(ReadOnlyMemory memory, char separator, out ReadOnlyMemory left, out ReadOnlyMemory right) + { + if (memory.IsEmpty) { - Contracts.CheckValueOrNull(separators); + left = memory; + right = default; + return false; + } - if (memory.IsEmpty || separators == null || separators.Length == 0) - { - left = memory; - right = default; - return false; - } + int index = memory.Span.IndexOf(separator); + if (index == -1) + { + left = memory; + right = default; + return false; + } - int index; - if (separators.Length == 1) - index = memory.Span.IndexOf(separators[0]); - else - index = memory.Span.IndexOfAny(separators); + left = memory.Slice(0, index); + right = memory.Slice(index + 1, memory.Length - index - 1); + return true; + } - if (index == -1) - { - left = memory; - right = default; - return false; - } + /// + /// Splits on the left-most occurrence of an element of separators character array and + /// produces the left and right of values. If does not contain any of the + /// characters in separators, this return false and initializes to this instance + /// and to the default of value. + /// + public static bool SplitOne(ReadOnlyMemory memory, char[] separators, out ReadOnlyMemory left, out ReadOnlyMemory right) + { + Contracts.CheckValueOrNull(separators); - left = memory.Slice(0, index); - right = memory.Slice(index + 1, memory.Length - index - 1); - return true; + if (memory.IsEmpty || separators == null || separators.Length == 0) + { + left = memory; + right = default; + return false; } - /// - /// Returns a of with leading and trailing spaces trimmed. Note that this - /// will remove only spaces, not any form of whitespace. 
- /// - public static ReadOnlyMemory TrimSpaces(ReadOnlyMemory memory) + int index; + if (separators.Length == 1) + index = memory.Span.IndexOf(separators[0]); + else + index = memory.Span.IndexOfAny(separators); + + if (index == -1) { - if (memory.IsEmpty) - return memory; - - int ichLim = memory.Length; - int ichMin = 0; - var span = memory.Span; - if (span[ichMin] != ' ' && span[ichLim - 1] != ' ') - return memory; - - while (ichMin < ichLim && span[ichMin] == ' ') - ichMin++; - while (ichMin < ichLim && span[ichLim - 1] == ' ') - ichLim--; - return memory.Slice(ichMin, ichLim - ichMin); + left = memory; + right = default; + return false; } - /// - /// Returns a of with leading and trailing whitespace trimmed. - /// - public static ReadOnlyMemory TrimWhiteSpace(ReadOnlyMemory memory) - { - if (memory.IsEmpty) - return memory; + left = memory.Slice(0, index); + right = memory.Slice(index + 1, memory.Length - index - 1); + return true; + } - int ichMin = 0; - int ichLim = memory.Length; - var span = memory.Span; - if (!char.IsWhiteSpace(span[ichMin]) && !char.IsWhiteSpace(span[ichLim - 1])) - return memory; + /// + /// Returns a of with leading and trailing spaces trimmed. Note that this + /// will remove only spaces, not any form of whitespace. + /// + public static ReadOnlyMemory TrimSpaces(ReadOnlyMemory memory) + { + if (memory.IsEmpty) + return memory; + + int ichLim = memory.Length; + int ichMin = 0; + var span = memory.Span; + if (span[ichMin] != ' ' && span[ichLim - 1] != ' ') + return memory; + + while (ichMin < ichLim && span[ichMin] == ' ') + ichMin++; + while (ichMin < ichLim && span[ichLim - 1] == ' ') + ichLim--; + return memory.Slice(ichMin, ichLim - ichMin); + } - while (ichMin < ichLim && char.IsWhiteSpace(span[ichMin])) - ichMin++; - while (ichMin < ichLim && char.IsWhiteSpace(span[ichLim - 1])) - ichLim--; + /// + /// Returns a of with leading and trailing whitespace trimmed. 
+ /// + public static ReadOnlyMemory TrimWhiteSpace(ReadOnlyMemory memory) + { + if (memory.IsEmpty) + return memory; - return memory.Slice(ichMin, ichLim - ichMin); - } + int ichMin = 0; + int ichLim = memory.Length; + var span = memory.Span; + if (!char.IsWhiteSpace(span[ichMin]) && !char.IsWhiteSpace(span[ichLim - 1])) + return memory; - /// - /// Returns a of with trailing whitespace trimmed. - /// - public static ReadOnlyMemory TrimEndWhiteSpace(ReadOnlyMemory memory) - { - if (memory.IsEmpty) - return memory; + while (ichMin < ichLim && char.IsWhiteSpace(span[ichMin])) + ichMin++; + while (ichMin < ichLim && char.IsWhiteSpace(span[ichLim - 1])) + ichLim--; + + return memory.Slice(ichMin, ichLim - ichMin); + } - int ichLim = memory.Length; - var span = memory.Span; - if (!char.IsWhiteSpace(span[ichLim - 1])) - return memory; + /// + /// Returns a of with trailing whitespace trimmed. + /// + public static ReadOnlyMemory TrimEndWhiteSpace(ReadOnlyMemory memory) + { + if (memory.IsEmpty) + return memory; - while (0 < ichLim && char.IsWhiteSpace(span[ichLim - 1])) - ichLim--; + int ichLim = memory.Length; + var span = memory.Span; + if (!char.IsWhiteSpace(span[ichLim - 1])) + return memory; - return memory.Slice(0, ichLim); - } + while (0 < ichLim && char.IsWhiteSpace(span[ichLim - 1])) + ichLim--; - public static void AddLowerCaseToStringBuilder(ReadOnlySpan span, StringBuilder sb) - { - Contracts.CheckValue(sb, nameof(sb)); + return memory.Slice(0, ichLim); + } - if (!span.IsEmpty) + public static void AddLowerCaseToStringBuilder(ReadOnlySpan span, StringBuilder sb) + { + Contracts.CheckValue(sb, nameof(sb)); + + if (!span.IsEmpty) + { + int min = 0; + int j; + for (j = min; j < span.Length; j++) { - int min = 0; - int j; - for (j = min; j < span.Length; j++) + char ch = CharUtils.ToLowerInvariant(span[j]); + if (ch != span[j]) { - char ch = CharUtils.ToLowerInvariant(span[j]); - if (ch != span[j]) - { - sb.AppendSpan(span.Slice(min, j - min)).Append(ch); - min 
= j + 1; - } + sb.AppendSpan(span.Slice(min, j - min)).Append(ch); + min = j + 1; } - - Contracts.Assert(j == span.Length); - if (min != j) - sb.AppendSpan(span.Slice(min, j - min)); } + + Contracts.Assert(j == span.Length); + if (min != j) + sb.AppendSpan(span.Slice(min, j - min)); } + } - public static StringBuilder AppendMemory(this StringBuilder sb, ReadOnlyMemory memory) - { - Contracts.CheckValue(sb, nameof(sb)); - if (!memory.IsEmpty) - sb.AppendSpan(memory.Span); + public static StringBuilder AppendMemory(this StringBuilder sb, ReadOnlyMemory memory) + { + Contracts.CheckValue(sb, nameof(sb)); + if (!memory.IsEmpty) + sb.AppendSpan(memory.Span); - return sb; - } + return sb; + } - public static StringBuilder AppendSpan(this StringBuilder sb, ReadOnlySpan span) + public static StringBuilder AppendSpan(this StringBuilder sb, ReadOnlySpan span) + { + unsafe { - unsafe + fixed (char* valueChars = &MemoryMarshal.GetReference(span)) { - fixed (char* valueChars = &MemoryMarshal.GetReference(span)) - { - sb.Append(valueChars, span.Length); - } + sb.Append(valueChars, span.Length); } - - return sb; } - public sealed class ReadonlyMemoryCharComparer : IEqualityComparer> + return sb; + } + + public sealed class ReadonlyMemoryCharComparer : IEqualityComparer> + { + public bool Equals(ReadOnlyMemory x, ReadOnlyMemory y) { - public bool Equals(ReadOnlyMemory x, ReadOnlyMemory y) - { - return x.Span.SequenceEqual(y.Span); - } + return x.Span.SequenceEqual(y.Span); + } - public int GetHashCode(ReadOnlyMemory obj) - { - return (int)Hashing.HashString(obj.Span); - } + public int GetHashCode(ReadOnlyMemory obj) + { + return (int)Hashing.HashString(obj.Span); } } } diff --git a/src/Microsoft.ML.Core/Data/Repository.cs b/src/Microsoft.ML.Core/Data/Repository.cs index 7a015a31b9..c3468c831e 100644 --- a/src/Microsoft.ML.Core/Data/Repository.cs +++ b/src/Microsoft.ML.Core/Data/Repository.cs @@ -13,544 +13,543 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; 
-namespace Microsoft.ML +namespace Microsoft.ML; + +/// +/// For saving a model into a repository. +/// Classes implementing should do an explicit implementation of . +/// Classes inheriting from a base class should overwrite the function invoked by +/// in that base class, if there is one. +/// +public interface ICanSaveModel { - /// - /// For saving a model into a repository. - /// Classes implementing should do an explicit implementation of . - /// Classes inheriting from a base class should overwrite the function invoked by - /// in that base class, if there is one. - /// - public interface ICanSaveModel - { - void Save(ModelSaveContext ctx); - } + void Save(ModelSaveContext ctx); +} - /// - /// For saving to a single stream. Note that this interface is mostly deprecated in favor of - /// saving more comprehensive and composable "model" objects, via . - /// - [BestFriend] - internal interface ICanSaveInBinaryFormat - { - void SaveAsBinary(BinaryWriter writer); - } +/// +/// For saving to a single stream. Note that this interface is mostly deprecated in favor of +/// saving more comprehensive and composable "model" objects, via . +/// +[BestFriend] +internal interface ICanSaveInBinaryFormat +{ + void SaveAsBinary(BinaryWriter writer); +} - /// - /// Abstraction around a or other hierarchical storage. - /// - [BestFriend] - internal abstract class Repository : IDisposable +/// +/// Abstraction around a or other hierarchical storage. +/// +[BestFriend] +internal abstract class Repository : IDisposable +{ + public sealed class Entry : IDisposable { - public sealed class Entry : IDisposable + // The parent repository. + private Repository _rep; + + /// + /// The relative path of this entry. + /// /// + public string Path { get; } + + /// + /// The stream for this entry. This is either a memory stream or a file stream in + /// the temporary directory. In either case, it is seekable and NOT the actual + /// archive stream. 
+ /// + public Stream Stream { get; } + + internal Entry(Repository rep, string path, Stream stream) { - // The parent repository. - private Repository _rep; - - /// - /// The relative path of this entry. - /// /// - public string Path { get; } - - /// - /// The stream for this entry. This is either a memory stream or a file stream in - /// the temporary directory. In either case, it is seekable and NOT the actual - /// archive stream. - /// - public Stream Stream { get; } - - internal Entry(Repository rep, string path, Stream stream) - { - _rep = rep; - Path = path; - Stream = stream; - } + _rep = rep; + Path = path; + Stream = stream; + } - public void Dispose() + public void Dispose() + { + if (_rep != null) { - if (_rep != null) - { - // Tell the repository that we're disposed. Note that the repository "owns" the stream - // so is in charge of closing it. - _rep.OnDispose(this); - _rep = null; - } + // Tell the repository that we're disposed. Note that the repository "owns" the stream + // so is in charge of closing it. + _rep.OnDispose(this); + _rep = null; } } + } - // These are the open entries that may contain streams into our DirTemp. - private readonly List _open; + // These are the open entries that may contain streams into our DirTemp. + private readonly List _open; - private bool _disposed; + private bool _disposed; - private readonly IExceptionContext _ectx; + private readonly IExceptionContext _ectx; - // This is a temporary directory that we create. It is essentially treated like an un-managed resource, - // hence the need for the complete dispose pattern. Note that it is optional - if we use memory - // streams for everything, we don't need it. This ability is needed for Scope or other environments - // where access to the file system is restricted. - protected readonly string DirTemp; + // This is a temporary directory that we create. It is essentially treated like an un-managed resource, + // hence the need for the complete dispose pattern. 
Note that it is optional - if we use memory + // streams for everything, we don't need it. This ability is needed for Scope or other environments + // where access to the file system is restricted. + protected readonly string DirTemp; - // Maps from relative path to the corresponding absolute path in the temp directory. - // This is populated as we decompress streams in the archive, so we don't de-compress - // more than once. - // REVIEW: Should we garbage collect to some degree? Currently we don't delete any - // of these temp files until the repository is disposed. - protected readonly ConcurrentDictionary PathMap; + // Maps from relative path to the corresponding absolute path in the temp directory. + // This is populated as we decompress streams in the archive, so we don't de-compress + // more than once. + // REVIEW: Should we garbage collect to some degree? Currently we don't delete any + // of these temp files until the repository is disposed. + protected readonly ConcurrentDictionary PathMap; - /// - /// Exception context. - /// - public IExceptionContext ExceptionContext => _ectx; + /// + /// Exception context. + /// + public IExceptionContext ExceptionContext => _ectx; - protected bool Disposed => _disposed; + protected bool Disposed => _disposed; - internal Repository(bool needDir, IExceptionContext ectx) - { - Contracts.AssertValueOrNull(ectx); - _ectx = ectx; + internal Repository(bool needDir, IExceptionContext ectx) + { + Contracts.AssertValueOrNull(ectx); + _ectx = ectx; + + PathMap = new ConcurrentDictionary(); + _open = new List(); + if (needDir) + DirTemp = GetShortTempDir(ectx); + else + GC.SuppressFinalize(this); + } - PathMap = new ConcurrentDictionary(); - _open = new List(); - if (needDir) - DirTemp = GetShortTempDir(ectx); - else - GC.SuppressFinalize(this); - } + private static string GetShortTempDir(IExceptionContext ectx) + { + string tempPath = ectx is IHostEnvironmentInternal iHostInternal ? 
+ iHostInternal.TempFilePath : + Path.GetTempPath(); - private static string GetShortTempDir(IExceptionContext ectx) - { - string tempPath = ectx is IHostEnvironmentInternal iHostInternal ? - iHostInternal.TempFilePath : - Path.GetTempPath(); + string path = Path.Combine(Path.GetFullPath(tempPath), "ml_dotnet", Path.GetRandomFileName()); + Directory.CreateDirectory(path); + return path; + } - string path = Path.Combine(Path.GetFullPath(tempPath), "ml_dotnet", Path.GetRandomFileName()); - Directory.CreateDirectory(path); - return path; - } + ~Repository() + { + if (!Disposed) + Dispose(false); + } - ~Repository() + public void Dispose() + { + if (!Disposed) { - if (!Disposed) - Dispose(false); + Dispose(true); + GC.SuppressFinalize(this); } + } - public void Dispose() + protected virtual void Dispose(bool disposing) + { + _ectx.Assert(!Disposed); + + // Close all temp files. + try { - if (!Disposed) - { - Dispose(true); - GC.SuppressFinalize(this); - } + DisposeAllEntries(); } - - protected virtual void Dispose(bool disposing) + catch { - _ectx.Assert(!Disposed); + _ectx.Assert(false, "Closing entries should not throw!"); + } - // Close all temp files. + // Delete the temp directory. + if (DirTemp != null) + { try { - DisposeAllEntries(); + Directory.Delete(DirTemp, true); } catch { - _ectx.Assert(false, "Closing entries should not throw!"); } - - // Delete the temp directory. - if (DirTemp != null) - { - try - { - Directory.Delete(DirTemp, true); - } - catch - { - } - } - - _disposed = true; } - /// - /// Force all open entries to be disposed. - /// - protected void DisposeAllEntries() + _disposed = true; + } + + /// + /// Force all open entries to be disposed. + /// + protected void DisposeAllEntries() + { + while (_open.Count > 0) { - while (_open.Count > 0) - { - var ent = _open[_open.Count - 1]; - ent.Dispose(); - } + var ent = _open[_open.Count - 1]; + ent.Dispose(); } + } - /// - /// Remove the entry from _open. 
Note that under normal access patterns, entries are LIFO, - /// so we search from the end of _open. - /// - protected void RemoveEntry(Entry ent) + /// + /// Remove the entry from _open. Note that under normal access patterns, entries are LIFO, + /// so we search from the end of _open. + /// + protected void RemoveEntry(Entry ent) + { + // Note that under normal access patterns, entries are LIFO, so we search from the end of _open. + for (int i = _open.Count; --i >= 0;) { - // Note that under normal access patterns, entries are LIFO, so we search from the end of _open. - for (int i = _open.Count; --i >= 0;) + if (_open[i] == ent) { - if (_open[i] == ent) - { - _open.RemoveAt(i); - return; - } + _open.RemoveAt(i); + return; } - _ectx.Assert(false, "Why wasn't the entry found?"); } + _ectx.Assert(false, "Why wasn't the entry found?"); + } - /// - /// The entry is being disposed. Note that overrides should always call RemoveEntry, in addition to whatever - /// they need to do with the corresponding stream. - /// - protected abstract void OnDispose(Entry ent); - - /// - /// When considering entries inside one of our model archives, we want to ensure that we - /// use a consistent directory separator. Zip archives are stored as flat lists of entries. - /// When we load those entries into our look-up dictionary, we normalize them to always use - /// backward slashes. - /// - protected static string NormalizeForArchiveEntry(string path) => path?.Replace('/', Path.DirectorySeparatorChar); - - /// - /// When building paths to our local file system, we want to force both forward and backward slashes - /// to the system directory separator character. We do this for cases where we either used Windows-specific - /// path building logic, or concatenated filesystem paths with zip archive entries on Linux. 
- /// - private static string NormalizeForFileSystem(string path) => - path?.Replace('/', Path.DirectorySeparatorChar).Replace('\\', Path.DirectorySeparatorChar); + /// + /// The entry is being disposed. Note that overrides should always call RemoveEntry, in addition to whatever + /// they need to do with the corresponding stream. + /// + protected abstract void OnDispose(Entry ent); - /// - /// Constructs both the relative path to the entry and the absolute path of a corresponding - /// temporary file. If createDir is true, makes sure the directory exists within the temp directory. - /// - protected void GetPath(out string pathEnt, out string pathTemp, string dir, string name, bool createDir) - { - _ectx.Assert(!Disposed); - _ectx.CheckValueOrNull(dir); - _ectx.CheckParam(dir == null || !dir.Contains(".."), nameof(dir)); - _ectx.CheckParam(!string.IsNullOrWhiteSpace(name), nameof(name)); - _ectx.CheckParam(!name.Contains(".."), nameof(name)); - - // The gymnastics below are meant to deal with bad invocations including absolute paths, etc. - // That's why we go through it even if DirTemp is null. - string root = Path.GetFullPath(DirTemp ?? @"x:\dummy"); - string entityPath = Path.Combine(root, dir ?? "", name); - entityPath = Path.GetFullPath(entityPath); - string tempPath = Path.Combine(root, Path.GetRandomFileName()); - tempPath = Path.GetFullPath(tempPath); - - string parent = Path.GetDirectoryName(entityPath); - _ectx.Check(parent != null); - _ectx.Check(parent.StartsWith(root)); - - int ichSplit = root.Length; - _ectx.Check(entityPath.Length > ichSplit && entityPath[ichSplit] == Path.DirectorySeparatorChar); - - if (createDir && DirTemp != null && parent.Length > ichSplit) - Directory.CreateDirectory(parent); - - // Get the relative path portion. This is the archive entry name. 
- pathEnt = entityPath.Substring(ichSplit + 1); - _ectx.Check(Utils.Size(pathEnt) > 0); - _ectx.Check(entityPath == Path.Combine(root, pathEnt)); - - // Set pathTemp to non-null iff _dirTemp is non-null. - pathTemp = DirTemp != null ? tempPath : null; - - pathEnt = NormalizeForArchiveEntry(pathEnt); - pathTemp = NormalizeForFileSystem(pathTemp); - } + /// + /// When considering entries inside one of our model archives, we want to ensure that we + /// use a consistent directory separator. Zip archives are stored as flat lists of entries. + /// When we load those entries into our look-up dictionary, we normalize them to always use + /// backward slashes. + /// + protected static string NormalizeForArchiveEntry(string path) => path?.Replace('/', Path.DirectorySeparatorChar); - protected Entry AddEntry(string pathEnt, Stream stream) - { - _ectx.Assert(!Disposed); - _ectx.AssertValue(stream); + /// + /// When building paths to our local file system, we want to force both forward and backward slashes + /// to the system directory separator character. We do this for cases where we either used Windows-specific + /// path building logic, or concatenated filesystem paths with zip archive entries on Linux. + /// + private static string NormalizeForFileSystem(string path) => + path?.Replace('/', Path.DirectorySeparatorChar).Replace('\\', Path.DirectorySeparatorChar); - var ent = new Entry(this, pathEnt, stream); - _open.Add(ent); - return ent; - } + /// + /// Constructs both the relative path to the entry and the absolute path of a corresponding + /// temporary file. If createDir is true, makes sure the directory exists within the temp directory. 
+ /// + protected void GetPath(out string pathEnt, out string pathTemp, string dir, string name, bool createDir) + { + _ectx.Assert(!Disposed); + _ectx.CheckValueOrNull(dir); + _ectx.CheckParam(dir == null || !dir.Contains(".."), nameof(dir)); + _ectx.CheckParam(!string.IsNullOrWhiteSpace(name), nameof(name)); + _ectx.CheckParam(!name.Contains(".."), nameof(name)); + + // The gymnastics below are meant to deal with bad invocations including absolute paths, etc. + // That's why we go through it even if DirTemp is null. + string root = Path.GetFullPath(DirTemp ?? @"x:\dummy"); + string entityPath = Path.Combine(root, dir ?? "", name); + entityPath = Path.GetFullPath(entityPath); + string tempPath = Path.Combine(root, Path.GetRandomFileName()); + tempPath = Path.GetFullPath(tempPath); + + string parent = Path.GetDirectoryName(entityPath); + _ectx.Check(parent != null); + _ectx.Check(parent.StartsWith(root)); + + int ichSplit = root.Length; + _ectx.Check(entityPath.Length > ichSplit && entityPath[ichSplit] == Path.DirectorySeparatorChar); + + if (createDir && DirTemp != null && parent.Length > ichSplit) + Directory.CreateDirectory(parent); + + // Get the relative path portion. This is the archive entry name. + pathEnt = entityPath.Substring(ichSplit + 1); + _ectx.Check(Utils.Size(pathEnt) > 0); + _ectx.Check(entityPath == Path.Combine(root, pathEnt)); + + // Set pathTemp to non-null iff _dirTemp is non-null. + pathTemp = DirTemp != null ? 
tempPath : null; + + pathEnt = NormalizeForArchiveEntry(pathEnt); + pathTemp = NormalizeForFileSystem(pathTemp); } - [BestFriend] - internal sealed class RepositoryWriter : Repository + protected Entry AddEntry(string pathEnt, Stream stream) { - private const string DirTrainingInfo = "TrainingInfo"; + _ectx.Assert(!Disposed); + _ectx.AssertValue(stream); - private ZipArchive _archive; - private Queue> _closed; - - public static RepositoryWriter CreateNew(Stream stream, IExceptionContext ectx = null, bool useFileSystem = true) - { - Contracts.CheckValueOrNull(ectx); - ectx.CheckValue(stream, nameof(stream)); - var rep = new RepositoryWriter(stream, ectx, useFileSystem); - - using (var ent = rep.CreateEntry(DirTrainingInfo, "Version.txt")) - using (var writer = Utils.OpenWriter(ent.Stream)) - writer.WriteLine(GetProductVersion()); - return rep; - } + var ent = new Entry(this, pathEnt, stream); + _open.Add(ent); + return ent; + } +} - private RepositoryWriter(Stream stream, IExceptionContext ectx, bool useFileSystem = true) - : base(useFileSystem, ectx) - { - _archive = new ZipArchive(stream, ZipArchiveMode.Create, leaveOpen: true); - _closed = new Queue>(); - } +[BestFriend] +internal sealed class RepositoryWriter : Repository +{ + private const string DirTrainingInfo = "TrainingInfo"; - public Entry CreateEntry(string name) - { - return CreateEntry(null, name); - } + private ZipArchive _archive; + private Queue> _closed; - public Entry CreateEntry(string dir, string name) - { - ExceptionContext.Check(!Disposed); + public static RepositoryWriter CreateNew(Stream stream, IExceptionContext ectx = null, bool useFileSystem = true) + { + Contracts.CheckValueOrNull(ectx); + ectx.CheckValue(stream, nameof(stream)); + var rep = new RepositoryWriter(stream, ectx, useFileSystem); + + using (var ent = rep.CreateEntry(DirTrainingInfo, "Version.txt")) + using (var writer = Utils.OpenWriter(ent.Stream)) + writer.WriteLine(GetProductVersion()); + return rep; + } - Flush(); + 
private RepositoryWriter(Stream stream, IExceptionContext ectx, bool useFileSystem = true) + : base(useFileSystem, ectx) + { + _archive = new ZipArchive(stream, ZipArchiveMode.Create, leaveOpen: true); + _closed = new Queue>(); + } - string pathEnt; - string pathTemp; - GetPath(out pathEnt, out pathTemp, dir, name, true); - if (!PathMap.TryAdd(pathEnt, pathTemp)) - throw ExceptionContext.ExceptParam(nameof(name), "Duplicate entry: '{0}'", pathEnt); + public Entry CreateEntry(string name) + { + return CreateEntry(null, name); + } - Stream stream; - if (pathTemp != null) - stream = new FileStream(pathTemp, FileMode.CreateNew); - else - stream = new MemoryStream(); + public Entry CreateEntry(string dir, string name) + { + ExceptionContext.Check(!Disposed); - return AddEntry(pathEnt, stream); - } + Flush(); - // The entry is being disposed. Note that this isn't supposed to throw, so we simply queue - // the stream so it can be written to the archive when it IS legal to throw. - protected override void OnDispose(Entry ent) - { - ExceptionContext.AssertValue(ent); - RemoveEntry(ent); + string pathEnt; + string pathTemp; + GetPath(out pathEnt, out pathTemp, dir, name, true); + if (!PathMap.TryAdd(pathEnt, pathTemp)) + throw ExceptionContext.ExceptParam(nameof(name), "Duplicate entry: '{0}'", pathEnt); - if (_closed != null) - _closed.Enqueue(new KeyValuePair(ent.Path, ent.Stream)); - else - ent.Stream.CloseEx(); - } + Stream stream; + if (pathTemp != null) + stream = new FileStream(pathTemp, FileMode.CreateNew); + else + stream = new MemoryStream(); - protected override void Dispose(bool disposing) - { - ExceptionContext.Assert(!Disposed); + return AddEntry(pathEnt, stream); + } - if (_closed != null) - { - while (_closed.Count > 0) - { - var kvp = _closed.Dequeue(); - kvp.Value.CloseEx(); - } - _closed = null; - } + // The entry is being disposed. 
Note that this isn't supposed to throw, so we simply queue + // the stream so it can be written to the archive when it IS legal to throw. + protected override void OnDispose(Entry ent) + { + ExceptionContext.AssertValue(ent); + RemoveEntry(ent); - if (_archive != null) - { - try - { - _archive.Dispose(); - } - catch - { - } - _archive = null; - } + if (_closed != null) + _closed.Enqueue(new KeyValuePair(ent.Path, ent.Stream)); + else + ent.Stream.CloseEx(); + } - // Close all the streams. - base.Dispose(disposing); - } + protected override void Dispose(bool disposing) + { + ExceptionContext.Assert(!Disposed); - // Write "closed" entries to the archive. - private void Flush() + if (_closed != null) { - ExceptionContext.Assert(!Disposed); - ExceptionContext.AssertValue(_closed); - ExceptionContext.AssertValue(_archive); - while (_closed.Count > 0) { - string path = null; var kvp = _closed.Dequeue(); - using (var src = kvp.Value) - { - var fs = src as FileStream; - if (fs != null) - path = fs.Name; - - var ae = _archive.CreateEntry(kvp.Key); - using (var dst = ae.Open()) - { - src.Position = 0; - src.CopyTo(dst); - } - } - - if (!string.IsNullOrEmpty(path)) - File.Delete(path); + kvp.Value.CloseEx(); } + _closed = null; } - /// - /// Commit the writing of the repository. This signals successful completion of the write. - /// - public void Commit() + if (_archive != null) { - ExceptionContext.Check(!Disposed); - ExceptionContext.AssertValue(_closed); - - DisposeAllEntries(); - Flush(); - Dispose(true); + try + { + _archive.Dispose(); + } + catch + { + } + _archive = null; } - private static string GetProductVersion() - { - var assembly = typeof(RepositoryWriter).Assembly; + // Close all the streams. + base.Dispose(disposing); + } - var assemblyInternationalVersionAttribute = assembly.CustomAttributes.FirstOrDefault(a => - a.AttributeType == typeof(AssemblyInformationalVersionAttribute)); + // Write "closed" entries to the archive. 
+ private void Flush() + { + ExceptionContext.Assert(!Disposed); + ExceptionContext.AssertValue(_closed); + ExceptionContext.AssertValue(_archive); - if (assemblyInternationalVersionAttribute == null) + while (_closed.Count > 0) + { + string path = null; + var kvp = _closed.Dequeue(); + using (var src = kvp.Value) { - throw new ApplicationException($"Cannot determine product version from assembly {assembly.FullName}."); + var fs = src as FileStream; + if (fs != null) + path = fs.Name; + + var ae = _archive.CreateEntry(kvp.Key); + using (var dst = ae.Open()) + { + src.Position = 0; + src.CopyTo(dst); + } } - return assemblyInternationalVersionAttribute.ConstructorArguments - .First() - .Value - .ToString(); + if (!string.IsNullOrEmpty(path)) + File.Delete(path); } } - [BestFriend] - internal sealed class RepositoryReader : Repository + /// + /// Commit the writing of the repository. This signals successful completion of the write. + /// + public void Commit() + { + ExceptionContext.Check(!Disposed); + ExceptionContext.AssertValue(_closed); + + DisposeAllEntries(); + Flush(); + Dispose(true); + } + + private static string GetProductVersion() { - private readonly ZipArchive _archive; + var assembly = typeof(RepositoryWriter).Assembly; - // Maps from a normalized path to the entry in the _archive. This is needed since - // a zip might use / or \ for directory separation. 
- private readonly Dictionary _entries; + var assemblyInternationalVersionAttribute = assembly.CustomAttributes.FirstOrDefault(a => + a.AttributeType == typeof(AssemblyInformationalVersionAttribute)); - public static RepositoryReader Open(Stream stream, IExceptionContext ectx = null, bool useFileSystem = true) + if (assemblyInternationalVersionAttribute == null) { - Contracts.CheckValueOrNull(ectx); - ectx.CheckValue(stream, nameof(stream)); - return new RepositoryReader(stream, ectx, useFileSystem); + throw new ApplicationException($"Cannot determine product version from assembly {assembly.FullName}."); } - private RepositoryReader(Stream stream, IExceptionContext ectx, bool useFileSystem) - : base(useFileSystem, ectx) - { - try - { - _archive = new ZipArchive(stream, ZipArchiveMode.Read, true); - } - catch (Exception ex) - { - throw ExceptionContext.ExceptDecode(ex, "Failed to open a zip archive"); - } + return assemblyInternationalVersionAttribute.ConstructorArguments + .First() + .Value + .ToString(); + } +} - _entries = new Dictionary(); - foreach (var entry in _archive.Entries) - { - var path = NormalizeForArchiveEntry(entry.FullName); - _entries[path] = entry; - } - } +[BestFriend] +internal sealed class RepositoryReader : Repository +{ + private readonly ZipArchive _archive; + + // Maps from a normalized path to the entry in the _archive. This is needed since + // a zip might use / or \ for directory separation. 
+ private readonly Dictionary _entries; - public Entry OpenEntry(string name) + public static RepositoryReader Open(Stream stream, IExceptionContext ectx = null, bool useFileSystem = true) + { + Contracts.CheckValueOrNull(ectx); + ectx.CheckValue(stream, nameof(stream)); + return new RepositoryReader(stream, ectx, useFileSystem); + } + + private RepositoryReader(Stream stream, IExceptionContext ectx, bool useFileSystem) + : base(useFileSystem, ectx) + { + try { - return OpenEntry(null, name); + _archive = new ZipArchive(stream, ZipArchiveMode.Read, true); } - - public Entry OpenEntry(string dir, string name) + catch (Exception ex) { - var ent = OpenEntryOrNull(dir, name); - if (ent != null) - return ent; - - string pathEnt; - string pathTemp; - GetPath(out pathEnt, out pathTemp, dir, name, false); - throw ExceptionContext.Except("Repository doesn't contain entry {0}", pathEnt); + throw ExceptionContext.ExceptDecode(ex, "Failed to open a zip archive"); } - public Entry OpenEntryOrNull(string name) + _entries = new Dictionary(); + foreach (var entry in _archive.Entries) { - return OpenEntryOrNull(null, name); + var path = NormalizeForArchiveEntry(entry.FullName); + _entries[path] = entry; } + } - public Entry OpenEntryOrNull(string dir, string name) - { - ExceptionContext.Check(!Disposed); + public Entry OpenEntry(string name) + { + return OpenEntry(null, name); + } - string pathEnt; - string pathTemp; - GetPath(out pathEnt, out pathTemp, dir, name, false); + public Entry OpenEntry(string dir, string name) + { + var ent = OpenEntryOrNull(dir, name); + if (ent != null) + return ent; - ZipArchiveEntry entry; - Stream stream; - string pathAbs; - string pathLower = pathEnt.ToLowerInvariant(); - if (PathMap.TryGetValue(pathLower, out pathAbs)) - { - stream = new FileStream(pathAbs, FileMode.Open, FileAccess.Read, FileShare.Read); - } - else - { - if (!_entries.TryGetValue(pathEnt, out entry)) - { - //Read old zip file that use backslash in filename - var pathEntTmp = 
pathEnt.Replace("/", "\\"); - if (!_entries.TryGetValue(pathEntTmp, out entry)) - { - return null; - } - } + string pathEnt; + string pathTemp; + GetPath(out pathEnt, out pathTemp, dir, name, false); + throw ExceptionContext.Except("Repository doesn't contain entry {0}", pathEnt); + } - if (pathTemp != null) - { - // Extract to a temporary file. - Directory.CreateDirectory(Path.GetDirectoryName(pathTemp)); - entry.ExtractToFile(pathTemp); - if (!PathMap.TryAdd(pathLower, pathTemp)) - throw ExceptionContext.ExceptParam(nameof(name), "Duplicate entry: '{0}'", pathLower); + public Entry OpenEntryOrNull(string name) + { + return OpenEntryOrNull(null, name); + } - stream = new FileStream(pathTemp, FileMode.Open, FileAccess.Read, FileShare.Read); - } - else + public Entry OpenEntryOrNull(string dir, string name) + { + ExceptionContext.Check(!Disposed); + + string pathEnt; + string pathTemp; + GetPath(out pathEnt, out pathTemp, dir, name, false); + + ZipArchiveEntry entry; + Stream stream; + string pathAbs; + string pathLower = pathEnt.ToLowerInvariant(); + if (PathMap.TryGetValue(pathLower, out pathAbs)) + { + stream = new FileStream(pathAbs, FileMode.Open, FileAccess.Read, FileShare.Read); + } + else + { + if (!_entries.TryGetValue(pathEnt, out entry)) + { + //Read old zip file that use backslash in filename + var pathEntTmp = pathEnt.Replace("/", "\\"); + if (!_entries.TryGetValue(pathEntTmp, out entry)) { - // Extract to a memory stream. - ExceptionContext.CheckDecode(entry.Length < int.MaxValue, "Repository stream too large to read into memory"); - stream = new MemoryStream((int)entry.Length); - using (var src = entry.Open()) - src.CopyTo(stream); - stream.Position = 0; + return null; } } - return AddEntry(pathEnt, stream); - } + if (pathTemp != null) + { + // Extract to a temporary file. 
+ Directory.CreateDirectory(Path.GetDirectoryName(pathTemp)); + entry.ExtractToFile(pathTemp); + if (!PathMap.TryAdd(pathLower, pathTemp)) + throw ExceptionContext.ExceptParam(nameof(name), "Duplicate entry: '{0}'", pathLower); - protected override void OnDispose(Entry ent) - { - ExceptionContext.AssertValue(ent); - RemoveEntry(ent); - ent.Stream.CloseEx(); + stream = new FileStream(pathTemp, FileMode.Open, FileAccess.Read, FileShare.Read); + } + else + { + // Extract to a memory stream. + ExceptionContext.CheckDecode(entry.Length < int.MaxValue, "Repository stream too large to read into memory"); + stream = new MemoryStream((int)entry.Length); + using (var src = entry.Open()) + src.CopyTo(stream); + stream.Position = 0; + } } + + return AddEntry(pathEnt, stream); + } + + protected override void OnDispose(Entry ent) + { + ExceptionContext.AssertValue(ent); + RemoveEntry(ent); + ent.Stream.CloseEx(); } } diff --git a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs index 32b5e7107a..2e64339f94 100644 --- a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs +++ b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs @@ -6,484 +6,483 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; -namespace Microsoft.ML.Data +namespace Microsoft.ML.Data; + +/// +/// Encapsulates a plus column role mapping information. The purpose of role mappings is to +/// provide information on what the intended usage is for. That is: while a given data view may have a column named +/// "Features", by itself that is insufficient: the trainer must be fed a role mapping that says that the role +/// mapping for features is filled by that "Features" column. This allows things like columns not named "Features" +/// to actually fill that role (as opposed to insisting on a hard coding, or having every trainer have to be +/// individually configured). 
Also, by being a one-to-many mapping, it is a way for learners that can consume +/// multiple features columns to consume that information. +/// +/// This class has convenience fields for several common column roles (for example, , ), but can hold an arbitrary set of column infos. The convenience fields are non-null if and only +/// if there is a unique column with the corresponding role. When there are no such columns or more than one such +/// column, the field is null. The , , and +/// methods provide some cardinality information. Note that all columns assigned roles are guaranteed to be non-hidden +/// in this schema. +/// +/// +/// Note that instances of this class are, like instances of , immutable. +/// +/// It is often the case that one wishes to bundle the actual data with the role mappings, not just the schema. For +/// that case, please use the class. +/// +/// Note that there is no need for components consuming a or +/// to make use of every defined mapping. Consuming components are also expected to ignore any +/// they do not handle. They may very well however complain if a mapping they wanted to see is not present, or the column(s) +/// mapped from the role are not of the form they require. +/// +/// +/// +[BestFriend] +internal sealed class RoleMappedSchema { + private const string FeatureString = "Feature"; + private const string LabelString = "Label"; + private const string GroupString = "Group"; + private const string WeightString = "Weight"; + private const string NameString = "Name"; + private const string FeatureContributionsString = "FeatureContributions"; + /// - /// Encapsulates a plus column role mapping information. The purpose of role mappings is to - /// provide information on what the intended usage is for. That is: while a given data view may have a column named - /// "Features", by itself that is insufficient: the trainer must be fed a role mapping that says that the role - /// mapping for features is filled by that "Features" column. 
This allows things like columns not named "Features" - /// to actually fill that role (as opposed to insisting on a hard coding, or having every trainer have to be - /// individually configured). Also, by being a one-to-many mapping, it is a way for learners that can consume - /// multiple features columns to consume that information. - /// - /// This class has convenience fields for several common column roles (for example, , ), but can hold an arbitrary set of column infos. The convenience fields are non-null if and only - /// if there is a unique column with the corresponding role. When there are no such columns or more than one such - /// column, the field is null. The , , and - /// methods provide some cardinality information. Note that all columns assigned roles are guaranteed to be non-hidden - /// in this schema. + /// Instances of this are the keys of a . This class also holds some important + /// commonly used pre-defined instances available (for example, , ) that should + /// be used when possible for consistency reasons. However, practitioners should not be afraid to declare custom + /// roles if approppriate for their task. /// - /// - /// Note that instances of this class are, like instances of , immutable. - /// - /// It is often the case that one wishes to bundle the actual data with the role mappings, not just the schema. For - /// that case, please use the class. - /// - /// Note that there is no need for components consuming a or - /// to make use of every defined mapping. Consuming components are also expected to ignore any - /// they do not handle. They may very well however complain if a mapping they wanted to see is not present, or the column(s) - /// mapped from the role are not of the form they require. 
- /// - /// - /// - [BestFriend] - internal sealed class RoleMappedSchema + public readonly struct ColumnRole { - private const string FeatureString = "Feature"; - private const string LabelString = "Label"; - private const string GroupString = "Group"; - private const string WeightString = "Weight"; - private const string NameString = "Name"; - private const string FeatureContributionsString = "FeatureContributions"; - /// - /// Instances of this are the keys of a . This class also holds some important - /// commonly used pre-defined instances available (for example, , ) that should - /// be used when possible for consistency reasons. However, practitioners should not be afraid to declare custom - /// roles if approppriate for their task. + /// Role for features. Commonly used as the independent variables given to trainers, and scorers. /// - public readonly struct ColumnRole - { - /// - /// Role for features. Commonly used as the independent variables given to trainers, and scorers. - /// - public static ColumnRole Feature => FeatureString; - - /// - /// Role for labels. Commonly used as the dependent variables given to trainers, and evaluators. - /// - public static ColumnRole Label => LabelString; - - /// - /// Role for group ID. Commonly used in ranking applications, for defining query boundaries, or - /// sequence classification, for defining the boundaries of an utterance. - /// - public static ColumnRole Group => GroupString; - - /// - /// Role for sample weights. Commonly used to point to a number to make trainers give more weight - /// to a particular example. - /// - public static ColumnRole Weight => WeightString; - - /// - /// Role for sample names. Useful for informational and tracking purposes when scoring, but typically - /// without affecting results. - /// - public static ColumnRole Name => NameString; - - // REVIEW: Does this really belong here? - /// - /// Role for feature contributions. Useful for specific diagnostic functionality. 
- /// - public static ColumnRole FeatureContributions => FeatureContributionsString; - - /// - /// The string value for the role. Guaranteed to be non-empty. - /// - public readonly string Value; - - /// - /// Constructor for the column role. - /// - /// The value for the role. Must be non-empty. - public ColumnRole(string value) - { - Contracts.CheckNonEmpty(value, nameof(value)); - Value = value; - } + public static ColumnRole Feature => FeatureString; - public static implicit operator ColumnRole(string value) - => new ColumnRole(value); - - /// - /// Convenience method for creating a mapping pair from a role to a column name - /// for giving to constructors of and . - /// - /// The column name to map to. Can be null, in which case when used - /// to construct a role mapping structure this pair will be ignored - /// A key-value pair with this instance as the key and as the value - public KeyValuePair Bind(string name) - => new KeyValuePair(this, name); - } + /// + /// Role for labels. Commonly used as the dependent variables given to trainers, and evaluators. + /// + public static ColumnRole Label => LabelString; - public static KeyValuePair CreatePair(ColumnRole role, string name) - => new KeyValuePair(role, name); + /// + /// Role for group ID. Commonly used in ranking applications, for defining query boundaries, or + /// sequence classification, for defining the boundaries of an utterance. + /// + public static ColumnRole Group => GroupString; /// - /// The source . + /// Role for sample weights. Commonly used to point to a number to make trainers give more weight + /// to a particular example. /// - public DataViewSchema Schema { get; } + public static ColumnRole Weight => WeightString; /// - /// The column, when there is exactly one (null otherwise). + /// Role for sample names. Useful for informational and tracking purposes when scoring, but typically + /// without affecting results. /// - public DataViewSchema.Column? 
Feature { get; } + public static ColumnRole Name => NameString; + // REVIEW: Does this really belong here? /// - /// The column, when there is exactly one (null otherwise). + /// Role for feature contributions. Useful for specific diagnostic functionality. /// - public DataViewSchema.Column? Label { get; } + public static ColumnRole FeatureContributions => FeatureContributionsString; /// - /// The column, when there is exactly one (null otherwise). + /// The string value for the role. Guaranteed to be non-empty. /// - public DataViewSchema.Column? Group { get; } + public readonly string Value; /// - /// The column, when there is exactly one (null otherwise). + /// Constructor for the column role. /// - public DataViewSchema.Column? Weight { get; } + /// The value for the role. Must be non-empty. + public ColumnRole(string value) + { + Contracts.CheckNonEmpty(value, nameof(value)); + Value = value; + } + + public static implicit operator ColumnRole(string value) + => new ColumnRole(value); /// - /// The column, when there is exactly one (null otherwise). + /// Convenience method for creating a mapping pair from a role to a column name + /// for giving to constructors of and . /// - public DataViewSchema.Column? Name { get; } + /// The column name to map to. Can be null, in which case when used + /// to construct a role mapping structure this pair will be ignored + /// A key-value pair with this instance as the key and as the value + public KeyValuePair Bind(string name) + => new KeyValuePair(this, name); + } - // Maps from role to the associated column infos. - private readonly Dictionary> _map; + public static KeyValuePair CreatePair(ColumnRole role, string name) + => new KeyValuePair(role, name); - private RoleMappedSchema(DataViewSchema schema, Dictionary> map) - { - Contracts.AssertValue(schema); - Contracts.AssertValue(map); + /// + /// The source . 
+ /// + public DataViewSchema Schema { get; } - Schema = schema; - _map = map; - foreach (var kvp in _map) - { - Contracts.Assert(Utils.Size(kvp.Value) > 0); - var cols = kvp.Value; + /// + /// The column, when there is exactly one (null otherwise). + /// + public DataViewSchema.Column? Feature { get; } + + /// + /// The column, when there is exactly one (null otherwise). + /// + public DataViewSchema.Column? Label { get; } + + /// + /// The column, when there is exactly one (null otherwise). + /// + public DataViewSchema.Column? Group { get; } + + /// + /// The column, when there is exactly one (null otherwise). + /// + public DataViewSchema.Column? Weight { get; } + + /// + /// The column, when there is exactly one (null otherwise). + /// + public DataViewSchema.Column? Name { get; } + + // Maps from role to the associated column infos. + private readonly Dictionary> _map; + + private RoleMappedSchema(DataViewSchema schema, Dictionary> map) + { + Contracts.AssertValue(schema); + Contracts.AssertValue(map); + + Schema = schema; + _map = map; + foreach (var kvp in _map) + { + Contracts.Assert(Utils.Size(kvp.Value) > 0); + var cols = kvp.Value; #if DEBUG - foreach (var info in cols) - Contracts.Assert(!schema[info.Index].IsHidden, "How did a hidden column sneak in?"); + foreach (var info in cols) + Contracts.Assert(!schema[info.Index].IsHidden, "How did a hidden column sneak in?"); #endif - if (cols.Count == 1) + if (cols.Count == 1) + { + switch (kvp.Key) { - switch (kvp.Key) - { - case FeatureString: - Feature = cols[0]; - break; - case LabelString: - Label = cols[0]; - break; - case GroupString: - Group = cols[0]; - break; - case WeightString: - Weight = cols[0]; - break; - case NameString: - Name = cols[0]; - break; - } + case FeatureString: + Feature = cols[0]; + break; + case LabelString: + Label = cols[0]; + break; + case GroupString: + Group = cols[0]; + break; + case WeightString: + Weight = cols[0]; + break; + case NameString: + Name = cols[0]; + break; } 
} } + } - private RoleMappedSchema(DataViewSchema schema, Dictionary> map) - : this(schema, Copy(map)) - { - } + private RoleMappedSchema(DataViewSchema schema, Dictionary> map) + : this(schema, Copy(map)) + { + } - private static void Add(Dictionary> map, ColumnRole role, DataViewSchema.Column column) - { - Contracts.AssertValue(map); - Contracts.AssertNonEmpty(role.Value); + private static void Add(Dictionary> map, ColumnRole role, DataViewSchema.Column column) + { + Contracts.AssertValue(map); + Contracts.AssertNonEmpty(role.Value); - if (!map.TryGetValue(role.Value, out var list)) - { - list = new List(); - map.Add(role.Value, list); - } - list.Add(column); + if (!map.TryGetValue(role.Value, out var list)) + { + list = new List(); + map.Add(role.Value, list); } + list.Add(column); + } - private static Dictionary> MapFromNames(DataViewSchema schema, IEnumerable> roles, bool opt = false) - { - Contracts.AssertValue(schema); - Contracts.AssertValue(roles); + private static Dictionary> MapFromNames(DataViewSchema schema, IEnumerable> roles, bool opt = false) + { + Contracts.AssertValue(schema); + Contracts.AssertValue(roles); - var map = new Dictionary>(); - foreach (var kvp in roles) - { - Contracts.AssertNonEmpty(kvp.Key.Value); - if (string.IsNullOrEmpty(kvp.Value)) - continue; - var info = schema.GetColumnOrNull(kvp.Value); - if (info.HasValue) - Add(map, kvp.Key.Value, info.Value); - else if (!opt) - throw Contracts.ExceptParam(nameof(schema), $"{kvp.Value} column '{kvp.Key.Value}' not found"); - } - return map; + var map = new Dictionary>(); + foreach (var kvp in roles) + { + Contracts.AssertNonEmpty(kvp.Key.Value); + if (string.IsNullOrEmpty(kvp.Value)) + continue; + var info = schema.GetColumnOrNull(kvp.Value); + if (info.HasValue) + Add(map, kvp.Key.Value, info.Value); + else if (!opt) + throw Contracts.ExceptParam(nameof(schema), $"{kvp.Value} column '{kvp.Key.Value}' not found"); } + return map; + } - /// - /// Returns whether there are any columns with 
the given column role. - /// - public bool Has(ColumnRole role) - => _map.ContainsKey(role.Value); - - /// - /// Returns whether there is exactly one column of the given role. - /// - public bool HasUnique(ColumnRole role) - => _map.TryGetValue(role.Value, out var cols) && cols.Count == 1; + /// + /// Returns whether there are any columns with the given column role. + /// + public bool Has(ColumnRole role) + => _map.ContainsKey(role.Value); - /// - /// Returns whether there are two or more columns of the given role. - /// - public bool HasMultiple(ColumnRole role) - => _map.TryGetValue(role.Value, out var cols) && cols.Count > 1; + /// + /// Returns whether there is exactly one column of the given role. + /// + public bool HasUnique(ColumnRole role) + => _map.TryGetValue(role.Value, out var cols) && cols.Count == 1; - /// - /// If there are columns of the given role, this returns the infos as a readonly list. Otherwise, - /// it returns null. - /// - public IReadOnlyList GetColumns(ColumnRole role) - => _map.TryGetValue(role.Value, out var list) ? list : null; + /// + /// Returns whether there are two or more columns of the given role. + /// + public bool HasMultiple(ColumnRole role) + => _map.TryGetValue(role.Value, out var cols) && cols.Count > 1; - /// - /// An enumerable over all role-column associations within this object. - /// - public IEnumerable> GetColumnRoles() - { - foreach (var roleAndList in _map) - { - foreach (var info in roleAndList.Value) - yield return new KeyValuePair(roleAndList.Key, info); - } - } + /// + /// If there are columns of the given role, this returns the infos as a readonly list. Otherwise, + /// it returns null. + /// + public IReadOnlyList GetColumns(ColumnRole role) + => _map.TryGetValue(role.Value, out var list) ? list : null; - /// - /// An enumerable over all role-column associations within this object. 
- /// - public IEnumerable> GetColumnRoleNames() + /// + /// An enumerable over all role-column associations within this object. + /// + public IEnumerable> GetColumnRoles() + { + foreach (var roleAndList in _map) { - foreach (var roleAndList in _map) - { - foreach (var info in roleAndList.Value) - yield return new KeyValuePair(roleAndList.Key, info.Name); - } + foreach (var info in roleAndList.Value) + yield return new KeyValuePair(roleAndList.Key, info); } + } - /// - /// An enumerable over all role-column associations for the given role. This is a helper function - /// for implementing the method. - /// - public IEnumerable> GetColumnRoleNames(ColumnRole role) + /// + /// An enumerable over all role-column associations within this object. + /// + public IEnumerable> GetColumnRoleNames() + { + foreach (var roleAndList in _map) { - if (_map.TryGetValue(role.Value, out var list)) - { - foreach (var info in list) - yield return new KeyValuePair(role, info.Name); - } + foreach (var info in roleAndList.Value) + yield return new KeyValuePair(roleAndList.Key, info.Name); } + } - /// - /// Returns the corresponding to if there is - /// exactly one such mapping, and otherwise throws an exception. - /// - /// The role to look up - /// The column corresponding to that role, assuming there was only one column - /// mapped to that - public DataViewSchema.Column GetUniqueColumn(ColumnRole role) + /// + /// An enumerable over all role-column associations for the given role. This is a helper function + /// for implementing the method. 
+ /// + public IEnumerable> GetColumnRoleNames(ColumnRole role) + { + if (_map.TryGetValue(role.Value, out var list)) { - var infos = GetColumns(role); - if (Utils.Size(infos) != 1) - throw Contracts.Except("Expected exactly one column with role '{0}', but found {1}.", role.Value, Utils.Size(infos)); - return infos[0]; + foreach (var info in list) + yield return new KeyValuePair(role, info.Name); } + } - private static Dictionary> Copy(Dictionary> map) - { - var copy = new Dictionary>(map.Count); - foreach (var kvp in map) - { - Contracts.Assert(Utils.Size(kvp.Value) > 0); - var cols = kvp.Value.ToArray(); - copy.Add(kvp.Key, cols); - } - return copy; - } + /// + /// Returns the corresponding to if there is + /// exactly one such mapping, and otherwise throws an exception. + /// + /// The role to look up + /// The column corresponding to that role, assuming there was only one column + /// mapped to that + public DataViewSchema.Column GetUniqueColumn(ColumnRole role) + { + var infos = GetColumns(role); + if (Utils.Size(infos) != 1) + throw Contracts.Except("Expected exactly one column with role '{0}', but found {1}.", role.Value, Utils.Size(infos)); + return infos[0]; + } - /// - /// Constructor given a schema, and mapping pairs of roles to columns in the schema. - /// This skips null or empty column-names. It will also skip column-names that are not - /// found in the schema if is true. - /// - /// The schema over which roles are defined - /// Whether to consider the column names specified "optional" or not. 
If false then any non-empty - /// values for the column names that does not appear in will result in an exception being thrown, - /// but if true such values will be ignored - /// The column role to column name mappings - public RoleMappedSchema(DataViewSchema schema, bool opt = false, params KeyValuePair[] roles) - : this(Contracts.CheckRef(schema, nameof(schema)), Contracts.CheckRef(roles, nameof(roles)), opt) + private static Dictionary> Copy(Dictionary> map) + { + var copy = new Dictionary>(map.Count); + foreach (var kvp in map) { + Contracts.Assert(Utils.Size(kvp.Value) > 0); + var cols = kvp.Value.ToArray(); + copy.Add(kvp.Key, cols); } + return copy; + } - /// - /// Constructor given a schema, and mapping pairs of roles to columns in the schema. - /// This skips null or empty column names. It will also skip column-names that are not - /// found in the schema if is true. - /// - /// The schema over which roles are defined - /// The column role to column name mappings - /// Whether to consider the column names specified "optional" or not. If false then any non-empty - /// values for the column names that does not appear in will result in an exception being thrown, - /// but if true such values will be ignored - public RoleMappedSchema(DataViewSchema schema, IEnumerable> roles, bool opt = false) - : this(Contracts.CheckRef(schema, nameof(schema)), - MapFromNames(schema, Contracts.CheckRef(roles, nameof(roles)), opt)) - { - } + /// + /// Constructor given a schema, and mapping pairs of roles to columns in the schema. + /// This skips null or empty column-names. It will also skip column-names that are not + /// found in the schema if is true. + /// + /// The schema over which roles are defined + /// Whether to consider the column names specified "optional" or not. 
If false then any non-empty + /// values for the column names that does not appear in will result in an exception being thrown, + /// but if true such values will be ignored + /// The column role to column name mappings + public RoleMappedSchema(DataViewSchema schema, bool opt = false, params KeyValuePair[] roles) + : this(Contracts.CheckRef(schema, nameof(schema)), Contracts.CheckRef(roles, nameof(roles)), opt) + { + } - private static IEnumerable> PredefinedRolesHelper( - string label, string feature, string group, string weight, string name, - IEnumerable> custom = null) - { - if (!string.IsNullOrWhiteSpace(label)) - yield return ColumnRole.Label.Bind(label); - if (!string.IsNullOrWhiteSpace(feature)) - yield return ColumnRole.Feature.Bind(feature); - if (!string.IsNullOrWhiteSpace(group)) - yield return ColumnRole.Group.Bind(group); - if (!string.IsNullOrWhiteSpace(weight)) - yield return ColumnRole.Weight.Bind(weight); - if (!string.IsNullOrWhiteSpace(name)) - yield return ColumnRole.Name.Bind(name); - if (custom != null) - { - foreach (var role in custom) - yield return role; - } - } + /// + /// Constructor given a schema, and mapping pairs of roles to columns in the schema. + /// This skips null or empty column names. It will also skip column-names that are not + /// found in the schema if is true. + /// + /// The schema over which roles are defined + /// The column role to column name mappings + /// Whether to consider the column names specified "optional" or not. If false then any non-empty + /// values for the column names that does not appear in will result in an exception being thrown, + /// but if true such values will be ignored + public RoleMappedSchema(DataViewSchema schema, IEnumerable> roles, bool opt = false) + : this(Contracts.CheckRef(schema, nameof(schema)), + MapFromNames(schema, Contracts.CheckRef(roles, nameof(roles)), opt)) + { + } - /// - /// Convenience constructor for role-mappings over the commonly used roles. 
Note that if any column name specified - /// is null or whitespace, it is ignored. - /// - /// The schema over which roles are defined - /// The column name that will be mapped to the role - /// The column name that will be mapped to the role - /// The column name that will be mapped to the role - /// The column name that will be mapped to the role - /// The column name that will be mapped to the role - /// Any additional desired custom column role mappings - /// Whether to consider the column names specified "optional" or not. If false then any non-empty - /// values for the column names that does not appear in will result in an exception being thrown, - /// but if true such values will be ignored - public RoleMappedSchema(DataViewSchema schema, string label, string feature, - string group = null, string weight = null, string name = null, - IEnumerable> custom = null, bool opt = false) - : this(Contracts.CheckRef(schema, nameof(schema)), PredefinedRolesHelper(label, feature, group, weight, name, custom), opt) + private static IEnumerable> PredefinedRolesHelper( + string label, string feature, string group, string weight, string name, + IEnumerable> custom = null) + { + if (!string.IsNullOrWhiteSpace(label)) + yield return ColumnRole.Label.Bind(label); + if (!string.IsNullOrWhiteSpace(feature)) + yield return ColumnRole.Feature.Bind(feature); + if (!string.IsNullOrWhiteSpace(group)) + yield return ColumnRole.Group.Bind(group); + if (!string.IsNullOrWhiteSpace(weight)) + yield return ColumnRole.Weight.Bind(weight); + if (!string.IsNullOrWhiteSpace(name)) + yield return ColumnRole.Name.Bind(name); + if (custom != null) { - Contracts.CheckValueOrNull(label); - Contracts.CheckValueOrNull(feature); - Contracts.CheckValueOrNull(group); - Contracts.CheckValueOrNull(weight); - Contracts.CheckValueOrNull(name); - Contracts.CheckValueOrNull(custom); + foreach (var role in custom) + yield return role; } } /// - /// Encapsulates an plus a corresponding . 
- /// Note that the schema of of is - /// guaranteed to equal the the of . + /// Convenience constructor for role-mappings over the commonly used roles. Note that if any column name specified + /// is null or whitespace, it is ignored. /// - [BestFriend] - internal sealed class RoleMappedData + /// The schema over which roles are defined + /// The column name that will be mapped to the role + /// The column name that will be mapped to the role + /// The column name that will be mapped to the role + /// The column name that will be mapped to the role + /// The column name that will be mapped to the role + /// Any additional desired custom column role mappings + /// Whether to consider the column names specified "optional" or not. If false then any non-empty + /// values for the column names that does not appear in will result in an exception being thrown, + /// but if true such values will be ignored + public RoleMappedSchema(DataViewSchema schema, string label, string feature, + string group = null, string weight = null, string name = null, + IEnumerable> custom = null, bool opt = false) + : this(Contracts.CheckRef(schema, nameof(schema)), PredefinedRolesHelper(label, feature, group, weight, name, custom), opt) { - /// - /// The data. - /// - public IDataView Data { get; } + Contracts.CheckValueOrNull(label); + Contracts.CheckValueOrNull(feature); + Contracts.CheckValueOrNull(group); + Contracts.CheckValueOrNull(weight); + Contracts.CheckValueOrNull(name); + Contracts.CheckValueOrNull(custom); + } +} - /// - /// The role mapped schema. Note that 's is - /// guaranteed to be the same as 's . - /// - public RoleMappedSchema Schema { get; } +/// +/// Encapsulates an plus a corresponding . +/// Note that the schema of of is +/// guaranteed to equal the the of . +/// +[BestFriend] +internal sealed class RoleMappedData +{ + /// + /// The data. 
+ /// + public IDataView Data { get; } - private RoleMappedData(IDataView data, RoleMappedSchema schema) - { - Contracts.AssertValue(data); - Contracts.AssertValue(schema); - Contracts.Assert(schema.Schema == data.Schema); - Data = data; - Schema = schema; - } + /// + /// The role mapped schema. Note that 's is + /// guaranteed to be the same as 's . + /// + public RoleMappedSchema Schema { get; } - /// - /// Constructor given a data view, and mapping pairs of roles to columns in the data view's schema. - /// This skips null or empty column-names. It will also skip column-names that are not - /// found in the schema if is true. - /// - /// The data over which roles are defined - /// Whether to consider the column names specified "optional" or not. If false then any non-empty - /// values for the column names that does not appear in 's schema will result in an exception being thrown, - /// but if true such values will be ignored - /// The column role to column name mappings - public RoleMappedData(IDataView data, bool opt = false, params KeyValuePair[] roles) - : this(Contracts.CheckRef(data, nameof(data)), new RoleMappedSchema(data.Schema, Contracts.CheckRef(roles, nameof(roles)), opt)) - { - } + private RoleMappedData(IDataView data, RoleMappedSchema schema) + { + Contracts.AssertValue(data); + Contracts.AssertValue(schema); + Contracts.Assert(schema.Schema == data.Schema); + Data = data; + Schema = schema; + } - /// - /// Constructor given a data view, and mapping pairs of roles to columns in the data view's schema. - /// This skips null or empty column-names. It will also skip column-names that are not - /// found in the schema if is true. - /// - /// The schema over which roles are defined - /// The column role to column name mappings - /// Whether to consider the column names specified "optional" or not. 
If false then any non-empty - /// values for the column names that does not appear in 's schema will result in an exception being thrown, - /// but if true such values will be ignored - public RoleMappedData(IDataView data, IEnumerable> roles, bool opt = false) - : this(Contracts.CheckRef(data, nameof(data)), new RoleMappedSchema(data.Schema, Contracts.CheckRef(roles, nameof(roles)), opt)) - { - } + /// + /// Constructor given a data view, and mapping pairs of roles to columns in the data view's schema. + /// This skips null or empty column-names. It will also skip column-names that are not + /// found in the schema if is true. + /// + /// The data over which roles are defined + /// Whether to consider the column names specified "optional" or not. If false then any non-empty + /// values for the column names that does not appear in 's schema will result in an exception being thrown, + /// but if true such values will be ignored + /// The column role to column name mappings + public RoleMappedData(IDataView data, bool opt = false, params KeyValuePair[] roles) + : this(Contracts.CheckRef(data, nameof(data)), new RoleMappedSchema(data.Schema, Contracts.CheckRef(roles, nameof(roles)), opt)) + { + } - /// - /// Convenience constructor for role-mappings over the commonly used roles. Note that if any column name specified - /// is null or whitespace, it is ignored. - /// - /// The data over which roles are defined - /// The column name that will be mapped to the role - /// The column name that will be mapped to the role - /// The column name that will be mapped to the role - /// The column name that will be mapped to the role - /// The column name that will be mapped to the role - /// Any additional desired custom column role mappings - /// Whether to consider the column names specified "optional" or not. 
If false then any non-empty - /// values for the column names that does not appear in 's schema will result in an exception being thrown, - /// but if true such values will be ignored - public RoleMappedData(IDataView data, string label, string feature, - string group = null, string weight = null, string name = null, - IEnumerable> custom = null, bool opt = false) - : this(Contracts.CheckRef(data, nameof(data)), - new RoleMappedSchema(data.Schema, label, feature, group, weight, name, custom, opt)) - { - Contracts.CheckValueOrNull(label); - Contracts.CheckValueOrNull(feature); - Contracts.CheckValueOrNull(group); - Contracts.CheckValueOrNull(weight); - Contracts.CheckValueOrNull(name); - Contracts.CheckValueOrNull(custom); - } + /// + /// Constructor given a data view, and mapping pairs of roles to columns in the data view's schema. + /// This skips null or empty column-names. It will also skip column-names that are not + /// found in the schema if is true. + /// + /// The schema over which roles are defined + /// The column role to column name mappings + /// Whether to consider the column names specified "optional" or not. If false then any non-empty + /// values for the column names that does not appear in 's schema will result in an exception being thrown, + /// but if true such values will be ignored + public RoleMappedData(IDataView data, IEnumerable> roles, bool opt = false) + : this(Contracts.CheckRef(data, nameof(data)), new RoleMappedSchema(data.Schema, Contracts.CheckRef(roles, nameof(roles)), opt)) + { + } + + /// + /// Convenience constructor for role-mappings over the commonly used roles. Note that if any column name specified + /// is null or whitespace, it is ignored. 
+ /// + /// The data over which roles are defined + /// The column name that will be mapped to the role + /// The column name that will be mapped to the role + /// The column name that will be mapped to the role + /// The column name that will be mapped to the role + /// The column name that will be mapped to the role + /// Any additional desired custom column role mappings + /// Whether to consider the column names specified "optional" or not. If false then any non-empty + /// values for the column names that does not appear in 's schema will result in an exception being thrown, + /// but if true such values will be ignored + public RoleMappedData(IDataView data, string label, string feature, + string group = null, string weight = null, string name = null, + IEnumerable> custom = null, bool opt = false) + : this(Contracts.CheckRef(data, nameof(data)), + new RoleMappedSchema(data.Schema, label, feature, group, weight, name, custom, opt)) + { + Contracts.CheckValueOrNull(label); + Contracts.CheckValueOrNull(feature); + Contracts.CheckValueOrNull(group); + Contracts.CheckValueOrNull(weight); + Contracts.CheckValueOrNull(name); + Contracts.CheckValueOrNull(custom); } } diff --git a/src/Microsoft.ML.Core/Data/RootCursorBase.cs b/src/Microsoft.ML.Core/Data/RootCursorBase.cs index 000c43fc6d..fbd02d06ff 100644 --- a/src/Microsoft.ML.Core/Data/RootCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/RootCursorBase.cs @@ -4,78 +4,77 @@ using Microsoft.ML.Runtime; -namespace Microsoft.ML.Data +namespace Microsoft.ML.Data; + +// REVIEW: Since each cursor will create a channel, it would be great that the RootCursorBase takes +// ownership of the channel so the derived classes don't have to. + +/// +/// Base class for creating a cursor with default tracking of . All calls to +/// will be seen by subclasses of this cursor. For a cursor that has an input cursor and does not need notification on +/// , use instead. 
+/// +[BestFriend] +internal abstract class RootCursorBase : DataViewRowCursor { - // REVIEW: Since each cursor will create a channel, it would be great that the RootCursorBase takes - // ownership of the channel so the derived classes don't have to. + protected readonly IChannel Ch; + private long _position; + private bool _disposed; /// - /// Base class for creating a cursor with default tracking of . All calls to - /// will be seen by subclasses of this cursor. For a cursor that has an input cursor and does not need notification on - /// , use instead. + /// Zero-based position of the cursor. /// - [BestFriend] - internal abstract class RootCursorBase : DataViewRowCursor - { - protected readonly IChannel Ch; - private long _position; - private bool _disposed; + public sealed override long Position => _position; + + /// + /// Convenience property for checking whether the current state of the cursor is one where data can be fetched. + /// + protected bool IsGood => _position >= 0; - /// - /// Zero-based position of the cursor. - /// - public sealed override long Position => _position; + /// + /// Creates an instance of the class + /// + /// Channel provider + protected RootCursorBase(IChannelProvider provider) + { + Contracts.CheckValue(provider, nameof(provider)); + Ch = provider.Start("Cursor"); - /// - /// Convenience property for checking whether the current state of the cursor is one where data can be fetched. 
- /// - protected bool IsGood => _position >= 0; + _position = -1; + } - /// - /// Creates an instance of the class - /// - /// Channel provider - protected RootCursorBase(IChannelProvider provider) + protected override void Dispose(bool disposing) + { + if (_disposed) + return; + if (disposing) { - Contracts.CheckValue(provider, nameof(provider)); - Ch = provider.Start("Cursor"); - + Ch.Dispose(); _position = -1; } + _disposed = true; + base.Dispose(disposing); - protected override void Dispose(bool disposing) - { - if (_disposed) - return; - if (disposing) - { - Ch.Dispose(); - _position = -1; - } - _disposed = true; - base.Dispose(disposing); + } - } + public sealed override bool MoveNext() + { + if (_disposed) + return false; - public sealed override bool MoveNext() + if (MoveNextCore()) { - if (_disposed) - return false; - - if (MoveNextCore()) - { - _position++; - return true; - } - - Dispose(); - return false; + _position++; + return true; } - /// - /// Core implementation of , called if no prior call to this method - /// has returned . - /// - protected abstract bool MoveNextCore(); + Dispose(); + return false; } + + /// + /// Core implementation of , called if no prior call to this method + /// has returned . 
+ /// + protected abstract bool MoveNextCore(); } diff --git a/src/Microsoft.ML.Core/Data/SchemaExtensions.cs b/src/Microsoft.ML.Core/Data/SchemaExtensions.cs index c89d06984b..d750d09628 100644 --- a/src/Microsoft.ML.Core/Data/SchemaExtensions.cs +++ b/src/Microsoft.ML.Core/Data/SchemaExtensions.cs @@ -4,26 +4,25 @@ using System.Collections.Generic; -namespace Microsoft.ML.Data +namespace Microsoft.ML.Data; + +[BestFriend] +internal static class SchemaExtensions { - [BestFriend] - internal static class SchemaExtensions + public static DataViewSchema MakeSchema(IEnumerable columns) { - public static DataViewSchema MakeSchema(IEnumerable columns) - { - var builder = new DataViewSchema.Builder(); - builder.AddColumns(columns); - return builder.ToSchema(); - } + var builder = new DataViewSchema.Builder(); + builder.AddColumns(columns); + return builder.ToSchema(); + } - /// - /// Legacy method to get the column index. - /// DO NOT USE: use instead. - /// - public static bool TryGetColumnIndex(this DataViewSchema schema, string name, out int col) - { - col = schema.GetColumnOrNull(name)?.Index ?? -1; - return col >= 0; - } + /// + /// Legacy method to get the column index. + /// DO NOT USE: use instead. + /// + public static bool TryGetColumnIndex(this DataViewSchema schema, string name, out int col) + { + col = schema.GetColumnOrNull(name)?.Index ?? -1; + return col >= 0; } } diff --git a/src/Microsoft.ML.Core/Data/ServerChannel.cs b/src/Microsoft.ML.Core/Data/ServerChannel.cs index 306c19e654..f1ad584661 100644 --- a/src/Microsoft.ML.Core/Data/ServerChannel.cs +++ b/src/Microsoft.ML.Core/Data/ServerChannel.cs @@ -6,266 +6,265 @@ using System.Collections.Generic; using Microsoft.ML.EntryPoints; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML.Runtime; + +/// +/// Instances of this class are used to set up a bundle of named delegates. These +/// delegates are registered through and its overloads. 
+/// Once all registrations are done, is called and a message +/// of type is sent through the input channel +/// provider. The intended use case is that any information surfaced through these +/// delegates will be published in some fashion, with the target scenario being +/// that the library will publish some sort of restful API. +/// +[BestFriend] +internal sealed class ServerChannel : ServerChannel.IPendingBundleNotification, IDisposable { + // See ServerChannel.md for a more elaborate discussion of high level usage and design. + private readonly IChannelProvider _chp; + private readonly string _identifier; + + // This holds the running collection of named delegates, if any. The dictionary itself + // is lazily initialized only when a listener + private Dictionary _toPublish; + private Action _onPublish; + private Bundle _published; + private bool _disposed; + /// - /// Instances of this class are used to set up a bundle of named delegates. These - /// delegates are registered through and its overloads. - /// Once all registrations are done, is called and a message - /// of type is sent through the input channel - /// provider. The intended use case is that any information surfaced through these - /// delegates will be published in some fashion, with the target scenario being - /// that the library will publish some sort of restful API. + /// Returns either this object, or null if there are no listeners on this server + /// channel. This can be used in conjunction with the ?. operator to have more + /// performant though more robust calls to and + /// . /// - [BestFriend] - internal sealed class ServerChannel : ServerChannel.IPendingBundleNotification, IDisposable - { - // See ServerChannel.md for a more elaborate discussion of high level usage and design. - private readonly IChannelProvider _chp; - private readonly string _identifier; - - // This holds the running collection of named delegates, if any. 
The dictionary itself - // is lazily initialized only when a listener - private Dictionary _toPublish; - private Action _onPublish; - private Bundle _published; - private bool _disposed; + private ServerChannel ThisIfActiveOrNull => _toPublish == null ? null : this; - /// - /// Returns either this object, or null if there are no listeners on this server - /// channel. This can be used in conjunction with the ?. operator to have more - /// performant though more robust calls to and - /// . - /// - private ServerChannel ThisIfActiveOrNull => _toPublish == null ? null : this; + private ServerChannel(IChannelProvider provider, string idenfier) + { + Contracts.AssertValue(provider); + _chp = provider; + _chp.AssertNonWhiteSpace(idenfier); + _identifier = idenfier; + } - private ServerChannel(IChannelProvider provider, string idenfier) + /// + /// Starts a new server channel. + /// + /// The channel provider, on which to send + /// the notification that a server is being constructed + /// A semi-unique identifier for this + /// "bundle" that is being constructed + /// The constructed server channel, or null if there + /// was no listeners for server channels registered on + public static ServerChannel Start(IChannelProvider provider, string identifier) + { + Contracts.CheckValue(provider, nameof(provider)); + provider.CheckNonWhiteSpace(identifier, nameof(identifier)); + using (var pipe = provider.StartPipe("Server")) { - Contracts.AssertValue(provider); - _chp = provider; - _chp.AssertNonWhiteSpace(idenfier); - _identifier = idenfier; + var sc = new ServerChannel(provider, identifier); + pipe.Send(sc); + return sc.ThisIfActiveOrNull; } + } - /// - /// Starts a new server channel. 
- /// - /// The channel provider, on which to send - /// the notification that a server is being constructed - /// A semi-unique identifier for this - /// "bundle" that is being constructed - /// The constructed server channel, or null if there - /// was no listeners for server channels registered on - public static ServerChannel Start(IChannelProvider provider, string identifier) + public void Dispose() + { + if (!_disposed) { - Contracts.CheckValue(provider, nameof(provider)); - provider.CheckNonWhiteSpace(identifier, nameof(identifier)); - using (var pipe = provider.StartPipe("Server")) - { - var sc = new ServerChannel(provider, identifier); - pipe.Send(sc); - return sc.ThisIfActiveOrNull; - } + _disposed = true; + _published?.Done(); } + } - public void Dispose() - { - if (!_disposed) - { - _disposed = true; - _published?.Done(); - } - } + private void RegisterCore(string name, Delegate func) + { + _chp.CheckNonEmpty(name, nameof(name)); + _chp.CheckValue(func, nameof(func)); + _chp.Check(_published == null, "Cannot expose more interfaces once a server channel has been published"); + _chp.AssertValue(_toPublish); - private void RegisterCore(string name, Delegate func) - { - _chp.CheckNonEmpty(name, nameof(name)); - _chp.CheckValue(func, nameof(func)); - _chp.Check(_published == null, "Cannot expose more interfaces once a server channel has been published"); - _chp.AssertValue(_toPublish); + _toPublish.Add(name, func); + } - _toPublish.Add(name, func); - } + public void Register(string name, Func func) + { + if (_toPublish != null) + RegisterCore(name, func); + } - public void Register(string name, Func func) - { - if (_toPublish != null) - RegisterCore(name, func); - } + public void Register(string name, Func func) + { + if (_toPublish != null) + RegisterCore(name, func); + } - public void Register(string name, Func func) - { - if (_toPublish != null) - RegisterCore(name, func); - } + public void Register(string name, Func func) + { + if (_toPublish != null) + 
RegisterCore(name, func); + } - public void Register(string name, Func func) - { - if (_toPublish != null) - RegisterCore(name, func); - } + public void Register(string name, Func func) + { + if (_toPublish != null) + RegisterCore(name, func); + } - public void Register(string name, Func func) - { - if (_toPublish != null) - RegisterCore(name, func); - } + /// + /// Finalizes all registrations of delegates, and pipes the bundle of objects + /// in a up through the pipe to be consumed by any + /// listeners. + /// + public void Publish() + { + _chp.Assert((_toPublish == null) == (_onPublish == null)); + if (_toPublish == null) + return; + _chp.Check(_published == null, "Cannot republish once a server channel has been published"); + _published = new Bundle(this); + _onPublish(_published); + } - /// - /// Finalizes all registrations of delegates, and pipes the bundle of objects - /// in a up through the pipe to be consumed by any - /// listeners. - /// - public void Publish() - { - _chp.Assert((_toPublish == null) == (_onPublish == null)); - if (_toPublish == null) - return; - _chp.Check(_published == null, "Cannot republish once a server channel has been published"); - _published = new Bundle(this); - _onPublish(_published); - } + public void Acknowledge(Action toDo) + { + _chp.CheckValue(toDo, nameof(toDo)); + _chp.Assert((_onPublish == null) == (_toPublish == null)); + if (_toPublish == null) + _toPublish = new Dictionary(); + _onPublish += toDo; + _chp.AssertValue(_onPublish); + } - public void Acknowledge(Action toDo) - { - _chp.CheckValue(toDo, nameof(toDo)); - _chp.Assert((_onPublish == null) == (_toPublish == null)); - if (_toPublish == null) - _toPublish = new Dictionary(); - _onPublish += toDo; - _chp.AssertValue(_onPublish); - } + /// + /// Entry point factory for creating instances. 
+ /// + [TlcModule.ComponentKind("Server")] + public interface IServerFactory : IComponentFactory + { + new IServer CreateComponent(IHostEnvironment env, IChannel ch); + } + /// + /// Classes that want to publish the bundles from server channels in some fashion should implement + /// this interface. The intended simple use case is that this will be some form of in-process web + /// server, and then when disposed, they should stop themselves. + /// + /// Note that the primary communication with the server from the client code's perspective is not + /// through method calls on this interface, but rather communication through an + /// that the server will listen to throughout its + /// lifetime. + /// + public interface IServer : IDisposable + { /// - /// Entry point factory for creating instances. + /// This should return the base address where the server is. If this server is not actually + /// serving content at any URL, this property should be null. /// - [TlcModule.ComponentKind("Server")] - public interface IServerFactory : IComponentFactory - { - new IServer CreateComponent(IHostEnvironment env, IChannel ch); - } + Uri BaseAddress { get; } + } - /// - /// Classes that want to publish the bundles from server channels in some fashion should implement - /// this interface. The intended simple use case is that this will be some form of in-process web - /// server, and then when disposed, they should stop themselves. - /// - /// Note that the primary communication with the server from the client code's perspective is not - /// through method calls on this interface, but rather communication through an - /// that the server will listen to throughout its - /// lifetime. - /// - public interface IServer : IDisposable - { - /// - /// This should return the base address where the server is. If this server is not actually - /// serving content at any URL, this property should be null. 
- /// - Uri BaseAddress { get; } - } + /// + /// Creates what might be considered a good "default" server factory, if possible, + /// or null if no good default was possible. A null value could be returned, + /// for example, if a user opted to remove all implementations of and + /// the associated for security reasons. + /// + public static IServerFactory CreateDefaultServerFactoryOrNull(IHostEnvironment env) + { + Contracts.CheckValue(env, nameof(env)); + // REVIEW: There should be a better way. There currently isn't, + // but there should be. This is pretty horrifying, but it is preferable to + // the alternative of having core components depend on an actual server + // implementation, since we want those to be removable because of security + // concerns in certain environments (since not everyone will be wild about + // web servers popping up everywhere). + ComponentCatalog.ComponentInfo component; + if (!env.ComponentCatalog.TryFindComponent(typeof(IServerFactory), "mini", out component)) + return null; + IServerFactory factory = (IServerFactory)Activator.CreateInstance(component.ArgumentType); + var field = factory.GetType().GetField("Port"); + if (field?.FieldType != typeof(int)) + return null; + field.SetValue(factory, 12345); + return factory; + } + /// + /// When a is created, the creation method will send an implementation + /// is a notification sent through an , to indicate that + /// a may be pending soon. Listeners that want to receive the bundle to + /// expose it, for example, a web service, should register this interest by passing in an action to be called. + /// If no listener registers interest, the server channel that sent the notification will act + /// differently by, say, acting as a no-op w.r.t. client calls to it. + /// + public interface IPendingBundleNotification + { /// - /// Creates what might be considered a good "default" server factory, if possible, - /// or null if no good default was possible. 
A null value could be returned, - /// for example, if a user opted to remove all implementations of and - /// the associated for security reasons. + /// Any publisher of the named delegates will call this method, upon receiving an instance + /// of this object through the pipe. This method serves two purposes: firstly it detects + /// whether anyone is even interested in publishing anything at all, so that we can just + /// ignore any input delegates in the case where no one is listening (which, we must expect, + /// is the majority of scenarios). The second is that it provides an action to call, once + /// all publishing is complete, and has been called by the client code. /// - public static IServerFactory CreateDefaultServerFactoryOrNull(IHostEnvironment env) - { - Contracts.CheckValue(env, nameof(env)); - // REVIEW: There should be a better way. There currently isn't, - // but there should be. This is pretty horrifying, but it is preferable to - // the alternative of having core components depend on an actual server - // implementation, since we want those to be removable because of security - // concerns in certain environments (since not everyone will be wild about - // web servers popping up everywhere). - ComponentCatalog.ComponentInfo component; - if (!env.ComponentCatalog.TryFindComponent(typeof(IServerFactory), "mini", out component)) - return null; - IServerFactory factory = (IServerFactory)Activator.CreateInstance(component.ArgumentType); - var field = factory.GetType().GetField("Port"); - if (field?.FieldType != typeof(int)) - return null; - field.SetValue(factory, 12345); - return factory; - } + /// The callback to perform when all named delegates have been registered, + /// and is called. + void Acknowledge(Action toDo); + } + /// + /// The final bundle of published named delegates that a listener can serve. 
+ /// + public sealed class Bundle + { /// - /// When a is created, the creation method will send an implementation - /// is a notification sent through an , to indicate that - /// a may be pending soon. Listeners that want to receive the bundle to - /// expose it, for example, a web service, should register this interest by passing in an action to be called. - /// If no listener registers interest, the server channel that sent the notification will act - /// differently by, say, acting as a no-op w.r.t. client calls to it. + /// This contains a name to delegate mappings. The delegates contained herein are gauranteed to be + /// some variety of , , + /// , etc. /// - public interface IPendingBundleNotification - { - /// - /// Any publisher of the named delegates will call this method, upon receiving an instance - /// of this object through the pipe. This method serves two purposes: firstly it detects - /// whether anyone is even interested in publishing anything at all, so that we can just - /// ignore any input delegates in the case where no one is listening (which, we must expect, - /// is the majority of scenarios). The second is that it provides an action to call, once - /// all publishing is complete, and has been called by the client code. - /// - /// The callback to perform when all named delegates have been registered, - /// and is called. - void Acknowledge(Action toDo); - } + public readonly IReadOnlyDictionary NameToFuncs; /// - /// The final bundle of published named delegates that a listener can serve. + /// This should be a more-or-less unique identifier for the type of API this bundle is producing. + /// Its intended use is that it will form part of the URL for the RESTful API, so to the extent that + /// it contains multiple tokens they must be slash delimited. /// - public sealed class Bundle - { - /// - /// This contains a name to delegate mappings. The delegates contained herein are gauranteed to be - /// some variety of , , - /// , etc. 
- /// - public readonly IReadOnlyDictionary NameToFuncs; - - /// - /// This should be a more-or-less unique identifier for the type of API this bundle is producing. - /// Its intended use is that it will form part of the URL for the RESTful API, so to the extent that - /// it contains multiple tokens they must be slash delimited. - /// - public readonly string Identifier; + public readonly string Identifier; - internal Action Done; + internal Action Done; - internal Bundle(ServerChannel sch) - { - Contracts.AssertValue(sch); + internal Bundle(ServerChannel sch) + { + Contracts.AssertValue(sch); - NameToFuncs = sch._toPublish; - Identifier = sch._identifier; - } + NameToFuncs = sch._toPublish; + Identifier = sch._identifier; + } - public void AddDoneAction(Action onDone) - { - Done += onDone; - } + public void AddDoneAction(Action onDone) + { + Done += onDone; } } +} - [BestFriend] - internal static class ServerChannelUtilities +[BestFriend] +internal static class ServerChannelUtilities +{ + /// + /// Convenience method for that looks more idiomatic to typical + /// channel creation methods on . + /// + /// The channel provider. + /// This is an identifier of the "type" of bundle that is being published, + /// and should form a path with forward-slash '/' delimiters. + /// The newly created server channel, or null if there was no listener for + /// server channels on . + public static ServerChannel StartServerChannel(this IChannelProvider provider, string identifier) { - /// - /// Convenience method for that looks more idiomatic to typical - /// channel creation methods on . - /// - /// The channel provider. - /// This is an identifier of the "type" of bundle that is being published, - /// and should form a path with forward-slash '/' delimiters. - /// The newly created server channel, or null if there was no listener for - /// server channels on . 
- public static ServerChannel StartServerChannel(this IChannelProvider provider, string identifier) - { - Contracts.CheckValue(provider, nameof(provider)); - Contracts.CheckNonWhiteSpace(identifier, nameof(identifier)); - return ServerChannel.Start(provider, identifier); - } + Contracts.CheckValue(provider, nameof(provider)); + Contracts.CheckNonWhiteSpace(identifier, nameof(identifier)); + return ServerChannel.Start(provider, identifier); } } diff --git a/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs b/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs index 26799aa1b8..187d612ab7 100644 --- a/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs @@ -4,69 +4,68 @@ using Microsoft.ML.Runtime; -namespace Microsoft.ML.Data +namespace Microsoft.ML.Data; + +/// +/// Base class for creating a cursor on top of another cursor that does not add or remove rows. +/// It forces one-to-one correspondence between items in the input cursor and this cursor. +/// It delegates all functionality except Dispose() to the root cursor. +/// Dispose is virtual with the default implementation delegating to the input cursor. +/// +[BestFriend] +internal abstract class SynchronizedCursorBase : DataViewRowCursor { + protected readonly IChannel Ch; + /// - /// Base class for creating a cursor on top of another cursor that does not add or remove rows. - /// It forces one-to-one correspondence between items in the input cursor and this cursor. - /// It delegates all functionality except Dispose() to the root cursor. - /// Dispose is virtual with the default implementation delegating to the input cursor. + /// The synchronized cursor base, as it merely passes through requests for all "positional" calls (including + /// , , , and so forth), offers an opportunity + /// for optimization for "wrapping" cursors (which are themselves often + /// implementors) to get this root cursor. 
But, this can only be done by exposing this root cursor, as we do here. + /// Internal code should be quite careful in using this as the potential for misuse is quite high. /// - [BestFriend] - internal abstract class SynchronizedCursorBase : DataViewRowCursor - { - protected readonly IChannel Ch; - - /// - /// The synchronized cursor base, as it merely passes through requests for all "positional" calls (including - /// , , , and so forth), offers an opportunity - /// for optimization for "wrapping" cursors (which are themselves often - /// implementors) to get this root cursor. But, this can only be done by exposing this root cursor, as we do here. - /// Internal code should be quite careful in using this as the potential for misuse is quite high. - /// - internal readonly DataViewRowCursor Root; - private bool _disposed; + internal readonly DataViewRowCursor Root; + private bool _disposed; - protected DataViewRowCursor Input { get; } + protected DataViewRowCursor Input { get; } - public sealed override long Position => Root.Position; + public sealed override long Position => Root.Position; - public sealed override long Batch => Root.Batch; + public sealed override long Batch => Root.Batch; - /// - /// Convenience property for checking whether the cursor is in a good state where values - /// can be retrieved, that is, whenever is non-negative. - /// - protected bool IsGood => Position >= 0; + /// + /// Convenience property for checking whether the cursor is in a good state where values + /// can be retrieved, that is, whenever is non-negative. 
+ /// + protected bool IsGood => Position >= 0; - protected SynchronizedCursorBase(IChannelProvider provider, DataViewRowCursor input) - { - Contracts.AssertValue(provider); - Ch = provider.Start("Cursor"); + protected SynchronizedCursorBase(IChannelProvider provider, DataViewRowCursor input) + { + Contracts.AssertValue(provider); + Ch = provider.Start("Cursor"); - Ch.AssertValue(input); - Input = input; - // If this thing happens to be itself an instance of this class (which, practically, it will - // be in the majority of situations), we can treat the input as likewise being a passthrough, - // thereby saving lots of "nested" calls on the stack when doing common operations like movement. - Root = Input is SynchronizedCursorBase syncInput ? syncInput.Root : input; - } + Ch.AssertValue(input); + Input = input; + // If this thing happens to be itself an instance of this class (which, practically, it will + // be in the majority of situations), we can treat the input as likewise being a passthrough, + // thereby saving lots of "nested" calls on the stack when doing common operations like movement. + Root = Input is SynchronizedCursorBase syncInput ? 
syncInput.Root : input; + } - protected override void Dispose(bool disposing) + protected override void Dispose(bool disposing) + { + if (_disposed) + return; + if (disposing) { - if (_disposed) - return; - if (disposing) - { - Input.Dispose(); - Ch.Dispose(); - } - base.Dispose(disposing); - _disposed = true; + Input.Dispose(); + Ch.Dispose(); } + base.Dispose(disposing); + _disposed = true; + } - public sealed override bool MoveNext() => Root.MoveNext(); + public sealed override bool MoveNext() => Root.MoveNext(); - public sealed override ValueGetter GetIdGetter() => Input.GetIdGetter(); - } + public sealed override ValueGetter GetIdGetter() => Input.GetIdGetter(); } diff --git a/src/Microsoft.ML.Core/Data/WrappingRow.cs b/src/Microsoft.ML.Core/Data/WrappingRow.cs index e1faf66540..ad2be326a7 100644 --- a/src/Microsoft.ML.Core/Data/WrappingRow.cs +++ b/src/Microsoft.ML.Core/Data/WrappingRow.cs @@ -5,61 +5,60 @@ using System; using Microsoft.ML.Runtime; -namespace Microsoft.ML.Data +namespace Microsoft.ML.Data; + +/// +/// Convenient base class for implementors that wrap a single +/// as their input. The , , and +/// are taken from this . +/// +[BestFriend] +internal abstract class WrappingRow : DataViewRow { + private bool _disposed; + /// - /// Convenient base class for implementors that wrap a single - /// as their input. The , , and - /// are taken from this . + /// The wrapped input row. /// - [BestFriend] - internal abstract class WrappingRow : DataViewRow - { - private bool _disposed; + protected DataViewRow Input { get; } - /// - /// The wrapped input row. 
- /// - protected DataViewRow Input { get; } + public sealed override long Batch => Input.Batch; + public sealed override long Position => Input.Position; + public override ValueGetter GetIdGetter() => Input.GetIdGetter(); - public sealed override long Batch => Input.Batch; - public sealed override long Position => Input.Position; - public override ValueGetter GetIdGetter() => Input.GetIdGetter(); - - [BestFriend] - private protected WrappingRow(DataViewRow input) - { - Contracts.AssertValue(input); - Input = input; - } + [BestFriend] + private protected WrappingRow(DataViewRow input) + { + Contracts.AssertValue(input); + Input = input; + } - /// - /// This override of the dispose method by default only calls 's - /// method, but subclasses can enable additional functionality - /// via the functionality. - /// - /// - protected sealed override void Dispose(bool disposing) - { - if (_disposed) - return; - // Since the input was created first, and this instance may depend on it, we should - // dispose local resources first before potentially disposing the input row resources. - DisposeCore(disposing); - if (disposing) - Input.Dispose(); - _disposed = true; - } + /// + /// This override of the dispose method by default only calls 's + /// method, but subclasses can enable additional functionality + /// via the functionality. + /// + /// + protected sealed override void Dispose(bool disposing) + { + if (_disposed) + return; + // Since the input was created first, and this instance may depend on it, we should + // dispose local resources first before potentially disposing the input row resources. + DisposeCore(disposing); + if (disposing) + Input.Dispose(); + _disposed = true; + } - /// - /// Called from with in the case where - /// that method has never been called before, and right after has been - /// disposed. The default implementation does nothing. - /// - /// Whether this was called through the dispose path, as opposed - /// to the finalizer path. 
- protected virtual void DisposeCore(bool disposing) - { - } + /// + /// Called from with in the case where + /// that method has never been called before, and right after has been + /// disposed. The default implementation does nothing. + /// + /// Whether this was called through the dispose path, as opposed + /// to the finalizer path. + protected virtual void DisposeCore(bool disposing) + { } }