From a1b76063e7fa75f925010c4a659588f7d25fd41e Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 26 Feb 2019 09:25:01 -0800 Subject: [PATCH 1/8] Remove IHostEnviroment from GetColumn's argument list --- .../Dynamic/FeatureSelectionTransform.cs | 4 ++-- .../Dynamic/KeyToValueValueToKey.cs | 6 ++--- .../Dynamic/LdaTransform.cs | 2 +- .../Dynamic/NgramExtraction.cs | 4 ++-- .../Dynamic/Normalizer.cs | 6 ++--- .../Dynamic/ProjectionTransforms.cs | 6 ++--- .../Dynamic/StopWordRemoverTransform.cs | 6 ++--- .../Dynamic/TextTransform.cs | 4 ++-- .../Transforms/Projection/VectorWhiten.cs | 2 +- .../VectorWhitenWithColumnOptions.cs | 2 +- .../Dynamic/WordEmbeddingTransform.cs | 6 ++--- .../Static/FeatureSelectionTransform.cs | 4 ++-- .../Utilities/ColumnCursor.cs | 22 ++++++++++++++++--- src/Microsoft.ML.StaticPipe/DataView.cs | 2 +- .../StaticPipeTests.cs | 8 +++---- test/Microsoft.ML.Tests/CachingTests.cs | 8 +++---- test/Microsoft.ML.Tests/RangeFilterTests.cs | 4 ++-- .../Api/CookbookSamples/CookbookSamples.cs | 2 +- .../CookbookSamplesDynamicApi.cs | 16 +++++++------- .../Scenarios/Api/Estimators/Visibility.cs | 6 ++--- .../Scenarios/Api/TestApi.cs | 2 +- .../Scenarios/GetColumnTests.cs | 20 ++++++++--------- .../SymSgdClassificationTests.cs | 6 ++--- .../Transformers/ValueMappingTests.cs | 2 +- 24 files changed, 83 insertions(+), 67 deletions(-) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs index 1f6b4f9b55..e003e34ee1 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs @@ -82,8 +82,8 @@ public static void Example() }; // Print the data that results from the transformations. 
- var countSelectColumn = transformedData.GetColumn>(ml, "FeaturesCountSelect"); - var MISelectColumn = transformedData.GetColumn>(ml, "FeaturesMISelect"); + var countSelectColumn = transformedData.GetColumn>("FeaturesCountSelect"); + var MISelectColumn = transformedData.GetColumn>("FeaturesMISelect"); printHelper("FeaturesCountSelect", countSelectColumn); printHelper("FeaturesMISelect", MISelectColumn); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValueValueToKey.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValueValueToKey.cs index 28e5fc918e..522ea886f9 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValueValueToKey.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValueValueToKey.cs @@ -60,7 +60,7 @@ public static void Example() }; // Preview of the DefaultKeys column obtained after processing the input. - var defaultColumn = transformedData_default.GetColumn>(ml, defaultColumnName); + var defaultColumn = transformedData_default.GetColumn>(defaultColumnName); printHelper(defaultColumnName, defaultColumn); // DefaultKeys column obtained post-transformation. @@ -71,7 +71,7 @@ public static void Example() // 9 10 11 12 13 6 // Previewing the CustomizedKeys column obtained after processing the input. - var customizedColumn = transformedData_customized.GetColumn>(ml, customizedColumnName); + var customizedColumn = transformedData_customized.GetColumn>(customizedColumnName); printHelper(customizedColumnName, customizedColumn); // CustomizedKeys column obtained post-transformation. @@ -87,7 +87,7 @@ public static void Example() transformedData_default = pipeline.Fit(trainData).Transform(trainData); // Preview of the DefaultColumnName column obtained. 
- var originalColumnBack = transformedData_default.GetColumn>>(ml, defaultColumnName); + var originalColumnBack = transformedData_default.GetColumn>>(defaultColumnName); foreach (var row in originalColumnBack) { diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/LdaTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/LdaTransform.cs index b55c1e0934..a3710ce5da 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/LdaTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/LdaTransform.cs @@ -37,7 +37,7 @@ public static void Example() var transformed_data = transformer.Transform(trainData); // Column obtained after processing the input. - var ldaFeaturesColumn = transformed_data.GetColumn>(ml, ldaFeatures); + var ldaFeaturesColumn = transformed_data.GetColumn>(ldaFeatures); Console.WriteLine($"{ldaFeatures} column obtained post-transformation."); foreach (var featureRow in ldaFeaturesColumn) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs index 310763b4e5..aa2d539b16 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs @@ -53,7 +53,7 @@ public static void NgramTransform() // Preview of the CharsUnigrams column obtained after processing the input. VBuffer> slotNames = default; transformedData_onechars.Schema["CharsUnigrams"].GetSlotNames(ref slotNames); - var charsOneGramColumn = transformedData_onechars.GetColumn>(ml, "CharsUnigrams"); + var charsOneGramColumn = transformedData_onechars.GetColumn>("CharsUnigrams"); printHelper("CharsUnigrams", charsOneGramColumn, slotNames); // CharsUnigrams column obtained post-transformation. @@ -61,7 +61,7 @@ public static void NgramTransform() // 'e' - 1 '' - 2 'd' - 1 '=' - 4 'R' - 1 'U' - 1 'D' - 2 'E' - 1 'u' - 1 ',' - 1 '2' - 1 // 'B' - 0 'e' - 6 's' - 3 't' - 6 '' - 9 'g' - 2 'a' - 2 'm' - 2 'I' - 0 ''' - 0 'v' - 0 ... 
// Preview of the CharsTwoGrams column obtained after processing the input. - var charsTwoGramColumn = transformedData_twochars.GetColumn>(ml, "CharsTwograms"); + var charsTwoGramColumn = transformedData_twochars.GetColumn>("CharsTwograms"); transformedData_twochars.Schema["CharsTwograms"].GetSlotNames(ref slotNames); printHelper("CharsTwograms", charsTwoGramColumn, slotNames); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs index 7a97c318fb..d33b9dbf49 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs @@ -44,7 +44,7 @@ public static void Example() var transformedData = transformer.Transform(trainData); // Getting the data of the newly created column, so we can preview it. - var normalizedColumn = transformedData.GetColumn(ml, "Induced"); + var normalizedColumn = transformedData.GetColumn("Induced"); // A small printing utility. Action> printHelper = (colName, column) => @@ -72,8 +72,8 @@ public static void Example() var multiColtransformedData = multiColtransformer.Transform(trainData); // Getting the newly created columns. - var normalizedInduced = multiColtransformedData.GetColumn(ml, "LogInduced"); - var normalizedSpont = multiColtransformedData.GetColumn(ml, "LogSpontaneous"); + var normalizedInduced = multiColtransformedData.GetColumn("LogInduced"); + var normalizedSpont = multiColtransformedData.GetColumn("LogSpontaneous"); printHelper("LogInduced", normalizedInduced); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ProjectionTransforms.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ProjectionTransforms.cs index 7bd5a41f36..faf3911878 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/ProjectionTransforms.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/ProjectionTransforms.cs @@ -41,7 +41,7 @@ public static void Example() // The transformed (projected) data. 
var transformedData = rffPipeline.Fit(trainData).Transform(trainData); // Getting the data of the newly created column, so we can preview it. - var randomFourier = transformedData.GetColumn>(ml, nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features)); + var randomFourier = transformedData.GetColumn>(nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features)); printHelper(nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features), randomFourier); @@ -59,7 +59,7 @@ public static void Example() // The transformed (projected) data. transformedData = lpNormalizePipeline.Fit(trainData).Transform(trainData); // Getting the data of the newly created column, so we can preview it. - var lpNormalize= transformedData.GetColumn>(ml, nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features)); + var lpNormalize= transformedData.GetColumn>(nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features)); printHelper(nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features), lpNormalize); @@ -77,7 +77,7 @@ public static void Example() // The transformed (projected) data. transformedData = gcNormalizePipeline.Fit(trainData).Transform(trainData); // Getting the data of the newly created column, so we can preview it. 
- var gcNormalize = transformedData.GetColumn>(ml, nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features)); + var gcNormalize = transformedData.GetColumn>(nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features)); printHelper(nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features), gcNormalize); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/StopWordRemoverTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/StopWordRemoverTransform.cs index 1b8a04ad66..9c04214f96 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/StopWordRemoverTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/StopWordRemoverTransform.cs @@ -54,14 +54,14 @@ public static void Example() }; // Preview the result of breaking string into array of words. - var originalText = transformedDataDefault.GetColumn>>(ml, originalTextColumnName); + var originalText = transformedDataDefault.GetColumn>>(originalTextColumnName); printHelper(originalTextColumnName, originalText); // Best|game|I've|ever|played.| // == RUDE ==| Dude,| 2 | // Until | the | next | game,| this |is| the | best | Xbox | game!| // Preview the result of cleaning with default stop word remover. - var defaultRemoverData = transformedDataDefault.GetColumn>>(ml, "DefaultRemover"); + var defaultRemoverData = transformedDataDefault.GetColumn>>("DefaultRemover"); printHelper("DefaultRemover", defaultRemoverData); // Best|game|I've|played.| // == RUDE ==| Dude,| 2 | @@ -70,7 +70,7 @@ public static void Example() // Preview the result of cleaning with default customized stop word remover. 
- var customizeRemoverData = transformedDataCustomized.GetColumn>>(ml, "RemovedWords"); + var customizeRemoverData = transformedDataCustomized.GetColumn>>("RemovedWords"); printHelper("RemovedWords", customizeRemoverData); // Best|game|I've|ever|played.| diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs index 044a744b20..3a82a4f827 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs @@ -58,7 +58,7 @@ public static void Example() }; // Preview of the DefaultTextFeatures column obtained after processing the input. - var defaultColumn = transformedData_default.GetColumn>(ml, defaultColumnName); + var defaultColumn = transformedData_default.GetColumn>(defaultColumnName); printHelper(defaultColumnName, defaultColumn); // DefaultTextFeatures column obtained post-transformation. @@ -68,7 +68,7 @@ public static void Example() // 0 0.1230915 0.1230915 0.1230915 0.1230915 0.246183 0.246183 0.246183 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0.1230915 0 0 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.3692745 0.246183 0.246183 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.246183 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.2886751 0 0 0 0 0 0 0 0.2886751 0.5773503 0.2886751 0.2886751 0.2886751 0.2886751 0.2886751 0.2886751 // Preview of the CustomizedTextFeatures column obtained after processing the input. - var customizedColumn = transformedData_customized.GetColumn>(ml, customizedColumnName); + var customizedColumn = transformedData_customized.GetColumn>(customizedColumnName); printHelper(customizedColumnName, customizedColumn); // CustomizedTextFeatures column obtained post-transformation. 
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhiten.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhiten.cs index 538af5d6eb..05ca3832cb 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhiten.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhiten.cs @@ -45,7 +45,7 @@ public static void Example() // The transformed (projected) data. var transformedData = whiteningPipeline.Fit(trainData).Transform(trainData); // Getting the data of the newly created column, so we can preview it. - var whitening = transformedData.GetColumn>(ml, nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features)); + var whitening = transformedData.GetColumn>(nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features)); printHelper(nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features), whitening); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhitenWithColumnOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhitenWithColumnOptions.cs index 9b4de2274d..0d49e150ad 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhitenWithColumnOptions.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhitenWithColumnOptions.cs @@ -44,7 +44,7 @@ public static void Example() // The transformed (projected) data. var transformedData = whiteningPipeline.Fit(trainData).Transform(trainData); // Getting the data of the newly created column, so we can preview it. 
- var whitening = transformedData.GetColumn>(ml, nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features)); + var whitening = transformedData.GetColumn>(nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features)); printHelper(nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features), whitening); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/WordEmbeddingTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/WordEmbeddingTransform.cs index 9d69f74e0c..6686be19e4 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/WordEmbeddingTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/WordEmbeddingTransform.cs @@ -31,7 +31,7 @@ public static void Example() var wordsDataview = wordsPipeline.Fit(trainData).Transform(trainData); // Preview of the CleanWords column obtained after processing SentimentText. - var cleanWords = wordsDataview.GetColumn[]>(ml, "CleanWords"); + var cleanWords = wordsDataview.GetColumn[]>("CleanWords"); Console.WriteLine($" CleanWords column obtained post-transformation."); foreach (var featureRow in cleanWords) { @@ -86,7 +86,7 @@ public static void Example() // And do all required transformations. var embeddingDataview = pipeline.Fit(wordsDataview).Transform(wordsDataview); - var customEmbeddings = embeddingDataview.GetColumn(ml, "CustomEmbeddings"); + var customEmbeddings = embeddingDataview.GetColumn("CustomEmbeddings"); printEmbeddings("GloveEmbeddings", customEmbeddings); // -1 -2 -3 -0.5 -1 8.5 0 0 20 @@ -98,7 +98,7 @@ public static void Example() // Second set of 3 floats in output represent average (for each dimension) for extracted values. // Third set of 3 floats in output represent maximum values (for each dimension) for extracted values. // Preview of GloveEmbeddings. 
- var gloveEmbeddings = embeddingDataview.GetColumn(ml, "GloveEmbeddings"); + var gloveEmbeddings = embeddingDataview.GetColumn("GloveEmbeddings"); printEmbeddings("GloveEmbeddings", gloveEmbeddings); // 0.23166 0.048825 0.26878 -1.3945 -0.86072 -0.026778 0.84075 -0.81987 -1.6681 -1.0658 -0.30596 0.50974 ... //-0.094905 0.61109 0.52546 - 0.2516 0.054786 0.022661 1.1801 0.33329 - 0.85388 0.15471 - 0.5984 0.4364 ... diff --git a/docs/samples/Microsoft.ML.Samples/Static/FeatureSelectionTransform.cs b/docs/samples/Microsoft.ML.Samples/Static/FeatureSelectionTransform.cs index 07b6060785..5f959b6b88 100644 --- a/docs/samples/Microsoft.ML.Samples/Static/FeatureSelectionTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Static/FeatureSelectionTransform.cs @@ -83,8 +83,8 @@ public static void FeatureSelectionTransform() }; // Print the data that results from the transformations. - var countSelectColumn = transformedData.AsDynamic.GetColumn>(ml, "FeaturesCountSelect"); - var MISelectColumn = transformedData.AsDynamic.GetColumn>(ml, "FeaturesMISelect"); + var countSelectColumn = transformedData.AsDynamic.GetColumn>("FeaturesCountSelect"); + var MISelectColumn = transformedData.AsDynamic.GetColumn>("FeaturesMISelect"); printHelper("FeaturesCountSelect", countSelectColumn); printHelper("FeaturesMISelect", MISelectColumn); diff --git a/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs b/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs index b1c65e42a8..ed4619bfdf 100644 --- a/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs +++ b/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs @@ -20,11 +20,10 @@ public static class ColumnCursorExtensions /// /// The type of the values. This must match the actual column type. /// The data view to get the column from. - /// The current host environment. /// The name of the column to extract. 
- public static IEnumerable GetColumn(this IDataView data, IHostEnvironment env, string columnName) + public static IEnumerable GetColumn(this IDataView data, string columnName) { - Contracts.CheckValue(env, nameof(env)); + var env = RetrieveHost(data); env.CheckValue(data, nameof(data)); env.CheckNonEmpty(columnName, nameof(columnName)); @@ -79,6 +78,23 @@ public static IEnumerable GetColumn(this IDataView data, IHostEnvironment throw env.Except($"Could not map a data view column '{columnName}' of type {colType} to {typeof(T)}."); } + /// + /// Return a assigned in a given implementation. + /// + /// an implementation. + private static IHost RetrieveHost(IDataView data) + { + // Search for the first (if there are multiples) field typed to IHost. + var fields = data.GetType().GetFields(System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + var fieldInfo = fields.Where((field, index) => field.FieldType == typeof(IHost)).FirstOrDefault(); + + // Check if a IHost really gets retrieved. 
+ string errorMessage = nameof(data) + " should contains a field of " + typeof(IHost); + Contracts.CheckValue(fieldInfo, nameof(data), errorMessage); + + return (IHost)fieldInfo.GetValue(data); + } + private static IEnumerable GetColumnDirect(IDataView data, int col) { Contracts.AssertValue(data); diff --git a/src/Microsoft.ML.StaticPipe/DataView.cs b/src/Microsoft.ML.StaticPipe/DataView.cs index 6e03a029ae..88f76b2bd2 100644 --- a/src/Microsoft.ML.StaticPipe/DataView.cs +++ b/src/Microsoft.ML.StaticPipe/DataView.cs @@ -49,7 +49,7 @@ private static IEnumerable GetColumnCore(DataView da var indexer = StaticPipeUtils.GetIndexer(data); string columnName = indexer.Get(column(indexer.Indices)); - return data.AsDynamic.GetColumn(env, columnName); + return data.AsDynamic.GetColumn(columnName); } public static IEnumerable GetColumn(this DataView data, Func> column) diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index 059d075370..9e71ad1afe 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -793,10 +793,10 @@ public void NAIndicatorStatic() IDataView newData = ml.Data.TakeRows(est.Fit(data).Transform(data).AsDynamic, 4); Assert.NotNull(newData); - bool[] ScalarFloat = newData.GetColumn(ml, "A").ToArray(); - bool[] ScalarDouble = newData.GetColumn(ml, "B").ToArray(); - bool[][] VectorFloat = newData.GetColumn(ml, "C").ToArray(); - bool[][] VectorDoulbe = newData.GetColumn(ml, "D").ToArray(); + bool[] ScalarFloat = newData.GetColumn("A").ToArray(); + bool[] ScalarDouble = newData.GetColumn("B").ToArray(); + bool[][] VectorFloat = newData.GetColumn("C").ToArray(); + bool[][] VectorDoulbe = newData.GetColumn("D").ToArray(); Assert.NotNull(ScalarFloat); Assert.NotNull(ScalarDouble); diff --git a/test/Microsoft.ML.Tests/CachingTests.cs b/test/Microsoft.ML.Tests/CachingTests.cs index 8fd8519889..fd505cb9ee 
100644 --- a/test/Microsoft.ML.Tests/CachingTests.cs +++ b/test/Microsoft.ML.Tests/CachingTests.cs @@ -66,15 +66,15 @@ public void CacheTest() { var src = Enumerable.Range(0, 100).Select(c => new MyData()).ToArray(); var data = ML.Data.LoadFromEnumerable(src); - data.GetColumn(ML, "Features").ToArray(); - data.GetColumn(ML, "Features").ToArray(); + data.GetColumn("Features").ToArray(); + data.GetColumn("Features").ToArray(); Assert.True(src.All(x => x.AccessCount == 2)); src = Enumerable.Range(0, 100).Select(c => new MyData()).ToArray(); data = ML.Data.LoadFromEnumerable(src); data = ML.Data.Cache(data); - data.GetColumn(ML, "Features").ToArray(); - data.GetColumn(ML, "Features").ToArray(); + data.GetColumn("Features").ToArray(); + data.GetColumn("Features").ToArray(); Assert.True(src.All(x => x.AccessCount == 1)); } diff --git a/test/Microsoft.ML.Tests/RangeFilterTests.cs b/test/Microsoft.ML.Tests/RangeFilterTests.cs index 0cc1b2d5f7..f66c5e9452 100644 --- a/test/Microsoft.ML.Tests/RangeFilterTests.cs +++ b/test/Microsoft.ML.Tests/RangeFilterTests.cs @@ -26,12 +26,12 @@ public void RangeFilterTest() var data = builder.GetDataView(); var data1 = ML.Data.FilterRowsByColumn(data, "Floats", upperBound: 2.8); - var cnt = data1.GetColumn(ML, "Floats").Count(); + var cnt = data1.GetColumn("Floats").Count(); Assert.Equal(2L, cnt); data = ML.Transforms.Conversion.Hash("Key", "Strings", hashBits: 20).Fit(data).Transform(data); var data2 = ML.Data.FilterRowsByKeyColumnFraction(data, "Key", upperBound: 0.5); - cnt = data2.GetColumn(ML, "Floats").Count(); + cnt = data2.GetColumn("Floats").Count(); Assert.Equal(1L, cnt); } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs index c794c23f11..64b1ad876b 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs @@ 
-79,7 +79,7 @@ private void IntermediateData(string dataPath) // The same extension method also applies to the dynamic-typed data, except you have to // specify the column name and type: var dynamicData = transformedData.AsDynamic; - var sameFeatureColumns = dynamicData.GetColumn(mlContext, "AllFeatures") + var sameFeatureColumns = dynamicData.GetColumn("AllFeatures") .Take(20).ToArray(); } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs index 3b760dae6d..c7f0ab42d1 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs @@ -61,7 +61,7 @@ private void IntermediateData(string dataPath) // This will give the entire dataset: make sure to only take several row // in case the dataset is huge. The is similar to the static API, except // you have to specify the column name and type. - var featureColumns = transformedData.GetColumn(mlContext, "AllFeatures") + var featureColumns = transformedData.GetColumn("AllFeatures") .Take(20).ToArray(); } @@ -251,7 +251,7 @@ private void NormalizationWorkout(string dataPath) var normalizedData = pipeline.Fit(trainData).Transform(trainData); // Inspect one column of the resulting dataset. - var meanVarValues = normalizedData.GetColumn(mlContext, "MeanVarNormalized").ToArray(); + var meanVarValues = normalizedData.GetColumn("MeanVarNormalized").ToArray(); } [Fact] @@ -289,7 +289,7 @@ private void TextFeaturizationOn(string dataPath) var data = loader.Load(dataPath); // Inspect the message texts that are read from the file. - var messageTexts = data.GetColumn(mlContext, "Message").Take(20).ToArray(); + var messageTexts = data.GetColumn("Message").Take(20).ToArray(); // Apply various kinds of text operations supported by ML.NET. 
var pipeline = @@ -321,8 +321,8 @@ private void TextFeaturizationOn(string dataPath) var transformedData = pipeline.Fit(data).Transform(data); // Inspect some columns of the resulting dataset. - var embeddings = transformedData.GetColumn(mlContext, "Embeddings").Take(10).ToArray(); - var unigrams = transformedData.GetColumn(mlContext, "BagOfWords").Take(10).ToArray(); + var embeddings = transformedData.GetColumn("Embeddings").Take(10).ToArray(); + var unigrams = transformedData.GetColumn("BagOfWords").Take(10).ToArray(); } [Fact(Skip = "This test is running for one minute")] @@ -361,7 +361,7 @@ private void CategoricalFeaturizationOn(params string[] dataPath) var data = loader.Load(dataPath); // Inspect the first 10 records of the categorical columns to check that they are correctly read. - var catColumns = data.GetColumn(mlContext, "CategoricalFeatures").Take(10).ToArray(); + var catColumns = data.GetColumn("CategoricalFeatures").Take(10).ToArray(); // Build several alternative featurization pipelines. var pipeline = @@ -377,8 +377,8 @@ private void CategoricalFeaturizationOn(params string[] dataPath) var transformedData = pipeline.Fit(data).Transform(data); // Inspect some columns of the resulting dataset. - var categoricalBags = transformedData.GetColumn(mlContext, "CategoricalBag").Take(10).ToArray(); - var workclasses = transformedData.GetColumn(mlContext, "WorkclassOneHotTrimmed").Take(10).ToArray(); + var categoricalBags = transformedData.GetColumn("CategoricalBag").Take(10).ToArray(); + var workclasses = transformedData.GetColumn("WorkclassOneHotTrimmed").Take(10).ToArray(); // Of course, if we want to train the model, we will need to compose a single float vector of all the features. 
// Here's how we could do this: diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs index 2e11a54d9d..43ce72912d 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs @@ -31,9 +31,9 @@ void Visibility() var src = new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename)); var data = pipeline.Fit(src).Load(src); - var textColumn = data.GetColumn(ml, "SentimentText").Take(20); - var transformedTextColumn = data.GetColumn(ml, "Features_TransformedText").Take(20); - var features = data.GetColumn(ml, "Features").Take(20); + var textColumn = data.GetColumn("SentimentText").Take(20); + var transformedTextColumn = data.GetColumn("Features_TransformedText").Take(20); + var features = data.GetColumn("Features").Take(20); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs index dda1652a2d..d0a0e7e30a 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs @@ -306,7 +306,7 @@ public void TestTrainTestSplit() // this function will accept dataview and return content of "Workclass" column as List of strings. Func> getWorkclass = (IDataView view) => { - return view.GetColumn>(mlContext, "Workclass").Select(x => x.ToString()).ToList(); + return view.GetColumn>("Workclass").Select(x => x.ToString()).ToList(); }; // Let's test what train test properly works with seed. 
diff --git a/test/Microsoft.ML.Tests/Scenarios/GetColumnTests.cs b/test/Microsoft.ML.Tests/Scenarios/GetColumnTests.cs index 7b2043e9e3..04c3ce390d 100644 --- a/test/Microsoft.ML.Tests/Scenarios/GetColumnTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/GetColumnTests.cs @@ -53,18 +53,18 @@ public void TestGetColumn() } }; - var enum1 = data.AsDynamic.GetColumn(env, "floatScalar").ToArray(); - var enum2 = data.AsDynamic.GetColumn(env, "floatVector").ToArray(); - var enum3 = data.AsDynamic.GetColumn>(env, "floatVector").ToArray(); + var enum1 = data.AsDynamic.GetColumn("floatScalar").ToArray(); + var enum2 = data.AsDynamic.GetColumn("floatVector").ToArray(); + var enum3 = data.AsDynamic.GetColumn>("floatVector").ToArray(); - var enum4 = data.AsDynamic.GetColumn(env, "stringScalar").ToArray(); - var enum5 = data.AsDynamic.GetColumn(env, "stringVector").ToArray(); + var enum4 = data.AsDynamic.GetColumn("stringScalar").ToArray(); + var enum5 = data.AsDynamic.GetColumn("stringVector").ToArray(); - mustFail(() => data.AsDynamic.GetColumn(env, "floatScalar")); - mustFail(() => data.AsDynamic.GetColumn(env, "floatVector")); - mustFail(() => data.AsDynamic.GetColumn(env, "floatScalar")); - mustFail(() => data.AsDynamic.GetColumn(env, "floatScalar")); - mustFail(() => data.AsDynamic.GetColumn(env, "floatScalar")); + mustFail(() => data.AsDynamic.GetColumn("floatScalar")); + mustFail(() => data.AsDynamic.GetColumn("floatVector")); + mustFail(() => data.AsDynamic.GetColumn("floatScalar")); + mustFail(() => data.AsDynamic.GetColumn("floatScalar")); + mustFail(() => data.AsDynamic.GetColumn("floatScalar")); // Static types. 
var enum8 = data.GetColumn(r => r.floatScalar); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs index 27e1558f88..a09407729a 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs @@ -42,9 +42,9 @@ public void TestEstimatorSymSgdInitPredictor() var outNoInitData = notInitPredictor.Transform(transformedData); int numExamples = 10; - var col1 = data.GetColumn(Env, "Score").Take(numExamples).ToArray(); - var col2 = outInitData.GetColumn(Env, "Score").Take(numExamples).ToArray(); - var col3 = outNoInitData.GetColumn(Env, "Score").Take(numExamples).ToArray(); + var col1 = data.GetColumn("Score").Take(numExamples).ToArray(); + var col2 = outInitData.GetColumn("Score").Take(numExamples).ToArray(); + var col3 = outNoInitData.GetColumn("Score").Take(numExamples).ToArray(); bool col12Diff = default; bool col23Diff = default; diff --git a/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs index b7b99f2291..4aadf5852e 100644 --- a/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs @@ -647,7 +647,7 @@ void TestValueMapBackCompatTermLookupKeyTypeValue() Assert.True(result.Schema[labelIdx].Type is KeyType); Assert.Equal((ulong)5, result.Schema[labelIdx].Type.GetItemType().GetKeyCount()); - var t = result.GetColumn(Env, "Label"); + var t = result.GetColumn("Label"); uint s = t.First(); Assert.Equal((uint)3, s); } From c5826d761b4ba94f5db5c189499fe212274c45b3 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 26 Feb 2019 09:59:47 -0800 Subject: [PATCH 2/8] Replace reflection with Contracts --- .../Utilities/ColumnCursor.cs | 28 ++++--------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git 
a/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs b/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs index ed4619bfdf..d2c55158b6 100644 --- a/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs +++ b/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs @@ -23,12 +23,11 @@ public static class ColumnCursorExtensions /// The name of the column to extract. public static IEnumerable GetColumn(this IDataView data, string columnName) { - var env = RetrieveHost(data); - env.CheckValue(data, nameof(data)); - env.CheckNonEmpty(columnName, nameof(columnName)); + Contracts.CheckValue(data, nameof(data)); + Contracts.CheckNonEmpty(columnName, nameof(columnName)); if (!data.Schema.TryGetColumnIndex(columnName, out int col)) - throw env.ExceptSchemaMismatch(nameof(columnName), "input", columnName); + throw Contracts.ExceptSchemaMismatch(nameof(columnName), "input", columnName); // There are two decisions that we make here: // - Is the T an array type? @@ -56,7 +55,7 @@ public static IEnumerable GetColumn(this IDataView data, string columnName { // Output is an array type. if (!(colType is VectorType colVectorType)) - throw env.ExceptSchemaMismatch(nameof(columnName), "input", columnName, "vector", "scalar"); + throw Contracts.ExceptSchemaMismatch(nameof(columnName), "input", columnName, "vector", "scalar"); var elementType = typeof(T).GetElementType(); if (elementType == colVectorType.ItemType.RawType) { @@ -75,24 +74,7 @@ public static IEnumerable GetColumn(this IDataView data, string columnName } // Fall through to the failure. } - throw env.Except($"Could not map a data view column '{columnName}' of type {colType} to {typeof(T)}."); - } - - /// - /// Return a assigned in a given implementation. - /// - /// an implementation. - private static IHost RetrieveHost(IDataView data) - { - // Search for the first (if there are multiples) field typed to IHost. 
- var fields = data.GetType().GetFields(System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); - var fieldInfo = fields.Where((field, index) => field.FieldType == typeof(IHost)).FirstOrDefault(); - - // Check if a IHost really gets retrieved. - string errorMessage = nameof(data) + " should contains a field of " + typeof(IHost); - Contracts.CheckValue(fieldInfo, nameof(data), errorMessage); - - return (IHost)fieldInfo.GetValue(data); + throw Contracts.Except($"Could not map a data view column '{columnName}' of type {colType} to {typeof(T)}."); } private static IEnumerable GetColumnDirect(IDataView data, int col) From 99eeae27b62931c8f940cce6ca2e66b566493d7a Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 26 Feb 2019 11:20:49 -0800 Subject: [PATCH 3/8] Use Schema.Column instead of string to identify which column to load --- .../Dynamic/FeatureSelectionTransform.cs | 4 +-- .../Dynamic/KeyToValueValueToKey.cs | 6 ++-- .../Dynamic/LdaTransform.cs | 2 +- .../Dynamic/NgramExtraction.cs | 30 +++++++++---------- .../Dynamic/Normalizer.cs | 6 ++-- .../Dynamic/ProjectionTransforms.cs | 6 ++-- .../Dynamic/StopWordRemoverTransform.cs | 6 ++-- .../Dynamic/TextTransform.cs | 4 +-- .../Transforms/Projection/VectorWhiten.cs | 2 +- .../VectorWhitenWithColumnOptions.cs | 2 +- .../Dynamic/WordEmbeddingTransform.cs | 6 ++-- .../Static/FeatureSelectionTransform.cs | 4 +-- .../Utilities/ColumnCursor.cs | 30 +++++++++++-------- src/Microsoft.ML.StaticPipe/DataView.cs | 3 +- .../StaticPipeTests.cs | 8 ++--- test/Microsoft.ML.Tests/CachingTests.cs | 8 ++--- test/Microsoft.ML.Tests/RangeFilterTests.cs | 4 +-- .../Api/CookbookSamples/CookbookSamples.cs | 2 +- .../CookbookSamplesDynamicApi.cs | 16 +++++----- .../Scenarios/Api/Estimators/Visibility.cs | 6 ++-- .../Scenarios/Api/TestApi.cs | 2 +- .../Scenarios/GetColumnTests.cs | 20 ++++++------- .../SymSgdClassificationTests.cs | 6 ++-- .../Transformers/ValueMappingTests.cs | 2 +- 24 files changed, 96 
insertions(+), 89 deletions(-) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs index e003e34ee1..113b4794fb 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs @@ -82,8 +82,8 @@ public static void Example() }; // Print the data that results from the transformations. - var countSelectColumn = transformedData.GetColumn>("FeaturesCountSelect"); - var MISelectColumn = transformedData.GetColumn>("FeaturesMISelect"); + var countSelectColumn = transformedData.GetColumn>(transformedData.Schema["FeaturesCountSelect"]); + var MISelectColumn = transformedData.GetColumn>(transformedData.Schema["FeaturesMISelect"]); printHelper("FeaturesCountSelect", countSelectColumn); printHelper("FeaturesMISelect", MISelectColumn); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValueValueToKey.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValueValueToKey.cs index 522ea886f9..e3ea8971cb 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValueValueToKey.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValueValueToKey.cs @@ -60,7 +60,7 @@ public static void Example() }; // Preview of the DefaultKeys column obtained after processing the input. - var defaultColumn = transformedData_default.GetColumn>(defaultColumnName); + var defaultColumn = transformedData_default.GetColumn>(transformedData_default.Schema[defaultColumnName]); printHelper(defaultColumnName, defaultColumn); // DefaultKeys column obtained post-transformation. @@ -71,7 +71,7 @@ public static void Example() // 9 10 11 12 13 6 // Previewing the CustomizedKeys column obtained after processing the input. 
- var customizedColumn = transformedData_customized.GetColumn>(customizedColumnName); + var customizedColumn = transformedData_customized.GetColumn>(transformedData_customized.Schema[customizedColumnName]); printHelper(customizedColumnName, customizedColumn); // CustomizedKeys column obtained post-transformation. @@ -87,7 +87,7 @@ public static void Example() transformedData_default = pipeline.Fit(trainData).Transform(trainData); // Preview of the DefaultColumnName column obtained. - var originalColumnBack = transformedData_default.GetColumn>>(defaultColumnName); + var originalColumnBack = transformedData_default.GetColumn>>(transformedData_default.Schema[defaultColumnName]); foreach (var row in originalColumnBack) { diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/LdaTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/LdaTransform.cs index a3710ce5da..e564b68cb8 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/LdaTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/LdaTransform.cs @@ -37,7 +37,7 @@ public static void Example() var transformed_data = transformer.Transform(trainData); // Column obtained after processing the input. - var ldaFeaturesColumn = transformed_data.GetColumn>(ldaFeatures); + var ldaFeaturesColumn = transformed_data.GetColumn>(transformed_data.Schema[ldaFeatures]); Console.WriteLine($"{ldaFeatures} column obtained post-transformation."); foreach (var featureRow in ldaFeaturesColumn) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs index aa2d539b16..fa3c6317bf 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs @@ -26,8 +26,8 @@ public static void NgramTransform() // A pipeline to tokenize text as characters and then combine them together into ngrams // The pipeline uses the default settings to featurize. 
- var charsPipeline = ml.Transforms.Text.TokenizeCharacters("Chars", "SentimentText", useMarkerCharacters:false); - var ngramOnePipeline = ml.Transforms.Text.ProduceNgrams("CharsUnigrams", "Chars", ngramLength:1); + var charsPipeline = ml.Transforms.Text.TokenizeCharacters("Chars", "SentimentText", useMarkerCharacters: false); + var ngramOnePipeline = ml.Transforms.Text.ProduceNgrams("CharsUnigrams", "Chars", ngramLength: 1); var ngramTwpPipeline = ml.Transforms.Text.ProduceNgrams("CharsTwograms", "Chars"); var oneCharsPipeline = charsPipeline.Append(ngramOnePipeline); var twoCharsPipeline = charsPipeline.Append(ngramTwpPipeline); @@ -38,22 +38,22 @@ public static void NgramTransform() // Small helper to print the text inside the columns, in the console. Action>, VBuffer>> printHelper = (columnName, column, names) => - { - Console.WriteLine($"{columnName} column obtained post-transformation."); - var slots = names.GetValues(); - foreach (var featureRow in column) - { - foreach (var item in featureRow.Items()) - Console.Write($"'{slots[item.Key]}' - {item.Value} "); - Console.WriteLine(""); - } + { + Console.WriteLine($"{columnName} column obtained post-transformation."); + var slots = names.GetValues(); + foreach (var featureRow in column) + { + foreach (var item in featureRow.Items()) + Console.Write($"'{slots[item.Key]}' - {item.Value} "); + Console.WriteLine(""); + } - Console.WriteLine("==================================================="); - }; + Console.WriteLine("==================================================="); + }; // Preview of the CharsUnigrams column obtained after processing the input. 
VBuffer> slotNames = default; transformedData_onechars.Schema["CharsUnigrams"].GetSlotNames(ref slotNames); - var charsOneGramColumn = transformedData_onechars.GetColumn>("CharsUnigrams"); + var charsOneGramColumn = transformedData_onechars.GetColumn>(transformedData_onechars.Schema["CharsUnigrams"]); printHelper("CharsUnigrams", charsOneGramColumn, slotNames); // CharsUnigrams column obtained post-transformation. @@ -61,7 +61,7 @@ public static void NgramTransform() // 'e' - 1 '' - 2 'd' - 1 '=' - 4 'R' - 1 'U' - 1 'D' - 2 'E' - 1 'u' - 1 ',' - 1 '2' - 1 // 'B' - 0 'e' - 6 's' - 3 't' - 6 '' - 9 'g' - 2 'a' - 2 'm' - 2 'I' - 0 ''' - 0 'v' - 0 ... // Preview of the CharsTwoGrams column obtained after processing the input. - var charsTwoGramColumn = transformedData_twochars.GetColumn>("CharsTwograms"); + var charsTwoGramColumn = transformedData_twochars.GetColumn>(transformedData_twochars.Schema["CharsTwograms"]); transformedData_twochars.Schema["CharsTwograms"].GetSlotNames(ref slotNames); printHelper("CharsTwograms", charsTwoGramColumn, slotNames); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs index d33b9dbf49..149f4ea9ce 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs @@ -44,7 +44,7 @@ public static void Example() var transformedData = transformer.Transform(trainData); // Getting the data of the newly created column, so we can preview it. - var normalizedColumn = transformedData.GetColumn("Induced"); + var normalizedColumn = transformedData.GetColumn(transformedData.Schema["Induced"]); // A small printing utility. Action> printHelper = (colName, column) => @@ -72,8 +72,8 @@ public static void Example() var multiColtransformedData = multiColtransformer.Transform(trainData); // Getting the newly created columns. 
- var normalizedInduced = multiColtransformedData.GetColumn("LogInduced"); - var normalizedSpont = multiColtransformedData.GetColumn("LogSpontaneous"); + var normalizedInduced = multiColtransformedData.GetColumn(multiColtransformedData.Schema["LogInduced"]); + var normalizedSpont = multiColtransformedData.GetColumn(multiColtransformedData.Schema["LogSpontaneous"]); printHelper("LogInduced", normalizedInduced); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ProjectionTransforms.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ProjectionTransforms.cs index faf3911878..f7df0a1155 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/ProjectionTransforms.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/ProjectionTransforms.cs @@ -41,7 +41,7 @@ public static void Example() // The transformed (projected) data. var transformedData = rffPipeline.Fit(trainData).Transform(trainData); // Getting the data of the newly created column, so we can preview it. - var randomFourier = transformedData.GetColumn>(nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features)); + var randomFourier = transformedData.GetColumn>(transformedData.Schema[nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features)]); printHelper(nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features), randomFourier); @@ -59,7 +59,7 @@ public static void Example() // The transformed (projected) data. transformedData = lpNormalizePipeline.Fit(trainData).Transform(trainData); // Getting the data of the newly created column, so we can preview it. - var lpNormalize= transformedData.GetColumn>(nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features)); + var lpNormalize= transformedData.GetColumn>(transformedData.Schema[nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features)]); printHelper(nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features), lpNormalize); @@ -77,7 +77,7 @@ public static void Example() // The transformed (projected) data. 
transformedData = gcNormalizePipeline.Fit(trainData).Transform(trainData); // Getting the data of the newly created column, so we can preview it. - var gcNormalize = transformedData.GetColumn>(nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features)); + var gcNormalize = transformedData.GetColumn>(transformedData.Schema[nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features)]); printHelper(nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features), gcNormalize); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/StopWordRemoverTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/StopWordRemoverTransform.cs index 9c04214f96..9eee1ff76e 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/StopWordRemoverTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/StopWordRemoverTransform.cs @@ -54,14 +54,14 @@ public static void Example() }; // Preview the result of breaking string into array of words. - var originalText = transformedDataDefault.GetColumn>>(originalTextColumnName); + var originalText = transformedDataDefault.GetColumn>>(transformedDataDefault.Schema[originalTextColumnName]); printHelper(originalTextColumnName, originalText); // Best|game|I've|ever|played.| // == RUDE ==| Dude,| 2 | // Until | the | next | game,| this |is| the | best | Xbox | game!| // Preview the result of cleaning with default stop word remover. - var defaultRemoverData = transformedDataDefault.GetColumn>>("DefaultRemover"); + var defaultRemoverData = transformedDataDefault.GetColumn>>(transformedDataDefault.Schema["DefaultRemover"]); printHelper("DefaultRemover", defaultRemoverData); // Best|game|I've|played.| // == RUDE ==| Dude,| 2 | @@ -70,7 +70,7 @@ public static void Example() // Preview the result of cleaning with default customized stop word remover. 
- var customizeRemoverData = transformedDataCustomized.GetColumn>>("RemovedWords"); + var customizeRemoverData = transformedDataCustomized.GetColumn>>(transformedDataCustomized.Schema["RemovedWords"]); printHelper("RemovedWords", customizeRemoverData); // Best|game|I've|ever|played.| diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs index 3a82a4f827..675fd1218d 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs @@ -58,7 +58,7 @@ public static void Example() }; // Preview of the DefaultTextFeatures column obtained after processing the input. - var defaultColumn = transformedData_default.GetColumn>(defaultColumnName); + var defaultColumn = transformedData_default.GetColumn>(transformedData_default.Schema[defaultColumnName]); printHelper(defaultColumnName, defaultColumn); // DefaultTextFeatures column obtained post-transformation. @@ -68,7 +68,7 @@ public static void Example() // 0 0.1230915 0.1230915 0.1230915 0.1230915 0.246183 0.246183 0.246183 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0.1230915 0 0 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.3692745 0.246183 0.246183 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.246183 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.1230915 0.2886751 0 0 0 0 0 0 0 0.2886751 0.5773503 0.2886751 0.2886751 0.2886751 0.2886751 0.2886751 0.2886751 // Preview of the CustomizedTextFeatures column obtained after processing the input. 
- var customizedColumn = transformedData_customized.GetColumn>(customizedColumnName); + var customizedColumn = transformedData_customized.GetColumn>(transformedData_customized.Schema[customizedColumnName]); printHelper(customizedColumnName, customizedColumn); // CustomizedTextFeatures column obtained post-transformation. diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhiten.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhiten.cs index 05ca3832cb..616e86255e 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhiten.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhiten.cs @@ -45,7 +45,7 @@ public static void Example() // The transformed (projected) data. var transformedData = whiteningPipeline.Fit(trainData).Transform(trainData); // Getting the data of the newly created column, so we can preview it. - var whitening = transformedData.GetColumn>(nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features)); + var whitening = transformedData.GetColumn>(transformedData.Schema[nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features)]); printHelper(nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features), whitening); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhitenWithColumnOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhitenWithColumnOptions.cs index 0d49e150ad..811c983768 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhitenWithColumnOptions.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhitenWithColumnOptions.cs @@ -44,7 +44,7 @@ public static void Example() // The transformed (projected) data. var transformedData = whiteningPipeline.Fit(trainData).Transform(trainData); // Getting the data of the newly created column, so we can preview it. 
- var whitening = transformedData.GetColumn>(nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features)); + var whitening = transformedData.GetColumn>(transformedData.Schema[nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features)]); printHelper(nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features), whitening); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/WordEmbeddingTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/WordEmbeddingTransform.cs index 6686be19e4..9da2f086e6 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/WordEmbeddingTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/WordEmbeddingTransform.cs @@ -31,7 +31,7 @@ public static void Example() var wordsDataview = wordsPipeline.Fit(trainData).Transform(trainData); // Preview of the CleanWords column obtained after processing SentimentText. - var cleanWords = wordsDataview.GetColumn[]>("CleanWords"); + var cleanWords = wordsDataview.GetColumn[]>(wordsDataview.Schema["CleanWords"]); Console.WriteLine($" CleanWords column obtained post-transformation."); foreach (var featureRow in cleanWords) { @@ -86,7 +86,7 @@ public static void Example() // And do all required transformations. var embeddingDataview = pipeline.Fit(wordsDataview).Transform(wordsDataview); - var customEmbeddings = embeddingDataview.GetColumn("CustomEmbeddings"); + var customEmbeddings = embeddingDataview.GetColumn(embeddingDataview.Schema["CustomEmbeddings"]); printEmbeddings("GloveEmbeddings", customEmbeddings); // -1 -2 -3 -0.5 -1 8.5 0 0 20 @@ -98,7 +98,7 @@ public static void Example() // Second set of 3 floats in output represent average (for each dimension) for extracted values. // Third set of 3 floats in output represent maximum values (for each dimension) for extracted values. // Preview of GloveEmbeddings. 
- var gloveEmbeddings = embeddingDataview.GetColumn("GloveEmbeddings"); + var gloveEmbeddings = embeddingDataview.GetColumn(embeddingDataview.Schema["GloveEmbeddings"]); printEmbeddings("GloveEmbeddings", gloveEmbeddings); // 0.23166 0.048825 0.26878 -1.3945 -0.86072 -0.026778 0.84075 -0.81987 -1.6681 -1.0658 -0.30596 0.50974 ... //-0.094905 0.61109 0.52546 - 0.2516 0.054786 0.022661 1.1801 0.33329 - 0.85388 0.15471 - 0.5984 0.4364 ... diff --git a/docs/samples/Microsoft.ML.Samples/Static/FeatureSelectionTransform.cs b/docs/samples/Microsoft.ML.Samples/Static/FeatureSelectionTransform.cs index 5f959b6b88..88b6c6839d 100644 --- a/docs/samples/Microsoft.ML.Samples/Static/FeatureSelectionTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Static/FeatureSelectionTransform.cs @@ -83,8 +83,8 @@ public static void FeatureSelectionTransform() }; // Print the data that results from the transformations. - var countSelectColumn = transformedData.AsDynamic.GetColumn>("FeaturesCountSelect"); - var MISelectColumn = transformedData.AsDynamic.GetColumn>("FeaturesMISelect"); + var countSelectColumn = transformedData.AsDynamic.GetColumn>(transformedData.AsDynamic.Schema["FeaturesCountSelect"]); + var MISelectColumn = transformedData.AsDynamic.GetColumn>(transformedData.AsDynamic.Schema["FeaturesMISelect"]); printHelper("FeaturesCountSelect", countSelectColumn); printHelper("FeaturesMISelect", MISelectColumn); diff --git a/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs b/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs index d2c55158b6..178c83345d 100644 --- a/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs +++ b/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs @@ -20,14 +20,18 @@ public static class ColumnCursorExtensions /// /// The type of the values. This must match the actual column type. /// The data view to get the column from. - /// The name of the column to extract. 
- public static IEnumerable GetColumn(this IDataView data, string columnName) + /// The column to be extracted. + public static IEnumerable GetColumn(this IDataView data, DataViewSchema.Column column) { Contracts.CheckValue(data, nameof(data)); - Contracts.CheckNonEmpty(columnName, nameof(columnName)); + Contracts.CheckNonEmpty(column.Name, nameof(column)); - if (!data.Schema.TryGetColumnIndex(columnName, out int col)) - throw Contracts.ExceptSchemaMismatch(nameof(columnName), "input", columnName); + if (!data.Schema.TryGetColumnIndex(column.Name, out int colIndex)) + throw Contracts.ExceptParam(nameof(column), string.Format("column name {0} cannot be found in {1}", column.Name, nameof(data))); + + if (data.Schema[colIndex].Type != column.Type) + throw Contracts.ExceptParam(nameof(column), string.Format("column {0}'s type {1} doesn't match the expected type {2} in {3}", + column.Name, column.Type, data.Schema[colIndex].Type, nameof(data))); // There are two decisions that we make here: // - Is the T an array type? @@ -37,11 +41,11 @@ public static IEnumerable GetColumn(this IDataView data, string columnName // - If this is the same type, we can map directly. // - Otherwise, we need a conversion delegate. - var colType = data.Schema[col].Type; + var colType = column.Type; if (colType.RawType == typeof(T)) { // Direct mapping is possible. 
- return GetColumnDirect(data, col); + return GetColumnDirect(data, colIndex); } else if (typeof(T) == typeof(string) && colType is TextDataViewType) { @@ -49,20 +53,20 @@ public static IEnumerable GetColumn(this IDataView data, string columnName Delegate convert = (Func, string>)((ReadOnlyMemory txt) => txt.ToString()); Func, IEnumerable> del = GetColumnConvert; var meth = del.Method.GetGenericMethodDefinition().MakeGenericMethod(typeof(T), colType.RawType); - return (IEnumerable)(meth.Invoke(null, new object[] { data, col, convert })); + return (IEnumerable)(meth.Invoke(null, new object[] { data, colIndex, convert })); } else if (typeof(T).IsArray) { // Output is an array type. if (!(colType is VectorType colVectorType)) - throw Contracts.ExceptSchemaMismatch(nameof(columnName), "input", columnName, "vector", "scalar"); + throw Contracts.ExceptParam(nameof(column), string.Format("Cannot load vector type, {0}, specified in {1} to the user-defined type, {2}.", column.Type, nameof(column), typeof(T))); var elementType = typeof(T).GetElementType(); if (elementType == colVectorType.ItemType.RawType) { // Direct mapping of items. 
Func> del = GetColumnArrayDirect; var meth = del.Method.GetGenericMethodDefinition().MakeGenericMethod(elementType); - return (IEnumerable)meth.Invoke(null, new object[] { data, col }); + return (IEnumerable)meth.Invoke(null, new object[] { data, colIndex }); } else if (elementType == typeof(string) && colVectorType.ItemType is TextDataViewType) { @@ -70,11 +74,13 @@ public static IEnumerable GetColumn(this IDataView data, string columnName Delegate convert = (Func, string>)((ReadOnlyMemory txt) => txt.ToString()); Func, IEnumerable> del = GetColumnArrayConvert; var meth = del.Method.GetGenericMethodDefinition().MakeGenericMethod(elementType, colVectorType.ItemType.RawType); - return (IEnumerable)meth.Invoke(null, new object[] { data, col, convert }); + return (IEnumerable)meth.Invoke(null, new object[] { data, colIndex, convert }); } // Fall through to the failure. } - throw Contracts.Except($"Could not map a data view column '{columnName}' of type {colType} to {typeof(T)}."); + + throw Contracts.ExceptParam(nameof(column), string.Format("Cannot map column (name: {0}, type: {1}) in {2} to the user-defined type, {3}.", + column.Name, column.Type, nameof(data), typeof(T))); } private static IEnumerable GetColumnDirect(IDataView data, int col) diff --git a/src/Microsoft.ML.StaticPipe/DataView.cs b/src/Microsoft.ML.StaticPipe/DataView.cs index 88f76b2bd2..42e45e2072 100644 --- a/src/Microsoft.ML.StaticPipe/DataView.cs +++ b/src/Microsoft.ML.StaticPipe/DataView.cs @@ -49,7 +49,8 @@ private static IEnumerable GetColumnCore(DataView da var indexer = StaticPipeUtils.GetIndexer(data); string columnName = indexer.Get(column(indexer.Indices)); - return data.AsDynamic.GetColumn(columnName); + var dynamicData = data.AsDynamic; + return dynamicData.GetColumn(dynamicData.Schema[columnName]); } public static IEnumerable GetColumn(this DataView data, Func> column) diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs 
b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index 9e71ad1afe..b39b4c34b8 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -793,10 +793,10 @@ public void NAIndicatorStatic() IDataView newData = ml.Data.TakeRows(est.Fit(data).Transform(data).AsDynamic, 4); Assert.NotNull(newData); - bool[] ScalarFloat = newData.GetColumn("A").ToArray(); - bool[] ScalarDouble = newData.GetColumn("B").ToArray(); - bool[][] VectorFloat = newData.GetColumn("C").ToArray(); - bool[][] VectorDoulbe = newData.GetColumn("D").ToArray(); + bool[] ScalarFloat = newData.GetColumn(newData.Schema["A"]).ToArray(); + bool[] ScalarDouble = newData.GetColumn(newData.Schema["B"]).ToArray(); + bool[][] VectorFloat = newData.GetColumn(newData.Schema["C"]).ToArray(); + bool[][] VectorDoulbe = newData.GetColumn(newData.Schema["D"]).ToArray(); Assert.NotNull(ScalarFloat); Assert.NotNull(ScalarDouble); diff --git a/test/Microsoft.ML.Tests/CachingTests.cs b/test/Microsoft.ML.Tests/CachingTests.cs index fd505cb9ee..8f4cba9f1b 100644 --- a/test/Microsoft.ML.Tests/CachingTests.cs +++ b/test/Microsoft.ML.Tests/CachingTests.cs @@ -66,15 +66,15 @@ public void CacheTest() { var src = Enumerable.Range(0, 100).Select(c => new MyData()).ToArray(); var data = ML.Data.LoadFromEnumerable(src); - data.GetColumn("Features").ToArray(); - data.GetColumn("Features").ToArray(); + data.GetColumn(data.Schema["Features"]).ToArray(); + data.GetColumn(data.Schema["Features"]).ToArray(); Assert.True(src.All(x => x.AccessCount == 2)); src = Enumerable.Range(0, 100).Select(c => new MyData()).ToArray(); data = ML.Data.LoadFromEnumerable(src); data = ML.Data.Cache(data); - data.GetColumn("Features").ToArray(); - data.GetColumn("Features").ToArray(); + data.GetColumn(data.Schema["Features"]).ToArray(); + data.GetColumn(data.Schema["Features"]).ToArray(); Assert.True(src.All(x => x.AccessCount == 1)); } diff --git 
a/test/Microsoft.ML.Tests/RangeFilterTests.cs b/test/Microsoft.ML.Tests/RangeFilterTests.cs index f66c5e9452..ee487e5cec 100644 --- a/test/Microsoft.ML.Tests/RangeFilterTests.cs +++ b/test/Microsoft.ML.Tests/RangeFilterTests.cs @@ -26,12 +26,12 @@ public void RangeFilterTest() var data = builder.GetDataView(); var data1 = ML.Data.FilterRowsByColumn(data, "Floats", upperBound: 2.8); - var cnt = data1.GetColumn("Floats").Count(); + var cnt = data1.GetColumn(data1.Schema["Floats"]).Count(); Assert.Equal(2L, cnt); data = ML.Transforms.Conversion.Hash("Key", "Strings", hashBits: 20).Fit(data).Transform(data); var data2 = ML.Data.FilterRowsByKeyColumnFraction(data, "Key", upperBound: 0.5); - cnt = data2.GetColumn("Floats").Count(); + cnt = data2.GetColumn(data2.Schema["Floats"]).Count(); Assert.Equal(1L, cnt); } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs index 64b1ad876b..9e49fa1ae2 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs @@ -79,7 +79,7 @@ private void IntermediateData(string dataPath) // The same extension method also applies to the dynamic-typed data, except you have to // specify the column name and type: var dynamicData = transformedData.AsDynamic; - var sameFeatureColumns = dynamicData.GetColumn("AllFeatures") + var sameFeatureColumns = dynamicData.GetColumn(dynamicData.Schema["AllFeatures"]) .Take(20).ToArray(); } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs index c7f0ab42d1..a5e3710ce7 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs @@ -61,7 +61,7 @@ 
private void IntermediateData(string dataPath) // This will give the entire dataset: make sure to only take several row // in case the dataset is huge. The is similar to the static API, except // you have to specify the column name and type. - var featureColumns = transformedData.GetColumn("AllFeatures") + var featureColumns = transformedData.GetColumn(transformedData.Schema["AllFeatures"]) .Take(20).ToArray(); } @@ -251,7 +251,7 @@ private void NormalizationWorkout(string dataPath) var normalizedData = pipeline.Fit(trainData).Transform(trainData); // Inspect one column of the resulting dataset. - var meanVarValues = normalizedData.GetColumn("MeanVarNormalized").ToArray(); + var meanVarValues = normalizedData.GetColumn(normalizedData.Schema["MeanVarNormalized"]).ToArray(); } [Fact] @@ -289,7 +289,7 @@ private void TextFeaturizationOn(string dataPath) var data = loader.Load(dataPath); // Inspect the message texts that are read from the file. - var messageTexts = data.GetColumn("Message").Take(20).ToArray(); + var messageTexts = data.GetColumn(data.Schema["Message"]).Take(20).ToArray(); // Apply various kinds of text operations supported by ML.NET. var pipeline = @@ -321,8 +321,8 @@ private void TextFeaturizationOn(string dataPath) var transformedData = pipeline.Fit(data).Transform(data); // Inspect some columns of the resulting dataset. - var embeddings = transformedData.GetColumn("Embeddings").Take(10).ToArray(); - var unigrams = transformedData.GetColumn("BagOfWords").Take(10).ToArray(); + var embeddings = transformedData.GetColumn(transformedData.Schema["Embeddings"]).Take(10).ToArray(); + var unigrams = transformedData.GetColumn(transformedData.Schema["BagOfWords"]).Take(10).ToArray(); } [Fact(Skip = "This test is running for one minute")] @@ -361,7 +361,7 @@ private void CategoricalFeaturizationOn(params string[] dataPath) var data = loader.Load(dataPath); // Inspect the first 10 records of the categorical columns to check that they are correctly read. 
- var catColumns = data.GetColumn("CategoricalFeatures").Take(10).ToArray(); + var catColumns = data.GetColumn(data.Schema["CategoricalFeatures"]).Take(10).ToArray(); // Build several alternative featurization pipelines. var pipeline = @@ -377,8 +377,8 @@ private void CategoricalFeaturizationOn(params string[] dataPath) var transformedData = pipeline.Fit(data).Transform(data); // Inspect some columns of the resulting dataset. - var categoricalBags = transformedData.GetColumn("CategoricalBag").Take(10).ToArray(); - var workclasses = transformedData.GetColumn("WorkclassOneHotTrimmed").Take(10).ToArray(); + var categoricalBags = transformedData.GetColumn(transformedData.Schema["CategoricalBag"]).Take(10).ToArray(); + var workclasses = transformedData.GetColumn(transformedData.Schema["WorkclassOneHotTrimmed"]).Take(10).ToArray(); // Of course, if we want to train the model, we will need to compose a single float vector of all the features. // Here's how we could do this: diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs index 43ce72912d..9f9838c9da 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs @@ -31,9 +31,9 @@ void Visibility() var src = new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename)); var data = pipeline.Fit(src).Load(src); - var textColumn = data.GetColumn("SentimentText").Take(20); - var transformedTextColumn = data.GetColumn("Features_TransformedText").Take(20); - var features = data.GetColumn("Features").Take(20); + var textColumn = data.GetColumn(data.Schema["SentimentText"]).Take(20); + var transformedTextColumn = data.GetColumn(data.Schema["Features_TransformedText"]).Take(20); + var features = data.GetColumn(data.Schema["Features"]).Take(20); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs 
b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs index d0a0e7e30a..a2e2da1daa 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs @@ -306,7 +306,7 @@ public void TestTrainTestSplit() // this function will accept dataview and return content of "Workclass" column as List of strings. Func> getWorkclass = (IDataView view) => { - return view.GetColumn>("Workclass").Select(x => x.ToString()).ToList(); + return view.GetColumn>(view.Schema["Workclass"]).Select(x => x.ToString()).ToList(); }; // Let's test what train test properly works with seed. diff --git a/test/Microsoft.ML.Tests/Scenarios/GetColumnTests.cs b/test/Microsoft.ML.Tests/Scenarios/GetColumnTests.cs index 04c3ce390d..b36bf990a4 100644 --- a/test/Microsoft.ML.Tests/Scenarios/GetColumnTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/GetColumnTests.cs @@ -53,18 +53,18 @@ public void TestGetColumn() } }; - var enum1 = data.AsDynamic.GetColumn("floatScalar").ToArray(); - var enum2 = data.AsDynamic.GetColumn("floatVector").ToArray(); - var enum3 = data.AsDynamic.GetColumn>("floatVector").ToArray(); + var enum1 = data.AsDynamic.GetColumn(data.AsDynamic.Schema["floatScalar"]).ToArray(); + var enum2 = data.AsDynamic.GetColumn(data.AsDynamic.Schema["floatVector"]).ToArray(); + var enum3 = data.AsDynamic.GetColumn>(data.AsDynamic.Schema["floatVector"]).ToArray(); - var enum4 = data.AsDynamic.GetColumn("stringScalar").ToArray(); - var enum5 = data.AsDynamic.GetColumn("stringVector").ToArray(); + var enum4 = data.AsDynamic.GetColumn(data.AsDynamic.Schema["stringScalar"]).ToArray(); + var enum5 = data.AsDynamic.GetColumn(data.AsDynamic.Schema["stringVector"]).ToArray(); - mustFail(() => data.AsDynamic.GetColumn("floatScalar")); - mustFail(() => data.AsDynamic.GetColumn("floatVector")); - mustFail(() => data.AsDynamic.GetColumn("floatScalar")); - mustFail(() => data.AsDynamic.GetColumn("floatScalar")); - mustFail(() => 
data.AsDynamic.GetColumn("floatScalar")); + mustFail(() => data.AsDynamic.GetColumn(data.AsDynamic.Schema["floatScalar"])); + mustFail(() => data.AsDynamic.GetColumn(data.AsDynamic.Schema["floatVector"])); + mustFail(() => data.AsDynamic.GetColumn(data.AsDynamic.Schema["floatScalar"])); + mustFail(() => data.AsDynamic.GetColumn(data.AsDynamic.Schema["floatScalar"])); + mustFail(() => data.AsDynamic.GetColumn(data.AsDynamic.Schema["floatScalar"])); // Static types. var enum8 = data.GetColumn(r => r.floatScalar); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs index a09407729a..8e4f6bf170 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs @@ -42,9 +42,9 @@ public void TestEstimatorSymSgdInitPredictor() var outNoInitData = notInitPredictor.Transform(transformedData); int numExamples = 10; - var col1 = data.GetColumn("Score").Take(numExamples).ToArray(); - var col2 = outInitData.GetColumn("Score").Take(numExamples).ToArray(); - var col3 = outNoInitData.GetColumn("Score").Take(numExamples).ToArray(); + var col1 = data.GetColumn(data.Schema["Score"]).Take(numExamples).ToArray(); + var col2 = outInitData.GetColumn(outInitData.Schema["Score"]).Take(numExamples).ToArray(); + var col3 = outNoInitData.GetColumn(outNoInitData.Schema["Score"]).Take(numExamples).ToArray(); bool col12Diff = default; bool col23Diff = default; diff --git a/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs index 4aadf5852e..11cb53ab22 100644 --- a/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs @@ -647,7 +647,7 @@ void TestValueMapBackCompatTermLookupKeyTypeValue() Assert.True(result.Schema[labelIdx].Type is KeyType); Assert.Equal((ulong)5, 
result.Schema[labelIdx].Type.GetItemType().GetKeyCount()); - var t = result.GetColumn("Label"); + var t = result.GetColumn(result.Schema["Label"]); uint s = t.First(); Assert.Equal((uint)3, s); } From 2798ebb12217c959535f8cb9654dd2664898acaf Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 26 Feb 2019 13:16:54 -0800 Subject: [PATCH 4/8] Use index in the column passed in to access everything --- src/Microsoft.ML.Data/Utilities/ColumnCursor.cs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs b/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs index 178c83345d..31d3081ec5 100644 --- a/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs +++ b/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs @@ -26,12 +26,15 @@ public static IEnumerable GetColumn(this IDataView data, DataViewSchema.Co Contracts.CheckValue(data, nameof(data)); Contracts.CheckNonEmpty(column.Name, nameof(column)); - if (!data.Schema.TryGetColumnIndex(column.Name, out int colIndex)) - throw Contracts.ExceptParam(nameof(column), string.Format("column name {0} cannot be found in {1}", column.Name, nameof(data))); + var colIndex = column.Index; + var colType = column.Type; + var colName = column.Name; - if (data.Schema[colIndex].Type != column.Type) - throw Contracts.ExceptParam(nameof(column), string.Format("column {0}'s type {1} doesn't match the expected type {2} in {3}", - column.Name, column.Type, data.Schema[colIndex].Type, nameof(data))); + // Use column index as the principal address of the specified input column and check if that address in data contains + // the column indicated. + if (data.Schema[colIndex].Name != colName || data.Schema[colIndex].Type != colType) + throw Contracts.ExceptParam(nameof(column), string.Format("column with name {0}, type {1}, and index {2} cannot be found in {3}", + colName, colType, colIndex, nameof(data))); // There are two decisions that we make here: // - Is the T an array type? 
@@ -41,7 +44,6 @@ public static IEnumerable GetColumn(this IDataView data, DataViewSchema.Co // - If this is the same type, we can map directly. // - Otherwise, we need a conversion delegate. - var colType = column.Type; if (colType.RawType == typeof(T)) { // Direct mapping is possible. From 9af95abcf7236a3325d28e156619ffd3e9d25920 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 26 Feb 2019 20:02:49 -0800 Subject: [PATCH 5/8] Update cookbook examples --- docs/code/MlNetCookBook.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/code/MlNetCookBook.md b/docs/code/MlNetCookBook.md index 7a9f9a77c0..80eeaa27c7 100644 --- a/docs/code/MlNetCookBook.md +++ b/docs/code/MlNetCookBook.md @@ -303,8 +303,8 @@ var someRows = mlContext // This will give the entire dataset: make sure to only take several row // in case the dataset is huge. The is similar to the static API, except // you have to specify the column name and type. -var featureColumns = transformedData.GetColumn(mlContext, "AllFeatures") - .Take(20).ToArray(); +var featureColumns = transformedData.GetColumn(transformedData.Schema["AllFeatures"]) + .Take(20).ToArray(); ``` ## How do I train a regression model? @@ -637,7 +637,7 @@ var pipeline = var normalizedData = pipeline.Fit(trainData).Transform(trainData); // Inspect one column of the resulting dataset. -var meanVarValues = normalizedData.GetColumn(mlContext, "MeanVarNormalized").ToArray(); +var meanVarValues = normalizedData.GetColumn(normalizedData.Schema["MeanVarNormalized"]).ToArray(); ``` ## How do I train my model on categorical data? @@ -682,8 +682,8 @@ var loader = mlContext.Data.CreateTextLoader(new[] // Load the data. var data = loader.Load(dataPath); -// Inspect the first 10 records of the categorical columns to check that they are correctly load. 
-var catColumns = data.GetColumn(mlContext, "CategoricalFeatures").Take(10).ToArray(); +// Inspect the first 10 records of the categorical columns to check that they are correctly read. +var catColumns = data.GetColumn(data.Schema["CategoricalFeatures"]).Take(10).ToArray(); // Build several alternative featurization pipelines. var pipeline = @@ -699,8 +699,8 @@ var pipeline = var transformedData = pipeline.Fit(data).Transform(data); // Inspect some columns of the resulting dataset. -var categoricalBags = transformedData.GetColumn(mlContext, "CategoricalBag").Take(10).ToArray(); -var workclasses = transformedData.GetColumn(mlContext, "WorkclassOneHotTrimmed").Take(10).ToArray(); +var categoricalBags = transformedData.GetColumn(transformedData.Schema["CategoricalBag"]).Take(10).ToArray(); +var workclasses = transformedData.GetColumn(transformedData.Schema["WorkclassOneHotTrimmed"]).Take(10).ToArray(); // Of course, if we want to train the model, we will need to compose a single float vector of all the features. // Here's how we could do this: @@ -756,8 +756,8 @@ var loader = mlContext.Data.CreateTextLoader(new[] // Load the data. var data = loader.Load(dataPath); -// Inspect the message texts that are load from the file. -var messageTexts = data.GetColumn(mlContext, "Message").Take(20).ToArray(); +// Inspect the message texts that are read from the file. +var messageTexts = data.GetColumn(data.Schema["Message"]).Take(20).ToArray(); // Apply various kinds of text operations supported by ML.NET. 
var pipeline = From 858542a962a3b3dbe866ad1d89c5b3369389a5a8 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Wed, 27 Feb 2019 09:40:50 -0800 Subject: [PATCH 6/8] Improve a test --- .../Microsoft.ML.Tests/Scenarios/GetColumnTests.cs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/Microsoft.ML.Tests/Scenarios/GetColumnTests.cs b/test/Microsoft.ML.Tests/Scenarios/GetColumnTests.cs index b36bf990a4..29f238f6f8 100644 --- a/test/Microsoft.ML.Tests/Scenarios/GetColumnTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/GetColumnTests.cs @@ -66,11 +66,25 @@ public void TestGetColumn() mustFail(() => data.AsDynamic.GetColumn(data.AsDynamic.Schema["floatScalar"])); mustFail(() => data.AsDynamic.GetColumn(data.AsDynamic.Schema["floatScalar"])); + // Static types. var enum8 = data.GetColumn(r => r.floatScalar); var enum9 = data.GetColumn(r => r.floatVector); var enum10 = data.GetColumn(r => r.stringScalar); var enum11 = data.GetColumn(r => r.stringVector); + + var data1 = TextLoaderStatic.CreateLoader(env, ctx => ( + floatScalar: ctx.LoadText(1), + anotherFloatVector: ctx.LoadFloat(2, 6), + stringVector: ctx.LoadText(5, 7) + )).Load(path); + + // Type wrong. Load float as string. + mustFail(() => data.AsDynamic.GetColumn(data1.AsDynamic.Schema["floatScalar"])); + // Name wrong. Load anotherFloatVector from floatVector column. + mustFail(() => data.AsDynamic.GetColumn(data1.AsDynamic.Schema["anotherFloatVector"])); + // Index wrong. stringVector is indexed by 3 in data but 2 in data1. 
+ mustFail(() => data.AsDynamic.GetColumn(data1.AsDynamic.Schema["stringVector"]).ToArray()); } } } From 51a89b2f91c225bba87479a61c8e36ead04e5566 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Wed, 27 Feb 2019 09:50:04 -0800 Subject: [PATCH 7/8] Add overload over string --- .../Utilities/ColumnCursor.cs | 11 +++ .../Scenarios/GetColumnTests.cs | 79 ++++++++++++++----- 2 files changed, 69 insertions(+), 21 deletions(-) diff --git a/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs b/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs index 31d3081ec5..721f91729d 100644 --- a/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs +++ b/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs @@ -15,6 +15,17 @@ namespace Microsoft.ML.Data /// public static class ColumnCursorExtensions { + + /// + /// Extract all values of one column of the data view in a form of an . + /// + /// The type of the values. This must match the actual column type. + /// The data view to get the column from. + /// The name of the column to be extracted. + + public static IEnumerable GetColumn(this IDataView data, string columnName) + => GetColumn(data, data.Schema[columnName]); + /// /// Extract all values of one column of the data view in a form of an . 
/// diff --git a/test/Microsoft.ML.Tests/Scenarios/GetColumnTests.cs b/test/Microsoft.ML.Tests/Scenarios/GetColumnTests.cs index 29f238f6f8..b71c058b38 100644 --- a/test/Microsoft.ML.Tests/Scenarios/GetColumnTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/GetColumnTests.cs @@ -23,7 +23,6 @@ public GetColumnTests(ITestOutputHelper output) : base(output) [Fact] public void TestGetColumn() { - var path = GetDataPath(TestDatasets.breastCancer.trainFilename); var env = new MLContext(); var data = TextLoaderStatic.CreateLoader(env, ctx => ( @@ -33,26 +32,6 @@ public void TestGetColumn() stringVector: ctx.LoadText(5, 7) )).Load(path); - Action mustFail = (Action action) => - { - try - { - action(); - Assert.False(true); - } - catch (ArgumentOutOfRangeException) { } - catch (InvalidOperationException) { } - catch (TargetInvocationException ex) - { - Exception e; - for (e = ex; e.InnerException != null; e = e.InnerException) - { - } - Assert.True(e is ArgumentOutOfRangeException || e is InvalidOperationException); - Assert.True(e.IsMarked()); - } - }; - var enum1 = data.AsDynamic.GetColumn(data.AsDynamic.Schema["floatScalar"]).ToArray(); var enum2 = data.AsDynamic.GetColumn(data.AsDynamic.Schema["floatVector"]).ToArray(); var enum3 = data.AsDynamic.GetColumn>(data.AsDynamic.Schema["floatVector"]).ToArray(); @@ -60,6 +39,7 @@ public void TestGetColumn() var enum4 = data.AsDynamic.GetColumn(data.AsDynamic.Schema["stringScalar"]).ToArray(); var enum5 = data.AsDynamic.GetColumn(data.AsDynamic.Schema["stringVector"]).ToArray(); + var mustFail = GetMustFail(); mustFail(() => data.AsDynamic.GetColumn(data.AsDynamic.Schema["floatScalar"])); mustFail(() => data.AsDynamic.GetColumn(data.AsDynamic.Schema["floatVector"])); mustFail(() => data.AsDynamic.GetColumn(data.AsDynamic.Schema["floatScalar"])); @@ -86,5 +66,62 @@ public void TestGetColumn() // Index wrong. stringVector is indexed by 3 in data but 2 in data1. 
mustFail(() => data.AsDynamic.GetColumn(data1.AsDynamic.Schema["stringVector"]).ToArray()); } + + [Fact] + public void TestGetColumnSelectedByString() + { + var path = GetDataPath(TestDatasets.breastCancer.trainFilename); + var env = new MLContext(); + var data = TextLoaderStatic.CreateLoader(env, ctx => ( + floatScalar: ctx.LoadFloat(1), + floatVector: ctx.LoadFloat(2, 6), + stringScalar: ctx.LoadText(4), + stringVector: ctx.LoadText(5, 7) + )).Load(path); + + var enum1 = data.AsDynamic.GetColumn("floatScalar").ToArray(); + var enum2 = data.AsDynamic.GetColumn("floatVector").ToArray(); + var enum3 = data.AsDynamic.GetColumn>("floatVector").ToArray(); + + var enum4 = data.AsDynamic.GetColumn("stringScalar").ToArray(); + var enum5 = data.AsDynamic.GetColumn("stringVector").ToArray(); + + var mustFail = GetMustFail(); + mustFail(() => data.AsDynamic.GetColumn("floatScalar")); + mustFail(() => data.AsDynamic.GetColumn("floatVector")); + mustFail(() => data.AsDynamic.GetColumn("floatScalar")); + mustFail(() => data.AsDynamic.GetColumn("floatScalar")); + mustFail(() => data.AsDynamic.GetColumn("floatScalar")); + + + // Static types. 
+ var enum8 = data.GetColumn(r => r.floatScalar); + var enum9 = data.GetColumn(r => r.floatVector); + var enum10 = data.GetColumn(r => r.stringScalar); + var enum11 = data.GetColumn(r => r.stringVector); + } + + private static Action GetMustFail() + { + return (Action action) => + { + try + { + action(); + Assert.False(true); + } + catch (ArgumentOutOfRangeException) { } + catch (InvalidOperationException) { } + catch (TargetInvocationException ex) + { + Exception e; + for (e = ex; e.InnerException != null; e = e.InnerException) + { + } + Assert.True(e is ArgumentOutOfRangeException || e is InvalidOperationException); + Assert.True(e.IsMarked()); + } + }; + } } } From b8cce02fe9901a4f6ca48013a03455a284d5eec0 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Wed, 27 Feb 2019 09:54:35 -0800 Subject: [PATCH 8/8] Drop repeated lines --- test/Microsoft.ML.Tests/Scenarios/GetColumnTests.cs | 7 ------- 1 file changed, 7 deletions(-) diff --git a/test/Microsoft.ML.Tests/Scenarios/GetColumnTests.cs b/test/Microsoft.ML.Tests/Scenarios/GetColumnTests.cs index b71c058b38..e3e76ae7b5 100644 --- a/test/Microsoft.ML.Tests/Scenarios/GetColumnTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/GetColumnTests.cs @@ -92,13 +92,6 @@ public void TestGetColumnSelectedByString() mustFail(() => data.AsDynamic.GetColumn("floatScalar")); mustFail(() => data.AsDynamic.GetColumn("floatScalar")); mustFail(() => data.AsDynamic.GetColumn("floatScalar")); - - - // Static types. - var enum8 = data.GetColumn(r => r.floatScalar); - var enum9 = data.GetColumn(r => r.floatVector); - var enum10 = data.GetColumn(r => r.stringScalar); - var enum11 = data.GetColumn(r => r.stringVector); } private static Action GetMustFail()