Skip to content

Commit c8b5a56

Browse files
committed
Removes ChooseColumnsTransform and DropColumnsTransform classes,
replacing them with SelectColumnsTransform. These changes include:
* Updates to SelectColumnsTransform to respect ordering when keeping columns. For example, if the input is ABC and CB is selected, the output will be CB.
* Updates to code that used Choose or Drop columns, replacing them with SelectColumns.
* Updates to the baseline output so that the tests pass.
* Re-enabling the SavePipeline tests.
This fixes #1342. These changes are also related to #754.
1 parent 586533c commit c8b5a56

File tree

45 files changed

+592
-1586
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+592
-1586
lines changed

src/Microsoft.ML.Data/Commands/SaveDataCommand.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
using Microsoft.ML.Runtime.Data;
1313
using Microsoft.ML.Runtime.Data.IO;
1414
using Microsoft.ML.Runtime.Internal.Utilities;
15+
using Microsoft.ML.Transforms;
1516

1617
[assembly: LoadableClass(SaveDataCommand.Summary, typeof(SaveDataCommand), typeof(SaveDataCommand.Arguments), typeof(SignatureCommand),
1718
"Save Data", "SaveData", "save")]
@@ -129,11 +130,10 @@ private void RunCore(IChannel ch)
129130

130131
if (!string.IsNullOrWhiteSpace(Args.Columns))
131132
{
132-
var args = new ChooseColumnsTransform.Arguments();
133-
args.Column = Args.Columns
134-
.Split(new char[] { ',' }, StringSplitOptions.RemoveEmptyEntries).Select(s => new ChooseColumnsTransform.Column() { Name = s }).ToArray();
135-
if (Utils.Size(args.Column) > 0)
136-
data = new ChooseColumnsTransform(Host, args, data);
133+
var keepColumns = Args.Columns
134+
.Split(new char[] { ',' }, StringSplitOptions.RemoveEmptyEntries).ToArray();
135+
if (keepColumns.Length > 0)
136+
data = SelectColumnsTransform.CreateKeep(Host, data, keepColumns);
137137
}
138138

139139
IDataSaver saver;

src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ public static CommonOutputs.TransformOutput RenameBinaryPredictionScoreColumns(I
101101
}
102102

103103
var copyColumn = new CopyColumnsTransform(env, copyCols.ToArray()).Transform(input.Data);
104-
var dropColumn = new DropColumnsTransform(env, new DropColumnsTransform.Arguments() { Column = copyCols.Select(c => c.Source).ToArray() }, copyColumn);
104+
var dropColumn = SelectColumnsTransform.CreateDrop(env, copyColumn, copyCols.Select(c => c.Source).ToArray());
105105
return new CommonOutputs.TransformOutput { Model = new TransformModel(env, dropColumn, input.Data), OutputData = dropColumn };
106106
}
107107
}

src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs

Lines changed: 42 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using Microsoft.ML.Runtime.Data;
1212
using Microsoft.ML.Runtime.EntryPoints;
1313
using Microsoft.ML.Runtime.Internal.Utilities;
14+
using Microsoft.ML.Transforms;
1415

1516
[assembly: LoadableClass(typeof(AnomalyDetectionEvaluator), typeof(AnomalyDetectionEvaluator), typeof(AnomalyDetectionEvaluator.Arguments), typeof(SignatureEvaluator),
1617
"Anomaly Detection Evaluator", AnomalyDetectionEvaluator.LoadName, "AnomalyDetection", "Anomaly")]
@@ -704,59 +705,56 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDa
704705
}
705706
}
706707

