diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index 5352040958..41abfffcf1 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -280,6 +280,7 @@ private ColInfo[] CreateInfos(ISchema inputSchema) if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colSrc)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); var type = inputSchema.GetColumnType(colSrc); + _parent.CheckInputColumn(inputSchema, i, colSrc); infos[i] = new ColInfo(_parent.ColumnPairs[i].output, _parent.ColumnPairs[i].input, type); } return infos; diff --git a/src/Microsoft.ML.Legacy/CSharpApi.cs b/src/Microsoft.ML.Legacy/CSharpApi.cs index 7f86b36bd7..ba6f0f866e 100644 --- a/src/Microsoft.ML.Legacy/CSharpApi.cs +++ b/src/Microsoft.ML.Legacy/CSharpApi.cs @@ -13983,7 +13983,7 @@ public MinMaxNormalizerPipelineStep(Output output) namespace Legacy.Transforms { - public enum NAHandleTransformReplacementKind + public enum NAHandleTransformReplacementKind : byte { DefaultValue = 0, Mean = 1, @@ -14444,7 +14444,7 @@ public MissingValuesRowDropperPipelineStep(Output output) namespace Legacy.Transforms { - public enum NAReplaceTransformReplacementKind + public enum NAReplaceTransformReplacementKind : byte { DefaultValue = 0, Mean = 1, diff --git a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs index cde6ae9aa1..955940b916 100644 --- a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs +++ b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs @@ -69,6 +69,7 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum Contracts.CheckValue(columns, nameof(columns)); return columns.Select(x => (x.Input, x.Output)).ToArray(); } + public IReadOnlyCollection Columns => _columns.AsReadOnly(); private readonly ColumnInfo[] _columns; @@ -209,7 +210,7 @@ private ColInfo[] CreateInfos(ISchema inputSchema) if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colSrc)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); var type = inputSchema.GetColumnType(colSrc); - + _parent.CheckInputColumn(inputSchema, i, colSrc); infos[i] = new ColInfo(_parent.ColumnPairs[i].output, _parent.ColumnPairs[i].input, type); } return infos; diff --git a/src/Microsoft.ML.Transforms/NAHandleTransform.cs b/src/Microsoft.ML.Transforms/NAHandleTransform.cs index bb31510252..bf3ab5f41a 100644 --- a/src/Microsoft.ML.Transforms/NAHandleTransform.cs +++ b/src/Microsoft.ML.Transforms/NAHandleTransform.cs @@ -20,28 +20,28 @@ namespace Microsoft.ML.Runtime.Data /// public static class NAHandleTransform { - public enum ReplacementKind + public enum ReplacementKind : byte { /// /// Replace with the default value of the column based on it's type. For example, 'zero' for numeric and 'empty' for string/text columns. /// [EnumValueDisplay("Zero/empty")] - DefaultValue, + DefaultValue = 0, /// /// Replace with the mean value of the column. Supports only numeric/time span/ DateTime columns. /// - Mean, + Mean = 1, /// /// Replace with the minimum value of the column. Supports only numeric/time span/ DateTime columns. /// - Minimum, + Minimum = 2, /// /// Replace with the maximum value of the column. Supports only numeric/time span/ DateTime columns. /// - Maximum, + Maximum = 3, [HideEnumValue] Def = DefaultValue, @@ -135,7 +135,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV h.CheckValue(input, nameof(input)); h.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column)); - var replaceCols = new List(); + var replaceCols = new List(); var naIndicatorCols = new List(); var naConvCols = new List(); var concatCols = new List(); @@ -149,26 +149,16 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV var addInd = column.ConcatIndicator ?? args.Concat; if (!addInd) { - replaceCols.Add( - new NAReplaceTransform.Column() - { - Kind = (NAReplaceTransform.ReplacementKind?)column.Kind, - Name = column.Name, - Source = column.Source, - Slot = column.ImputeBySlot - }); + replaceCols.Add(new NAReplaceTransform.ColumnInfo(column.Source, column.Name, (NAReplaceTransform.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot)); continue; } // Check that the indicator column has a type that can be converted to the NAReplaceTransform output type, // so that they can be concatenated. - int inputCol; - if (!input.Schema.TryGetColumnIndex(column.Source, out inputCol)) + if (!input.Schema.TryGetColumnIndex(column.Source, out int inputCol)) throw h.Except("Column '{0}' does not exist", column.Source); var replaceType = input.Schema.GetColumnType(inputCol); - Delegate conv; - bool identity; - if (!Conversions.Instance.TryGetStandardConversion(BoolType.Instance, replaceType.ItemType, out conv, out identity)) + if (!Conversions.Instance.TryGetStandardConversion(BoolType.Instance, replaceType.ItemType, out Delegate conv, out bool identity)) { throw h.Except("Cannot concatenate indicator column of type '{0}' to input column of type '{1}'", BoolType.Instance, replaceType.ItemType); @@ -186,14 +176,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV naConvCols.Add(new ConvertTransform.Column() { Name = tmpIsMissingColName, Source = tmpIsMissingColName, ResultType = replaceType.ItemType.RawKind }); // Add the NAReplaceTransform column. - replaceCols.Add( - new NAReplaceTransform.Column() - { - Kind = (NAReplaceTransform.ReplacementKind?)column.Kind, - Name = tmpReplacementColName, - Source = column.Source, - Slot = column.ImputeBySlot - }); + replaceCols.Add(new NAReplaceTransform.ColumnInfo(column.Source, tmpReplacementColName, (NAReplaceTransform.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot)); // Add the ConcatTransform column. if (replaceType.IsVector) @@ -237,15 +220,8 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV h.AssertValue(output); output = new ConvertTransform(h, new ConvertTransform.Arguments() { Column = naConvCols.ToArray() }, output); } - // Create the NAReplace transform. - output = new NAReplaceTransform(h, - new NAReplaceTransform.Arguments() - { - Column = replaceCols.ToArray(), - ReplacementKind = (NAReplaceTransform.ReplacementKind)args.ReplaceWith, - ImputeBySlot = args.ImputeBySlot - }, output ?? input); + output = NAReplaceTransform.Create(env, output ?? input, replaceCols.ToArray()); // Concat the NAReplaceTransform output and the NAIndicatorTransform output. if (naIndicatorCols.Count > 0) diff --git a/src/Microsoft.ML.Transforms/NAHandling.cs b/src/Microsoft.ML.Transforms/NAHandling.cs index 0870d16461..8f8fdaabb9 100644 --- a/src/Microsoft.ML.Transforms/NAHandling.cs +++ b/src/Microsoft.ML.Transforms/NAHandling.cs @@ -88,7 +88,7 @@ public static CommonOutputs.TransformOutput Indicator(IHostEnvironment env, NAIn public static CommonOutputs.TransformOutput Replace(IHostEnvironment env, NAReplaceTransform.Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "NAReplace", input); - var xf = new NAReplaceTransform(h, input, input.Data); + var xf = NAReplaceTransform.Create(h, input, input.Data); return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, xf, input.Data), diff --git a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs index c9b309af89..d8cd698658 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs @@ -2,29 +2,37 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Collections; -using System.Collections.Generic; -using System.Reflection; -using System.Text; -using System.IO; -using System.Linq; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data.StaticPipe.Runtime; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Data.Conversion; +using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Model.Onnx; +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Text; + +[assembly: LoadableClass(NAReplaceTransform.Summary, typeof(IDataTransform), typeof(NAReplaceTransform), typeof(NAReplaceTransform.Arguments), typeof(SignatureDataTransform), + NAReplaceTransform.FriendlyName, NAReplaceTransform.LoadName, "NAReplace", NAReplaceTransform.ShortName, DocName = "transform/NAHandle.md")] -[assembly: LoadableClass(typeof(NAReplaceTransform), typeof(NAReplaceTransform.Arguments), typeof(SignatureDataTransform), - NAReplaceTransform.FriendlyName, NAReplaceTransform.LoadName, "NAReplace", NAReplaceTransform.ShortName, DocName = "transform/NAHandle.md")] +[assembly: LoadableClass(NAReplaceTransform.Summary, typeof(IDataTransform), typeof(NAReplaceTransform), null, typeof(SignatureLoadDataTransform), + NAReplaceTransform.FriendlyName, NAReplaceTransform.LoadName)] -[assembly: LoadableClass(typeof(NAReplaceTransform), null, typeof(SignatureLoadDataTransform), +[assembly: LoadableClass(NAReplaceTransform.Summary, typeof(NAReplaceTransform), null, typeof(SignatureLoadModel), NAReplaceTransform.FriendlyName, NAReplaceTransform.LoadName)] +[assembly: LoadableClass(typeof(IRowMapper), typeof(NAReplaceTransform), null, typeof(SignatureLoadRowMapper), + NAReplaceTransform.FriendlyName, NAReplaceTransform.LoadName)] + namespace Microsoft.ML.Runtime.Data { // This transform can transform either scalars or vectors (both fixed and variable size), @@ -33,16 +41,16 @@ namespace Microsoft.ML.Runtime.Data // Imputation modes are supported for vectors both by slot and across all slots. // REVIEW: May make sense to implement the transform template interface. /// - public sealed partial class NAReplaceTransform : OneToOneTransformBase + public sealed partial class NAReplaceTransform : OneToOneTransformerBase { - public enum ReplacementKind + public enum ReplacementKind : byte { // REVIEW: What should the full list of options for this transform be? - DefaultValue, - Mean, - Minimum, - Maximum, - SpecifiedValue, + DefaultValue = 0, + Mean = 1, + Minimum = 2, + Maximum = 3, + SpecifiedValue = 4, [HideEnumValue] Def = DefaultValue, @@ -112,14 +120,15 @@ public sealed class Arguments : TransformInputBase public Column[] Column; [Argument(ArgumentType.AtMostOnce, HelpText = "The replacement method to utilize", ShortName = "kind")] - public ReplacementKind ReplacementKind = ReplacementKind.DefaultValue; + public ReplacementKind ReplacementKind = (ReplacementKind)NAReplaceEstimator.Defaults.ReplacementMode; // Specifying by-slot imputation for vectors of unknown size will cause a warning, and the imputation will be global. [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to impute values by slot", ShortName = "slot")] - public bool ImputeBySlot = true; + public bool ImputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot; } public const string LoadName = "NAReplaceTransform"; + private static VersionInfo GetVersionInfo() { return new VersionInfo( @@ -133,12 +142,12 @@ private static VersionInfo GetVersionInfo() } internal const string Summary = "Create an output column of the same type and size of the input column, where missing values " - + "are replaced with either the default value or the mean/min/max value (for non-text columns only)."; + + "are replaced with either the default value or the mean/min/max value (for non-text columns only)."; internal const string FriendlyName = "NA Replace Transform"; internal const string ShortName = "NARep"; - private static string TestType(ColumnType type) + internal static string TestType(ColumnType type) { // Item type must have an NA value that exists and is not equal to its default value. Func func = TestType; @@ -149,8 +158,7 @@ private static string TestType(ColumnType type) private static string TestType(ColumnType type) { Contracts.Assert(type.ItemType.RawType == typeof(T)); - RefPredicate isNA; - if (!Conversions.Instance.TryGetIsNAPredicate(type.ItemType, out isNA)) + if (!Conversions.Instance.TryGetIsNAPredicate(type.ItemType, out RefPredicate isNA)) { return string.Format("Type '{0}' is not supported by {1} since it doesn't have an NA value", type, LoadName); @@ -165,8 +173,50 @@ private static string TestType(ColumnType type) return null; } + public class ColumnInfo + { + public enum ReplacementMode : byte + { + DefaultValue = 0, + Mean = 1, + Minimum = 2, + Maximum = 3, + } + + public readonly string Input; + public readonly string Output; + public readonly bool ImputeBySlot; + public readonly ReplacementMode Replacement; + + /// + /// Describes how the transformer handles one column pair. + /// + /// Name of input column. + /// Name of output column. + /// What to replace the missing value with. + /// If true, per-slot imputation of replacement is performed. + /// Otherwise, replacement value is imputed for the entire vector column. This setting is ignored for scalars and variable vectors, + /// where imputation is always for the entire column. + public ColumnInfo(string input, string output, ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, + bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + { + Input = input; + Output = output; + ImputeBySlot = imputeBySlot; + Replacement = replacementMode; + } + + internal string ReplacementString { get; set; } + } + + private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + { + Contracts.CheckValue(columns, nameof(columns)); + return columns.Select(x => (x.Input, x.Output)).ToArray(); + } + // The output column types, parallel to Infos. - private readonly ColumnType[] _types; + private readonly ColumnType[] _replaceTypes; // The replacementValues for the columns, parallel to Infos. // The elements of this array can be either primitive values or arrays of primitive values. When replacing a scalar valued column in Infos, @@ -178,89 +228,56 @@ private static string TestType(ColumnType type) // Marks if the replacement values in given slots of _repValues are the default value. // REVIEW: Currently these arrays are constructed on load but could be changed to being constructed lazily. - private BitArray[] _repIsDefault; - - // The isNA delegates, parallel to Infos. - private readonly Delegate[] _isNAs; + private readonly BitArray[] _repIsDefault; - public override bool CanSaveOnnx => true; - - /// - /// Convenience constructor for public facing API. - /// - /// Host Environment. - /// Input . This is the output from previous transform or loader. - /// Name of the output column. - /// Name of the column to be transformed. If this is null '' will be used. - /// The replacement method to utilize. - public NAReplaceTransform(IHostEnvironment env, IDataView input, string name, string source = null, ReplacementKind replacementKind = ReplacementKind.DefaultValue) - : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, ReplacementKind = replacementKind }, input) + protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { + var type = inputSchema.GetColumnType(srcCol); + string reason = TestType(type); + if (reason != null) + throw Host.ExceptParam(nameof(inputSchema), reason); } - /// - /// Public constructor corresponding to SignatureDataTransform. - /// - public NAReplaceTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, LoadName, Contracts.CheckRef(args, nameof(args)).Column, - input, TestType) + public NAReplaceTransform(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NAReplaceTransform)), GetColumnPairs(columns)) { - Host.CheckValue(args, nameof(args)); - Host.AssertNonEmpty(Infos); - Host.Assert(Infos.Length == Utils.Size(args.Column)); - - GetInfoAndMetadata(out _types, out _isNAs); - GetReplacementValues(args, out _repValues, out _repIsDefault); + // Check that all the input columns are present and correct. + for (int i = 0; i < ColumnPairs.Length; i++) + { + if (!input.Schema.TryGetColumnIndex(ColumnPairs[i].input, out int srcCol)) + throw Host.ExceptSchemaMismatch(nameof(input), "input", ColumnPairs[i].input); + CheckInputColumn(input.Schema, i, srcCol); + } + GetReplacementValues(input, columns, out _repValues, out _repIsDefault, out _replaceTypes); } - private NAReplaceTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, TestType) + private NAReplaceTransform(IHost host, ModelLoadContext ctx) + : base(host, ctx) { - Host.AssertValue(ctx); - Host.AssertNonEmpty(Infos); - - GetInfoAndMetadata(out _types, out _isNAs); - - // *** Binary format *** - // - // for each column: - // type and value - _repValues = new object[Infos.Length]; - _repIsDefault = new BitArray[Infos.Length]; + var columnsLength = ColumnPairs.Length; + _repValues = new object[columnsLength]; + _repIsDefault = new BitArray[columnsLength]; + _replaceTypes = new ColumnType[columnsLength]; var saver = new BinarySaver(Host, new BinarySaver.Arguments()); - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) + for (int i = 0; i < columnsLength; i++) { - object repValue; - ColumnType repType; - if (!saver.TryLoadTypeAndValue(ctx.Reader.BaseStream, out repType, out repValue)) + if (!saver.TryLoadTypeAndValue(ctx.Reader.BaseStream, out ColumnType savedType, out object repValue)) throw Host.ExceptDecode(); - if (!_types[iinfo].ItemType.Equals(repType.ItemType)) - throw Host.ExceptParam(nameof(input), "Decoded serialization of type '{0}' does not match expected ColumnType of '{1}'", repType.ItemType, _types[iinfo].ItemType); - // If type is a vector and the value is not either a scalar or a vector of the same size, throw an error. - if (repType.IsVector) + _replaceTypes[i] = savedType; + if (savedType.IsVector) { - if (!_types[iinfo].IsVector) - throw Host.ExceptParam(nameof(input), "Decoded serialization of type '{0}' cannot be a vector when Columntype is a scalar of type '{1}'", repType, _types[iinfo]); - if (!_types[iinfo].IsKnownSizeVector) - throw Host.ExceptParam(nameof(input), "Decoded serialization for unknown size vector '{0}' must be a scalar instead of type '{1}'", _types[iinfo], repType); - if (_types[iinfo].VectorSize != repType.VectorSize) - { - throw Host.ExceptParam(nameof(input), "Decoded serialization of type '{0}' must be a scalar or a vector of the same size as Columntype '{1}'", - repType, _types[iinfo]); - } - // REVIEW: The current implementation takes the serialized VBuffer, densifies it, and stores the values array. // It might be of value to consider storing the VBUffer in order to possibly benefit from sparsity. However, this would // necessitate a reimplementation of the FillValues code to accomodate sparse VBuffers. - object[] args = new object[] { repValue, _types[iinfo], iinfo }; + object[] args = new object[] { repValue, _replaceTypes[i], i }; Func, ColumnType, int, int[]> func = GetValuesArray; - var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(repType.ItemType.RawType); - _repValues[iinfo] = meth.Invoke(this, args); + var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(savedType.ItemType.RawType); + _repValues[i] = meth.Invoke(this, args); } else - _repValues[iinfo] = repValue; + _repValues[i] = repValue; - Host.Assert(repValue.GetType() == _types[iinfo].RawType || repValue.GetType() == _types[iinfo].ItemType.RawType); + Host.Assert(repValue.GetType() == _replaceTypes[i].RawType || repValue.GetType() == _replaceTypes[i].ItemType.RawType); } } @@ -282,105 +299,53 @@ private T[] GetValuesArray(VBuffer src, ColumnType srcType, int iinfo) return valReturn; } - public static NAReplaceTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) - { - Contracts.CheckValue(env, nameof(env)); - var h = env.Register(LoadName); - h.CheckValue(ctx, nameof(ctx)); - h.CheckValue(input, nameof(input)); - ctx.CheckAtModel(GetVersionInfo()); - return h.Apply("Loading Model", ch => new NAReplaceTransform(h, ctx, input)); - } - - public override void Save(ModelSaveContext ctx) - { - Host.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(); - ctx.SetVersionInfo(GetVersionInfo()); - - // *** Binary format *** - // - // for each column: - // type and value - SaveBase(ctx); - var saver = new BinarySaver(Host, new BinarySaver.Arguments()); - for (int iinfo = 0; iinfo < _types.Length; iinfo++) - { - var repValue = _repValues[iinfo]; - var repType = _types[iinfo].ItemType; - if (_repIsDefault[iinfo] != null) - { - Host.Assert(repValue is Array); - Func> function = CreateVBuffer; - var method = function.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(_types[iinfo].ItemType.RawType); - repValue = method.Invoke(this, new object[] { _repValues[iinfo] }); - repType = _types[iinfo]; - } - Host.Assert(!(repValue is Array)); - object[] args = new object[] { ctx.Writer.BaseStream, saver, repType, repValue }; - Action func = WriteTypeAndValue; - Host.Assert(repValue.GetType() == _types[iinfo].RawType || repValue.GetType() == _types[iinfo].ItemType.RawType); - var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(repValue.GetType()); - meth.Invoke(this, args); - } - } - - private VBuffer CreateVBuffer(T[] array) - { - Host.AssertValue(array); - return new VBuffer(array.Length, array); - } - - private void WriteTypeAndValue(Stream stream, BinarySaver saver, ColumnType type, T rep) - { - Host.AssertValue(stream); - Host.AssertValue(saver); - Host.Assert(type.RawType == typeof(T) || type.ItemType.RawType == typeof(T)); - - int bytesWritten; - if (!saver.TryWriteTypeAndValue(stream, type, ref rep, out bytesWritten)) - throw Host.Except("We do not know how to serialize terms of type '{0}'", type); - } - /// /// Fill the repValues array with the correct replacement values based on the user-given replacement kinds. /// Vectors default to by-slot imputation unless otherwise specified, except for unknown sized vectors /// which force across-slot imputation. /// - private void GetReplacementValues(Arguments args, out object[] repValues, out BitArray[] slotIsDefault) + private void GetReplacementValues(IDataView input, ColumnInfo[] columns, out object[] repValues, out BitArray[] slotIsDefault, out ColumnType[] types) { - repValues = new object[Infos.Length]; - slotIsDefault = new BitArray[Infos.Length]; - - ReplacementKind?[] imputationModes = new ReplacementKind?[Infos.Length]; + repValues = new object[columns.Length]; + slotIsDefault = new BitArray[columns.Length]; + types = new ColumnType[columns.Length]; + var sources = new int[columns.Length]; + ReplacementKind[] imputationModes = new ReplacementKind[columns.Length]; List columnsToImpute = null; // REVIEW: Would like to get rid of the sourceColumns list but seems to be the best way to provide // the cursor with what columns to cursor through. HashSet sourceColumns = null; - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) + for (int iinfo = 0; iinfo < columns.Length; iinfo++) { - ReplacementKind kind = args.Column[iinfo].Kind ?? args.ReplacementKind; + input.Schema.TryGetColumnIndex(columns[iinfo].Input, out int colSrc); + sources[iinfo] = colSrc; + var type = input.Schema.GetColumnType(colSrc); + if (type.IsVector) + type = new VectorType(type.ItemType.AsPrimitive, type.AsVector); + Delegate isNa = GetIsNADelegate(type); + types[iinfo] = type; + var kind = (ReplacementKind)columns[iinfo].Replacement; switch (kind) { - case ReplacementKind.SpecifiedValue: - repValues[iinfo] = GetSpecifiedValue(args.Column[iinfo].ReplacementString, _types[iinfo], _isNAs[iinfo]); - break; - case ReplacementKind.DefaultValue: - repValues[iinfo] = GetDefault(_types[iinfo]); - break; - case ReplacementKind.Mean: - case ReplacementKind.Min: - case ReplacementKind.Max: - if (!_types[iinfo].ItemType.IsNumber && !_types[iinfo].ItemType.IsTimeSpan && !_types[iinfo].ItemType.IsDateTime) - throw Host.Except("Cannot perform mean imputations on non-numeric '{0}'", _types[iinfo].ItemType); - imputationModes[iinfo] = kind; - Utils.Add(ref columnsToImpute, iinfo); - Utils.Add(ref sourceColumns, Infos[iinfo].Source); - break; - default: - Host.Assert(false); - throw Host.Except("Internal error, undefined ReplacementKind '{0}' assigned in NAReplaceTransform.", kind); + case ReplacementKind.SpecifiedValue: + repValues[iinfo] = GetSpecifiedValue(columns[iinfo].ReplacementString, _replaceTypes[iinfo], isNa); + break; + case ReplacementKind.DefaultValue: + repValues[iinfo] = GetDefault(type); + break; + case ReplacementKind.Mean: + case ReplacementKind.Minimum: + case ReplacementKind.Maximum: + if (!type.ItemType.IsNumber && !type.ItemType.IsTimeSpan && !type.ItemType.IsDateTime) + throw Host.Except("Cannot perform mean imputations on non-numeric '{0}'", type.ItemType); + imputationModes[iinfo] = kind; + Utils.Add(ref columnsToImpute, iinfo); + Utils.Add(ref sourceColumns, colSrc); + break; + default: + Host.Assert(false); + throw Host.Except("Internal error, undefined ReplacementKind '{0}' assigned in NAReplaceTransform.", columns[iinfo].Replacement); } } @@ -390,20 +355,21 @@ private void GetReplacementValues(Arguments args, out object[] repValues, out Bi // Impute values. using (var ch = Host.Start("Computing Statistics")) - using (var cursor = Source.GetRowCursor(sourceColumns.Contains)) + using (var cursor = input.GetRowCursor(sourceColumns.Contains)) { StatAggregator[] statAggregators = new StatAggregator[columnsToImpute.Count]; for (int ii = 0; ii < columnsToImpute.Count; ii++) { int iinfo = columnsToImpute[ii]; - bool bySlot = args.Column[ii].Slot ?? args.ImputeBySlot; - if (_types[iinfo].IsVector && !_types[iinfo].IsKnownSizeVector && bySlot) + bool bySlot = columns[ii].ImputeBySlot; + if (types[iinfo].IsVector && !types[iinfo].IsKnownSizeVector && bySlot) { ch.Warning("By-slot imputation can not be done on variable-length column"); bySlot = false; } - statAggregators[ii] = CreateStatAggregator(ch, _types[iinfo], imputationModes[iinfo], bySlot, - cursor, Infos[iinfo].Source); + + statAggregators[ii] = CreateStatAggregator(ch, types[iinfo], imputationModes[iinfo], bySlot, + cursor, sources[iinfo]); } while (cursor.MoveNext()) @@ -425,8 +391,8 @@ private void GetReplacementValues(Arguments args, out object[] repValues, out Bi if (repValues[slot] is Array) { Func func = ComputeDefaultSlots; - var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(_types[slot].ItemType.RawType); - slotIsDefault[slot] = (BitArray)meth.Invoke(this, new object[] { _types[slot], repValues[slot] }); + var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(types[slot].ItemType.RawType); + slotIsDefault[slot] = (BitArray)meth.Invoke(this, new object[] { types[slot], repValues[slot] }); } } } @@ -444,31 +410,6 @@ private BitArray ComputeDefaultSlots(ColumnType type, T[] values) return defaultSlots; } - private void GetInfoAndMetadata(out ColumnType[] types, out Delegate[] isNAs) - { - var md = Metadata; - types = new ColumnType[Infos.Length]; - isNAs = new Delegate[Infos.Length]; - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) - { - var type = Infos[iinfo].TypeSrc; - - if (!type.IsVector) - types[iinfo] = type; - else - types[iinfo] = new VectorType(type.ItemType.AsPrimitive, type.AsVector); - - isNAs[iinfo] = GetIsNADelegate(type); - - // Pass through slot name metadata and normalization data. - using (var bldr = md.BuildMetadata(iinfo, Source.Schema, Infos[iinfo].Source, - MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.IsNormalized)) - { - } - } - md.Seal(); - } - private object GetDefault(ColumnType type) { Func func = GetDefault; @@ -495,12 +436,6 @@ private Delegate GetIsNADelegate(ColumnType type) return Conversions.Instance.GetIsNAPredicate(type.ItemType); } - protected override ColumnType GetColumnTypeCore(int iinfo) - { - Host.Assert(0 <= iinfo & iinfo < Infos.Length); - return _types[iinfo]; - } - /// /// Converts a string to its respective value in the corresponding type. /// @@ -518,8 +453,7 @@ private object GetSpecifiedValue(string srcStr, ColumnType dstType, RefPredic { // Handles converting input strings to correct types. DvText srcTxt = new DvText(srcStr); - bool identity; - var strToT = Conversions.Instance.GetStandardConversion(TextType.Instance, dstType.ItemType, out identity); + var strToT = Conversions.Instance.GetStandardConversion(TextType.Instance, dstType.ItemType, out bool identity); strToT(ref srcTxt, ref val); // Make sure that the srcTxt can legitimately be converted to dstType, throw error otherwise. if (isNA(ref val)) @@ -529,367 +463,678 @@ private object GetSpecifiedValue(string srcStr, ColumnType dstType, RefPredic return val; } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + // Factory method for SignatureDataTransform. + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { - Host.AssertValueOrNull(ch); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - disposer = null; + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); - if (!Infos[iinfo].TypeSrc.IsVector) - return ComposeGetterOne(input, iinfo); - return ComposeGetterVec(input, iinfo); + env.CheckValue(args.Column, nameof(args.Column)); + var cols = new ColumnInfo[args.Column.Length]; + for (int i = 0; i < cols.Length; i++) + { + var item = args.Column[i]; + var kind = item.Kind ?? args.ReplacementKind; + if (!Enum.IsDefined(typeof(ReplacementKind), kind)) + throw env.ExceptUserArg(nameof(args.ReplacementKind), "Undefined sorting criteria '{0}' detected for column '{1}'", kind, item.Name); + + cols[i] = new ColumnInfo(item.Source, + item.Name, + (ColumnInfo.ReplacementMode)(item.Kind ?? args.ReplacementKind), + item.Slot ?? args.ImputeBySlot); + cols[i].ReplacementString = item.ReplacementString; + }; + return new NAReplaceTransform(env, input, cols).MakeDataTransform(input); } - /// - /// Getter generator for single valued inputs. - /// - private Delegate ComposeGetterOne(IRow input, int iinfo) - => Utils.MarshalInvoke(ComposeGetterOne, Infos[iinfo].TypeSrc.RawType, input, iinfo); + public static IDataTransform Create(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) + { + return new NAReplaceTransform(env, input, columns).MakeDataTransform(input); + } - /// - /// Replaces NA values for scalars. - /// - private Delegate ComposeGetterOne(IRow input, int iinfo) + // Factory method for SignatureLoadModel. + public static NAReplaceTransform Create(IHostEnvironment env, ModelLoadContext ctx) { - var getSrc = GetSrcGetter(input, iinfo); - var src = default(T); - var isNA = (RefPredicate)_isNAs[iinfo]; - Host.Assert(_repValues[iinfo] is T); - T rep = (T)_repValues[iinfo]; - ValueGetter getter; + Contracts.CheckValue(env, nameof(env)); + var host = env.Register(LoadName); - return getter = - (ref T dst) => - { - getSrc(ref src); - dst = isNA(ref src) ? rep : src; - }; + host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + + return new NAReplaceTransform(host, ctx); } - /// - /// Getter generator for vector valued inputs. - /// - private Delegate ComposeGetterVec(IRow input, int iinfo) - => Utils.MarshalInvoke(ComposeGetterVec, Infos[iinfo].TypeSrc.ItemType.RawType, input, iinfo); + // Factory method for SignatureLoadDataTransform. + public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + => Create(env, ctx).MakeDataTransform(input); - /// - /// Replaces NA values for vectors. - /// - private Delegate ComposeGetterVec(IRow input, int iinfo) + // Factory method for SignatureLoadRowMapper. + public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); + + private VBuffer CreateVBuffer(T[] array) { - var getSrc = GetSrcGetter>(input, iinfo); - var isNA = (RefPredicate)_isNAs[iinfo]; - var isDefault = Conversions.Instance.GetIsDefaultPredicate(input.Schema.GetColumnType(Infos[iinfo].Source).ItemType); + Host.AssertValue(array); + return new VBuffer(array.Length, array); + } - var src = default(VBuffer); - ValueGetter> getter; + private void WriteTypeAndValue(Stream stream, BinarySaver saver, ColumnType type, T rep) + { + Host.AssertValue(stream); + Host.AssertValue(saver); + Host.Assert(type.RawType == typeof(T) || type.ItemType.RawType == typeof(T)); - if (_repIsDefault[iinfo] == null) - { - // One replacement value for all slots. - Host.Assert(_repValues[iinfo] is T); - T rep = (T)_repValues[iinfo]; - bool repIsDefault = isDefault(ref rep); - return getter = - (ref VBuffer dst) => - { - getSrc(ref src); - FillValues(ref src, ref dst, isNA, rep, repIsDefault); - }; - } + if (!saver.TryWriteTypeAndValue(stream, type, ref rep, out int bytesWritten)) + throw Host.Except("We do not know how to serialize terms of type '{0}'", type); + } - // Replacement values by slot. - Host.Assert(_repValues[iinfo] is T[]); - // The replacement array. - T[] repArray = (T[])_repValues[iinfo]; + public override void Save(ModelSaveContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); - return getter = - (ref VBuffer dst) => + SaveColumns(ctx); + var saver = new BinarySaver(Host, new BinarySaver.Arguments()); + for (int iinfo = 0; iinfo < _replaceTypes.Length; iinfo++) + { + var repValue = _repValues[iinfo]; + var repType = _replaceTypes[iinfo].ItemType; + if (_repIsDefault[iinfo] != null) { - getSrc(ref src); - Host.Check(src.Length == repArray.Length); - FillValues(ref src, ref dst, isNA, repArray, _repIsDefault[iinfo]); - }; + Host.Assert(repValue is Array); + Func> function = CreateVBuffer; + var method = function.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(repType.RawType); + repValue = method.Invoke(this, new object[] { _repValues[iinfo] }); + repType = _replaceTypes[iinfo]; + } + Host.Assert(!(repValue is Array)); + object[] args = new object[] { ctx.Writer.BaseStream, saver, repType, repValue }; + Action func = WriteTypeAndValue; + Host.Assert(repValue.GetType() == _replaceTypes[iinfo].RawType || repValue.GetType() == _replaceTypes[iinfo].ItemType.RawType); + var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(repValue.GetType()); + meth.Invoke(this, args); + } } - protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) + protected override IRowMapper MakeRowMapper(ISchema schema) + => new Mapper(this, schema); + + private sealed class Mapper : MapperBase, ISaveAsOnnx { - DataKind rawKind; - var type = Infos[iinfo].TypeSrc; - if (type.IsVector) - rawKind = type.AsVector.ItemType.RawKind; - else if (type.IsKey) - rawKind = type.AsKey.RawKind; - else - rawKind = type.RawKind; - if (rawKind != DataKind.R4) - return false; + private sealed class ColInfo + { + public readonly string Name; + public readonly string Source; + public readonly ColumnType TypeSrc; + + public ColInfo(string name, string source, ColumnType type) + { + Name = name; + Source = source; + TypeSrc = type; + } + } - string opType = "Imputer"; - var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); - node.AddAttribute("replaced_value_float", Single.NaN); + private readonly NAReplaceTransform _parent; + private readonly ColInfo[] _infos; + private readonly ColumnType[] _types; + // The isNA delegates, parallel to Infos. + private readonly Delegate[] _isNAs; + public bool CanSaveOnnx => true; - if (!Infos[iinfo].TypeSrc.IsVector) - node.AddAttribute("imputed_value_floats", Enumerable.Repeat((float)_repValues[iinfo], 1)); - else + public Mapper(NAReplaceTransform parent, ISchema inputSchema) + : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { - if (_repIsDefault[iinfo] != null) - node.AddAttribute("imputed_value_floats", (float[])_repValues[iinfo]); - else - node.AddAttribute("imputed_value_floats", Enumerable.Repeat((float)_repValues[iinfo], 1)); + _parent = parent; + _infos = CreateInfos(inputSchema); + _types = new ColumnType[_parent.ColumnPairs.Length]; + _isNAs = new Delegate[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + var type = _infos[i].TypeSrc; + if (type.IsVector) + type = new VectorType(type.ItemType.AsPrimitive, type.AsVector); + var repType = _parent._repIsDefault[i] != null ? _parent._replaceTypes[i] : _parent._replaceTypes[i].ItemType; + if (!type.ItemType.Equals(repType.ItemType)) + throw Host.ExceptParam(nameof(InputSchema), "Column '{0}' item type '{1}' does not match expected ColumnType of '{2}'", + _infos[i].Source, _parent._replaceTypes[i].ItemType.ToString(), _infos[i].TypeSrc); + // If type is a vector and the value is not either a scalar or a vector of the same size, throw an error. + if (repType.IsVector) + { + if (!type.IsVector) + throw Host.ExceptParam(nameof(inputSchema), "Column '{0}' item type '{1}' cannot be a vector when Columntype is a scalar of type '{2}'", + _infos[i].Source, repType, type); + if (!type.IsKnownSizeVector) + throw Host.ExceptParam(nameof(inputSchema), "Column '{0}' is unknown size vector '{1}' must be a scalar instead of type '{2}'", _infos[i].Source, type, parent._replaceTypes[i]); + if (type.VectorSize != repType.VectorSize) + throw Host.ExceptParam(nameof(inputSchema), "Column '{0}' item type '{1}' must be a scalar or a vector of the same size as Columntype '{2}'", + _infos[i].Source, repType, type); + } + _types[i] = type; + _isNAs[i] = _parent.GetIsNADelegate(type); + } } - return true; - } - - protected override VectorType GetSlotTypeCore(int iinfo) - { - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - return Infos[iinfo].SlotTypeSrc; - } + private ColInfo[] CreateInfos(ISchema inputSchema) + { + Host.AssertValue(inputSchema); + var infos = new ColInfo[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colSrc)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); + _parent.CheckInputColumn(inputSchema, i, colSrc); + var type = inputSchema.GetColumnType(colSrc); + infos[i] = new ColInfo(_parent.ColumnPairs[i].output, _parent.ColumnPairs[i].input, type); + } + return infos; + } - protected override ISlotCursor GetSlotCursorCore(int iinfo) - { - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - Host.AssertValue(Infos[iinfo].SlotTypeSrc); + public override RowMapperColumnInfo[] GetOutputColumns() + { + var result = new RowMapperColumnInfo[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colIndex); + Host.Assert(colIndex >= 0); + var colMetaInfo = new ColumnMetadataInfo(_parent.ColumnPairs[i].output); + var meta = RowColumnUtils.GetMetadataAsRow(InputSchema, colIndex, x => x == MetadataUtils.Kinds.SlotNames || x == MetadataUtils.Kinds.IsNormalized); + result[i] = new RowMapperColumnInfo(_parent.ColumnPairs[i].output, _types[i], meta); + } + return result; + } - ISlotCursor cursor = InputTranspose.GetSlotCursor(Infos[iinfo].Source); - var type = GetSlotTypeCore(iinfo); - Host.AssertValue(type); - return Utils.MarshalInvoke(GetSlotCursorCore, type.ItemType.RawType, this, iinfo, cursor, type); - } + protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) + { + Host.AssertValue(input); + Host.Assert(0 <= iinfo && iinfo < _infos.Length); + disposer = null; - private ISlotCursor GetSlotCursorCore(NAReplaceTransform parent, int iinfo, ISlotCursor cursor, VectorType type) - => new SlotCursor(parent, iinfo, cursor, type); + if (!_infos[iinfo].TypeSrc.IsVector) + return ComposeGetterOne(input, iinfo); + return ComposeGetterVec(input, iinfo); + } - private sealed class SlotCursor : SynchronizedCursorBase, ISlotCursor - { - private readonly ValueGetter> _getter; - private readonly VectorType _type; + /// + /// Getter generator for single valued inputs. + /// + private Delegate ComposeGetterOne(IRow input, int iinfo) + => Utils.MarshalInvoke(ComposeGetterOne, _infos[iinfo].TypeSrc.RawType, input, iinfo); - public SlotCursor(NAReplaceTransform parent, int iinfo, ISlotCursor cursor, VectorType type) - : base(parent.Host, cursor) + /// + /// Replaces NA values for scalars. + /// + private Delegate ComposeGetterOne(IRow input, int iinfo) { - Ch.Assert(0 <= iinfo && iinfo < parent.Infos.Length); - Ch.AssertValue(cursor); - Ch.AssertValue(type); - var srcGetter = cursor.GetGetter(); - _type = type; - _getter = CreateGetter(parent, iinfo, cursor, type); + var getSrc = input.GetGetter(ColMapNewToOld[iinfo]); + var src = default(T); + var isNA = (RefPredicate)_isNAs[iinfo]; + Host.Assert(_parent._repValues[iinfo] is T); + T rep = (T)_parent._repValues[iinfo]; + ValueGetter getter; + + return getter = + (ref T dst) => + { + getSrc(ref src); + dst = isNA(ref src) ? rep : src; + }; } - private ValueGetter> CreateGetter(NAReplaceTransform parent, int iinfo, ISlotCursor cursor, VectorType type) + /// + /// Getter generator for vector valued inputs. + /// + private Delegate ComposeGetterVec(IRow input, int iinfo) + => Utils.MarshalInvoke(ComposeGetterVec, _infos[iinfo].TypeSrc.ItemType.RawType, input, iinfo); + + /// + /// Replaces NA values for vectors. + /// + private Delegate ComposeGetterVec(IRow input, int iinfo) { + var getSrc = input.GetGetter>(ColMapNewToOld[iinfo]); + var isNA = (RefPredicate)_isNAs[iinfo]; + var isDefault = Conversions.Instance.GetIsDefaultPredicate(_infos[iinfo].TypeSrc.ItemType); + var src = default(VBuffer); ValueGetter> getter; - var getSrc = cursor.GetGetter(); - var isNA = (RefPredicate)parent._isNAs[iinfo]; - var isDefault = Conversions.Instance.GetIsDefaultPredicate(type.ItemType); - - if (parent._repIsDefault[iinfo] == null) + if (_parent._repIsDefault[iinfo] == null) { // One replacement value for all slots. - Ch.Assert(parent._repValues[iinfo] is T); - T rep = (T)parent._repValues[iinfo]; + Host.Assert(_parent._repValues[iinfo] is T); + T rep = (T)_parent._repValues[iinfo]; bool repIsDefault = isDefault(ref rep); - - return (ref VBuffer dst) => - { - getSrc(ref src); - parent.FillValues(ref src, ref dst, isNA, rep, repIsDefault); - }; + return getter = + (ref VBuffer dst) => + { + getSrc(ref src); + FillValues(ref src, ref dst, isNA, rep, repIsDefault); + }; } // Replacement values by slot. - Ch.Assert(parent._repValues[iinfo] is T[]); + Host.Assert(_parent._repValues[iinfo] is T[]); // The replacement array. - T[] repArray = (T[])parent._repValues[iinfo]; + T[] repArray = (T[])_parent._repValues[iinfo]; return getter = (ref VBuffer dst) => { getSrc(ref src); - Ch.Check(0 <= Position && Position < repArray.Length); - T rep = repArray[(int)Position]; - parent.FillValues(ref src, ref dst, isNA, rep, isDefault(ref rep)); + Host.Check(src.Length == repArray.Length); + FillValues(ref src, ref dst, isNA, repArray, _parent._repIsDefault[iinfo]); }; } - public VectorType GetSlotType() => _type; - - public ValueGetter> GetGetter() + /// + /// Fills values for vectors where there is one replacement value. + /// + private void FillValues(ref VBuffer src, ref VBuffer dst, RefPredicate isNA, T rep, bool repIsDefault) { - ValueGetter> getter = _getter as ValueGetter>; - if (getter == null) - throw Ch.Except("Invalid TValue: '{0}'", typeof(TValue)); - return getter; - } - } + Host.AssertValue(isNA); - /// - /// Fills values for vectors where there is one replacement value. - /// - private void FillValues(ref VBuffer src, ref VBuffer dst, RefPredicate isNA, T rep, bool repIsDefault) - { - Host.AssertValue(isNA); + int srcSize = src.Length; + int srcCount = src.Count; + var srcValues = src.Values; + Host.Assert(Utils.Size(srcValues) >= srcCount); + var srcIndices = src.Indices; - int srcSize = src.Length; - int srcCount = src.Count; - var srcValues = src.Values; - Host.Assert(Utils.Size(srcValues) >= srcCount); - var srcIndices = src.Indices; + var dstValues = dst.Values; + var dstIndices = dst.Indices; - var dstValues = dst.Values; - var dstIndices = dst.Indices; + // If the values array is not large enough, allocate sufficient space. + // Note: We can't set the max to srcSize as vectors can be of variable lengths. + Utils.EnsureSize(ref dstValues, srcCount, keepOld: false); - // If the values array is not large enough, allocate sufficient space. - // Note: We can't set the max to srcSize as vectors can be of variable lengths. - Utils.EnsureSize(ref dstValues, srcCount, keepOld: false); + int iivDst = 0; + if (src.IsDense) + { + // The source vector is dense. + Host.Assert(srcSize == srcCount); + + for (int ivSrc = 0; ivSrc < srcCount; ivSrc++) + { + var srcVal = srcValues[ivSrc]; + + // The output for dense inputs is always dense. + // Note: Theoretically, one could imagine a dataset with NA values that one wished to replace with + // the default value, resulting in more than half of the indices being the default value. + // In this case, changing the dst vector to be sparse would be more memory efficient -- the current decision + // is it is not worth handling this case at the expense of running checks that will almost always not be triggered. + dstValues[ivSrc] = isNA(ref srcVal) ? rep : srcVal; + } + iivDst = srcCount; + } + else + { + // The source vector is sparse. + Host.Assert(Utils.Size(srcIndices) >= srcCount); + Host.Assert(srcCount < srcSize); + + // Allocate more space if necessary. + // REVIEW: One thing that changing the code to simply ensure that there are srcCount indices in the arrays + // does is over-allocate space if the replacement value is the default value in a dataset with a + // signficiant amount of NA values -- is it worth handling allocation of memory for this case? + Utils.EnsureSize(ref dstIndices, srcCount, keepOld: false); + + // Note: ivPrev is only used for asserts. + int ivPrev = -1; + for (int iivSrc = 0; iivSrc < srcCount; iivSrc++) + { + Host.Assert(iivDst <= iivSrc); + var srcVal = srcValues[iivSrc]; + int iv = srcIndices[iivSrc]; + Host.Assert(ivPrev < iv & iv < srcSize); + ivPrev = iv; + + if (!isNA(ref srcVal)) + { + dstValues[iivDst] = srcVal; + dstIndices[iivDst++] = iv; + } + else if (!repIsDefault) + { + // Allow for further sparsification. + dstValues[iivDst] = rep; + dstIndices[iivDst++] = iv; + } + } + Host.Assert(iivDst <= srcCount); + } + Host.Assert(0 <= iivDst); + Host.Assert(repIsDefault || iivDst == srcCount); + dst = new VBuffer(srcSize, iivDst, dstValues, dstIndices); + } - int iivDst = 0; - if (src.IsDense) + /// + /// Fills values for vectors where there is slot-wise replacement values. + /// + private void FillValues(ref VBuffer src, ref VBuffer dst, RefPredicate isNA, T[] rep, BitArray repIsDefault) { - // The source vector is dense. - Host.Assert(srcSize == srcCount); + Host.AssertValue(rep); + Host.Assert(rep.Length == src.Length); + Host.AssertValue(repIsDefault); + Host.Assert(repIsDefault.Length == src.Length); + Host.AssertValue(isNA); + + int srcSize = src.Length; + int srcCount = src.Count; + var srcValues = src.Values; + Host.Assert(Utils.Size(srcValues) >= srcCount); + var srcIndices = src.Indices; + + var dstValues = dst.Values; + var dstIndices = dst.Indices; + + // If the values array is not large enough, allocate sufficient space. + Utils.EnsureSize(ref dstValues, srcCount, srcSize, keepOld: false); + + int iivDst = 0; + Host.Assert(Utils.Size(srcValues) >= srcCount); + if (src.IsDense) + { + // The source vector is dense. + Host.Assert(srcSize == srcCount); - for (int ivSrc = 0; ivSrc < srcCount; ivSrc++) + for (int ivSrc = 0; ivSrc < srcCount; ivSrc++) + { + var srcVal = srcValues[ivSrc]; + + // The output for dense inputs is always dense. + // Note: Theoretically, one could imagine a dataset with NA values that one wished to replace with + // the default value, resulting in more than half of the indices being the default value. + // In this case, changing the dst vector to be sparse would be more memory efficient -- the current decision + // is it is not worth handling this case at the expense of running checks that will almost always not be triggered. + dstValues[ivSrc] = isNA(ref srcVal) ? rep[ivSrc] : srcVal; + } + iivDst = srcCount; + } + else { - var srcVal = srcValues[ivSrc]; - - // The output for dense inputs is always dense. - // Note: Theoretically, one could imagine a dataset with NA values that one wished to replace with - // the default value, resulting in more than half of the indices being the default value. - // In this case, changing the dst vector to be sparse would be more memory efficient -- the current decision - // is it is not worth handling this case at the expense of running checks that will almost always not be triggered. - dstValues[ivSrc] = isNA(ref srcVal) ? rep : srcVal; + // The source vector is sparse. + Host.Assert(Utils.Size(srcIndices) >= srcCount); + Host.Assert(srcCount < srcSize); + + // Allocate more space if necessary. + // REVIEW: One thing that changing the code to simply ensure that there are srcCount indices in the arrays + // does is over-allocate space if the replacement value is the default value in a dataset with a + // signficiant amount of NA values -- is it worth handling allocation of memory for this case? + Utils.EnsureSize(ref dstIndices, srcCount, srcSize, keepOld: false); + + // Note: ivPrev is only used for asserts. + int ivPrev = -1; + for (int iivSrc = 0; iivSrc < srcCount; iivSrc++) + { + Host.Assert(iivDst <= iivSrc); + var srcVal = srcValues[iivSrc]; + int iv = srcIndices[iivSrc]; + Host.Assert(ivPrev < iv & iv < srcSize); + ivPrev = iv; + + if (!isNA(ref srcVal)) + { + dstValues[iivDst] = srcVal; + dstIndices[iivDst++] = iv; + } + else if (!repIsDefault[iv]) + { + // Allow for further sparsification. + dstValues[iivDst] = rep[iv]; + dstIndices[iivDst++] = iv; + } + } + Host.Assert(iivDst <= srcCount); } - iivDst = srcCount; + Host.Assert(0 <= iivDst); + dst = new VBuffer(srcSize, iivDst, dstValues, dstIndices); } - else - { - // The source vector is sparse. - Host.Assert(Utils.Size(srcIndices) >= srcCount); - Host.Assert(srcCount < srcSize); - // Allocate more space if necessary. - // REVIEW: One thing that changing the code to simply ensure that there are srcCount indices in the arrays - // does is over-allocate space if the replacement value is the default value in a dataset with a - // signficiant amount of NA values -- is it worth handling allocation of memory for this case? - Utils.EnsureSize(ref dstIndices, srcCount, keepOld: false); + public void SaveAsOnnx(OnnxContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); - // Note: ivPrev is only used for asserts. - int ivPrev = -1; - for (int iivSrc = 0; iivSrc < srcCount; iivSrc++) + for (int iinfo = 0; iinfo < _infos.Length; ++iinfo) { - Host.Assert(iivDst <= iivSrc); - var srcVal = srcValues[iivSrc]; - int iv = srcIndices[iivSrc]; - Host.Assert(ivPrev < iv & iv < srcSize); - ivPrev = iv; - - if (!isNA(ref srcVal)) + ColInfo info = _infos[iinfo]; + string sourceColumnName = info.Source; + if (!ctx.ContainsColumn(sourceColumnName)) { - dstValues[iivDst] = srcVal; - dstIndices[iivDst++] = iv; + ctx.RemoveColumn(info.Name, false); + continue; } - else if (!repIsDefault) + + if (!SaveAsOnnxCore(ctx, iinfo, info, ctx.GetVariableName(sourceColumnName), + ctx.AddIntermediateVariable(_parent._replaceTypes[iinfo], info.Name))) { - // Allow for further sparsification. - dstValues[iivDst] = rep; - dstIndices[iivDst++] = iv; + ctx.RemoveColumn(info.Name, true); } } - Host.Assert(iivDst <= srcCount); } - Host.Assert(0 <= iivDst); - Host.Assert(repIsDefault || iivDst == srcCount); - dst = new VBuffer(srcSize, iivDst, dstValues, dstIndices); + + private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) + { + DataKind rawKind; + var type = _infos[iinfo].TypeSrc; + if (type.IsVector) + rawKind = type.AsVector.ItemType.RawKind; + else if (type.IsKey) + rawKind = type.AsKey.RawKind; + else + rawKind = type.RawKind; + + if (rawKind != DataKind.R4) + return false; + + string opType = "Imputer"; + var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); + node.AddAttribute("replaced_value_float", Single.NaN); + + if (!_infos[iinfo].TypeSrc.IsVector) + node.AddAttribute("imputed_value_floats", Enumerable.Repeat((float)_parent._repValues[iinfo], 1)); + else + { + if (_parent._repIsDefault[iinfo] != null) + node.AddAttribute("imputed_value_floats", (float[])_parent._repValues[iinfo]); + else + node.AddAttribute("imputed_value_floats", Enumerable.Repeat((float)_parent._repValues[iinfo], 1)); + } + return true; + } } + } - /// - /// Fills values for vectors where there is slot-wise replacement values. - /// - private void FillValues(ref VBuffer src, ref VBuffer dst, RefPredicate isNA, T[] rep, BitArray repIsDefault) + public sealed class NAReplaceEstimator : IEstimator + { + public static class Defaults { - Host.AssertValue(rep); - Host.Assert(rep.Length == src.Length); - Host.AssertValue(repIsDefault); - Host.Assert(repIsDefault.Length == src.Length); - Host.AssertValue(isNA); + public const NAReplaceTransform.ColumnInfo.ReplacementMode ReplacementMode = NAReplaceTransform.ColumnInfo.ReplacementMode.DefaultValue; + public const bool ImputeBySlot = true; + } - int srcSize = src.Length; - int srcCount = src.Count; - var srcValues = src.Values; - Host.Assert(Utils.Size(srcValues) >= srcCount); - var srcIndices = src.Indices; + private readonly IHost _host; + private readonly NAReplaceTransform.ColumnInfo[] _columns; - var dstValues = dst.Values; - var dstIndices = dst.Indices; + public NAReplaceEstimator(IHostEnvironment env, string name, string source = null, NAReplaceTransform.ColumnInfo.ReplacementMode replacementKind = Defaults.ReplacementMode) + : this(env, new NAReplaceTransform.ColumnInfo(source ?? name, name, replacementKind)) + { + + } - // If the values array is not large enough, allocate sufficient space. - Utils.EnsureSize(ref dstValues, srcCount, srcSize, keepOld: false); + public NAReplaceEstimator(IHostEnvironment env, params NAReplaceTransform.ColumnInfo[] columns) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(NAReplaceEstimator)); + _columns = columns; + } - int iivDst = 0; - Host.Assert(Utils.Size(srcValues) >= srcCount); - if (src.IsDense) + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + _host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var colInfo in _columns) { - // The source vector is dense. - Host.Assert(srcSize == srcCount); + if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + string reason = NAReplaceTransform.TestType(col.ItemType); + if (reason != null) + throw _host.ExceptParam(nameof(inputSchema), reason); + var metadata = new List(); + if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta)) + metadata.Add(slotMeta); + if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.IsNormalized, out var normalized)) + metadata.Add(normalized); + var type = !col.ItemType.IsVector ? col.ItemType : new VectorType(col.ItemType.ItemType.AsPrimitive, col.ItemType.AsVector); + result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, type, false, new SchemaShape(metadata.ToArray())); + } + return new SchemaShape(result.Values); + } - for (int ivSrc = 0; ivSrc < srcCount; ivSrc++) - { - var srcVal = srcValues[ivSrc]; - - // The output for dense inputs is always dense. - // Note: Theoretically, one could imagine a dataset with NA values that one wished to replace with - // the default value, resulting in more than half of the indices being the default value. - // In this case, changing the dst vector to be sparse would be more memory efficient -- the current decision - // is it is not worth handling this case at the expense of running checks that will almost always not be triggered. - dstValues[ivSrc] = isNA(ref srcVal) ? rep[ivSrc] : srcVal; - } - iivDst = srcCount; + public NAReplaceTransform Fit(IDataView input) => new NAReplaceTransform(_host, input, _columns); + } + + /// + /// Extension methods for the static-pipeline over objects. + /// + public static class NAReplaceExtensions + { + private struct Config + { + public readonly bool ImputeBySlot; + public readonly NAReplaceTransform.ColumnInfo.ReplacementMode ReplacementMode; + + public Config(NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, + bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + { + ImputeBySlot = imputeBySlot; + ReplacementMode = replacementMode; } - else + } + + private interface IColInput + { + PipelineColumn Input { get; } + Config Config { get; } + } + + private sealed class OutScalar : Scalar, IColInput + { + public PipelineColumn Input { get; } + public Config Config { get; } + + public OutScalar(Scalar input, Config config) + : base(Reconciler.Inst, input) { - // The source vector is sparse. - Host.Assert(Utils.Size(srcIndices) >= srcCount); - Host.Assert(srcCount < srcSize); + Input = input; + Config = config; + } + } - // Allocate more space if necessary. - // REVIEW: One thing that changing the code to simply ensure that there are srcCount indices in the arrays - // does is over-allocate space if the replacement value is the default value in a dataset with a - // signficiant amount of NA values -- is it worth handling allocation of memory for this case? - Utils.EnsureSize(ref dstIndices, srcCount, srcSize, keepOld: false); + private sealed class OutVectorColumn : Vector, IColInput + { + public PipelineColumn Input { get; } + public Config Config { get; } - // Note: ivPrev is only used for asserts. - int ivPrev = -1; - for (int iivSrc = 0; iivSrc < srcCount; iivSrc++) - { - Host.Assert(iivDst <= iivSrc); - var srcVal = srcValues[iivSrc]; - int iv = srcIndices[iivSrc]; - Host.Assert(ivPrev < iv & iv < srcSize); - ivPrev = iv; + public OutVectorColumn(Vector input, Config config) + : base(Reconciler.Inst, input) + { + Input = input; + Config = config; + } - if (!isNA(ref srcVal)) - { - dstValues[iivDst] = srcVal; - dstIndices[iivDst++] = iv; - } - else if (!repIsDefault[iv]) - { - // Allow for further sparsification. - dstValues[iivDst] = rep[iv]; - dstIndices[iivDst++] = iv; - } + } + + private sealed class OutVarVectorColumn : VarVector, IColInput + { + public PipelineColumn Input { get; } + public Config Config { get; } + + public OutVarVectorColumn(VarVector input, Config config) + : base(Reconciler.Inst, input) + { + Input = input; + Config = config; + } + } + + private sealed class Reconciler : EstimatorReconciler + { + public static Reconciler Inst = new Reconciler(); + + private Reconciler() { } + + public override IEstimator Reconcile(IHostEnvironment env, + PipelineColumn[] toOutput, + IReadOnlyDictionary inputNames, + IReadOnlyDictionary outputNames, + IReadOnlyCollection usedNames) + { + var infos = new NAReplaceTransform.ColumnInfo[toOutput.Length]; + for (int i = 0; i < toOutput.Length; ++i) + { + var col = (IColInput)toOutput[i]; + infos[i] = new NAReplaceTransform.ColumnInfo(inputNames[col.Input], outputNames[toOutput[i]], col.Config.ReplacementMode, col.Config.ImputeBySlot); } - Host.Assert(iivDst <= srcCount); + return new NAReplaceEstimator(env, infos); } - Host.Assert(0 <= iivDst); - dst = new VBuffer(srcSize, iivDst, dstValues, dstIndices); + } + + public static Scalar ReplaceWithMissingValues(this Scalar input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + { + Contracts.CheckValue(input, nameof(input)); + return new OutScalar(input, new Config(replacementMode, imputeBySlot)); + } + + public static Scalar ReplaceWithMissingValues(this Scalar input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + { + Contracts.CheckValue(input, nameof(input)); + return new OutScalar(input, new Config(replacementMode, imputeBySlot)); + } + + public static Scalar ReplaceWithMissingValues(this Scalar input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + { + Contracts.CheckValue(input, nameof(input)); + return new OutScalar(input, new Config(replacementMode, imputeBySlot)); + } + + public static Vector ReplaceWithMissingValues(this Vector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input, new Config(replacementMode, imputeBySlot)); + } + + public static Vector ReplaceWithMissingValues(this Vector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input, new Config(replacementMode, imputeBySlot)); + } + + public static Vector ReplaceWithMissingValues(this Vector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input, new Config(replacementMode, imputeBySlot)); + } + + public static VarVector ReplaceWithMissingValues(this VarVector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVarVectorColumn(input, new Config(replacementMode, imputeBySlot)); + } + + public static VarVector NAReplace(this VarVector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVarVectorColumn(input, new Config(replacementMode, imputeBySlot)); + } + + public static VarVector ReplaceWithMissingValues(this VarVector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVarVectorColumn(input, new Config(replacementMode, imputeBySlot)); } } } diff --git a/src/Microsoft.ML.Transforms/NAReplaceUtils.cs b/src/Microsoft.ML.Transforms/NAReplaceUtils.cs index 2340f9b413..8100a5f84b 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceUtils.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceUtils.cs @@ -1597,7 +1597,7 @@ public override object GetStat() // If sparsity occurred, fold in a zero. if (ValueCount > (ulong)ValuesProcessed) { - TItem def = default(TItem); + TItem def = default; ProcValueDelegate(ref def); } return _converter.FromLong(Stat); diff --git a/test/BaselineOutput/SingleDebug/NAReplace/featurized.tsv b/test/BaselineOutput/SingleDebug/NAReplace/featurized.tsv new file mode 100644 index 0000000000..f9ba7bf9f4 --- /dev/null +++ b/test/BaselineOutput/SingleDebug/NAReplace/featurized.tsv @@ -0,0 +1,14 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=A:TX:0 +#@ col=B:R4:1 +#@ col=C:R8:2 +#@ col=D:TX:3-6 +#@ col=E:R4:7-10 +#@ } +A B C 8 0:"" +5 5 5 5 1 1 1 5 1 1 1 +5 5 5 5 4 4 5 5 4 4 5 +3 3 3 3 1 1 1 3 1 1 1 +6 6 6 6 8 8 1 6 8 8 1 diff --git a/test/BaselineOutput/SingleRelease/NAReplace/featurized.tsv b/test/BaselineOutput/SingleRelease/NAReplace/featurized.tsv new file mode 100644 index 0000000000..f9ba7bf9f4 --- /dev/null +++ b/test/BaselineOutput/SingleRelease/NAReplace/featurized.tsv @@ -0,0 +1,14 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=A:TX:0 +#@ col=B:R4:1 +#@ col=C:R8:2 +#@ col=D:TX:3-6 +#@ col=E:R4:7-10 +#@ } +A B C 8 0:"" +5 5 5 5 1 1 1 5 1 1 1 +5 5 5 5 4 4 5 5 4 4 5 +3 3 3 3 1 1 1 3 1 1 1 +6 6 6 6 8 8 1 6 8 8 1 diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs b/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs index fdd240c514..5a39ef268b 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs @@ -150,10 +150,7 @@ private void ValidateMetadata(IDataView result) [Fact] public void TestCommandLine() { - using (var env = new ConsoleEnvironment()) - { - Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} xf=KeyToBinary{col=C:B} in=f:\2.txt" }), (int)0); - } + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} xf=KeyToBinary{col=C:B} in=f:\2.txt" }), (int)0); } [Fact] diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs index 45c6f938c5..1570840f84 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs @@ -215,10 +215,7 @@ private void ValidateMetadata(IDataView result) [Fact] public void TestCommandLine() { - using (var env = new ConsoleEnvironment()) - { - Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} xf=KeyToVector{col=C:B col={name=D source=B bag+}} in=f:\2.txt" }), (int)0); - } + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} xf=KeyToVector{col=C:B col={name=D source=B bag+}} in=f:\2.txt" }), (int)0); } [Fact] diff --git a/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs b/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs new file mode 100644 index 0000000000..969c5198b8 --- /dev/null +++ b/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs @@ -0,0 +1,133 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Data.IO; +using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.RunTests; +using Microsoft.ML.Runtime.Tools; +using System.IO; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Tests.Transformers +{ + public class NAReplaceTests : TestDataPipeBase + { + private class TestClass + { + public float A; + public string B; + public double C; + [VectorType(2)] + public float[] D; + [VectorType(2)] + public double[] E; + } + + public NAReplaceTests(ITestOutputHelper output) : base(output) + { + } + + [Fact] + public void NAReplaceWorkout() + { + var data = new[] { + new TestClass() { A = 1, B = "A", C = 3, D= new float[2]{ 1, 2 } , E = new double[2]{ 3,4} }, + new TestClass() { A = float.NaN, B = null, C = double.NaN, D= new float[2]{ float.NaN, float.NaN } , E = new double[2]{ double.NaN,double.NaN}}, + new TestClass() { A = float.NegativeInfinity, B = null, C = double.NegativeInfinity,D= new float[2]{ float.NegativeInfinity, float.NegativeInfinity } , E = new double[2]{ double.NegativeInfinity, double.NegativeInfinity}}, + new TestClass() { A = float.PositiveInfinity, B = null, C = double.PositiveInfinity,D= new float[2]{ float.PositiveInfinity, float.PositiveInfinity, } , E = new double[2]{ double.PositiveInfinity, double.PositiveInfinity}}, + new TestClass() { A = 2, B = "B", C = 1 ,D= new float[2]{ 3, 4 } , E = new double[2]{ 5,6}}, + }; + + var dataView = ComponentCreation.CreateDataView(Env, data); + var pipe = new NAReplaceEstimator(Env, + new NAReplaceTransform.ColumnInfo("A", "NAA", NAReplaceTransform.ColumnInfo.ReplacementMode.Mean), + new NAReplaceTransform.ColumnInfo("B", "NAB", NAReplaceTransform.ColumnInfo.ReplacementMode.DefaultValue), + new NAReplaceTransform.ColumnInfo("C", "NAC", NAReplaceTransform.ColumnInfo.ReplacementMode.Mean), + new NAReplaceTransform.ColumnInfo("D", "NAD", NAReplaceTransform.ColumnInfo.ReplacementMode.Mean), + new NAReplaceTransform.ColumnInfo("E", "NAE", NAReplaceTransform.ColumnInfo.ReplacementMode.Mean)); + TestEstimatorCore(pipe, dataView); + Done(); + } + + [Fact] + public void NAReplaceStatic() + { + string dataPath = GetDataPath("breast-cancer.txt"); + var reader = TextLoader.CreateReader(Env, ctx => ( + ScalarString: ctx.LoadText(1), + ScalarFloat: ctx.LoadFloat(1), + ScalarDouble: ctx.LoadDouble(1), + VectorString: ctx.LoadText(1, 4), + VectorFloat: ctx.LoadFloat(1, 4), + VectorDoulbe: ctx.LoadDouble(1, 4) + )); + + var data = reader.Read(new MultiFileSource(dataPath)); + var wrongCollection = new[] { new TestClass() { A = 1, B = "A", C = 3, D = new float[2] { 1, 2 }, E = new double[2] { 3, 4 } } }; + var invalidData = ComponentCreation.CreateDataView(Env, wrongCollection); + + var est = data.MakeNewEstimator(). + Append(row => ( + A: row.ScalarString.ReplaceWithMissingValues(), + B: row.ScalarFloat.ReplaceWithMissingValues(NAReplaceTransform.ColumnInfo.ReplacementMode.Maximum), + C: row.ScalarDouble.ReplaceWithMissingValues(NAReplaceTransform.ColumnInfo.ReplacementMode.Mean), + D: row.VectorString.ReplaceWithMissingValues(), + E: row.VectorFloat.ReplaceWithMissingValues(NAReplaceTransform.ColumnInfo.ReplacementMode.Mean), + F: row.VectorDoulbe.ReplaceWithMissingValues(NAReplaceTransform.ColumnInfo.ReplacementMode.Minimum) + )); + + TestEstimatorCore(est.AsDynamic, data.AsDynamic, invalidInput: invalidData); + var outputPath = GetOutputPath("NAReplace", "featurized.tsv"); + using (var ch = Env.Start("save")) + { + var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); + IDataView savedData = TakeFilter.Create(Env, est.Fit(data).Transform(data).AsDynamic, 4); + savedData = new ChooseColumnsTransform(Env, savedData, "A", "B", "C", "D", "E"); + using (var fs = File.Create(outputPath)) + DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); + } + + CheckEquality("NAReplace", "featurized.tsv"); + Done(); + } + + [Fact] + public void TestCommandLine() + { + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=NAReplace{col=C:A} in=f:\2.txt" }), (int)0); + } + + [Fact] + public void TestOldSavingAndLoading() + { + var data = new[] { + new TestClass() { A = 1, B = "A", C = 3, D= new float[2]{ 1, 2 } , E = new double[2]{ 3,4} }, + new TestClass() { A = float.NaN, B = null, C = double.NaN, D= new float[2]{ float.NaN, float.NaN } , E = new double[2]{ double.NaN,double.NaN}}, + new TestClass() { A = float.NegativeInfinity, B = null, C = double.NegativeInfinity,D= new float[2]{ float.NegativeInfinity, float.NegativeInfinity } , E = new double[2]{ double.NegativeInfinity, double.NegativeInfinity}}, + new TestClass() { A = float.PositiveInfinity, B = null, C = double.PositiveInfinity,D= new float[2]{ float.PositiveInfinity, float.PositiveInfinity, } , E = new double[2]{ double.PositiveInfinity, double.PositiveInfinity}}, + new TestClass() { A = 2, B = "B", C = 1 ,D= new float[2]{ 3, 4 } , E = new double[2]{ 5,6}}, + }; + + var dataView = ComponentCreation.CreateDataView(Env, data); + var pipe = new NAReplaceEstimator(Env, + new NAReplaceTransform.ColumnInfo("A", "NAA", NAReplaceTransform.ColumnInfo.ReplacementMode.Mean), + new NAReplaceTransform.ColumnInfo("B", "NAB", NAReplaceTransform.ColumnInfo.ReplacementMode.DefaultValue), + new NAReplaceTransform.ColumnInfo("C", "NAC", NAReplaceTransform.ColumnInfo.ReplacementMode.Mean), + new NAReplaceTransform.ColumnInfo("D", "NAD", NAReplaceTransform.ColumnInfo.ReplacementMode.Mean), + new NAReplaceTransform.ColumnInfo("E", "NAE", NAReplaceTransform.ColumnInfo.ReplacementMode.Mean)); + + var result = pipe.Fit(dataView).Transform(dataView); + var resultRoles = new RoleMappedData(result); + using (var ms = new MemoryStream()) + { + TrainUtils.SaveModel(Env, Env.Start("saving"), ms, null, resultRoles); + ms.Position = 0; + var loadedView = ModelFileUtils.LoadTransforms(Env, dataView, ms); + } + } + } +}