From 16f44ad911d997366763a36c7bb5597d177ec3b3 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Fri, 14 Sep 2018 11:35:21 -0700 Subject: [PATCH 1/8] need metadata tests //Ivan resolve and pigsty --- .../Transforms/KeyToVectorTransform.cs | 1 + .../Transforms/TermEstimator.cs | 1 + .../KeyToBinaryVectorTransform.cs | 3 +- .../NAHandleTransform.cs | 29 +- src/Microsoft.ML.Transforms/NAHandling.cs | 2 +- .../NAReplaceTransform.cs | 1170 +++++++++-------- .../Transformers/NAReplaceTests.cs | 93 ++ 7 files changed, 705 insertions(+), 594 deletions(-) create mode 100644 test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index efa519e05e..b5bbc63466 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -280,6 +280,7 @@ private ColInfo[] CreateInfos(ISchema inputSchema) if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colSrc)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); var type = inputSchema.GetColumnType(colSrc); + _parent.CheckInputColumn(inputSchema, i, colSrc); infos[i] = new ColInfo(_parent.ColumnPairs[i].output, _parent.ColumnPairs[i].input, type); } return infos; diff --git a/src/Microsoft.ML.Data/Transforms/TermEstimator.cs b/src/Microsoft.ML.Data/Transforms/TermEstimator.cs index 44e873e9e7..02a47fd90a 100644 --- a/src/Microsoft.ML.Data/Transforms/TermEstimator.cs +++ b/src/Microsoft.ML.Data/Transforms/TermEstimator.cs @@ -15,6 +15,7 @@ public sealed class TermEstimator : IEstimator { private readonly IHost _host; private readonly TermTransform.ColumnInfo[] _columns; + public TermEstimator(IHostEnvironment env, string name, string source = null, int maxNumTerms = TermTransform.Defaults.MaxNumTerms, TermTransform.SortOrder sort = TermTransform.Defaults.Sort) : this(env, new TermTransform.ColumnInfo(name, source ?? name, maxNumTerms, sort)) { diff --git a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs index 52a49b1563..b97612e380 100644 --- a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs +++ b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs @@ -69,6 +69,7 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum Contracts.CheckValue(columns, nameof(columns)); return columns.Select(x => (x.Input, x.Output)).ToArray(); } + public IReadOnlyCollection Columns => _columns.AsReadOnly(); private readonly ColumnInfo[] _columns; @@ -209,7 +210,7 @@ private ColInfo[] CreateInfos(ISchema inputSchema) if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colSrc)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); var type = inputSchema.GetColumnType(colSrc); - + _parent.CheckInputColumn(inputSchema, i, colSrc); infos[i] = new ColInfo(_parent.ColumnPairs[i].output, _parent.ColumnPairs[i].input, type); } return infos; diff --git a/src/Microsoft.ML.Transforms/NAHandleTransform.cs b/src/Microsoft.ML.Transforms/NAHandleTransform.cs index 9e7390948b..d723b52dee 100644 --- a/src/Microsoft.ML.Transforms/NAHandleTransform.cs +++ b/src/Microsoft.ML.Transforms/NAHandleTransform.cs @@ -135,7 +135,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV h.CheckValue(input, nameof(input)); h.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column)); - var replaceCols = new List(); + var replaceCols = new List(); var naIndicatorCols = new List(); var naConvCols = new List(); var concatCols = new List(); @@ -149,14 +149,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV var addInd = column.ConcatIndicator ?? args.Concat; if (!addInd) { - replaceCols.Add( - new NAReplaceTransform.Column() - { - Kind = (NAReplaceTransform.ReplacementKind?)column.Kind, - Name = column.Name, - Source = column.Source, - Slot = column.ImputeBySlot - }); + replaceCols.Add(new NAReplaceTransform.ColumnInfo(column.Source, column.Name, (NAReplaceTransform.ReplacementKind)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot)); continue; } @@ -186,14 +179,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV naConvCols.Add(new ConvertTransform.Column() { Name = tmpIsMissingColName, Source = tmpIsMissingColName, ResultType = replaceType.ItemType.RawKind }); // Add the NAReplaceTransform column. - replaceCols.Add( - new NAReplaceTransform.Column() - { - Kind = (NAReplaceTransform.ReplacementKind?)column.Kind, - Name = tmpReplacementColName, - Source = column.Source, - Slot = column.ImputeBySlot - }); + replaceCols.Add(new NAReplaceTransform.ColumnInfo(column.Source, tmpReplacementColName, (NAReplaceTransform.ReplacementKind)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot)); // Add the ConcatTransform column. if (replaceType.IsVector) @@ -237,15 +223,8 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV h.AssertValue(output); output = new ConvertTransform(h, new ConvertTransform.Arguments() { Column = naConvCols.ToArray() }, output); } - // Create the NAReplace transform. - output = new NAReplaceTransform(h, - new NAReplaceTransform.Arguments() - { - Column = replaceCols.ToArray(), - ReplacementKind = (NAReplaceTransform.ReplacementKind)args.ReplaceWith, - ImputeBySlot = args.ImputeBySlot - }, output ?? input); + output = NAReplaceTransform.Create(env, output ?? input, replaceCols.ToArray()); // Concat the NAReplaceTransform output and the NAIndicatorTransform output. if (naIndicatorCols.Count > 0) diff --git a/src/Microsoft.ML.Transforms/NAHandling.cs b/src/Microsoft.ML.Transforms/NAHandling.cs index 0870d16461..8f8fdaabb9 100644 --- a/src/Microsoft.ML.Transforms/NAHandling.cs +++ b/src/Microsoft.ML.Transforms/NAHandling.cs @@ -88,7 +88,7 @@ public static CommonOutputs.TransformOutput Indicator(IHostEnvironment env, NAIn public static CommonOutputs.TransformOutput Replace(IHostEnvironment env, NAReplaceTransform.Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "NAReplace", input); - var xf = new NAReplaceTransform(h, input, input.Data); + var xf = NAReplaceTransform.Create(h, input, input.Data); return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, xf, input.Data), diff --git a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs index c9b309af89..9dc00fe8d0 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs @@ -18,23 +18,25 @@ using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Model.Onnx; +using Microsoft.ML.Core.Data; -[assembly: LoadableClass(typeof(NAReplaceTransform), typeof(NAReplaceTransform.Arguments), typeof(SignatureDataTransform), - NAReplaceTransform.FriendlyName, NAReplaceTransform.LoadName, "NAReplace", NAReplaceTransform.ShortName, DocName = "transform/NAHandle.md")] +[assembly: LoadableClass(NAReplaceTransform.Summary, typeof(IDataTransform), typeof(NAReplaceTransform), typeof(NAReplaceTransform.Arguments), typeof(SignatureDataTransform), + NAReplaceTransform.FriendlyName, NAReplaceTransform.LoadName, "NAReplace", NAReplaceTransform.ShortName, DocName = "transform/NAHandle.md")] -[assembly: LoadableClass(typeof(NAReplaceTransform), null, typeof(SignatureLoadDataTransform), +[assembly: LoadableClass(NAReplaceTransform.Summary, typeof(IDataView), typeof(NAReplaceTransform), null, typeof(SignatureLoadDataTransform), NAReplaceTransform.FriendlyName, NAReplaceTransform.LoadName)] +[assembly: LoadableClass(NAReplaceTransform.Summary, typeof(NAReplaceTransform), null, typeof(SignatureLoadModel), + NAReplaceTransform.FriendlyName, NAReplaceTransform.LoadName)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(NAReplaceTransform), null, typeof(SignatureLoadRowMapper), + NAReplaceTransform.FriendlyName, NAReplaceTransform.LoadName)] + namespace Microsoft.ML.Runtime.Data { - // This transform can transform either scalars or vectors (both fixed and variable size), - // creating output columns that are identical to the input columns except for replacing NA values - // with either the default value, user input, or imputed values (min/max/mean are currently supported). - // Imputation modes are supported for vectors both by slot and across all slots. - // REVIEW: May make sense to implement the transform template interface. - /// - public sealed partial class NAReplaceTransform : OneToOneTransformBase + public sealed partial class NAReplaceTransform : OneToOneTransformerBase { + //IVAN: CLEAN IT public enum ReplacementKind { // REVIEW: What should the full list of options for this transform be? @@ -120,54 +122,50 @@ public sealed class Arguments : TransformInputBase } public const string LoadName = "NAReplaceTransform"; - private static VersionInfo GetVersionInfo() - { - return new VersionInfo( - // REVIEW: temporary name - modelSignature: "NAREP TF", - // verWrittenCur: 0x00010001, // Initial - verWrittenCur: 0x0010002, // Added imputation methods. - verReadableCur: 0x00010002, - verWeCanReadBack: 0x00010001, - loaderSignature: LoadName); - } - internal const string Summary = "Create an output column of the same type and size of the input column, where missing values " - + "are replaced with either the default value or the mean/min/max value (for non-text columns only)."; + + "are replaced with either the default value or the mean/min/max value (for non-text columns only)."; internal const string FriendlyName = "NA Replace Transform"; internal const string ShortName = "NARep"; - private static string TestType(ColumnType type) + public class ColumnInfo { - // Item type must have an NA value that exists and is not equal to its default value. - Func func = TestType; - var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(type.ItemType.RawType); - return (string)meth.Invoke(null, new object[] { type.ItemType }); - } + public readonly string Input; + public readonly string Output; + public readonly bool ImputeBySlot; + public readonly ReplacementKind Kind; - private static string TestType(ColumnType type) - { - Contracts.Assert(type.ItemType.RawType == typeof(T)); - RefPredicate isNA; - if (!Conversions.Instance.TryGetIsNAPredicate(type.ItemType, out isNA)) + public ColumnInfo(string input, string output, ReplacementKind kind = ReplacementKind.DefaultValue, bool imputeBySlot = true) { - return string.Format("Type '{0}' is not supported by {1} since it doesn't have an NA value", - type, LoadName); + Input = input; + Output = output; + ImputeBySlot = imputeBySlot; + Kind = kind; } - var t = default(T); - if (isNA(ref t)) + internal string ReplacementString { get; set; } + } + + private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + { + Contracts.CheckValue(columns, nameof(columns)); + return columns.Select(x => (x.Input, x.Output)).ToArray(); + } + + ///IVAN: move to mapper. + internal sealed class ColInfo + { + public readonly string Name; + public readonly string Source; + public readonly ColumnType TypeSrc; + + public ColInfo(string name, string source, ColumnType type) { - // REVIEW: Key values will be handled in a "new key value" transform. - return string.Format("Type '{0}' is not supported by {1} since its NA value is equivalent to its default value", - type, LoadName); + Name = name; + Source = source; + TypeSrc = type; } - return null; } - // The output column types, parallel to Infos. - private readonly ColumnType[] _types; - // The replacementValues for the columns, parallel to Infos. // The elements of this array can be either primitive values or arrays of primitive values. When replacing a scalar valued column in Infos, // this array will hold a primitive value. When replacing a vector valued column in Infos, this array will either hold a primitive @@ -178,168 +176,52 @@ private static string TestType(ColumnType type) // Marks if the replacement values in given slots of _repValues are the default value. // REVIEW: Currently these arrays are constructed on load but could be changed to being constructed lazily. - private BitArray[] _repIsDefault; - - // The isNA delegates, parallel to Infos. - private readonly Delegate[] _isNAs; - - public override bool CanSaveOnnx => true; - - /// - /// Convenience constructor for public facing API. - /// - /// Host Environment. - /// Input . This is the output from previous transform or loader. - /// Name of the output column. - /// Name of the column to be transformed. If this is null '' will be used. - /// The replacement method to utilize. - public NAReplaceTransform(IHostEnvironment env, IDataView input, string name, string source = null, ReplacementKind replacementKind = ReplacementKind.DefaultValue) - : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, ReplacementKind = replacementKind }, input) - { - } - - /// - /// Public constructor corresponding to SignatureDataTransform. - /// - public NAReplaceTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, LoadName, Contracts.CheckRef(args, nameof(args)).Column, - input, TestType) - { - Host.CheckValue(args, nameof(args)); - Host.AssertNonEmpty(Infos); - Host.Assert(Infos.Length == Utils.Size(args.Column)); + private readonly BitArray[] _repIsDefault; - GetInfoAndMetadata(out _types, out _isNAs); - GetReplacementValues(args, out _repValues, out _repIsDefault); - } + // The output column types, parallel to Infos. + private readonly ColumnType[] _types; - private NAReplaceTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, TestType) + private string TestType(ColumnType type) { - Host.AssertValue(ctx); - Host.AssertNonEmpty(Infos); - - GetInfoAndMetadata(out _types, out _isNAs); - - // *** Binary format *** - // - // for each column: - // type and value - _repValues = new object[Infos.Length]; - _repIsDefault = new BitArray[Infos.Length]; - var saver = new BinarySaver(Host, new BinarySaver.Arguments()); - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) - { - object repValue; - ColumnType repType; - if (!saver.TryLoadTypeAndValue(ctx.Reader.BaseStream, out repType, out repValue)) - throw Host.ExceptDecode(); - if (!_types[iinfo].ItemType.Equals(repType.ItemType)) - throw Host.ExceptParam(nameof(input), "Decoded serialization of type '{0}' does not match expected ColumnType of '{1}'", repType.ItemType, _types[iinfo].ItemType); - // If type is a vector and the value is not either a scalar or a vector of the same size, throw an error. - if (repType.IsVector) - { - if (!_types[iinfo].IsVector) - throw Host.ExceptParam(nameof(input), "Decoded serialization of type '{0}' cannot be a vector when Columntype is a scalar of type '{1}'", repType, _types[iinfo]); - if (!_types[iinfo].IsKnownSizeVector) - throw Host.ExceptParam(nameof(input), "Decoded serialization for unknown size vector '{0}' must be a scalar instead of type '{1}'", _types[iinfo], repType); - if (_types[iinfo].VectorSize != repType.VectorSize) - { - throw Host.ExceptParam(nameof(input), "Decoded serialization of type '{0}' must be a scalar or a vector of the same size as Columntype '{1}'", - repType, _types[iinfo]); - } - - // REVIEW: The current implementation takes the serialized VBuffer, densifies it, and stores the values array. - // It might be of value to consider storing the VBUffer in order to possibly benefit from sparsity. However, this would - // necessitate a reimplementation of the FillValues code to accomodate sparse VBuffers. - object[] args = new object[] { repValue, _types[iinfo], iinfo }; - Func, ColumnType, int, int[]> func = GetValuesArray; - var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(repType.ItemType.RawType); - _repValues[iinfo] = meth.Invoke(this, args); - } - else - _repValues[iinfo] = repValue; - - Host.Assert(repValue.GetType() == _types[iinfo].RawType || repValue.GetType() == _types[iinfo].ItemType.RawType); - } + // Item type must have an NA value that exists and is not equal to its default value. + Func func = TestType; + var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(type.ItemType.RawType); + return (string)meth.Invoke(null, new object[] { type.ItemType }); } - private T[] GetValuesArray(VBuffer src, ColumnType srcType, int iinfo) + private static string TestType(ColumnType type) { - Host.Assert(srcType.IsVector); - Host.Assert(srcType.VectorSize == src.Length); - VBufferUtils.Densify(ref src); - RefPredicate defaultPred = Conversions.Instance.GetIsDefaultPredicate(srcType.ItemType); - _repIsDefault[iinfo] = new BitArray(srcType.VectorSize); - for (int slot = 0; slot < src.Length; slot++) + Contracts.Assert(type.ItemType.RawType == typeof(T)); + RefPredicate isNA; + if (!Conversions.Instance.TryGetIsNAPredicate(type.ItemType, out isNA)) { - if (defaultPred(ref src.Values[slot])) - _repIsDefault[iinfo][slot] = true; + return string.Format("Type '{0}' is not supported by {1} since it doesn't have an NA value", + type, LoadName); } - T[] valReturn = src.Values; - Array.Resize(ref valReturn, srcType.VectorSize); - Host.Assert(valReturn.Length == src.Length); - return valReturn; - } - - public static NAReplaceTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) - { - Contracts.CheckValue(env, nameof(env)); - var h = env.Register(LoadName); - h.CheckValue(ctx, nameof(ctx)); - h.CheckValue(input, nameof(input)); - ctx.CheckAtModel(GetVersionInfo()); - return h.Apply("Loading Model", ch => new NAReplaceTransform(h, ctx, input)); - } - - public override void Save(ModelSaveContext ctx) - { - Host.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(); - ctx.SetVersionInfo(GetVersionInfo()); - - // *** Binary format *** - // - // for each column: - // type and value - SaveBase(ctx); - var saver = new BinarySaver(Host, new BinarySaver.Arguments()); - for (int iinfo = 0; iinfo < _types.Length; iinfo++) + var t = default(T); + if (isNA(ref t)) { - var repValue = _repValues[iinfo]; - var repType = _types[iinfo].ItemType; - if (_repIsDefault[iinfo] != null) - { - Host.Assert(repValue is Array); - Func> function = CreateVBuffer; - var method = function.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(_types[iinfo].ItemType.RawType); - repValue = method.Invoke(this, new object[] { _repValues[iinfo] }); - repType = _types[iinfo]; - } - Host.Assert(!(repValue is Array)); - object[] args = new object[] { ctx.Writer.BaseStream, saver, repType, repValue }; - Action func = WriteTypeAndValue; - Host.Assert(repValue.GetType() == _types[iinfo].RawType || repValue.GetType() == _types[iinfo].ItemType.RawType); - var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(repValue.GetType()); - meth.Invoke(this, args); + // REVIEW: Key values will be handled in a "new key value" transform. + return string.Format("Type '{0}' is not supported by {1} since its NA value is equivalent to its default value", + type, LoadName); } + return null; } - private VBuffer CreateVBuffer(T[] array) + protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { - Host.AssertValue(array); - return new VBuffer(array.Length, array); + var type = inputSchema.GetColumnType(srcCol); + string reason = TestType(type); + if (reason != null) + //IVAN: not sure about schema mismatch + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, reason, type.ToString()); } - private void WriteTypeAndValue(Stream stream, BinarySaver saver, ColumnType type, T rep) + public NAReplaceTransform(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NAReplaceTransform)), GetColumnPairs(columns)) { - Host.AssertValue(stream); - Host.AssertValue(saver); - Host.Assert(type.RawType == typeof(T) || type.ItemType.RawType == typeof(T)); - - int bytesWritten; - if (!saver.TryWriteTypeAndValue(stream, type, ref rep, out bytesWritten)) - throw Host.Except("We do not know how to serialize terms of type '{0}'", type); + // Validate input schema. + GetReplacementValues(input, columns, out _repValues, out _repIsDefault, out _types); } /// @@ -347,40 +229,47 @@ private void WriteTypeAndValue(Stream stream, BinarySaver saver, ColumnType t /// Vectors default to by-slot imputation unless otherwise specified, except for unknown sized vectors /// which force across-slot imputation. /// - private void GetReplacementValues(Arguments args, out object[] repValues, out BitArray[] slotIsDefault) + private void GetReplacementValues(IDataView input, ColumnInfo[] columns, out object[] repValues, out BitArray[] slotIsDefault, out ColumnType[] types) { - repValues = new object[Infos.Length]; - slotIsDefault = new BitArray[Infos.Length]; - - ReplacementKind?[] imputationModes = new ReplacementKind?[Infos.Length]; + repValues = new object[columns.Length]; + slotIsDefault = new BitArray[columns.Length]; + types = new ColumnType[columns.Length]; + var sources = new int[columns.Length]; + ReplacementKind?[] imputationModes = new ReplacementKind?[columns.Length]; List columnsToImpute = null; // REVIEW: Would like to get rid of the sourceColumns list but seems to be the best way to provide // the cursor with what columns to cursor through. HashSet sourceColumns = null; - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) + for (int iinfo = 0; iinfo < columns.Length; iinfo++) { - ReplacementKind kind = args.Column[iinfo].Kind ?? args.ReplacementKind; - switch (kind) + input.Schema.TryGetColumnIndex(columns[iinfo].Input, out int colSrc); + sources[iinfo] = colSrc; + var type = input.Schema.GetColumnType(colSrc); + if (type.IsVector) + type = new VectorType(type.ItemType.AsPrimitive, type.AsVector); + types[iinfo] = type; + switch (columns[iinfo].Kind) { - case ReplacementKind.SpecifiedValue: - repValues[iinfo] = GetSpecifiedValue(args.Column[iinfo].ReplacementString, _types[iinfo], _isNAs[iinfo]); - break; - case ReplacementKind.DefaultValue: - repValues[iinfo] = GetDefault(_types[iinfo]); - break; - case ReplacementKind.Mean: - case ReplacementKind.Min: - case ReplacementKind.Max: - if (!_types[iinfo].ItemType.IsNumber && !_types[iinfo].ItemType.IsTimeSpan && !_types[iinfo].ItemType.IsDateTime) - throw Host.Except("Cannot perform mean imputations on non-numeric '{0}'", _types[iinfo].ItemType); - imputationModes[iinfo] = kind; - Utils.Add(ref columnsToImpute, iinfo); - Utils.Add(ref sourceColumns, Infos[iinfo].Source); - break; - default: - Host.Assert(false); - throw Host.Except("Internal error, undefined ReplacementKind '{0}' assigned in NAReplaceTransform.", kind); + //IVAN: what to do with you? + //case ReplacementKind.SpecifiedValue: + // repValues[iinfo] = GetSpecifiedValue(args.Column[iinfo].ReplacementString, _types[iinfo], _isNAs[iinfo]); + // break; + case ReplacementKind.DefaultValue: + repValues[iinfo] = GetDefault(type); + break; + case ReplacementKind.Mean: + case ReplacementKind.Minimum: + case ReplacementKind.Maximum: + if (!type.ItemType.IsNumber && !type.ItemType.IsTimeSpan && !type.ItemType.IsDateTime) + throw Host.Except("Cannot perform mean imputations on non-numeric '{0}'", type.ItemType); + imputationModes[iinfo] = columns[iinfo].Kind; + Utils.Add(ref columnsToImpute, iinfo); + Utils.Add(ref sourceColumns, colSrc); + break; + default: + Host.Assert(false); + throw Host.Except("Internal error, undefined ReplacementKind '{0}' assigned in NAReplaceTransform.", columns[iinfo].Kind); } } @@ -390,20 +279,21 @@ private void GetReplacementValues(Arguments args, out object[] repValues, out Bi // Impute values. using (var ch = Host.Start("Computing Statistics")) - using (var cursor = Source.GetRowCursor(sourceColumns.Contains)) + using (var cursor = input.GetRowCursor(sourceColumns.Contains)) { StatAggregator[] statAggregators = new StatAggregator[columnsToImpute.Count]; for (int ii = 0; ii < columnsToImpute.Count; ii++) { int iinfo = columnsToImpute[ii]; - bool bySlot = args.Column[ii].Slot ?? args.ImputeBySlot; - if (_types[iinfo].IsVector && !_types[iinfo].IsKnownSizeVector && bySlot) + bool bySlot = columns[ii].ImputeBySlot; + if (types[iinfo].IsVector && !types[iinfo].IsKnownSizeVector && bySlot) { ch.Warning("By-slot imputation can not be done on variable-length column"); bySlot = false; } - statAggregators[ii] = CreateStatAggregator(ch, _types[iinfo], imputationModes[iinfo], bySlot, - cursor, Infos[iinfo].Source); + + statAggregators[ii] = CreateStatAggregator(ch, types[iinfo], imputationModes[iinfo], bySlot, + cursor, sources[iinfo]); } while (cursor.MoveNext()) @@ -425,8 +315,8 @@ private void GetReplacementValues(Arguments args, out object[] repValues, out Bi if (repValues[slot] is Array) { Func func = ComputeDefaultSlots; - var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(_types[slot].ItemType.RawType); - slotIsDefault[slot] = (BitArray)meth.Invoke(this, new object[] { _types[slot], repValues[slot] }); + var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(types[slot].ItemType.RawType); + slotIsDefault[slot] = (BitArray)meth.Invoke(this, new object[] { types[slot], repValues[slot] }); } } } @@ -444,31 +334,6 @@ private BitArray ComputeDefaultSlots(ColumnType type, T[] values) return defaultSlots; } - private void GetInfoAndMetadata(out ColumnType[] types, out Delegate[] isNAs) - { - var md = Metadata; - types = new ColumnType[Infos.Length]; - isNAs = new Delegate[Infos.Length]; - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) - { - var type = Infos[iinfo].TypeSrc; - - if (!type.IsVector) - types[iinfo] = type; - else - types[iinfo] = new VectorType(type.ItemType.AsPrimitive, type.AsVector); - - isNAs[iinfo] = GetIsNADelegate(type); - - // Pass through slot name metadata and normalization data. - using (var bldr = md.BuildMetadata(iinfo, Source.Schema, Infos[iinfo].Source, - MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.IsNormalized)) - { - } - } - md.Seal(); - } - private object GetDefault(ColumnType type) { Func func = GetDefault; @@ -481,415 +346,586 @@ private object GetDefault() return default(T); } - /// - /// Returns the isNA predicate for the respective type. - /// - private Delegate GetIsNADelegate(ColumnType type) + private static VersionInfo GetVersionInfo() { - Func func = GetIsNADelegate; - return Utils.MarshalInvoke(func, type.ItemType.RawType, type); + return new VersionInfo( + // REVIEW: temporary name + modelSignature: "NAREP TF", + // verWrittenCur: 0x00010001, // Initial + verWrittenCur: 0x0010002, // Added imputation methods. + verReadableCur: 0x00010002, + verWeCanReadBack: 0x00010001, + loaderSignature: LoadName); } - private Delegate GetIsNADelegate(ColumnType type) + // Factory method for SignatureDataTransform. + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { - return Conversions.Instance.GetIsNAPredicate(type.ItemType); - } + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); - protected override ColumnType GetColumnTypeCore(int iinfo) - { - Host.Assert(0 <= iinfo & iinfo < Infos.Length); - return _types[iinfo]; + env.CheckValue(args.Column, nameof(args.Column)); + var cols = new ColumnInfo[args.Column.Length]; + for (int i = 0; i < cols.Length; i++) + { + var item = args.Column[i]; + var kind = item.Kind ?? args.ReplacementKind; + if (!Enum.IsDefined(typeof(ReplacementKind), kind)) + throw env.ExceptUserArg(nameof(args.ReplacementKind), "Undefined sorting criteria '{0}' detected for column '{1}'", kind, item.Name); + + cols[i] = new ColumnInfo(item.Source, + item.Name, + item.Kind ?? args.ReplacementKind, + item.Slot ?? args.ImputeBySlot); + cols[i].ReplacementString = item.ReplacementString; + }; + return new NAReplaceTransform(env, input, cols).MakeDataTransform(input); } - /// - /// Converts a string to its respective value in the corresponding type. - /// - private object GetSpecifiedValue(string srcStr, ColumnType dstType, Delegate isNA) + public static IDataTransform Create(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) { - Func, object> func = GetSpecifiedValue; - var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(dstType.ItemType.RawType); - return meth.Invoke(this, new object[] { srcStr, dstType, isNA }); + return new NAReplaceTransform(env, input, columns).MakeDataTransform(input); } - private object GetSpecifiedValue(string srcStr, ColumnType dstType, RefPredicate isNA) + // Factory method for SignatureLoadModel. + public static NAReplaceTransform Create(IHostEnvironment env, ModelLoadContext ctx) { - var val = default(T); - if (!string.IsNullOrEmpty(srcStr)) - { - // Handles converting input strings to correct types. - DvText srcTxt = new DvText(srcStr); - bool identity; - var strToT = Conversions.Instance.GetStandardConversion(TextType.Instance, dstType.ItemType, out identity); - strToT(ref srcTxt, ref val); - // Make sure that the srcTxt can legitimately be converted to dstType, throw error otherwise. - if (isNA(ref val)) - throw Contracts.Except("No conversion of '{0}' to '{1}'", srcStr, dstType.ItemType); - } + Contracts.CheckValue(env, nameof(env)); + var host = env.Register(LoadName); - return val; - } + host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) - { - Host.AssertValueOrNull(ch); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - disposer = null; - - if (!Infos[iinfo].TypeSrc.IsVector) - return ComposeGetterOne(input, iinfo); - return ComposeGetterVec(input, iinfo); + return new NAReplaceTransform(host, ctx); } - /// - /// Getter generator for single valued inputs. - /// - private Delegate ComposeGetterOne(IRow input, int iinfo) - => Utils.MarshalInvoke(ComposeGetterOne, Infos[iinfo].TypeSrc.RawType, input, iinfo); + // Factory method for SignatureLoadDataTransform. + public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + => Create(env, ctx).MakeDataTransform(input); - /// - /// Replaces NA values for scalars. - /// - private Delegate ComposeGetterOne(IRow input, int iinfo) + // Factory method for SignatureLoadRowMapper. + public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); + + private VBuffer CreateVBuffer(T[] array) { - var getSrc = GetSrcGetter(input, iinfo); - var src = default(T); - var isNA = (RefPredicate)_isNAs[iinfo]; - Host.Assert(_repValues[iinfo] is T); - T rep = (T)_repValues[iinfo]; - ValueGetter getter; - - return getter = - (ref T dst) => - { - getSrc(ref src); - dst = isNA(ref src) ? rep : src; - }; + Host.AssertValue(array); + return new VBuffer(array.Length, array); } - /// - /// Getter generator for vector valued inputs. - /// - private Delegate ComposeGetterVec(IRow input, int iinfo) - => Utils.MarshalInvoke(ComposeGetterVec, Infos[iinfo].TypeSrc.ItemType.RawType, input, iinfo); - - /// - /// Replaces NA values for vectors. - /// - private Delegate ComposeGetterVec(IRow input, int iinfo) + private void WriteTypeAndValue(Stream stream, BinarySaver saver, ColumnType type, T rep) { - var getSrc = GetSrcGetter>(input, iinfo); - var isNA = (RefPredicate)_isNAs[iinfo]; - var isDefault = Conversions.Instance.GetIsDefaultPredicate(input.Schema.GetColumnType(Infos[iinfo].Source).ItemType); + Host.AssertValue(stream); + Host.AssertValue(saver); + Host.Assert(type.RawType == typeof(T) || type.ItemType.RawType == typeof(T)); - var src = default(VBuffer); - ValueGetter> getter; + int bytesWritten; + if (!saver.TryWriteTypeAndValue(stream, type, ref rep, out bytesWritten)) + throw Host.Except("We do not know how to serialize terms of type '{0}'", type); + } - if (_repIsDefault[iinfo] == null) - { - // One replacement value for all slots. - Host.Assert(_repValues[iinfo] is T); - T rep = (T)_repValues[iinfo]; - bool repIsDefault = isDefault(ref rep); - return getter = - (ref VBuffer dst) => - { - getSrc(ref src); - FillValues(ref src, ref dst, isNA, rep, repIsDefault); - }; - } + public override void Save(ModelSaveContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); - // Replacement values by slot. - Host.Assert(_repValues[iinfo] is T[]); - // The replacement array. - T[] repArray = (T[])_repValues[iinfo]; + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); - return getter = - (ref VBuffer dst) => + SaveColumns(ctx); + var saver = new BinarySaver(Host, new BinarySaver.Arguments()); + for (int iinfo = 0; iinfo < _types.Length; iinfo++) + { + var repValue = _repValues[iinfo]; + var repType = _types[iinfo].ItemType; + if (_repIsDefault[iinfo] != null) { - getSrc(ref src); - Host.Check(src.Length == repArray.Length); - FillValues(ref src, ref dst, isNA, repArray, _repIsDefault[iinfo]); - }; + Host.Assert(repValue is Array); + Func> function = CreateVBuffer; + var method = function.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(_types[iinfo].ItemType.RawType); + repValue = method.Invoke(this, new object[] { _repValues[iinfo] }); + repType = _types[iinfo]; + } + Host.Assert(!(repValue is Array)); + object[] args = new object[] { ctx.Writer.BaseStream, saver, repType, repValue }; + Action func = WriteTypeAndValue; + Host.Assert(repValue.GetType() == _types[iinfo].RawType || repValue.GetType() == _types[iinfo].ItemType.RawType); + var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(repValue.GetType()); + meth.Invoke(this, args); + } } - protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) + private NAReplaceTransform(IHost host, ModelLoadContext ctx) + : base(host, ctx) { - DataKind rawKind; - var type = Infos[iinfo].TypeSrc; - if (type.IsVector) - rawKind = type.AsVector.ItemType.RawKind; - else if (type.IsKey) - rawKind = type.AsKey.RawKind; - else - rawKind = type.RawKind; - - if (rawKind != DataKind.R4) - return false; - - string opType = "Imputer"; - var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); - node.AddAttribute("replaced_value_float", Single.NaN); - - if (!Infos[iinfo].TypeSrc.IsVector) - node.AddAttribute("imputed_value_floats", Enumerable.Repeat((float)_repValues[iinfo], 1)); - else + var columnsLength = ColumnPairs.Length; + _repValues = new object[columnsLength]; + _repIsDefault = new BitArray[columnsLength]; + _types = new ColumnType[columnsLength]; + var saver = new BinarySaver(Host, new BinarySaver.Arguments()); + for (int i = 0; i< columnsLength; i++) { - if (_repIsDefault[iinfo] != null) - node.AddAttribute("imputed_value_floats", (float[])_repValues[iinfo]); + if (!saver.TryLoadTypeAndValue(ctx.Reader.BaseStream, out ColumnType savedType, out object repValue)) + throw Host.ExceptDecode(); + _types[i] = savedType; + if (savedType.IsVector) + { + // REVIEW: The current implementation takes the serialized VBuffer, densifies it, and stores the values array. + // It might be of value to consider storing the VBUffer in order to possibly benefit from sparsity. However, this would + // necessitate a reimplementation of the FillValues code to accomodate sparse VBuffers. + object[] args = new object[] { repValue, _types[i], i }; + Func, ColumnType, int, int[]> func = GetValuesArray; + var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(savedType.ItemType.RawType); + _repValues[i] = meth.Invoke(this, args); + } else - node.AddAttribute("imputed_value_floats", Enumerable.Repeat((float)_repValues[iinfo], 1)); - } + _repValues[i] = repValue; - return true; + Host.Assert(repValue.GetType() == _types[i].RawType || repValue.GetType() == _types[i].ItemType.RawType); + } } - protected override VectorType GetSlotTypeCore(int iinfo) + private T[] GetValuesArray(VBuffer src, ColumnType srcType, int iinfo) { - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - return Infos[iinfo].SlotTypeSrc; + Host.Assert(srcType.IsVector); + Host.Assert(srcType.VectorSize == src.Length); + VBufferUtils.Densify(ref src); + RefPredicate defaultPred = Conversions.Instance.GetIsDefaultPredicate(srcType.ItemType); + _repIsDefault[iinfo] = new BitArray(srcType.VectorSize); + for (int slot = 0; slot < src.Length; slot++) + { + if (defaultPred(ref src.Values[slot])) + _repIsDefault[iinfo][slot] = true; + } + T[] valReturn = src.Values; + Array.Resize(ref valReturn, srcType.VectorSize); + Host.Assert(valReturn.Length == src.Length); + return valReturn; } - protected override ISlotCursor GetSlotCursorCore(int iinfo) - { - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - Host.AssertValue(Infos[iinfo].SlotTypeSrc); + protected override IRowMapper MakeRowMapper(ISchema schema) + => new Mapper(this, schema); - ISlotCursor cursor = InputTranspose.GetSlotCursor(Infos[iinfo].Source); - var type = GetSlotTypeCore(iinfo); - Host.AssertValue(type); - return Utils.MarshalInvoke(GetSlotCursorCore, type.ItemType.RawType, this, iinfo, cursor, type); - } + private sealed class Mapper : MapperBase, ISaveAsOnnx + { + private readonly NAReplaceTransform _parent; + private readonly ColInfo[] _infos; + // The isNA delegates, parallel to Infos. + private readonly Delegate[] _isNAs; + public bool CanSaveOnnx => true; - private ISlotCursor GetSlotCursorCore(NAReplaceTransform parent, int iinfo, ISlotCursor cursor, VectorType type) - => new SlotCursor(parent, iinfo, cursor, type); + public Mapper(NAReplaceTransform parent, ISchema inputSchema) + : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) + { + _parent = parent; + _infos = CreateInfos(inputSchema); + _isNAs = new Delegate[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + var type = _infos[i].TypeSrc; + _isNAs[i] = GetIsNADelegate(type); + } + } - private sealed class SlotCursor : SynchronizedCursorBase, ISlotCursor - { - private readonly ValueGetter> _getter; - private readonly VectorType _type; + /// + /// Returns the isNA predicate for the respective type. + /// + private Delegate GetIsNADelegate(ColumnType type) + { + Func func = GetIsNADelegate; + return Utils.MarshalInvoke(func, type.ItemType.RawType, type); + } - public SlotCursor(NAReplaceTransform parent, int iinfo, ISlotCursor cursor, VectorType type) - : base(parent.Host, cursor) + private Delegate GetIsNADelegate(ColumnType type) { - Ch.Assert(0 <= iinfo && iinfo < parent.Infos.Length); - Ch.AssertValue(cursor); - Ch.AssertValue(type); - var srcGetter = cursor.GetGetter(); - _type = type; - _getter = CreateGetter(parent, iinfo, cursor, type); + return Conversions.Instance.GetIsNAPredicate(type.ItemType); } - private ValueGetter> CreateGetter(NAReplaceTransform parent, int iinfo, ISlotCursor cursor, VectorType type) + private ColInfo[] CreateInfos(ISchema inputSchema) { - var src = default(VBuffer); - ValueGetter> getter; + Host.AssertValue(inputSchema); + var infos = new ColInfo[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colSrc)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); + _parent.CheckInputColumn(inputSchema, i, colSrc); + var type = inputSchema.GetColumnType(colSrc); + infos[i] = new ColInfo(_parent.ColumnPairs[i].output, _parent.ColumnPairs[i].input, type); + } + return infos; + } - var getSrc = cursor.GetGetter(); - var isNA = (RefPredicate)parent._isNAs[iinfo]; - var isDefault = Conversions.Instance.GetIsDefaultPredicate(type.ItemType); + public override RowMapperColumnInfo[] GetOutputColumns() + { + var result = new RowMapperColumnInfo[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colIndex); + Host.Assert(colIndex >= 0); + var colMetaInfo = new ColumnMetadataInfo(_parent.ColumnPairs[i].output); + foreach (var type in InputSchema.GetMetadataTypes(colIndex).Where(x => x.Key == MetadataUtils.Kinds.SlotNames || x.Key == MetadataUtils.Kinds.IsNormalized)) + Utils.MarshalInvoke(AddMetaGetter, type.Value.RawType, colMetaInfo, InputSchema, type.Key, type.Value, colIndex); + result[i] = new RowMapperColumnInfo(_parent.ColumnPairs[i].output, _parent._types[i], colMetaInfo); + } + return result; + } - if (parent._repIsDefault[iinfo] == null) + private int AddMetaGetter(ColumnMetadataInfo colMetaInfo, ISchema schema, string kind, ColumnType ct, int originalCol) + { + MetadataUtils.MetadataGetter getter = (int col, ref T dst) => { - // One replacement value for all slots. - Ch.Assert(parent._repValues[iinfo] is T); - T rep = (T)parent._repValues[iinfo]; - bool repIsDefault = isDefault(ref rep); + // We don't care about 'col': this getter is specialized for a column 'originalCol', + // and 'col' in this case is the 'metadata kind index', not the column index. + schema.GetMetadata(kind, originalCol, ref dst); + }; + var info = new MetadataInfo(ct, getter); + colMetaInfo.Add(kind, info); + return 0; + } + + protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) + { + Host.AssertValue(input); + Host.Assert(0 <= iinfo && iinfo < _infos.Length); + disposer = null; + + if (!_infos[iinfo].TypeSrc.IsVector) + return ComposeGetterOne(input, iinfo); + return ComposeGetterVec(input, iinfo); + } - return (ref VBuffer dst) => + /// + /// Getter generator for single valued inputs. + /// + private Delegate ComposeGetterOne(IRow input, int iinfo) + => Utils.MarshalInvoke(ComposeGetterOne, _infos[iinfo].TypeSrc.RawType, input, iinfo); + + /// + /// Replaces NA values for scalars. + /// + private Delegate ComposeGetterOne(IRow input, int iinfo) + { + var getSrc = input.GetGetter(ColMapNewToOld[iinfo]); + var src = default(T); + var isNA = (RefPredicate)_isNAs[iinfo]; + Host.Assert(_parent._repValues[iinfo] is T); + T rep = (T)_parent._repValues[iinfo]; + ValueGetter getter; + + return getter = + (ref T dst) => { getSrc(ref src); - parent.FillValues(ref src, ref dst, isNA, rep, repIsDefault); + dst = isNA(ref src) ? rep : src; }; + } + + /// + /// Getter generator for vector valued inputs. + /// + private Delegate ComposeGetterVec(IRow input, int iinfo) + => Utils.MarshalInvoke(ComposeGetterVec, _infos[iinfo].TypeSrc.ItemType.RawType, input, iinfo); + + /// + /// Replaces NA values for vectors. + /// + private Delegate ComposeGetterVec(IRow input, int iinfo) + { + //IVAN: check you use corret iinfo or maybe it should be ColMapNewToOld[iinfo] everywhere in code + var getSrc = input.GetGetter>(ColMapNewToOld[iinfo]); + var isNA = (RefPredicate)_isNAs[iinfo]; + var isDefault = Conversions.Instance.GetIsDefaultPredicate(_infos[iinfo].TypeSrc.ItemType); + + var src = default(VBuffer); + ValueGetter> getter; + + if (_parent._repIsDefault[iinfo] == null) + { + // One replacement value for all slots. + Host.Assert(_parent._repValues[iinfo] is T); + T rep = (T)_parent._repValues[iinfo]; + bool repIsDefault = isDefault(ref rep); + return getter = + (ref VBuffer dst) => + { + getSrc(ref src); + FillValues(ref src, ref dst, isNA, rep, repIsDefault); + }; } // Replacement values by slot. - Ch.Assert(parent._repValues[iinfo] is T[]); + Host.Assert(_parent._repValues[iinfo] is T[]); // The replacement array. - T[] repArray = (T[])parent._repValues[iinfo]; + T[] repArray = (T[])_parent._repValues[iinfo]; return getter = (ref VBuffer dst) => { getSrc(ref src); - Ch.Check(0 <= Position && Position < repArray.Length); - T rep = repArray[(int)Position]; - parent.FillValues(ref src, ref dst, isNA, rep, isDefault(ref rep)); + Host.Check(src.Length == repArray.Length); + FillValues(ref src, ref dst, isNA, repArray, _parent._repIsDefault[iinfo]); }; } - public VectorType GetSlotType() => _type; - - public ValueGetter> GetGetter() + /// + /// Fills values for vectors where there is one replacement value. + /// + private void FillValues(ref VBuffer src, ref VBuffer dst, RefPredicate isNA, T rep, bool repIsDefault) { - ValueGetter> getter = _getter as ValueGetter>; - if (getter == null) - throw Ch.Except("Invalid TValue: '{0}'", typeof(TValue)); - return getter; - } - } - - /// - /// Fills values for vectors where there is one replacement value. - /// - private void FillValues(ref VBuffer src, ref VBuffer dst, RefPredicate isNA, T rep, bool repIsDefault) - { - Host.AssertValue(isNA); + Host.AssertValue(isNA); - int srcSize = src.Length; - int srcCount = src.Count; - var srcValues = src.Values; - Host.Assert(Utils.Size(srcValues) >= srcCount); - var srcIndices = src.Indices; + int srcSize = src.Length; + int srcCount = src.Count; + var srcValues = src.Values; + Host.Assert(Utils.Size(srcValues) >= srcCount); + var srcIndices = src.Indices; - var dstValues = dst.Values; - var dstIndices = dst.Indices; + var dstValues = dst.Values; + var dstIndices = dst.Indices; - // If the values array is not large enough, allocate sufficient space. - // Note: We can't set the max to srcSize as vectors can be of variable lengths. - Utils.EnsureSize(ref dstValues, srcCount, keepOld: false); + // If the values array is not large enough, allocate sufficient space. + // Note: We can't set the max to srcSize as vectors can be of variable lengths. + Utils.EnsureSize(ref dstValues, srcCount, keepOld: false); - int iivDst = 0; - if (src.IsDense) - { - // The source vector is dense. - Host.Assert(srcSize == srcCount); + int iivDst = 0; + if (src.IsDense) + { + // The source vector is dense. + Host.Assert(srcSize == srcCount); - for (int ivSrc = 0; ivSrc < srcCount; ivSrc++) + for (int ivSrc = 0; ivSrc < srcCount; ivSrc++) + { + var srcVal = srcValues[ivSrc]; + + // The output for dense inputs is always dense. + // Note: Theoretically, one could imagine a dataset with NA values that one wished to replace with + // the default value, resulting in more than half of the indices being the default value. + // In this case, changing the dst vector to be sparse would be more memory efficient -- the current decision + // is it is not worth handling this case at the expense of running checks that will almost always not be triggered. + dstValues[ivSrc] = isNA(ref srcVal) ? rep : srcVal; + } + iivDst = srcCount; + } + else { - var srcVal = srcValues[ivSrc]; - - // The output for dense inputs is always dense. - // Note: Theoretically, one could imagine a dataset with NA values that one wished to replace with - // the default value, resulting in more than half of the indices being the default value. - // In this case, changing the dst vector to be sparse would be more memory efficient -- the current decision - // is it is not worth handling this case at the expense of running checks that will almost always not be triggered. - dstValues[ivSrc] = isNA(ref srcVal) ? rep : srcVal; + // The source vector is sparse. + Host.Assert(Utils.Size(srcIndices) >= srcCount); + Host.Assert(srcCount < srcSize); + + // Allocate more space if necessary. + // REVIEW: One thing that changing the code to simply ensure that there are srcCount indices in the arrays + // does is over-allocate space if the replacement value is the default value in a dataset with a + // signficiant amount of NA values -- is it worth handling allocation of memory for this case? + Utils.EnsureSize(ref dstIndices, srcCount, keepOld: false); + + // Note: ivPrev is only used for asserts. + int ivPrev = -1; + for (int iivSrc = 0; iivSrc < srcCount; iivSrc++) + { + Host.Assert(iivDst <= iivSrc); + var srcVal = srcValues[iivSrc]; + int iv = srcIndices[iivSrc]; + Host.Assert(ivPrev < iv & iv < srcSize); + ivPrev = iv; + + if (!isNA(ref srcVal)) + { + dstValues[iivDst] = srcVal; + dstIndices[iivDst++] = iv; + } + else if (!repIsDefault) + { + // Allow for further sparsification. + dstValues[iivDst] = rep; + dstIndices[iivDst++] = iv; + } + } + Host.Assert(iivDst <= srcCount); } - iivDst = srcCount; + Host.Assert(0 <= iivDst); + Host.Assert(repIsDefault || iivDst == srcCount); + dst = new VBuffer(srcSize, iivDst, dstValues, dstIndices); } - else + + /// + /// Fills values for vectors where there is slot-wise replacement values. + /// + private void FillValues(ref VBuffer src, ref VBuffer dst, RefPredicate isNA, T[] rep, BitArray repIsDefault) { - // The source vector is sparse. - Host.Assert(Utils.Size(srcIndices) >= srcCount); - Host.Assert(srcCount < srcSize); - - // Allocate more space if necessary. - // REVIEW: One thing that changing the code to simply ensure that there are srcCount indices in the arrays - // does is over-allocate space if the replacement value is the default value in a dataset with a - // signficiant amount of NA values -- is it worth handling allocation of memory for this case? - Utils.EnsureSize(ref dstIndices, srcCount, keepOld: false); - - // Note: ivPrev is only used for asserts. - int ivPrev = -1; - for (int iivSrc = 0; iivSrc < srcCount; iivSrc++) + Host.AssertValue(rep); + Host.Assert(rep.Length == src.Length); + Host.AssertValue(repIsDefault); + Host.Assert(repIsDefault.Length == src.Length); + Host.AssertValue(isNA); + + int srcSize = src.Length; + int srcCount = src.Count; + var srcValues = src.Values; + Host.Assert(Utils.Size(srcValues) >= srcCount); + var srcIndices = src.Indices; + + var dstValues = dst.Values; + var dstIndices = dst.Indices; + + // If the values array is not large enough, allocate sufficient space. + Utils.EnsureSize(ref dstValues, srcCount, srcSize, keepOld: false); + + int iivDst = 0; + Host.Assert(Utils.Size(srcValues) >= srcCount); + if (src.IsDense) { - Host.Assert(iivDst <= iivSrc); - var srcVal = srcValues[iivSrc]; - int iv = srcIndices[iivSrc]; - Host.Assert(ivPrev < iv & iv < srcSize); - ivPrev = iv; + // The source vector is dense. + Host.Assert(srcSize == srcCount); - if (!isNA(ref srcVal)) + for (int ivSrc = 0; ivSrc < srcCount; ivSrc++) { - dstValues[iivDst] = srcVal; - dstIndices[iivDst++] = iv; + var srcVal = srcValues[ivSrc]; + + // The output for dense inputs is always dense. + // Note: Theoretically, one could imagine a dataset with NA values that one wished to replace with + // the default value, resulting in more than half of the indices being the default value. + // In this case, changing the dst vector to be sparse would be more memory efficient -- the current decision + // is it is not worth handling this case at the expense of running checks that will almost always not be triggered. + dstValues[ivSrc] = isNA(ref srcVal) ? rep[ivSrc] : srcVal; } - else if (!repIsDefault) + iivDst = srcCount; + } + else + { + // The source vector is sparse. + Host.Assert(Utils.Size(srcIndices) >= srcCount); + Host.Assert(srcCount < srcSize); + + // Allocate more space if necessary. + // REVIEW: One thing that changing the code to simply ensure that there are srcCount indices in the arrays + // does is over-allocate space if the replacement value is the default value in a dataset with a + // signficiant amount of NA values -- is it worth handling allocation of memory for this case? + Utils.EnsureSize(ref dstIndices, srcCount, srcSize, keepOld: false); + + // Note: ivPrev is only used for asserts. + int ivPrev = -1; + for (int iivSrc = 0; iivSrc < srcCount; iivSrc++) { - // Allow for further sparsification. - dstValues[iivDst] = rep; - dstIndices[iivDst++] = iv; + Host.Assert(iivDst <= iivSrc); + var srcVal = srcValues[iivSrc]; + int iv = srcIndices[iivSrc]; + Host.Assert(ivPrev < iv & iv < srcSize); + ivPrev = iv; + + if (!isNA(ref srcVal)) + { + dstValues[iivDst] = srcVal; + dstIndices[iivDst++] = iv; + } + else if (!repIsDefault[iv]) + { + // Allow for further sparsification. + dstValues[iivDst] = rep[iv]; + dstIndices[iivDst++] = iv; + } } + Host.Assert(iivDst <= srcCount); } - Host.Assert(iivDst <= srcCount); + Host.Assert(0 <= iivDst); + dst = new VBuffer(srcSize, iivDst, dstValues, dstIndices); } - Host.Assert(0 <= iivDst); - Host.Assert(repIsDefault || iivDst == srcCount); - dst = new VBuffer(srcSize, iivDst, dstValues, dstIndices); - } - /// - /// Fills values for vectors where there is slot-wise replacement values. - /// - private void FillValues(ref VBuffer src, ref VBuffer dst, RefPredicate isNA, T[] rep, BitArray repIsDefault) - { - Host.AssertValue(rep); - Host.Assert(rep.Length == src.Length); - Host.AssertValue(repIsDefault); - Host.Assert(repIsDefault.Length == src.Length); - Host.AssertValue(isNA); - - int srcSize = src.Length; - int srcCount = src.Count; - var srcValues = src.Values; - Host.Assert(Utils.Size(srcValues) >= srcCount); - var srcIndices = src.Indices; - - var dstValues = dst.Values; - var dstIndices = dst.Indices; - - // If the values array is not large enough, allocate sufficient space. - Utils.EnsureSize(ref dstValues, srcCount, srcSize, keepOld: false); - - int iivDst = 0; - Host.Assert(Utils.Size(srcValues) >= srcCount); - if (src.IsDense) + public void SaveAsOnnx(OnnxContext ctx) { - // The source vector is dense. - Host.Assert(srcSize == srcCount); + Host.CheckValue(ctx, nameof(ctx)); - for (int ivSrc = 0; ivSrc < srcCount; ivSrc++) + for (int iinfo = 0; iinfo < _infos.Length; ++iinfo) { - var srcVal = srcValues[ivSrc]; - - // The output for dense inputs is always dense. - // Note: Theoretically, one could imagine a dataset with NA values that one wished to replace with - // the default value, resulting in more than half of the indices being the default value. - // In this case, changing the dst vector to be sparse would be more memory efficient -- the current decision - // is it is not worth handling this case at the expense of running checks that will almost always not be triggered. - dstValues[ivSrc] = isNA(ref srcVal) ? rep[ivSrc] : srcVal; - } - iivDst = srcCount; - } - else - { - // The source vector is sparse. - Host.Assert(Utils.Size(srcIndices) >= srcCount); - Host.Assert(srcCount < srcSize); - - // Allocate more space if necessary. - // REVIEW: One thing that changing the code to simply ensure that there are srcCount indices in the arrays - // does is over-allocate space if the replacement value is the default value in a dataset with a - // signficiant amount of NA values -- is it worth handling allocation of memory for this case? - Utils.EnsureSize(ref dstIndices, srcCount, srcSize, keepOld: false); - - // Note: ivPrev is only used for asserts. - int ivPrev = -1; - for (int iivSrc = 0; iivSrc < srcCount; iivSrc++) - { - Host.Assert(iivDst <= iivSrc); - var srcVal = srcValues[iivSrc]; - int iv = srcIndices[iivSrc]; - Host.Assert(ivPrev < iv & iv < srcSize); - ivPrev = iv; - - if (!isNA(ref srcVal)) + ColInfo info = _infos[iinfo]; + string sourceColumnName = info.Source; + if (!ctx.ContainsColumn(sourceColumnName)) { - dstValues[iivDst] = srcVal; - dstIndices[iivDst++] = iv; + ctx.RemoveColumn(info.Name, false); + continue; } - else if (!repIsDefault[iv]) + + if (!SaveAsOnnxCore(ctx, iinfo, info, ctx.GetVariableName(sourceColumnName), + ctx.AddIntermediateVariable(_parent._types[iinfo], info.Name))) { - // Allow for further sparsification. - dstValues[iivDst] = rep[iv]; - dstIndices[iivDst++] = iv; + ctx.RemoveColumn(info.Name, true); } } - Host.Assert(iivDst <= srcCount); } - Host.Assert(0 <= iivDst); - dst = new VBuffer(srcSize, iivDst, dstValues, dstIndices); + + private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) + { + DataKind rawKind; + var type = _infos[iinfo].TypeSrc; + if (type.IsVector) + rawKind = type.AsVector.ItemType.RawKind; + else if (type.IsKey) + rawKind = type.AsKey.RawKind; + else + rawKind = type.RawKind; + + if (rawKind != DataKind.R4) + return false; + + string opType = "Imputer"; + var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); + node.AddAttribute("replaced_value_float", Single.NaN); + + if (!_infos[iinfo].TypeSrc.IsVector) + node.AddAttribute("imputed_value_floats", Enumerable.Repeat((float)_parent._repValues[iinfo], 1)); + else + { + if (_parent._repIsDefault[iinfo] != null) + node.AddAttribute("imputed_value_floats", (float[])_parent._repValues[iinfo]); + else + node.AddAttribute("imputed_value_floats", Enumerable.Repeat((float)_parent._repValues[iinfo], 1)); + } + + return true; + } + } + } + + public sealed class NAReplaceEstimator : IEstimator + { + private readonly IHost _host; + private readonly NAReplaceTransform.ColumnInfo[] _columns; + + public NAReplaceEstimator(IHostEnvironment env, string name, string source = null, NAReplaceTransform.ReplacementKind replacementKind = NAReplaceTransform.ReplacementKind.DefaultValue) + : this(env, new NAReplaceTransform.ColumnInfo(source ?? name, name, replacementKind)) + { + + } + public NAReplaceEstimator(IHostEnvironment env, params NAReplaceTransform.ColumnInfo[] columns) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(NAReplaceEstimator)); + _columns = columns; } + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + _host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var colInfo in _columns) + { + if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + //IVAN Proper validation, not copypaste from term estimator + if ((col.ItemType.ItemType.RawKind == default) || !(col.ItemType.IsVector || col.ItemType.IsPrimitive)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + var metadata = new List(); + if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta)) + metadata.Add(slotMeta); + if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.IsNormalized, out var normalized)) + metadata.Add(normalized); + var type = !col.ItemType.IsVector ? col.ItemType : new VectorType(col.ItemType.ItemType.AsPrimitive, col.ItemType.AsVector); + result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, type, false, new SchemaShape(metadata.ToArray())); + } + + return new SchemaShape(result.Values); + } + + public NAReplaceTransform Fit(IDataView input) => new NAReplaceTransform(_host, input, _columns); } } diff --git a/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs b/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs new file mode 100644 index 0000000000..751fec1807 --- /dev/null +++ b/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs @@ -0,0 +1,93 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.RunTests; +using Microsoft.ML.Runtime.Tools; +using System.IO; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Tests.Transformers +{ + public class NAReplaceTests : TestDataPipeBase + { + private class TestClass + { + public float A; + public string B; + public double C; + [VectorType(2)] + public float[] D; + [VectorType(2)] + public double[] E; + } + + public NAReplaceTests(ITestOutputHelper output) : base(output) + { + } + + [Fact] + public void NAReplaceWorkout() + { + var data = new[] { + new TestClass() { A = 1, B = "A", C = 3, D= new float[2]{ 1, 2 } , E = new double[2]{ 3,4} }, + new TestClass() { A = float.NaN, B = null, C = double.NaN, D= new float[2]{ float.NaN, float.NaN } , E = new double[2]{ double.NaN,double.NaN}}, + new TestClass() { A = float.NegativeInfinity, B = null, C = double.NegativeInfinity,D= new float[2]{ float.NegativeInfinity, float.NegativeInfinity } , E = new double[2]{ double.NegativeInfinity, double.NegativeInfinity}}, + new TestClass() { A = float.PositiveInfinity, B = null, C = double.PositiveInfinity,D= new float[2]{ float.PositiveInfinity, float.PositiveInfinity, } , E = new double[2]{ double.PositiveInfinity, double.PositiveInfinity}}, + new TestClass() { A = 2, B = "B", C = 1 ,D= new float[2]{ 3, 4 } , E = new double[2]{ 5,6}}, + }; + + var dataView = ComponentCreation.CreateDataView(Env, data); + var pipe = new NAReplaceEstimator(Env, + new NAReplaceTransform.ColumnInfo("A", "NAA", NAReplaceTransform.ReplacementKind.Mean), + new NAReplaceTransform.ColumnInfo("B", "NAB", NAReplaceTransform.ReplacementKind.Default), + new NAReplaceTransform.ColumnInfo("C", "NAC", NAReplaceTransform.ReplacementKind.Mean), + new NAReplaceTransform.ColumnInfo("D", "NAD", NAReplaceTransform.ReplacementKind.Mean), + new NAReplaceTransform.ColumnInfo("E", "NAE", NAReplaceTransform.ReplacementKind.Mean)); + TestEstimatorCore(pipe, dataView); + Done(); + } + + [Fact] + public void TestCommandLine() + { + using (var env = new TlcEnvironment()) + { + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=NAReplace{col=C:A} in=f:\2.txt" }), (int)0); + } + } + + [Fact] + public void TestOldSavingAndLoading() + { + var data = new[] { + new TestClass() { A = 1, B = "A", C = 3, D= new float[2]{ 1, 2 } , E = new double[2]{ 3,4} }, + new TestClass() { A = float.NaN, B = null, C = double.NaN, D= new float[2]{ float.NaN, float.NaN } , E = new double[2]{ double.NaN,double.NaN}}, + new TestClass() { A = float.NegativeInfinity, B = null, C = double.NegativeInfinity,D= new float[2]{ float.NegativeInfinity, float.NegativeInfinity } , E = new double[2]{ double.NegativeInfinity, double.NegativeInfinity}}, + new TestClass() { A = float.PositiveInfinity, B = null, C = double.PositiveInfinity,D= new float[2]{ float.PositiveInfinity, float.PositiveInfinity, } , E = new double[2]{ double.PositiveInfinity, double.PositiveInfinity}}, + new TestClass() { A = 2, B = "B", C = 1 ,D= new float[2]{ 3, 4 } , E = new double[2]{ 5,6}}, + }; + + var dataView = ComponentCreation.CreateDataView(Env, data); + var pipe = new NAReplaceEstimator(Env, + new NAReplaceTransform.ColumnInfo("A", "NAA", NAReplaceTransform.ReplacementKind.Mean), + new NAReplaceTransform.ColumnInfo("B", "NAB", NAReplaceTransform.ReplacementKind.Default), + new NAReplaceTransform.ColumnInfo("C", "NAC", NAReplaceTransform.ReplacementKind.Mean), + new NAReplaceTransform.ColumnInfo("D", "NAD", NAReplaceTransform.ReplacementKind.Mean), + new NAReplaceTransform.ColumnInfo("E", "NAE", NAReplaceTransform.ReplacementKind.Mean)); + + var result = pipe.Fit(dataView).Transform(dataView); + var resultRoles = new RoleMappedData(result); + using (var ms = new MemoryStream()) + { + TrainUtils.SaveModel(Env, Env.Start("saving"), ms, null, resultRoles); + ms.Position = 0; + var loadedView = ModelFileUtils.LoadTransforms(Env, dataView, ms); + } + } + } +} From 2f79f974c2dcfaa29efae4668b5b065d7d5d2209 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Fri, 14 Sep 2018 11:45:12 -0700 Subject: [PATCH 2/8] some code rearrangemnts --- .../NAReplaceTransform.cs | 180 +++++++++--------- 1 file changed, 90 insertions(+), 90 deletions(-) diff --git a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs index 9dc00fe8d0..fda0559d31 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs @@ -122,12 +122,51 @@ public sealed class Arguments : TransformInputBase } public const string LoadName = "NAReplaceTransform"; + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + // REVIEW: temporary name + modelSignature: "NAREP TF", + // verWrittenCur: 0x00010001, // Initial + verWrittenCur: 0x0010002, // Added imputation methods. + verReadableCur: 0x00010002, + verWeCanReadBack: 0x00010001, + loaderSignature: LoadName); + } + internal const string Summary = "Create an output column of the same type and size of the input column, where missing values " + "are replaced with either the default value or the mean/min/max value (for non-text columns only)."; internal const string FriendlyName = "NA Replace Transform"; internal const string ShortName = "NARep"; + private static string TestType(ColumnType type) + { + // Item type must have an NA value that exists and is not equal to its default value. + Func func = TestType; + var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(type.ItemType.RawType); + return (string)meth.Invoke(null, new object[] { type.ItemType }); + } + + private static string TestType(ColumnType type) + { + Contracts.Assert(type.ItemType.RawType == typeof(T)); + RefPredicate isNA; + if (!Conversions.Instance.TryGetIsNAPredicate(type.ItemType, out isNA)) + { + return string.Format("Type '{0}' is not supported by {1} since it doesn't have an NA value", + type, LoadName); + } + var t = default(T); + if (isNA(ref t)) + { + // REVIEW: Key values will be handled in a "new key value" transform. + return string.Format("Type '{0}' is not supported by {1} since its NA value is equivalent to its default value", + type, LoadName); + } + return null; + } + public class ColumnInfo { public readonly string Input; @@ -166,6 +205,9 @@ public ColInfo(string name, string source, ColumnType type) } } + // The output column types, parallel to Infos. + private readonly ColumnType[] _types; + // The replacementValues for the columns, parallel to Infos. // The elements of this array can be either primitive values or arrays of primitive values. When replacing a scalar valued column in Infos, // this array will hold a primitive value. When replacing a vector valued column in Infos, this array will either hold a primitive @@ -178,36 +220,6 @@ public ColInfo(string name, string source, ColumnType type) // REVIEW: Currently these arrays are constructed on load but could be changed to being constructed lazily. private readonly BitArray[] _repIsDefault; - // The output column types, parallel to Infos. - private readonly ColumnType[] _types; - - private string TestType(ColumnType type) - { - // Item type must have an NA value that exists and is not equal to its default value. - Func func = TestType; - var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(type.ItemType.RawType); - return (string)meth.Invoke(null, new object[] { type.ItemType }); - } - - private static string TestType(ColumnType type) - { - Contracts.Assert(type.ItemType.RawType == typeof(T)); - RefPredicate isNA; - if (!Conversions.Instance.TryGetIsNAPredicate(type.ItemType, out isNA)) - { - return string.Format("Type '{0}' is not supported by {1} since it doesn't have an NA value", - type, LoadName); - } - var t = default(T); - if (isNA(ref t)) - { - // REVIEW: Key values will be handled in a "new key value" transform. - return string.Format("Type '{0}' is not supported by {1} since its NA value is equivalent to its default value", - type, LoadName); - } - return null; - } - protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { var type = inputSchema.GetColumnType(srcCol); @@ -224,6 +236,54 @@ public NAReplaceTransform(IHostEnvironment env, IDataView input, params ColumnIn GetReplacementValues(input, columns, out _repValues, out _repIsDefault, out _types); } + private NAReplaceTransform(IHost host, ModelLoadContext ctx) + : base(host, ctx) + { + var columnsLength = ColumnPairs.Length; + _repValues = new object[columnsLength]; + _repIsDefault = new BitArray[columnsLength]; + _types = new ColumnType[columnsLength]; + var saver = new BinarySaver(Host, new BinarySaver.Arguments()); + for (int i = 0; i < columnsLength; i++) + { + if (!saver.TryLoadTypeAndValue(ctx.Reader.BaseStream, out ColumnType savedType, out object repValue)) + throw Host.ExceptDecode(); + _types[i] = savedType; + if (savedType.IsVector) + { + // REVIEW: The current implementation takes the serialized VBuffer, densifies it, and stores the values array. + // It might be of value to consider storing the VBUffer in order to possibly benefit from sparsity. However, this would + // necessitate a reimplementation of the FillValues code to accomodate sparse VBuffers. + object[] args = new object[] { repValue, _types[i], i }; + Func, ColumnType, int, int[]> func = GetValuesArray; + var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(savedType.ItemType.RawType); + _repValues[i] = meth.Invoke(this, args); + } + else + _repValues[i] = repValue; + + Host.Assert(repValue.GetType() == _types[i].RawType || repValue.GetType() == _types[i].ItemType.RawType); + } + } + + private T[] GetValuesArray(VBuffer src, ColumnType srcType, int iinfo) + { + Host.Assert(srcType.IsVector); + Host.Assert(srcType.VectorSize == src.Length); + VBufferUtils.Densify(ref src); + RefPredicate defaultPred = Conversions.Instance.GetIsDefaultPredicate(srcType.ItemType); + _repIsDefault[iinfo] = new BitArray(srcType.VectorSize); + for (int slot = 0; slot < src.Length; slot++) + { + if (defaultPred(ref src.Values[slot])) + _repIsDefault[iinfo][slot] = true; + } + T[] valReturn = src.Values; + Array.Resize(ref valReturn, srcType.VectorSize); + Host.Assert(valReturn.Length == src.Length); + return valReturn; + } + /// /// Fill the repValues array with the correct replacement values based on the user-given replacement kinds. /// Vectors default to by-slot imputation unless otherwise specified, except for unknown sized vectors @@ -346,18 +406,6 @@ private object GetDefault() return default(T); } - private static VersionInfo GetVersionInfo() - { - return new VersionInfo( - // REVIEW: temporary name - modelSignature: "NAREP TF", - // verWrittenCur: 0x00010001, // Initial - verWrittenCur: 0x0010002, // Added imputation methods. - verReadableCur: 0x00010002, - verWeCanReadBack: 0x00010001, - loaderSignature: LoadName); - } - // Factory method for SignatureDataTransform. public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { @@ -455,54 +503,6 @@ public override void Save(ModelSaveContext ctx) } } - private NAReplaceTransform(IHost host, ModelLoadContext ctx) - : base(host, ctx) - { - var columnsLength = ColumnPairs.Length; - _repValues = new object[columnsLength]; - _repIsDefault = new BitArray[columnsLength]; - _types = new ColumnType[columnsLength]; - var saver = new BinarySaver(Host, new BinarySaver.Arguments()); - for (int i = 0; i< columnsLength; i++) - { - if (!saver.TryLoadTypeAndValue(ctx.Reader.BaseStream, out ColumnType savedType, out object repValue)) - throw Host.ExceptDecode(); - _types[i] = savedType; - if (savedType.IsVector) - { - // REVIEW: The current implementation takes the serialized VBuffer, densifies it, and stores the values array. - // It might be of value to consider storing the VBUffer in order to possibly benefit from sparsity. However, this would - // necessitate a reimplementation of the FillValues code to accomodate sparse VBuffers. - object[] args = new object[] { repValue, _types[i], i }; - Func, ColumnType, int, int[]> func = GetValuesArray; - var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(savedType.ItemType.RawType); - _repValues[i] = meth.Invoke(this, args); - } - else - _repValues[i] = repValue; - - Host.Assert(repValue.GetType() == _types[i].RawType || repValue.GetType() == _types[i].ItemType.RawType); - } - } - - private T[] GetValuesArray(VBuffer src, ColumnType srcType, int iinfo) - { - Host.Assert(srcType.IsVector); - Host.Assert(srcType.VectorSize == src.Length); - VBufferUtils.Densify(ref src); - RefPredicate defaultPred = Conversions.Instance.GetIsDefaultPredicate(srcType.ItemType); - _repIsDefault[iinfo] = new BitArray(srcType.VectorSize); - for (int slot = 0; slot < src.Length; slot++) - { - if (defaultPred(ref src.Values[slot])) - _repIsDefault[iinfo][slot] = true; - } - T[] valReturn = src.Values; - Array.Resize(ref valReturn, srcType.VectorSize); - Host.Assert(valReturn.Length == src.Length); - return valReturn; - } - protected override IRowMapper MakeRowMapper(ISchema schema) => new Mapper(this, schema); From 57c0774aeb51157098d190fa234d22a60f48ee59 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Fri, 14 Sep 2018 14:54:25 -0700 Subject: [PATCH 3/8] some polishing --- .../NAHandleTransform.cs | 21 +-- .../NAReplaceTransform.cs | 168 ++++++++++-------- src/Microsoft.ML.Transforms/NAReplaceUtils.cs | 2 +- .../Transformers/NAReplaceTests.cs | 20 +-- 4 files changed, 118 insertions(+), 93 deletions(-) diff --git a/src/Microsoft.ML.Transforms/NAHandleTransform.cs b/src/Microsoft.ML.Transforms/NAHandleTransform.cs index d723b52dee..eb520cbe15 100644 --- a/src/Microsoft.ML.Transforms/NAHandleTransform.cs +++ b/src/Microsoft.ML.Transforms/NAHandleTransform.cs @@ -20,28 +20,28 @@ namespace Microsoft.ML.Runtime.Data /// public static class NAHandleTransform { - public enum ReplacementKind + public enum ReplacementKind:byte { /// /// Replace with the default value of the column based on it's type. For example, 'zero' for numeric and 'empty' for string/text columns. /// [EnumValueDisplay("Zero/empty")] - DefaultValue, + DefaultValue=0, /// /// Replace with the mean value of the column. Supports only numeric/time span/ DateTime columns. /// - Mean, + Mean=1, /// /// Replace with the minimum value of the column. Supports only numeric/time span/ DateTime columns. /// - Minimum, + Minimum = 2, /// /// Replace with the maximum value of the column. Supports only numeric/time span/ DateTime columns. /// - Maximum, + Maximum =3, [HideEnumValue] Def = DefaultValue, @@ -149,19 +149,16 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV var addInd = column.ConcatIndicator ?? args.Concat; if (!addInd) { - replaceCols.Add(new NAReplaceTransform.ColumnInfo(column.Source, column.Name, (NAReplaceTransform.ReplacementKind)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot)); + replaceCols.Add(new NAReplaceTransform.ColumnInfo(column.Source, column.Name, (NAReplaceTransform.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot)); continue; } // Check that the indicator column has a type that can be converted to the NAReplaceTransform output type, // so that they can be concatenated. - int inputCol; - if (!input.Schema.TryGetColumnIndex(column.Source, out inputCol)) + if (!input.Schema.TryGetColumnIndex(column.Source, out int inputCol)) throw h.Except("Column '{0}' does not exist", column.Source); var replaceType = input.Schema.GetColumnType(inputCol); - Delegate conv; - bool identity; - if (!Conversions.Instance.TryGetStandardConversion(BoolType.Instance, replaceType.ItemType, out conv, out identity)) + if (!Conversions.Instance.TryGetStandardConversion(BoolType.Instance, replaceType.ItemType, out Delegate conv, out bool identity)) { throw h.Except("Cannot concatenate indicator column of type '{0}' to input column of type '{1}'", BoolType.Instance, replaceType.ItemType); @@ -179,7 +176,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV naConvCols.Add(new ConvertTransform.Column() { Name = tmpIsMissingColName, Source = tmpIsMissingColName, ResultType = replaceType.ItemType.RawKind }); // Add the NAReplaceTransform column. - replaceCols.Add(new NAReplaceTransform.ColumnInfo(column.Source, tmpReplacementColName, (NAReplaceTransform.ReplacementKind)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot)); + replaceCols.Add(new NAReplaceTransform.ColumnInfo(column.Source, tmpReplacementColName, (NAReplaceTransform.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot)); // Add the ConcatTransform column. if (replaceType.IsVector) diff --git a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs index fda0559d31..e4c5e7e334 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs @@ -37,14 +37,14 @@ namespace Microsoft.ML.Runtime.Data public sealed partial class NAReplaceTransform : OneToOneTransformerBase { //IVAN: CLEAN IT - public enum ReplacementKind + public enum ReplacementKind : byte { // REVIEW: What should the full list of options for this transform be? - DefaultValue, - Mean, - Minimum, - Maximum, - SpecifiedValue, + DefaultValue = 0, + Mean = 1, + Minimum = 2, + Maximum = 3, + SpecifiedValue = 4, [HideEnumValue] Def = DefaultValue, @@ -151,8 +151,7 @@ private static string TestType(ColumnType type) private static string TestType(ColumnType type) { Contracts.Assert(type.ItemType.RawType == typeof(T)); - RefPredicate isNA; - if (!Conversions.Instance.TryGetIsNAPredicate(type.ItemType, out isNA)) + if (!Conversions.Instance.TryGetIsNAPredicate(type.ItemType, out RefPredicate isNA)) { return string.Format("Type '{0}' is not supported by {1} since it doesn't have an NA value", type, LoadName); @@ -169,18 +168,34 @@ private static string TestType(ColumnType type) public class ColumnInfo { + public enum ReplacementMode : byte + { + DefaultValue = 0, + Mean = 1, + Minimum = 2, + Maximum = 3, + } + public readonly string Input; public readonly string Output; public readonly bool ImputeBySlot; - public readonly ReplacementKind Kind; + public readonly ReplacementMode Replacement; - public ColumnInfo(string input, string output, ReplacementKind kind = ReplacementKind.DefaultValue, bool imputeBySlot = true) + /// + /// Description what to do with columns + /// + /// Input column + /// Output column + /// Strategy how to replace NA value + /// If we operate on vector array do we want to find replace value for each slot in vector or for whole vector? + public ColumnInfo(string input, string output, ReplacementMode replacementMode = ReplacementMode.DefaultValue, bool imputeBySlot = true) { Input = input; Output = output; ImputeBySlot = imputeBySlot; - Kind = kind; + Replacement = replacementMode; } + internal string ReplacementString { get; set; } } @@ -190,21 +205,6 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum return columns.Select(x => (x.Input, x.Output)).ToArray(); } - ///IVAN: move to mapper. - internal sealed class ColInfo - { - public readonly string Name; - public readonly string Source; - public readonly ColumnType TypeSrc; - - public ColInfo(string name, string source, ColumnType type) - { - Name = name; - Source = source; - TypeSrc = type; - } - } - // The output column types, parallel to Infos. private readonly ColumnType[] _types; @@ -295,7 +295,7 @@ private void GetReplacementValues(IDataView input, ColumnInfo[] columns, out obj slotIsDefault = new BitArray[columns.Length]; types = new ColumnType[columns.Length]; var sources = new int[columns.Length]; - ReplacementKind?[] imputationModes = new ReplacementKind?[columns.Length]; + ReplacementKind[] imputationModes = new ReplacementKind[columns.Length]; List columnsToImpute = null; // REVIEW: Would like to get rid of the sourceColumns list but seems to be the best way to provide @@ -308,13 +308,14 @@ private void GetReplacementValues(IDataView input, ColumnInfo[] columns, out obj var type = input.Schema.GetColumnType(colSrc); if (type.IsVector) type = new VectorType(type.ItemType.AsPrimitive, type.AsVector); + Delegate isNa = GetIsNADelegate(type); types[iinfo] = type; - switch (columns[iinfo].Kind) + var kind = (ReplacementKind)columns[iinfo].Replacement; + switch (kind) { - //IVAN: what to do with you? - //case ReplacementKind.SpecifiedValue: - // repValues[iinfo] = GetSpecifiedValue(args.Column[iinfo].ReplacementString, _types[iinfo], _isNAs[iinfo]); - // break; + case ReplacementKind.SpecifiedValue: + repValues[iinfo] = GetSpecifiedValue(columns[iinfo].ReplacementString, _types[iinfo], isNa); + break; case ReplacementKind.DefaultValue: repValues[iinfo] = GetDefault(type); break; @@ -323,13 +324,13 @@ private void GetReplacementValues(IDataView input, ColumnInfo[] columns, out obj case ReplacementKind.Maximum: if (!type.ItemType.IsNumber && !type.ItemType.IsTimeSpan && !type.ItemType.IsDateTime) throw Host.Except("Cannot perform mean imputations on non-numeric '{0}'", type.ItemType); - imputationModes[iinfo] = columns[iinfo].Kind; + imputationModes[iinfo] = kind; Utils.Add(ref columnsToImpute, iinfo); Utils.Add(ref sourceColumns, colSrc); break; default: Host.Assert(false); - throw Host.Except("Internal error, undefined ReplacementKind '{0}' assigned in NAReplaceTransform.", columns[iinfo].Kind); + throw Host.Except("Internal error, undefined ReplacementKind '{0}' assigned in NAReplaceTransform.", columns[iinfo].Replacement); } } @@ -406,6 +407,47 @@ private object GetDefault() return default(T); } + /// + /// Returns the isNA predicate for the respective type. + /// + private Delegate GetIsNADelegate(ColumnType type) + { + Func func = GetIsNADelegate; + return Utils.MarshalInvoke(func, type.ItemType.RawType, type); + } + + private Delegate GetIsNADelegate(ColumnType type) + { + return Conversions.Instance.GetIsNAPredicate(type.ItemType); + } + + /// + /// Converts a string to its respective value in the corresponding type. + /// + private object GetSpecifiedValue(string srcStr, ColumnType dstType, Delegate isNA) + { + Func, object> func = GetSpecifiedValue; + var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(dstType.ItemType.RawType); + return meth.Invoke(this, new object[] { srcStr, dstType, isNA }); + } + + private object GetSpecifiedValue(string srcStr, ColumnType dstType, RefPredicate isNA) + { + var val = default(T); + if (!string.IsNullOrEmpty(srcStr)) + { + // Handles converting input strings to correct types. + DvText srcTxt = new DvText(srcStr); + var strToT = Conversions.Instance.GetStandardConversion(TextType.Instance, dstType.ItemType, out bool identity); + strToT(ref srcTxt, ref val); + // Make sure that the srcTxt can legitimately be converted to dstType, throw error otherwise. + if (isNA(ref val)) + throw Contracts.Except("No conversion of '{0}' to '{1}'", srcStr, dstType.ItemType); + } + + return val; + } + // Factory method for SignatureDataTransform. public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { @@ -424,7 +466,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV cols[i] = new ColumnInfo(item.Source, item.Name, - item.Kind ?? args.ReplacementKind, + (ColumnInfo.ReplacementMode)(item.Kind ?? args.ReplacementKind), item.Slot ?? args.ImputeBySlot); cols[i].ReplacementString = item.ReplacementString; }; @@ -468,8 +510,7 @@ private void WriteTypeAndValue(Stream stream, BinarySaver saver, ColumnType t Host.AssertValue(saver); Host.Assert(type.RawType == typeof(T) || type.ItemType.RawType == typeof(T)); - int bytesWritten; - if (!saver.TryWriteTypeAndValue(stream, type, ref rep, out bytesWritten)) + if (!saver.TryWriteTypeAndValue(stream, type, ref rep, out int bytesWritten)) throw Host.Except("We do not know how to serialize terms of type '{0}'", type); } @@ -508,6 +549,21 @@ protected override IRowMapper MakeRowMapper(ISchema schema) private sealed class Mapper : MapperBase, ISaveAsOnnx { + + private sealed class ColInfo + { + public readonly string Name; + public readonly string Source; + public readonly ColumnType TypeSrc; + + public ColInfo(string name, string source, ColumnType type) + { + Name = name; + Source = source; + TypeSrc = type; + } + } + private readonly NAReplaceTransform _parent; private readonly ColInfo[] _infos; // The isNA delegates, parallel to Infos. @@ -523,24 +579,10 @@ public Mapper(NAReplaceTransform parent, ISchema inputSchema) for (int i = 0; i < _parent.ColumnPairs.Length; i++) { var type = _infos[i].TypeSrc; - _isNAs[i] = GetIsNADelegate(type); + _isNAs[i] = _parent.GetIsNADelegate(type); } } - /// - /// Returns the isNA predicate for the respective type. - /// - private Delegate GetIsNADelegate(ColumnType type) - { - Func func = GetIsNADelegate; - return Utils.MarshalInvoke(func, type.ItemType.RawType, type); - } - - private Delegate GetIsNADelegate(ColumnType type) - { - return Conversions.Instance.GetIsNAPredicate(type.ItemType); - } - private ColInfo[] CreateInfos(ISchema inputSchema) { Host.AssertValue(inputSchema); @@ -564,26 +606,12 @@ public override RowMapperColumnInfo[] GetOutputColumns() InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colIndex); Host.Assert(colIndex >= 0); var colMetaInfo = new ColumnMetadataInfo(_parent.ColumnPairs[i].output); - foreach (var type in InputSchema.GetMetadataTypes(colIndex).Where(x => x.Key == MetadataUtils.Kinds.SlotNames || x.Key == MetadataUtils.Kinds.IsNormalized)) - Utils.MarshalInvoke(AddMetaGetter, type.Value.RawType, colMetaInfo, InputSchema, type.Key, type.Value, colIndex); - result[i] = new RowMapperColumnInfo(_parent.ColumnPairs[i].output, _parent._types[i], colMetaInfo); + var meta = RowColumnUtils.GetMetadataAsRow(InputSchema, colIndex, x => x == MetadataUtils.Kinds.SlotNames || x == MetadataUtils.Kinds.IsNormalized); + result[i] = new RowMapperColumnInfo(_parent.ColumnPairs[i].output, _parent._types[i], meta); } return result; } - private int AddMetaGetter(ColumnMetadataInfo colMetaInfo, ISchema schema, string kind, ColumnType ct, int originalCol) - { - MetadataUtils.MetadataGetter getter = (int col, ref T dst) => - { - // We don't care about 'col': this getter is specialized for a column 'originalCol', - // and 'col' in this case is the 'metadata kind index', not the column index. - schema.GetMetadata(kind, originalCol, ref dst); - }; - var info = new MetadataInfo(ct, getter); - colMetaInfo.Add(kind, info); - return 0; - } - protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) { Host.AssertValue(input); @@ -880,7 +908,6 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src else node.AddAttribute("imputed_value_floats", Enumerable.Repeat((float)_parent._repValues[iinfo], 1)); } - return true; } } @@ -891,11 +918,12 @@ public sealed class NAReplaceEstimator : IEstimator private readonly IHost _host; private readonly NAReplaceTransform.ColumnInfo[] _columns; - public NAReplaceEstimator(IHostEnvironment env, string name, string source = null, NAReplaceTransform.ReplacementKind replacementKind = NAReplaceTransform.ReplacementKind.DefaultValue) + public NAReplaceEstimator(IHostEnvironment env, string name, string source = null, NAReplaceTransform.ColumnInfo.ReplacementMode replacementKind = NAReplaceTransform.ColumnInfo.ReplacementMode.DefaultValue) : this(env, new NAReplaceTransform.ColumnInfo(source ?? name, name, replacementKind)) { } + public NAReplaceEstimator(IHostEnvironment env, params NAReplaceTransform.ColumnInfo[] columns) { Contracts.CheckValue(env, nameof(env)); diff --git a/src/Microsoft.ML.Transforms/NAReplaceUtils.cs b/src/Microsoft.ML.Transforms/NAReplaceUtils.cs index 2340f9b413..8100a5f84b 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceUtils.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceUtils.cs @@ -1597,7 +1597,7 @@ public override object GetStat() // If sparsity occurred, fold in a zero. if (ValueCount > (ulong)ValuesProcessed) { - TItem def = default(TItem); + TItem def = default; ProcValueDelegate(ref def); } return _converter.FromLong(Stat); diff --git a/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs b/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs index 751fec1807..b1ebd43b3d 100644 --- a/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs @@ -43,11 +43,11 @@ public void NAReplaceWorkout() var dataView = ComponentCreation.CreateDataView(Env, data); var pipe = new NAReplaceEstimator(Env, - new NAReplaceTransform.ColumnInfo("A", "NAA", NAReplaceTransform.ReplacementKind.Mean), - new NAReplaceTransform.ColumnInfo("B", "NAB", NAReplaceTransform.ReplacementKind.Default), - new NAReplaceTransform.ColumnInfo("C", "NAC", NAReplaceTransform.ReplacementKind.Mean), - new NAReplaceTransform.ColumnInfo("D", "NAD", NAReplaceTransform.ReplacementKind.Mean), - new NAReplaceTransform.ColumnInfo("E", "NAE", NAReplaceTransform.ReplacementKind.Mean)); + new NAReplaceTransform.ColumnInfo("A", "NAA", NAReplaceTransform.ColumnInfo.ReplacementMode.Mean), + new NAReplaceTransform.ColumnInfo("B", "NAB", NAReplaceTransform.ColumnInfo.ReplacementMode.DefaultValue), + new NAReplaceTransform.ColumnInfo("C", "NAC", NAReplaceTransform.ColumnInfo.ReplacementMode.Mean), + new NAReplaceTransform.ColumnInfo("D", "NAD", NAReplaceTransform.ColumnInfo.ReplacementMode.Mean), + new NAReplaceTransform.ColumnInfo("E", "NAE", NAReplaceTransform.ColumnInfo.ReplacementMode.Mean)); TestEstimatorCore(pipe, dataView); Done(); } @@ -74,11 +74,11 @@ public void TestOldSavingAndLoading() var dataView = ComponentCreation.CreateDataView(Env, data); var pipe = new NAReplaceEstimator(Env, - new NAReplaceTransform.ColumnInfo("A", "NAA", NAReplaceTransform.ReplacementKind.Mean), - new NAReplaceTransform.ColumnInfo("B", "NAB", NAReplaceTransform.ReplacementKind.Default), - new NAReplaceTransform.ColumnInfo("C", "NAC", NAReplaceTransform.ReplacementKind.Mean), - new NAReplaceTransform.ColumnInfo("D", "NAD", NAReplaceTransform.ReplacementKind.Mean), - new NAReplaceTransform.ColumnInfo("E", "NAE", NAReplaceTransform.ReplacementKind.Mean)); + new NAReplaceTransform.ColumnInfo("A", "NAA", NAReplaceTransform.ColumnInfo.ReplacementMode.Mean), + new NAReplaceTransform.ColumnInfo("B", "NAB", NAReplaceTransform.ColumnInfo.ReplacementMode.DefaultValue), + new NAReplaceTransform.ColumnInfo("C", "NAC", NAReplaceTransform.ColumnInfo.ReplacementMode.Mean), + new NAReplaceTransform.ColumnInfo("D", "NAD", NAReplaceTransform.ColumnInfo.ReplacementMode.Mean), + new NAReplaceTransform.ColumnInfo("E", "NAE", NAReplaceTransform.ColumnInfo.ReplacementMode.Mean)); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); From 7ee53322cf5accd7d40d6a5c0822004058e02018 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Fri, 14 Sep 2018 15:16:17 -0700 Subject: [PATCH 4/8] clean code more --- .../NAReplaceTransform.cs | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs index e4c5e7e334..0c91dd7e0c 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs @@ -36,7 +36,6 @@ namespace Microsoft.ML.Runtime.Data { public sealed partial class NAReplaceTransform : OneToOneTransformerBase { - //IVAN: CLEAN IT public enum ReplacementKind : byte { // REVIEW: What should the full list of options for this transform be? @@ -140,7 +139,7 @@ private static VersionInfo GetVersionInfo() internal const string FriendlyName = "NA Replace Transform"; internal const string ShortName = "NARep"; - private static string TestType(ColumnType type) + internal static string TestType(ColumnType type) { // Item type must have an NA value that exists and is not equal to its default value. Func func = TestType; @@ -225,19 +224,24 @@ protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCo var type = inputSchema.GetColumnType(srcCol); string reason = TestType(type); if (reason != null) - //IVAN: not sure about schema mismatch - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, reason, type.ToString()); + throw Host.ExceptParam(nameof(inputSchema), reason); } public NAReplaceTransform(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NAReplaceTransform)), GetColumnPairs(columns)) { - // Validate input schema. + // Check that all the input columns are present and correct. + for (int i = 0; i < ColumnPairs.Length; i++) + { + if (!input.Schema.TryGetColumnIndex(ColumnPairs[i].input, out int srcCol)) + throw Host.ExceptSchemaMismatch(nameof(input), "input", ColumnPairs[i].input); + CheckInputColumn(input.Schema, i, srcCol); + } GetReplacementValues(input, columns, out _repValues, out _repIsDefault, out _types); } private NAReplaceTransform(IHost host, ModelLoadContext ctx) - : base(host, ctx) + : base(host, ctx) { var columnsLength = ColumnPairs.Length; _repValues = new object[columnsLength]; @@ -660,7 +664,6 @@ private Delegate ComposeGetterVec(IRow input, int iinfo) /// private Delegate ComposeGetterVec(IRow input, int iinfo) { - //IVAN: check you use corret iinfo or maybe it should be ColMapNewToOld[iinfo] everywhere in code var getSrc = input.GetGetter>(ColMapNewToOld[iinfo]); var isNA = (RefPredicate)_isNAs[iinfo]; var isDefault = Conversions.Instance.GetIsDefaultPredicate(_infos[iinfo].TypeSrc.ItemType); @@ -939,9 +942,9 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); - //IVAN Proper validation, not copypaste from term estimator - if ((col.ItemType.ItemType.RawKind == default) || !(col.ItemType.IsVector || col.ItemType.IsPrimitive)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + string reason = NAReplaceTransform.TestType(col.ItemType); + if (reason != null) + throw _host.ExceptParam(nameof(inputSchema), reason); var metadata = new List(); if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta)) metadata.Add(slotMeta); From 0d064c30fcd5eb889e397ea4c07470e4ddc1b58f Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Fri, 14 Sep 2018 15:58:33 -0700 Subject: [PATCH 5/8] csharp api strike back! --- src/Microsoft.ML/CSharpApi.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index 708c8c62f6..d031f53d06 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -13979,7 +13979,7 @@ public MinMaxNormalizerPipelineStep(Output output) namespace Transforms { - public enum NAHandleTransformReplacementKind + public enum NAHandleTransformReplacementKind : byte { DefaultValue = 0, Mean = 1, @@ -14440,7 +14440,7 @@ public MissingValuesRowDropperPipelineStep(Output output) namespace Transforms { - public enum NAReplaceTransformReplacementKind + public enum NAReplaceTransformReplacementKind : byte { DefaultValue = 0, Mean = 1, From f2869fe568bb13fa85a390869785d2407113ecef Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Mon, 17 Sep 2018 15:26:36 -0700 Subject: [PATCH 6/8] merge with master + pigsty --- .../NAHandleTransform.cs | 10 +- .../NAReplaceTransform.cs | 166 +++++++++++++++++- .../KeyToBinaryVectorEstimatorTest.cs | 5 +- .../Transformers/KeyToVectorEstimatorTests.cs | 5 +- .../Transformers/NAReplaceTests.cs | 47 ++++- 5 files changed, 212 insertions(+), 21 deletions(-) diff --git a/src/Microsoft.ML.Transforms/NAHandleTransform.cs b/src/Microsoft.ML.Transforms/NAHandleTransform.cs index 0245c3fbab..bf3ab5f41a 100644 --- a/src/Microsoft.ML.Transforms/NAHandleTransform.cs +++ b/src/Microsoft.ML.Transforms/NAHandleTransform.cs @@ -20,18 +20,18 @@ namespace Microsoft.ML.Runtime.Data /// public static class NAHandleTransform { - public enum ReplacementKind:byte + public enum ReplacementKind : byte { /// /// Replace with the default value of the column based on it's type. For example, 'zero' for numeric and 'empty' for string/text columns. /// [EnumValueDisplay("Zero/empty")] - DefaultValue=0, + DefaultValue = 0, /// /// Replace with the mean value of the column. Supports only numeric/time span/ DateTime columns. /// - Mean=1, + Mean = 1, /// /// Replace with the minimum value of the column. Supports only numeric/time span/ DateTime columns. @@ -41,7 +41,7 @@ public enum ReplacementKind:byte /// /// Replace with the maximum value of the column. Supports only numeric/time span/ DateTime columns. /// - Maximum =3, + Maximum = 3, [HideEnumValue] Def = DefaultValue, @@ -158,7 +158,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV if (!input.Schema.TryGetColumnIndex(column.Source, out int inputCol)) throw h.Except("Column '{0}' does not exist", column.Source); var replaceType = input.Schema.GetColumnType(inputCol); - if (!Conversions.Instance.TryGetStandardConversion(BoolType.Instance, replaceType.ItemType, out Delegate conv, out bool identity)) + if (!Conversions.Instance.TryGetStandardConversion(BoolType.Instance, replaceType.ItemType, out Delegate conv, out bool identity)) { throw h.Except("Cannot concatenate indicator column of type '{0}' to input column of type '{1}'", BoolType.Instance, replaceType.ItemType); diff --git a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs index 0c91dd7e0c..d7d4820819 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs @@ -19,6 +19,7 @@ using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Model.Onnx; using Microsoft.ML.Core.Data; +using Microsoft.ML.Data.StaticPipe.Runtime; [assembly: LoadableClass(NAReplaceTransform.Summary, typeof(IDataTransform), typeof(NAReplaceTransform), typeof(NAReplaceTransform.Arguments), typeof(SignatureDataTransform), NAReplaceTransform.FriendlyName, NAReplaceTransform.LoadName, "NAReplace", NAReplaceTransform.ShortName, DocName = "transform/NAHandle.md")] @@ -34,6 +35,12 @@ namespace Microsoft.ML.Runtime.Data { + // This transform can transform either scalars or vectors (both fixed and variable size), + // creating output columns that are identical to the input columns except for replacing NA values + // with either the default value, user input, or imputed values (min/max/mean are currently supported). + // Imputation modes are supported for vectors both by slot and across all slots. + // REVIEW: May make sense to implement the transform template interface. + /// public sealed partial class NAReplaceTransform : OneToOneTransformerBase { public enum ReplacementKind : byte @@ -113,14 +120,15 @@ public sealed class Arguments : TransformInputBase public Column[] Column; [Argument(ArgumentType.AtMostOnce, HelpText = "The replacement method to utilize", ShortName = "kind")] - public ReplacementKind ReplacementKind = ReplacementKind.DefaultValue; + public ReplacementKind ReplacementKind = (ReplacementKind)NAReplaceEstimator.Defaults.ReplacementMode; // Specifying by-slot imputation for vectors of unknown size will cause a warning, and the imputation will be global. [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to impute values by slot", ShortName = "slot")] - public bool ImputeBySlot = true; + public bool ImputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot; } public const string LoadName = "NAReplaceTransform"; + private static VersionInfo GetVersionInfo() { return new VersionInfo( @@ -187,7 +195,8 @@ public enum ReplacementMode : byte /// Output column /// Strategy how to replace NA value /// If we operate on vector array do we want to find replace value for each slot in vector or for whole vector? - public ColumnInfo(string input, string output, ReplacementMode replacementMode = ReplacementMode.DefaultValue, bool imputeBySlot = true) + public ColumnInfo(string input, string output, ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, + bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) { Input = input; Output = output; @@ -252,6 +261,7 @@ private NAReplaceTransform(IHost host, ModelLoadContext ctx) { if (!saver.TryLoadTypeAndValue(ctx.Reader.BaseStream, out ColumnType savedType, out object repValue)) throw Host.ExceptDecode(); + //IVAN: savedType contains _types[iinfo].ItemType not _types[iinfo] _types[i] = savedType; if (savedType.IsVector) { @@ -918,10 +928,16 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src public sealed class NAReplaceEstimator : IEstimator { + public static class Defaults + { + public const NAReplaceTransform.ColumnInfo.ReplacementMode ReplacementMode = NAReplaceTransform.ColumnInfo.ReplacementMode.DefaultValue; + public const bool ImputeBySlot = true; + } + private readonly IHost _host; private readonly NAReplaceTransform.ColumnInfo[] _columns; - public NAReplaceEstimator(IHostEnvironment env, string name, string source = null, NAReplaceTransform.ColumnInfo.ReplacementMode replacementKind = NAReplaceTransform.ColumnInfo.ReplacementMode.DefaultValue) + public NAReplaceEstimator(IHostEnvironment env, string name, string source = null, NAReplaceTransform.ColumnInfo.ReplacementMode replacementKind = Defaults.ReplacementMode) : this(env, new NAReplaceTransform.ColumnInfo(source ?? name, name, replacementKind)) { @@ -953,10 +969,150 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) var type = !col.ItemType.IsVector ? col.ItemType : new VectorType(col.ItemType.ItemType.AsPrimitive, col.ItemType.AsVector); result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, type, false, new SchemaShape(metadata.ToArray())); } - return new SchemaShape(result.Values); } public NAReplaceTransform Fit(IDataView input) => new NAReplaceTransform(_host, input, _columns); } + + /// + /// Extension methods for the static-pipeline over objects. + /// + public static class NAReplaceExtensions + { + private struct Config + { + public readonly bool ImputeBySlot; + public readonly NAReplaceTransform.ColumnInfo.ReplacementMode ReplacementMode; + + public Config(NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, + bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + { + ImputeBySlot = imputeBySlot; + ReplacementMode = replacementMode; + } + } + + private interface IColInput + { + PipelineColumn Input { get; } + Config Config { get; } + } + + private sealed class OutScalar : Scalar, IColInput + { + public PipelineColumn Input { get; } + public Config Config { get; } + + public OutScalar(Scalar input, Config config) + : base(Reconciler.Inst, input) + { + Input = input; + Config = config; + } + } + + private sealed class OutVectorColumn : Vector, IColInput + { + public PipelineColumn Input { get; } + public Config Config { get; } + + public OutVectorColumn(Vector input, Config config) + : base(Reconciler.Inst, input) + { + Input = input; + Config = config; + } + + } + + private sealed class OutVarVectorColumn : VarVector, IColInput + { + public PipelineColumn Input { get; } + public Config Config { get; } + + public OutVarVectorColumn(VarVector input, Config config) + : base(Reconciler.Inst, input) + { + Input = input; + Config = config; + } + } + + private sealed class Reconciler : EstimatorReconciler + { + public static Reconciler Inst = new Reconciler(); + + private Reconciler() { } + + public override IEstimator Reconcile(IHostEnvironment env, + PipelineColumn[] toOutput, + IReadOnlyDictionary inputNames, + IReadOnlyDictionary outputNames, + IReadOnlyCollection usedNames) + { + var infos = new NAReplaceTransform.ColumnInfo[toOutput.Length]; + for (int i = 0; i < toOutput.Length; ++i) + { + var col = (IColInput)toOutput[i]; + infos[i] = new NAReplaceTransform.ColumnInfo(inputNames[col.Input], outputNames[toOutput[i]], col.Config.ReplacementMode, col.Config.ImputeBySlot); + } + return new NAReplaceEstimator(env, infos); + } + } + + public static Scalar NAReplace(this Scalar input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + { + Contracts.CheckValue(input, nameof(input)); + return new OutScalar(input, new Config( replacementMode, imputeBySlot)); + } + + public static Scalar NAReplace(this Scalar input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + { + Contracts.CheckValue(input, nameof(input)); + return new OutScalar(input, new Config(replacementMode, imputeBySlot)); + } + + public static Scalar NAReplace (this Scalar input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + { + Contracts.CheckValue(input, nameof(input)); + return new OutScalar(input, new Config(replacementMode, imputeBySlot)); + } + + public static Vector NAReplace(this Vector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input, new Config(replacementMode, imputeBySlot)); + } + + public static Vector NAReplace(this Vector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input, new Config(replacementMode, imputeBySlot)); + } + + public static Vector NAReplace(this Vector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input, new Config(replacementMode, imputeBySlot)); + } + + public static VarVector NAReplace(this VarVector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVarVectorColumn(input, new Config(replacementMode, imputeBySlot)); + } + + public static VarVector NAReplace(this VarVector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVarVectorColumn(input, new Config(replacementMode, imputeBySlot)); + } + + public static VarVector NAReplace(this VarVector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVarVectorColumn(input, new Config(replacementMode, imputeBySlot)); + } + } } diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs b/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs index d7b57c232a..c5f3b68baf 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs @@ -150,10 +150,7 @@ private void ValidateMetadata(IDataView result) [Fact] public void TestCommandLine() { - using (var env = new ConsoleEnvironment()) - { - Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} xf=KeyToBinary{col=C:B} in=f:\2.txt" }), (int)0); - } + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} xf=KeyToBinary{col=C:B} in=f:\2.txt" }), (int)0); } [Fact] diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs index 1701c17083..3584159ac3 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs @@ -215,10 +215,7 @@ private void ValidateMetadata(IDataView result) [Fact] public void TestCommandLine() { - using (var env = new ConsoleEnvironment()) - { - Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} xf=KeyToVector{col=C:B col={name=D source=B bag+}} in=f:\2.txt" }), (int)0); - } + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} xf=KeyToVector{col=C:B col={name=D source=B bag+}} in=f:\2.txt" }), (int)0); } [Fact] diff --git a/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs b/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs index b1ebd43b3d..369668dd55 100644 --- a/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs @@ -4,6 +4,7 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.Runtime.Tools; @@ -53,12 +54,52 @@ public void NAReplaceWorkout() } [Fact] - public void TestCommandLine() + public void NAReplaceStatic() { - using (var env = new TlcEnvironment()) + string dataPath = GetDataPath("breast-cancer.txt"); + var reader = TextLoader.CreateReader(Env, ctx => ( + ScalarString: ctx.LoadText(1), + ScalarFloat: ctx.LoadFloat(1), + ScalarDouble: ctx.LoadDouble(1), + VectorString: ctx.LoadText(1, 4), + VectorFloat: ctx.LoadFloat(1, 4), + VectorDoulbe: ctx.LoadDouble(1, 4) + )); + + var data = reader.Read(new MultiFileSource(dataPath)); + var wrongCollection = new[] { new TestClass() { A = 1, B = "A", C = 3, D = new float[2] { 1, 2 }, E = new double[2] { 3, 4 } } }; + var invalidData = ComponentCreation.CreateDataView(Env, wrongCollection); + + var est = data.MakeNewEstimator(). + Append(row => ( + A: row.ScalarString.NAReplace(), + B: row.ScalarFloat.NAReplace(NAReplaceTransform.ColumnInfo.ReplacementMode.Maximum), + C: row.ScalarDouble.NAReplace(NAReplaceTransform.ColumnInfo.ReplacementMode.Mean), + D: row.VectorString.NAReplace(), + E: row.VectorFloat.NAReplace(NAReplaceTransform.ColumnInfo.ReplacementMode.Mean), + F: row.VectorDoulbe.NAReplace(NAReplaceTransform.ColumnInfo.ReplacementMode.Minimum) + )); + + TestEstimatorCore(est.AsDynamic, data.AsDynamic, invalidInput: invalidData); + var outputPath = GetOutputPath("NAReplace", "featurized.tsv"); + using (var ch = Env.Start("save")) { - Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=NAReplace{col=C:A} in=f:\2.txt" }), (int)0); + var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); + IDataView savedData = TakeFilter.Create(Env, est.Fit(data).Transform(data).AsDynamic, 4); + savedData = new ChooseColumnsTransform(Env, savedData, "A", "B", "C", "D", "E"); + using (var fs = File.Create(outputPath)) + DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); } + + CheckEquality("NAReplace", "featurized.tsv"); + Done(); + Done(); + } + + [Fact] + public void TestCommandLine() + { + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=NAReplace{col=C:A} in=f:\2.txt" }), (int)0); } [Fact] From 8a94270cdec3d4fb32e2cda29035a3a157de5138 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Mon, 17 Sep 2018 17:54:29 -0700 Subject: [PATCH 7/8] update --- .../NAReplaceTransform.cs | 100 +++++++++++------- .../SingleDebug/NAReplace/featurized.tsv | 14 +++ .../SingleRelease/NAReplace/featurized.tsv | 14 +++ .../Transformers/NAReplaceTests.cs | 13 ++- 4 files changed, 95 insertions(+), 46 deletions(-) create mode 100644 test/BaselineOutput/SingleDebug/NAReplace/featurized.tsv create mode 100644 test/BaselineOutput/SingleRelease/NAReplace/featurized.tsv diff --git a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs index d7d4820819..b106b7f376 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs @@ -2,24 +2,24 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Collections; -using System.Collections.Generic; -using System.Reflection; -using System.Text; -using System.IO; -using System.Linq; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data.StaticPipe.Runtime; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Data.Conversion; +using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Core.Data; -using Microsoft.ML.Data.StaticPipe.Runtime; +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Text; [assembly: LoadableClass(NAReplaceTransform.Summary, typeof(IDataTransform), typeof(NAReplaceTransform), typeof(NAReplaceTransform.Arguments), typeof(SignatureDataTransform), NAReplaceTransform.FriendlyName, NAReplaceTransform.LoadName, "NAReplace", NAReplaceTransform.ShortName, DocName = "transform/NAHandle.md")] @@ -189,12 +189,14 @@ public enum ReplacementMode : byte public readonly ReplacementMode Replacement; /// - /// Description what to do with columns + /// Describes how the transformer handles one column pair. /// - /// Input column - /// Output column - /// Strategy how to replace NA value - /// If we operate on vector array do we want to find replace value for each slot in vector or for whole vector? + /// Name of input column. + /// Name of output column. + /// What to replace the missing value with. + /// If true, per-slot imputation of replacement is performed. + /// Otherwise, replacement value is imputed for the entire vector column. This setting is ignored for scalars and variable vectors, + /// where imputation is always for the entire column. public ColumnInfo(string input, string output, ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) { @@ -214,7 +216,7 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum } // The output column types, parallel to Infos. - private readonly ColumnType[] _types; + private readonly ColumnType[] _replaceTypes; // The replacementValues for the columns, parallel to Infos. // The elements of this array can be either primitive values or arrays of primitive values. When replacing a scalar valued column in Infos, @@ -246,7 +248,7 @@ public NAReplaceTransform(IHostEnvironment env, IDataView input, params ColumnIn throw Host.ExceptSchemaMismatch(nameof(input), "input", ColumnPairs[i].input); CheckInputColumn(input.Schema, i, srcCol); } - GetReplacementValues(input, columns, out _repValues, out _repIsDefault, out _types); + GetReplacementValues(input, columns, out _repValues, out _repIsDefault, out _replaceTypes); } private NAReplaceTransform(IHost host, ModelLoadContext ctx) @@ -255,20 +257,19 @@ private NAReplaceTransform(IHost host, ModelLoadContext ctx) var columnsLength = ColumnPairs.Length; _repValues = new object[columnsLength]; _repIsDefault = new BitArray[columnsLength]; - _types = new ColumnType[columnsLength]; + _replaceTypes = new ColumnType[columnsLength]; var saver = new BinarySaver(Host, new BinarySaver.Arguments()); for (int i = 0; i < columnsLength; i++) { if (!saver.TryLoadTypeAndValue(ctx.Reader.BaseStream, out ColumnType savedType, out object repValue)) throw Host.ExceptDecode(); - //IVAN: savedType contains _types[iinfo].ItemType not _types[iinfo] - _types[i] = savedType; + _replaceTypes[i] = savedType; if (savedType.IsVector) { // REVIEW: The current implementation takes the serialized VBuffer, densifies it, and stores the values array. // It might be of value to consider storing the VBUffer in order to possibly benefit from sparsity. However, this would // necessitate a reimplementation of the FillValues code to accomodate sparse VBuffers. - object[] args = new object[] { repValue, _types[i], i }; + object[] args = new object[] { repValue, _replaceTypes[i], i }; Func, ColumnType, int, int[]> func = GetValuesArray; var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(savedType.ItemType.RawType); _repValues[i] = meth.Invoke(this, args); @@ -276,7 +277,7 @@ private NAReplaceTransform(IHost host, ModelLoadContext ctx) else _repValues[i] = repValue; - Host.Assert(repValue.GetType() == _types[i].RawType || repValue.GetType() == _types[i].ItemType.RawType); + Host.Assert(repValue.GetType() == _replaceTypes[i].RawType || repValue.GetType() == _replaceTypes[i].ItemType.RawType); } } @@ -328,7 +329,7 @@ private void GetReplacementValues(IDataView input, ColumnInfo[] columns, out obj switch (kind) { case ReplacementKind.SpecifiedValue: - repValues[iinfo] = GetSpecifiedValue(columns[iinfo].ReplacementString, _types[iinfo], isNa); + repValues[iinfo] = GetSpecifiedValue(columns[iinfo].ReplacementString, _replaceTypes[iinfo], isNa); break; case ReplacementKind.DefaultValue: repValues[iinfo] = GetDefault(type); @@ -537,22 +538,22 @@ public override void Save(ModelSaveContext ctx) SaveColumns(ctx); var saver = new BinarySaver(Host, new BinarySaver.Arguments()); - for (int iinfo = 0; iinfo < _types.Length; iinfo++) + for (int iinfo = 0; iinfo < _replaceTypes.Length; iinfo++) { var repValue = _repValues[iinfo]; - var repType = _types[iinfo].ItemType; + var repType = _replaceTypes[iinfo].ItemType; if (_repIsDefault[iinfo] != null) { Host.Assert(repValue is Array); Func> function = CreateVBuffer; - var method = function.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(_types[iinfo].ItemType.RawType); + var method = function.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(repType.RawType); repValue = method.Invoke(this, new object[] { _repValues[iinfo] }); - repType = _types[iinfo]; + repType = _replaceTypes[iinfo]; } Host.Assert(!(repValue is Array)); object[] args = new object[] { ctx.Writer.BaseStream, saver, repType, repValue }; Action func = WriteTypeAndValue; - Host.Assert(repValue.GetType() == _types[iinfo].RawType || repValue.GetType() == _types[iinfo].ItemType.RawType); + Host.Assert(repValue.GetType() == _replaceTypes[iinfo].RawType || repValue.GetType() == _replaceTypes[iinfo].ItemType.RawType); var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(repValue.GetType()); meth.Invoke(this, args); } @@ -580,6 +581,7 @@ public ColInfo(string name, string source, ColumnType type) private readonly NAReplaceTransform _parent; private readonly ColInfo[] _infos; + private readonly ColumnType[] _types; // The isNA delegates, parallel to Infos. private readonly Delegate[] _isNAs; public bool CanSaveOnnx => true; @@ -589,10 +591,30 @@ public Mapper(NAReplaceTransform parent, ISchema inputSchema) { _parent = parent; _infos = CreateInfos(inputSchema); + _types = new ColumnType[_parent.ColumnPairs.Length]; _isNAs = new Delegate[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { var type = _infos[i].TypeSrc; + if (type.IsVector) + type = new VectorType(type.ItemType.AsPrimitive, type.AsVector); + var repType = _parent._repIsDefault[i] != null ? _parent._replaceTypes[i] : _parent._replaceTypes[i].ItemType; + if (!type.ItemType.Equals(repType.ItemType)) + throw Host.ExceptParam(nameof(InputSchema), "Column '{0}' item type '{1}' does not match expected ColumnType of '{2}'", + _infos[i].Source, _parent._replaceTypes[i].ItemType.ToString(), _infos[i].TypeSrc); + // If type is a vector and the value is not either a scalar or a vector of the same size, throw an error. + if (repType.IsVector) + { + if (!type.IsVector) + throw Host.ExceptParam(nameof(inputSchema), "Column '{0}' item type '{1}' cannot be a vector when Columntype is a scalar of type '{2}'", + _infos[i].Source, repType, type); + if (!type.IsKnownSizeVector) + throw Host.ExceptParam(nameof(inputSchema), "Column '{0}' is unknown size vector '{1}' must be a scalar instead of type '{2}'", _infos[i].Source, type, parent._replaceTypes[i]); + if (type.VectorSize != repType.VectorSize) + throw Host.ExceptParam(nameof(inputSchema), "Column '{0}' item type '{1}' must be a scalar or a vector of the same size as Columntype '{2}'", + _infos[i].Source, repType, type); + } + _types[i] = type; _isNAs[i] = _parent.GetIsNADelegate(type); } } @@ -621,7 +643,7 @@ public override RowMapperColumnInfo[] GetOutputColumns() Host.Assert(colIndex >= 0); var colMetaInfo = new ColumnMetadataInfo(_parent.ColumnPairs[i].output); var meta = RowColumnUtils.GetMetadataAsRow(InputSchema, colIndex, x => x == MetadataUtils.Kinds.SlotNames || x == MetadataUtils.Kinds.IsNormalized); - result[i] = new RowMapperColumnInfo(_parent.ColumnPairs[i].output, _parent._types[i], meta); + result[i] = new RowMapperColumnInfo(_parent.ColumnPairs[i].output, _types[i], meta); } return result; } @@ -887,7 +909,7 @@ public void SaveAsOnnx(OnnxContext ctx) } if (!SaveAsOnnxCore(ctx, iinfo, info, ctx.GetVariableName(sourceColumnName), - ctx.AddIntermediateVariable(_parent._types[iinfo], info.Name))) + ctx.AddIntermediateVariable(_parent._replaceTypes[iinfo], info.Name))) { ctx.RemoveColumn(info.Name, true); } @@ -1061,43 +1083,43 @@ public override IEstimator Reconcile(IHostEnvironment env, } } - public static Scalar NAReplace(this Scalar input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + public static Scalar ReplaceWithMissingValues(this Scalar input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) { Contracts.CheckValue(input, nameof(input)); - return new OutScalar(input, new Config( replacementMode, imputeBySlot)); + return new OutScalar(input, new Config(replacementMode, imputeBySlot)); } - public static Scalar NAReplace(this Scalar input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + public static Scalar ReplaceWithMissingValues(this Scalar input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) { Contracts.CheckValue(input, nameof(input)); return new OutScalar(input, new Config(replacementMode, imputeBySlot)); } - public static Scalar NAReplace (this Scalar input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + public static Scalar ReplaceWithMissingValues(this Scalar input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) { Contracts.CheckValue(input, nameof(input)); return new OutScalar(input, new Config(replacementMode, imputeBySlot)); } - public static Vector NAReplace(this Vector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + public static Vector ReplaceWithMissingValues(this Vector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) { Contracts.CheckValue(input, nameof(input)); return new OutVectorColumn(input, new Config(replacementMode, imputeBySlot)); } - public static Vector NAReplace(this Vector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + public static Vector ReplaceWithMissingValues(this Vector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) { Contracts.CheckValue(input, nameof(input)); return new OutVectorColumn(input, new Config(replacementMode, imputeBySlot)); } - public static Vector NAReplace(this Vector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + public static Vector ReplaceWithMissingValues(this Vector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) { Contracts.CheckValue(input, nameof(input)); return new OutVectorColumn(input, new Config(replacementMode, imputeBySlot)); } - public static VarVector NAReplace(this VarVector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + public static VarVector ReplaceWithMissingValues(this VarVector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) { Contracts.CheckValue(input, nameof(input)); return new OutVarVectorColumn(input, new Config(replacementMode, imputeBySlot)); @@ -1109,7 +1131,7 @@ public static VarVector NAReplace(this VarVector input, NAReplac return new OutVarVectorColumn(input, new Config(replacementMode, imputeBySlot)); } - public static VarVector NAReplace(this VarVector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) + public static VarVector ReplaceWithMissingValues(this VarVector input, NAReplaceTransform.ColumnInfo.ReplacementMode replacementMode = NAReplaceEstimator.Defaults.ReplacementMode, bool imputeBySlot = NAReplaceEstimator.Defaults.ImputeBySlot) { Contracts.CheckValue(input, nameof(input)); return new OutVarVectorColumn(input, new Config(replacementMode, imputeBySlot)); diff --git a/test/BaselineOutput/SingleDebug/NAReplace/featurized.tsv b/test/BaselineOutput/SingleDebug/NAReplace/featurized.tsv new file mode 100644 index 0000000000..f9ba7bf9f4 --- /dev/null +++ b/test/BaselineOutput/SingleDebug/NAReplace/featurized.tsv @@ -0,0 +1,14 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=A:TX:0 +#@ col=B:R4:1 +#@ col=C:R8:2 +#@ col=D:TX:3-6 +#@ col=E:R4:7-10 +#@ } +A B C 8 0:"" +5 5 5 5 1 1 1 5 1 1 1 +5 5 5 5 4 4 5 5 4 4 5 +3 3 3 3 1 1 1 3 1 1 1 +6 6 6 6 8 8 1 6 8 8 1 diff --git a/test/BaselineOutput/SingleRelease/NAReplace/featurized.tsv b/test/BaselineOutput/SingleRelease/NAReplace/featurized.tsv new file mode 100644 index 0000000000..f9ba7bf9f4 --- /dev/null +++ b/test/BaselineOutput/SingleRelease/NAReplace/featurized.tsv @@ -0,0 +1,14 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=A:TX:0 +#@ col=B:R4:1 +#@ col=C:R8:2 +#@ col=D:TX:3-6 +#@ col=E:R4:7-10 +#@ } +A B C 8 0:"" +5 5 5 5 1 1 1 5 1 1 1 +5 5 5 5 4 4 5 5 4 4 5 +3 3 3 3 1 1 1 3 1 1 1 +6 6 6 6 8 8 1 6 8 8 1 diff --git a/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs b/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs index 369668dd55..969c5198b8 100644 --- a/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs @@ -72,12 +72,12 @@ public void NAReplaceStatic() var est = data.MakeNewEstimator(). Append(row => ( - A: row.ScalarString.NAReplace(), - B: row.ScalarFloat.NAReplace(NAReplaceTransform.ColumnInfo.ReplacementMode.Maximum), - C: row.ScalarDouble.NAReplace(NAReplaceTransform.ColumnInfo.ReplacementMode.Mean), - D: row.VectorString.NAReplace(), - E: row.VectorFloat.NAReplace(NAReplaceTransform.ColumnInfo.ReplacementMode.Mean), - F: row.VectorDoulbe.NAReplace(NAReplaceTransform.ColumnInfo.ReplacementMode.Minimum) + A: row.ScalarString.ReplaceWithMissingValues(), + B: row.ScalarFloat.ReplaceWithMissingValues(NAReplaceTransform.ColumnInfo.ReplacementMode.Maximum), + C: row.ScalarDouble.ReplaceWithMissingValues(NAReplaceTransform.ColumnInfo.ReplacementMode.Mean), + D: row.VectorString.ReplaceWithMissingValues(), + E: row.VectorFloat.ReplaceWithMissingValues(NAReplaceTransform.ColumnInfo.ReplacementMode.Mean), + F: row.VectorDoulbe.ReplaceWithMissingValues(NAReplaceTransform.ColumnInfo.ReplacementMode.Minimum) )); TestEstimatorCore(est.AsDynamic, data.AsDynamic, invalidInput: invalidData); @@ -93,7 +93,6 @@ public void NAReplaceStatic() CheckEquality("NAReplace", "featurized.tsv"); Done(); - Done(); } [Fact] From 7839b60fae63ac725266257de8e77f8f53ba3c96 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Tue, 18 Sep 2018 17:28:08 -0700 Subject: [PATCH 8/8] merge with master --- src/Microsoft.ML.Transforms/NAReplaceTransform.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs index b106b7f376..d8cd698658 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs @@ -24,7 +24,7 @@ [assembly: LoadableClass(NAReplaceTransform.Summary, typeof(IDataTransform), typeof(NAReplaceTransform), typeof(NAReplaceTransform.Arguments), typeof(SignatureDataTransform), NAReplaceTransform.FriendlyName, NAReplaceTransform.LoadName, "NAReplace", NAReplaceTransform.ShortName, DocName = "transform/NAHandle.md")] -[assembly: LoadableClass(NAReplaceTransform.Summary, typeof(IDataView), typeof(NAReplaceTransform), null, typeof(SignatureLoadDataTransform), +[assembly: LoadableClass(NAReplaceTransform.Summary, typeof(IDataTransform), typeof(NAReplaceTransform), null, typeof(SignatureLoadDataTransform), NAReplaceTransform.FriendlyName, NAReplaceTransform.LoadName)] [assembly: LoadableClass(NAReplaceTransform.Summary, typeof(NAReplaceTransform), null, typeof(SignatureLoadModel),