707-
var args = new ChooseColumnsTransform.Arguments();
708-
var cols = new List<ChooseColumnsTransform.Column>()
708+
var kFormatName = string.Format(FoldDrAtKFormat, _k);
709+
var pFormatName = string.Format(FoldDrAtPFormat, _p);
710+
var numAnomName = string.Format(FoldDrAtNumAnomaliesFormat, numAnomalies);
711+
712+
var args = new CopyColumnsTransform.Arguments();
713+
var cols = new List<CopyColumnsTransform.Column>()
714+
{
715+
new CopyColumnsTransform.Column()
709716
{
710-
new ChooseColumnsTransform.Column()
711-
{
712-
Name = string.Format(FoldDrAtKFormat, _k),
713-
Source = AnomalyDetectionEvaluator.OverallMetrics.DrAtK
714-
},
715-
new ChooseColumnsTransform.Column()
716-
{
717-
Name = string.Format(FoldDrAtPFormat, _p),
718-
Source = AnomalyDetectionEvaluator.OverallMetrics.DrAtPFpr
719-
},
720-
new ChooseColumnsTransform.Column()
721-
{
722-
Name = string.Format(FoldDrAtNumAnomaliesFormat, numAnomalies),
723-
Source=AnomalyDetectionEvaluator.OverallMetrics.DrAtNumPos
724-
},
725-
new ChooseColumnsTransform.Column()
726-
{
727-
Name=AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK
728-
},
729-
new ChooseColumnsTransform.Column()
730-
{
731-
Name=AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP
732-
},
733-
new ChooseColumnsTransform.Column()
734-
{
735-
Name=AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos
736-
},
737-
new ChooseColumnsTransform.Column()
738-
{
739-
Name = BinaryClassifierEvaluator.Auc
740-
}
741-
};
717+
Name = kFormatName,
718+
Source = AnomalyDetectionEvaluator.OverallMetrics.DrAtK
719+
},
720+
new CopyColumnsTransform.Column()
721+
{
722+
Name = pFormatName,
723+
Source = AnomalyDetectionEvaluator.OverallMetrics.DrAtPFpr
724+
},
725+
new CopyColumnsTransform.Column()
726+
{
727+
Name = numAnomName,
728+
Source=AnomalyDetectionEvaluator.OverallMetrics.DrAtNumPos
729+
}
730+
};
731+
732+
// List of columns to keep. Note that the order specified here determines the order of the output columns.
733+
var colsToKeep = new List<string>();
734+
colsToKeep.Add(kFormatName);
735+
colsToKeep.Add(pFormatName);
736+
colsToKeep.Add(numAnomName);
737+
colsToKeep.Add(AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK);
738+
colsToKeep.Add(AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP);
739+
colsToKeep.Add(AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos);
740+
colsToKeep.Add(BinaryClassifierEvaluator.Auc);
742741

743742
args.Column = cols.ToArray();
744-
IDataView fold = new ChooseColumnsTransform(Host, args, overall);
743+
overall = CopyColumnsTransform.Create(Host, args, overall);
744+
IDataView fold = SelectColumnsTransform.CreateKeep(Host, overall, colsToKeep.ToArray());
745+
745746
string weightedFold;
746747
ch.Info(MetricWriter.GetPerFoldResults(Host, fold, out weightedFold));
747748
}
748749

749750
protected override IDataView GetOverallResultsCore(IDataView overall)
750751
{
751-
var args = new DropColumnsTransform.Arguments();
752-
args.Column = new[]
753-
{
754-
AnomalyDetectionEvaluator.OverallMetrics.NumAnomalies,
755-
AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK,
756-
AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP,
757-
AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos
758-
};
759-
return new DropColumnsTransform(Host, args, overall);
752+
return SelectColumnsTransform.CreateDrop(Host,
753+
overall,
754+
AnomalyDetectionEvaluator.OverallMetrics.NumAnomalies,
755+
AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK,
756+
AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP,
757+
AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos);
760758
}
761759

762760
protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)

src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
using Microsoft.ML.Runtime.Internal.Utilities;
1313
using Microsoft.ML.Runtime.Model;
1414
using Microsoft.ML.Runtime.Internal.Internallearn;
15+
using Microsoft.ML.Transforms;
1516

1617
[assembly: LoadableClass(typeof(BinaryClassifierEvaluator), typeof(BinaryClassifierEvaluator), typeof(BinaryClassifierEvaluator.Arguments), typeof(SignatureEvaluator),
1718
"Binary Classifier Evaluator", BinaryClassifierEvaluator.LoadName, "BinaryClassifier", "Binary", "bin")]
@@ -1333,43 +1334,47 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDa
13331334
if (!metrics.TryGetValue(MetricKinds.ConfusionMatrix, out conf))
13341335
throw ch.Except("No overall metrics found");
13351336

1336-
var args = new ChooseColumnsTransform.Arguments();
1337-
var cols = new List<ChooseColumnsTransform.Column>()
1337+
var args = new CopyColumnsTransform.Arguments();
1338+
var cols = new List<CopyColumnsTransform.Column>()
13381339
{
1339-
new ChooseColumnsTransform.Column()
1340+
new CopyColumnsTransform.Column()
13401341
{
13411342
Name = FoldAccuracy,
13421343
Source = BinaryClassifierEvaluator.Accuracy
13431344
},
1344-
new ChooseColumnsTransform.Column()
1345+
new CopyColumnsTransform.Column()
13451346
{
13461347
Name = FoldLogLoss,
13471348
Source = BinaryClassifierEvaluator.LogLoss
13481349
},
1349-
new ChooseColumnsTransform.Column()
1350-
{
1351-
Name = BinaryClassifierEvaluator.Entropy
1352-
},
1353-
new ChooseColumnsTransform.Column()
1350+
new CopyColumnsTransform.Column()
13541351
{
13551352
Name = FoldLogLosRed,
13561353
Source = BinaryClassifierEvaluator.LogLossReduction
1357-
},
1358-
new ChooseColumnsTransform.Column()
1359-
{
1360-
Name = BinaryClassifierEvaluator.Auc
13611354
}
13621355
};
1356+
1357+
var colsToKeep = new List<string>();
1358+
colsToKeep.Add(FoldAccuracy);
1359+
colsToKeep.Add(FoldLogLoss);
1360+
colsToKeep.Add(BinaryClassifierEvaluator.Entropy);
1361+
colsToKeep.Add(FoldLogLosRed);
1362+
colsToKeep.Add(BinaryClassifierEvaluator.Auc);
1363+
13631364
int index;
13641365
if (fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.IsWeighted, out index))
1365-
cols.Add(new ChooseColumnsTransform.Column() { Name = MetricKinds.ColumnNames.IsWeighted });
1366+
colsToKeep.Add(MetricKinds.ColumnNames.IsWeighted);
13661367
if (fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out index))
1367-
cols.Add(new ChooseColumnsTransform.Column() { Name = MetricKinds.ColumnNames.StratCol });
1368+
colsToKeep.Add(MetricKinds.ColumnNames.StratCol);
13681369
if (fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out index))
1369-
cols.Add(new ChooseColumnsTransform.Column() { Name = MetricKinds.ColumnNames.StratVal });
1370+
colsToKeep.Add(MetricKinds.ColumnNames.StratVal);
13701371

13711372
args.Column = cols.ToArray();
1372-
fold = new ChooseColumnsTransform(Host, args, fold);
1373+
fold = CopyColumnsTransform.Create(Host, args, fold);
1374+
1375+
// Keep the renamed fold columns together with the pass-through columns listed in colsToKeep, in that order.
1376+
fold = SelectColumnsTransform.CreateKeep(Host, fold, colsToKeep.ToArray());
1377+
13731378
string weightedConf;
13741379
var unweightedConf = MetricWriter.GetConfusionTable(Host, conf, out weightedConf);
13751380
string weightedFold;
@@ -1386,9 +1391,7 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDa
13861391

13871392
protected override IDataView GetOverallResultsCore(IDataView overall)
13881393
{
1389-
var args = new DropColumnsTransform.Arguments();
1390-
args.Column = new[] { BinaryClassifierEvaluator.Entropy };
1391-
return new DropColumnsTransform(Host, args, overall);
1394+
return SelectColumnsTransform.CreateDrop(Host, overall, BinaryClassifierEvaluator.Entropy);
13921395
}
13931396

13941397
protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary<string, IDataView>[] metrics)

src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using System.Threading;
1212
using Microsoft.ML.Runtime.Data.IO;
1313
using Microsoft.ML.Runtime.Internal.Utilities;
14+
using Microsoft.ML.Transforms;
1415

1516
namespace Microsoft.ML.Runtime.Data
1617
{
@@ -931,7 +932,7 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, string
931932
variableSizeVectorColumnName, type);
932933

933934
// Drop the old column that does not have variable length.
934-
idv = new DropColumnsTransform(env, new DropColumnsTransform.Arguments() { Column = new[] { variableSizeVectorColumnName } }, idv);
935+
idv = SelectColumnsTransform.CreateDrop(env, idv, variableSizeVectorColumnName);
935936
}
936937
return idv;
937938
};
@@ -1057,8 +1058,7 @@ internal static IDataView GetOverallMetricsData(IHostEnvironment env, IDataView
10571058
{
10581059
if (Utils.Size(nonAveragedCols) > 0)
10591060
{
1060-
var dropArgs = new DropColumnsTransform.Arguments() { Column = nonAveragedCols.ToArray() };
1061-
data = new DropColumnsTransform(env, dropArgs, data);
1061+
data = SelectColumnsTransform.CreateDrop(env, data, nonAveragedCols.ToArray());
10621062
}
10631063
idvList.Add(data);
10641064
}
@@ -1732,9 +1732,7 @@ public static IDataView GetNonStratifiedMetrics(IHostEnvironment env, IDataView
17321732
var found = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out stratVal);
17331733
env.Check(found, "If stratification column exist, data view must also contain a StratVal column");
17341734

1735-
var dropArgs = new DropColumnsTransform.Arguments();
1736-
dropArgs.Column = new[] { data.Schema.GetColumnName(stratCol), data.Schema.GetColumnName(stratVal) };
1737-
data = new DropColumnsTransform(env, dropArgs, data);
1735+
data = SelectColumnsTransform.CreateDrop(env, data, data.Schema.GetColumnName(stratCol), data.Schema.GetColumnName(stratVal));
17381736
return data;
17391737
}
17401738
}

src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Linq;
77
using Microsoft.ML.Runtime.CommandLine;
88
using Microsoft.ML.Runtime.EntryPoints;
9+
using Microsoft.ML.Transforms;
910

1011
namespace Microsoft.ML.Runtime.Data
1112
{
@@ -212,13 +213,14 @@ private IDataView WrapPerInstance(RoleMappedData perInst)
212213
var idv = perInst.Data;
213214

214215
// Make a list of column names that Maml outputs as part of the per-instance data view, and then wrap
215-
// the per-instance data computed by the evaluator in a ChooseColumnsTransform.
216-
var cols = new List<ChooseColumnsTransform.Column>();
216+
// the per-instance data computed by the evaluator in a SelectColumnsTransform.
217+
var cols = new List<CopyColumnsTransform.Column>();
218+
var colsToKeep = new List<string>();
217219

218220
// If perInst is the result of cross-validation and contains a fold Id column, include it.
219221
int foldCol;
220222
if (perInst.Schema.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.FoldIndex, out foldCol))
221-
cols.Add(new ChooseColumnsTransform.Column() { Source = MetricKinds.ColumnNames.FoldIndex });
223+
colsToKeep.Add(MetricKinds.ColumnNames.FoldIndex);
222224

223225
// Maml always outputs a name column, if it doesn't exist add a GenerateNumberTransform.
224226
if (perInst.Schema.Name == null)
@@ -227,22 +229,26 @@ private IDataView WrapPerInstance(RoleMappedData perInst)
227229
args.Column = new[] { new GenerateNumberTransform.Column() { Name = "Instance" } };
228230
args.UseCounter = true;
229231
idv = new GenerateNumberTransform(Host, args, idv);
230-
cols.Add(new ChooseColumnsTransform.Column() { Name = "Instance" });
232+
colsToKeep.Add("Instance");
231233
}
232234
else
233-
cols.Add(new ChooseColumnsTransform.Column() { Source = perInst.Schema.Name.Name, Name = "Instance" });
235+
{
236+
cols.Add(new CopyColumnsTransform.Column() { Source = perInst.Schema.Name.Name, Name = "Instance" });
237+
colsToKeep.Add("Instance");
238+
}
234239

235240
// Maml outputs the weight column if it exists.
236241
if (perInst.Schema.Weight != null)
237-
cols.Add(new ChooseColumnsTransform.Column() { Name = perInst.Schema.Weight.Name });
242+
colsToKeep.Add(perInst.Schema.Weight.Name);
238243

239244
// Get the other columns from the evaluator.
240245
foreach (var col in GetPerInstanceColumnsToSave(perInst.Schema))
241-
cols.Add(new ChooseColumnsTransform.Column() { Name = col });
246+
colsToKeep.Add(col);
242247

243-
var chooseArgs = new ChooseColumnsTransform.Arguments();
244-
chooseArgs.Column = cols.ToArray();
245-
idv = new ChooseColumnsTransform(Host, chooseArgs, idv);
248+
var copyArgs = new CopyColumnsTransform.Arguments();
249+
copyArgs.Column = cols.ToArray();
250+
idv = CopyColumnsTransform.Create(Host, copyArgs, idv);
251+
idv = SelectColumnsTransform.CreateKeep(Host, idv, colsToKeep.ToArray());
246252
return GetPerInstanceMetricsCore(idv, perInst.Schema);
247253
}
248254

src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,22 +1051,14 @@ protected override IDataView GetOverallResultsCore(IDataView overall)
10511051
private IDataView ChangeTopKAccColumnName(IDataView input)
10521052
{
10531053
input = new CopyColumnsTransform(Host, (MultiClassClassifierEvaluator.TopKAccuracy, string.Format(TopKAccuracyFormat, _outputTopKAcc))).Transform(input);
1054-
var dropArgs = new DropColumnsTransform.Arguments
1055-
{
1056-
Column = new[] { MultiClassClassifierEvaluator.TopKAccuracy }
1057-
};
1058-
return new DropColumnsTransform(Host, dropArgs, input);
1054+
return SelectColumnsTransform.CreateDrop(Host, input, MultiClassClassifierEvaluator.TopKAccuracy );
10591055
}
10601056

10611057
private IDataView DropPerClassColumn(IDataView input)
10621058
{
10631059
if (input.Schema.TryGetColumnIndex(MultiClassClassifierEvaluator.PerClassLogLoss, out int perClassCol))
10641060
{
1065-
var args = new DropColumnsTransform.Arguments
1066-
{
1067-
Column = new[] { MultiClassClassifierEvaluator.PerClassLogLoss }
1068-
};
1069-
input = new DropColumnsTransform(Host, args, input);
1061+
input = SelectColumnsTransform.CreateDrop(Host, input, MultiClassClassifierEvaluator.PerClassLogLoss);
10701062
}
10711063
return input;
10721064
}

0 commit comments

Comments
 (0)