diff --git a/.vsts-dotnet-ci.yml b/.vsts-dotnet-ci.yml
new file mode 100644
index 0000000000..c2e0e2c0d4
--- /dev/null
+++ b/.vsts-dotnet-ci.yml
@@ -0,0 +1,15 @@
+phases:
+- template: /build/ci/phase-template.yml
+ parameters:
+ name: Linux
+ buildScript: ./build.sh
+ dockerImage: microsoft/dotnet-buildtools-prereqs:centos-7-b46d863-20180719033416
+
+- template: /build/ci/phase-template.yml
+ parameters:
+ name: Windows_NT
+ buildScript: build.cmd
+ queue:
+ name: Hosted VS2017
+ demands:
+ - agent.os -equals Windows_NT
diff --git a/Directory.Build.props b/Directory.Build.props
index a37097204f..73144201c7 100644
--- a/Directory.Build.props
+++ b/Directory.Build.props
@@ -14,6 +14,7 @@
https://api.nuget.org/v3/index.json;
+ https://dotnet.myget.org/F/dotnet-core/api/v3/index.json;
@@ -21,6 +22,7 @@
$(MSBuildThisFileDirectory)
$(RepoRoot)src/
+ $(RepoRoot)pkg/
$(RepoRoot)bin/
diff --git a/Directory.Build.targets b/Directory.Build.targets
index 1ab549e60a..5e6446add9 100644
--- a/Directory.Build.targets
+++ b/Directory.Build.targets
@@ -5,5 +5,33 @@
Text="The tools directory [$(ToolsDir)] does not exist. Please run build in the root of the repo to ensure the tools are installed before attempting to build an individual project." />
+
+
+
+ lib
+ .dll
+ .so
+ .dylib
+
+
+
+
+ $(NativeOutputPath)$(LibPrefix)%(NativeAssemblyReference.Identity)$(LibExtension)
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln
index 3529c0e5b7..58e24041f1 100644
--- a/Microsoft.ML.sln
+++ b/Microsoft.ML.sln
@@ -5,6 +5,9 @@ MinimumVisualStudioVersion = 10.0.40219.1
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Core", "src\Microsoft.ML.Core\Microsoft.ML.Core.csproj", "{A6CA6CC6-5D7C-4D7F-A0F5-35E14B383B0A}"
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{09EADF06-BE25-4228-AB53-95AE3E15B530}"
+ ProjectSection(SolutionItems) = preProject
+ src\Source.ruleset = src\Source.ruleset
+ EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "test", "test", "{AED9C836-31E3-4F3F-8ABC-929555D3F3C4}"
EndProject
@@ -18,7 +21,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.InferenceTesti
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Data", "src\Microsoft.ML.Data\Microsoft.ML.Data.csproj", "{AD92D96B-0E96-4F22-8DCE-892E13B1F282}"
EndProject
-Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.UniversalModelFormat", "src\Microsoft.ML.UniversalModelFormat\Microsoft.ML.UniversalModelFormat.csproj", "{65D0603E-B96C-4DFC-BDD1-705891B88C18}"
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Onnx", "src\Microsoft.ML.Onnx\Microsoft.ML.Onnx.csproj", "{65D0603E-B96C-4DFC-BDD1-705891B88C18}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.StandardLearners", "src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj", "{707BB22C-7E5F-497A-8C2F-74578F675705}"
EndProject
@@ -34,16 +37,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Api", "src\Mic
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Tests", "test\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj", "{64BC22D3-1E76-41EF-94D8-C79E471FF2DD}"
EndProject
-Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "data", "data", "{FDA2FD2C-A708-43AC-A941-4D941B0853BF}"
- ProjectSection(SolutionItems) = preProject
- test\data\sentiment_data.tsv = test\data\sentiment_data.tsv
- test\data\sentiment_test.tsv = test\data\sentiment_test.tsv
- EndProjectSection
-EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.TestFramework", "test\Microsoft.ML.TestFramework\Microsoft.ML.TestFramework.csproj", "{B5989C06-4FFA-46C1-9D85-9366B34AB0A2}"
EndProject
-Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.InternalStreams", "src\Microsoft.ML.InternalStreams\Microsoft.ML.InternalStreams.csproj", "{C4F7938F-7109-43C8-92A5-9BE47C7FF7D9}"
-EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Predictor.Tests", "test\Microsoft.ML.Predictor.Tests\Microsoft.ML.Predictor.Tests.csproj", "{6B047E09-39C9-4583-96F3-685D84CA4117}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.ResultProcessor", "src\Microsoft.ML.ResultProcessor\Microsoft.ML.ResultProcessor.csproj", "{3769FCC3-9AFF-4C37-97E9-6854324681DF}"
@@ -54,36 +49,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Parquet", "src
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Sweeper", "src\Microsoft.ML.Sweeper\Microsoft.ML.Sweeper.csproj", "{55C8122D-79EA-48AB-85D0-EB551FC1C427}"
EndProject
-Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "docs", "docs", "{E20AF96D-3F66-4065-8A89-BEE479D74536}"
- ProjectSection(SolutionItems) = preProject
- docs\README.md = docs\README.md
- docs\release-notes\0.1\release-0.1.md = docs\release-notes\0.1\release-0.1.md
- EndProjectSection
-EndProject
-Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "project-docs", "project-docs", "{52794B40-AB8A-41AF-9EF7-799C80D6E0BC}"
- ProjectSection(SolutionItems) = preProject
- docs\project-docs\contributing.md = docs\project-docs\contributing.md
- docs\project-docs\developer-guide.md = docs\project-docs\developer-guide.md
- EndProjectSection
-EndProject
-Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{76F579E4-B9D2-4A0C-A511-EEFA4B2B829F}"
- ProjectSection(SolutionItems) = preProject
- CONTRIBUTING.md = CONTRIBUTING.md
- README.md = README.md
- ROADMAP.md = ROADMAP.md
- EndProjectSection
-EndProject
-Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "building", "building", "{DB751004-5D49-4B88-B78F-29CA9887087D}"
- ProjectSection(SolutionItems) = preProject
- docs\building\unix-instructions.md = docs\building\unix-instructions.md
- docs\building\windows-instructions.md = docs\building\windows-instructions.md
- EndProjectSection
-EndProject
-Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "specs", "specs", "{2DEFC784-F2B5-44EA-ABBB-0DCF3E689DAC}"
- ProjectSection(SolutionItems) = preProject
- docs\specs\mvp.md = docs\specs\mvp.md
- EndProjectSection
-EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "pkg", "pkg", "{D3D38B03-B557-484D-8348-8BADEE4DF592}"
ProjectSection(SolutionItems) = preProject
pkg\Directory.Build.props = pkg\Directory.Build.props
@@ -106,108 +71,264 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Maml", "src\Mi
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Console", "src\Microsoft.ML.Console\Microsoft.ML.Console.csproj", "{362A98CF-FBF7-4EBB-A11B-990BBF845B15}"
EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "build", "build", "{487213C9-E8A9-4F94-85D7-28A05DBBFE3A}"
+EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "netstandard2.0", "netstandard2.0", "{9252A8EB-ABFB-440C-AB4D-1D562753CE0F}"
+ ProjectSection(SolutionItems) = preProject
+ pkg\Microsoft.ML\build\netstandard2.0\Microsoft.ML.props = pkg\Microsoft.ML\build\netstandard2.0\Microsoft.ML.props
+ pkg\Microsoft.ML\build\netstandard2.0\Microsoft.ML.targets = pkg\Microsoft.ML\build\netstandard2.0\Microsoft.ML.targets
+ EndProjectSection
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Sweeper.Tests", "test\Microsoft.ML.Sweeper.Tests\Microsoft.ML.Sweeper.Tests.csproj", "{3DEB504D-7A07-48CE-91A2-8047461CB3D4}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.LightGBM", "src\Microsoft.ML.LightGBM\Microsoft.ML.LightGBM.csproj", "{001F3B4E-FBE4-4001-AFD2-A6A989CD1C25}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Ensemble", "src\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.csproj", "{DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27}"
+EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.CpuMath", "Microsoft.ML.CpuMath", "{BF66A305-DF10-47E4-8D81-42049B149D2B}"
+ ProjectSection(SolutionItems) = preProject
+ pkg\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.nupkgproj = pkg\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.nupkgproj
+ pkg\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.symbols.nupkgproj = pkg\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.symbols.nupkgproj
+ EndProjectSection
+EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tools-local", "tools-local", "{7F13E156-3EBA-4021-84A5-CD56BA72F99E}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeAnalyzer", "tools-local\Microsoft.ML.CodeAnalyzer\Microsoft.ML.CodeAnalyzer.csproj", "{B4E55B2D-2A92-46E7-B72F-E76D6FD83440}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeAnalyzer.Tests", "test\Microsoft.ML.CodeAnalyzer.Tests\Microsoft.ML.CodeAnalyzer.Tests.csproj", "{3E4ABF07-7970-4BE6-B45B-A13D3C397545}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
+ Debug-Intrinsics|Any CPU = Debug-Intrinsics|Any CPU
Release|Any CPU = Release|Any CPU
+ Release-Intrinsics|Any CPU = Release-Intrinsics|Any CPU
EndGlobalSection
GlobalSection(ProjectConfigurationPlatforms) = postSolution
{A6CA6CC6-5D7C-4D7F-A0F5-35E14B383B0A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{A6CA6CC6-5D7C-4D7F-A0F5-35E14B383B0A}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {A6CA6CC6-5D7C-4D7F-A0F5-35E14B383B0A}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {A6CA6CC6-5D7C-4D7F-A0F5-35E14B383B0A}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{A6CA6CC6-5D7C-4D7F-A0F5-35E14B383B0A}.Release|Any CPU.ActiveCfg = Release|Any CPU
{A6CA6CC6-5D7C-4D7F-A0F5-35E14B383B0A}.Release|Any CPU.Build.0 = Release|Any CPU
+ {A6CA6CC6-5D7C-4D7F-A0F5-35E14B383B0A}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {A6CA6CC6-5D7C-4D7F-A0F5-35E14B383B0A}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{EC743D1D-7691-43B7-B9B0-5F2F7018A8F6}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{EC743D1D-7691-43B7-B9B0-5F2F7018A8F6}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {EC743D1D-7691-43B7-B9B0-5F2F7018A8F6}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {EC743D1D-7691-43B7-B9B0-5F2F7018A8F6}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{EC743D1D-7691-43B7-B9B0-5F2F7018A8F6}.Release|Any CPU.ActiveCfg = Release|Any CPU
{EC743D1D-7691-43B7-B9B0-5F2F7018A8F6}.Release|Any CPU.Build.0 = Release|Any CPU
+ {EC743D1D-7691-43B7-B9B0-5F2F7018A8F6}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {EC743D1D-7691-43B7-B9B0-5F2F7018A8F6}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{46F2F967-C23F-4076-858D-33F7DA9BD2DA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{46F2F967-C23F-4076-858D-33F7DA9BD2DA}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {46F2F967-C23F-4076-858D-33F7DA9BD2DA}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
+ {46F2F967-C23F-4076-858D-33F7DA9BD2DA}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU
{46F2F967-C23F-4076-858D-33F7DA9BD2DA}.Release|Any CPU.ActiveCfg = Release|Any CPU
{46F2F967-C23F-4076-858D-33F7DA9BD2DA}.Release|Any CPU.Build.0 = Release|Any CPU
+ {46F2F967-C23F-4076-858D-33F7DA9BD2DA}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
+ {46F2F967-C23F-4076-858D-33F7DA9BD2DA}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
{2D7391C9-8254-4B8F-BF26-FADAF8F02F44}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{2D7391C9-8254-4B8F-BF26-FADAF8F02F44}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {2D7391C9-8254-4B8F-BF26-FADAF8F02F44}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {2D7391C9-8254-4B8F-BF26-FADAF8F02F44}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{2D7391C9-8254-4B8F-BF26-FADAF8F02F44}.Release|Any CPU.ActiveCfg = Release|Any CPU
{2D7391C9-8254-4B8F-BF26-FADAF8F02F44}.Release|Any CPU.Build.0 = Release|Any CPU
+ {2D7391C9-8254-4B8F-BF26-FADAF8F02F44}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {2D7391C9-8254-4B8F-BF26-FADAF8F02F44}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{E278EC99-E6EE-49FE-92E6-0A309A478D98}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{E278EC99-E6EE-49FE-92E6-0A309A478D98}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {E278EC99-E6EE-49FE-92E6-0A309A478D98}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {E278EC99-E6EE-49FE-92E6-0A309A478D98}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{E278EC99-E6EE-49FE-92E6-0A309A478D98}.Release|Any CPU.ActiveCfg = Release|Any CPU
{E278EC99-E6EE-49FE-92E6-0A309A478D98}.Release|Any CPU.Build.0 = Release|Any CPU
+ {E278EC99-E6EE-49FE-92E6-0A309A478D98}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {E278EC99-E6EE-49FE-92E6-0A309A478D98}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{AD92D96B-0E96-4F22-8DCE-892E13B1F282}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{AD92D96B-0E96-4F22-8DCE-892E13B1F282}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {AD92D96B-0E96-4F22-8DCE-892E13B1F282}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {AD92D96B-0E96-4F22-8DCE-892E13B1F282}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{AD92D96B-0E96-4F22-8DCE-892E13B1F282}.Release|Any CPU.ActiveCfg = Release|Any CPU
{AD92D96B-0E96-4F22-8DCE-892E13B1F282}.Release|Any CPU.Build.0 = Release|Any CPU
+ {AD92D96B-0E96-4F22-8DCE-892E13B1F282}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {AD92D96B-0E96-4F22-8DCE-892E13B1F282}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{65D0603E-B96C-4DFC-BDD1-705891B88C18}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{65D0603E-B96C-4DFC-BDD1-705891B88C18}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {65D0603E-B96C-4DFC-BDD1-705891B88C18}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {65D0603E-B96C-4DFC-BDD1-705891B88C18}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{65D0603E-B96C-4DFC-BDD1-705891B88C18}.Release|Any CPU.ActiveCfg = Release|Any CPU
{65D0603E-B96C-4DFC-BDD1-705891B88C18}.Release|Any CPU.Build.0 = Release|Any CPU
+ {65D0603E-B96C-4DFC-BDD1-705891B88C18}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {65D0603E-B96C-4DFC-BDD1-705891B88C18}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{707BB22C-7E5F-497A-8C2F-74578F675705}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{707BB22C-7E5F-497A-8C2F-74578F675705}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {707BB22C-7E5F-497A-8C2F-74578F675705}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {707BB22C-7E5F-497A-8C2F-74578F675705}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{707BB22C-7E5F-497A-8C2F-74578F675705}.Release|Any CPU.ActiveCfg = Release|Any CPU
{707BB22C-7E5F-497A-8C2F-74578F675705}.Release|Any CPU.Build.0 = Release|Any CPU
+ {707BB22C-7E5F-497A-8C2F-74578F675705}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {707BB22C-7E5F-497A-8C2F-74578F675705}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{2911A286-ECA4-4730-97A9-DA1FEE2DED97}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{2911A286-ECA4-4730-97A9-DA1FEE2DED97}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {2911A286-ECA4-4730-97A9-DA1FEE2DED97}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {2911A286-ECA4-4730-97A9-DA1FEE2DED97}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{2911A286-ECA4-4730-97A9-DA1FEE2DED97}.Release|Any CPU.ActiveCfg = Release|Any CPU
{2911A286-ECA4-4730-97A9-DA1FEE2DED97}.Release|Any CPU.Build.0 = Release|Any CPU
+ {2911A286-ECA4-4730-97A9-DA1FEE2DED97}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {2911A286-ECA4-4730-97A9-DA1FEE2DED97}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{7288C084-11C0-43BE-AC7F-45DCFEAEEBF6}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{7288C084-11C0-43BE-AC7F-45DCFEAEEBF6}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {7288C084-11C0-43BE-AC7F-45DCFEAEEBF6}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {7288C084-11C0-43BE-AC7F-45DCFEAEEBF6}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{7288C084-11C0-43BE-AC7F-45DCFEAEEBF6}.Release|Any CPU.ActiveCfg = Release|Any CPU
{7288C084-11C0-43BE-AC7F-45DCFEAEEBF6}.Release|Any CPU.Build.0 = Release|Any CPU
+ {7288C084-11C0-43BE-AC7F-45DCFEAEEBF6}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {7288C084-11C0-43BE-AC7F-45DCFEAEEBF6}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{F1CAE3AB-4F86-4BC0-BBA8-C4A58E7E8A4A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{F1CAE3AB-4F86-4BC0-BBA8-C4A58E7E8A4A}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {F1CAE3AB-4F86-4BC0-BBA8-C4A58E7E8A4A}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {F1CAE3AB-4F86-4BC0-BBA8-C4A58E7E8A4A}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{F1CAE3AB-4F86-4BC0-BBA8-C4A58E7E8A4A}.Release|Any CPU.ActiveCfg = Release|Any CPU
{F1CAE3AB-4F86-4BC0-BBA8-C4A58E7E8A4A}.Release|Any CPU.Build.0 = Release|Any CPU
+ {F1CAE3AB-4F86-4BC0-BBA8-C4A58E7E8A4A}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {F1CAE3AB-4F86-4BC0-BBA8-C4A58E7E8A4A}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Release|Any CPU.ActiveCfg = Release|Any CPU
{58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Release|Any CPU.Build.0 = Release|Any CPU
+ {58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Release|Any CPU.ActiveCfg = Release|Any CPU
{2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Release|Any CPU.Build.0 = Release|Any CPU
+ {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{64BC22D3-1E76-41EF-94D8-C79E471FF2DD}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{64BC22D3-1E76-41EF-94D8-C79E471FF2DD}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {64BC22D3-1E76-41EF-94D8-C79E471FF2DD}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {64BC22D3-1E76-41EF-94D8-C79E471FF2DD}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{64BC22D3-1E76-41EF-94D8-C79E471FF2DD}.Release|Any CPU.ActiveCfg = Release|Any CPU
{64BC22D3-1E76-41EF-94D8-C79E471FF2DD}.Release|Any CPU.Build.0 = Release|Any CPU
+ {64BC22D3-1E76-41EF-94D8-C79E471FF2DD}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {64BC22D3-1E76-41EF-94D8-C79E471FF2DD}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{B5989C06-4FFA-46C1-9D85-9366B34AB0A2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{B5989C06-4FFA-46C1-9D85-9366B34AB0A2}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {B5989C06-4FFA-46C1-9D85-9366B34AB0A2}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {B5989C06-4FFA-46C1-9D85-9366B34AB0A2}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{B5989C06-4FFA-46C1-9D85-9366B34AB0A2}.Release|Any CPU.ActiveCfg = Release|Any CPU
{B5989C06-4FFA-46C1-9D85-9366B34AB0A2}.Release|Any CPU.Build.0 = Release|Any CPU
- {C4F7938F-7109-43C8-92A5-9BE47C7FF7D9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
- {C4F7938F-7109-43C8-92A5-9BE47C7FF7D9}.Debug|Any CPU.Build.0 = Debug|Any CPU
- {C4F7938F-7109-43C8-92A5-9BE47C7FF7D9}.Release|Any CPU.ActiveCfg = Release|Any CPU
- {C4F7938F-7109-43C8-92A5-9BE47C7FF7D9}.Release|Any CPU.Build.0 = Release|Any CPU
+ {B5989C06-4FFA-46C1-9D85-9366B34AB0A2}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {B5989C06-4FFA-46C1-9D85-9366B34AB0A2}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{6B047E09-39C9-4583-96F3-685D84CA4117}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{6B047E09-39C9-4583-96F3-685D84CA4117}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {6B047E09-39C9-4583-96F3-685D84CA4117}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {6B047E09-39C9-4583-96F3-685D84CA4117}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{6B047E09-39C9-4583-96F3-685D84CA4117}.Release|Any CPU.ActiveCfg = Release|Any CPU
{6B047E09-39C9-4583-96F3-685D84CA4117}.Release|Any CPU.Build.0 = Release|Any CPU
+ {6B047E09-39C9-4583-96F3-685D84CA4117}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {6B047E09-39C9-4583-96F3-685D84CA4117}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{3769FCC3-9AFF-4C37-97E9-6854324681DF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{3769FCC3-9AFF-4C37-97E9-6854324681DF}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {3769FCC3-9AFF-4C37-97E9-6854324681DF}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {3769FCC3-9AFF-4C37-97E9-6854324681DF}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{3769FCC3-9AFF-4C37-97E9-6854324681DF}.Release|Any CPU.ActiveCfg = Release|Any CPU
{3769FCC3-9AFF-4C37-97E9-6854324681DF}.Release|Any CPU.Build.0 = Release|Any CPU
+ {3769FCC3-9AFF-4C37-97E9-6854324681DF}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {3769FCC3-9AFF-4C37-97E9-6854324681DF}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{B7B593C5-FB8C-4ADA-A638-5B53B47D087E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{B7B593C5-FB8C-4ADA-A638-5B53B47D087E}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {B7B593C5-FB8C-4ADA-A638-5B53B47D087E}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {B7B593C5-FB8C-4ADA-A638-5B53B47D087E}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{B7B593C5-FB8C-4ADA-A638-5B53B47D087E}.Release|Any CPU.ActiveCfg = Release|Any CPU
{B7B593C5-FB8C-4ADA-A638-5B53B47D087E}.Release|Any CPU.Build.0 = Release|Any CPU
+ {B7B593C5-FB8C-4ADA-A638-5B53B47D087E}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {B7B593C5-FB8C-4ADA-A638-5B53B47D087E}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{16BB1454-2108-40E5-B3A6-594654005303}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{16BB1454-2108-40E5-B3A6-594654005303}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {16BB1454-2108-40E5-B3A6-594654005303}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {16BB1454-2108-40E5-B3A6-594654005303}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{16BB1454-2108-40E5-B3A6-594654005303}.Release|Any CPU.ActiveCfg = Release|Any CPU
{16BB1454-2108-40E5-B3A6-594654005303}.Release|Any CPU.Build.0 = Release|Any CPU
+ {16BB1454-2108-40E5-B3A6-594654005303}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {16BB1454-2108-40E5-B3A6-594654005303}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{55C8122D-79EA-48AB-85D0-EB551FC1C427}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{55C8122D-79EA-48AB-85D0-EB551FC1C427}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {55C8122D-79EA-48AB-85D0-EB551FC1C427}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {55C8122D-79EA-48AB-85D0-EB551FC1C427}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{55C8122D-79EA-48AB-85D0-EB551FC1C427}.Release|Any CPU.ActiveCfg = Release|Any CPU
{55C8122D-79EA-48AB-85D0-EB551FC1C427}.Release|Any CPU.Build.0 = Release|Any CPU
+ {55C8122D-79EA-48AB-85D0-EB551FC1C427}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {55C8122D-79EA-48AB-85D0-EB551FC1C427}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Release|Any CPU.ActiveCfg = Release|Any CPU
{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Release|Any CPU.Build.0 = Release|Any CPU
+ {7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Release|Any CPU.ActiveCfg = Release|Any CPU
{64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Release|Any CPU.Build.0 = Release|Any CPU
+ {64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Release|Any CPU.ActiveCfg = Release|Any CPU
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Release|Any CPU.Build.0 = Release|Any CPU
+ {362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
+ {3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
+ {3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Release|Any CPU.Build.0 = Release|Any CPU
+ {3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
+ {001F3B4E-FBE4-4001-AFD2-A6A989CD1C25}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {001F3B4E-FBE4-4001-AFD2-A6A989CD1C25}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {001F3B4E-FBE4-4001-AFD2-A6A989CD1C25}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {001F3B4E-FBE4-4001-AFD2-A6A989CD1C25}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
+ {001F3B4E-FBE4-4001-AFD2-A6A989CD1C25}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {001F3B4E-FBE4-4001-AFD2-A6A989CD1C25}.Release|Any CPU.Build.0 = Release|Any CPU
+ {001F3B4E-FBE4-4001-AFD2-A6A989CD1C25}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {001F3B4E-FBE4-4001-AFD2-A6A989CD1C25}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
+ {DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
+ {DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27}.Release|Any CPU.Build.0 = Release|Any CPU
+ {DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
+ {B4E55B2D-2A92-46E7-B72F-E76D6FD83440}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {B4E55B2D-2A92-46E7-B72F-E76D6FD83440}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {B4E55B2D-2A92-46E7-B72F-E76D6FD83440}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {B4E55B2D-2A92-46E7-B72F-E76D6FD83440}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
+ {B4E55B2D-2A92-46E7-B72F-E76D6FD83440}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {B4E55B2D-2A92-46E7-B72F-E76D6FD83440}.Release|Any CPU.Build.0 = Release|Any CPU
+ {B4E55B2D-2A92-46E7-B72F-E76D6FD83440}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {B4E55B2D-2A92-46E7-B72F-E76D6FD83440}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
+ {3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
+ {3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release|Any CPU.Build.0 = Release|Any CPU
+ {3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -227,22 +348,25 @@ Global
{58E06735-1129-4DD5-86E0-6BBFF049AAD9} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{2F636A2C-062C-49F4-85F3-60DCADAB6A43} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{64BC22D3-1E76-41EF-94D8-C79E471FF2DD} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
- {FDA2FD2C-A708-43AC-A941-4D941B0853BF} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{B5989C06-4FFA-46C1-9D85-9366B34AB0A2} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
- {C4F7938F-7109-43C8-92A5-9BE47C7FF7D9} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{6B047E09-39C9-4583-96F3-685D84CA4117} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{3769FCC3-9AFF-4C37-97E9-6854324681DF} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{B7B593C5-FB8C-4ADA-A638-5B53B47D087E} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{16BB1454-2108-40E5-B3A6-594654005303} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{55C8122D-79EA-48AB-85D0-EB551FC1C427} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
- {52794B40-AB8A-41AF-9EF7-799C80D6E0BC} = {E20AF96D-3F66-4065-8A89-BEE479D74536}
- {DB751004-5D49-4B88-B78F-29CA9887087D} = {E20AF96D-3F66-4065-8A89-BEE479D74536}
- {2DEFC784-F2B5-44EA-ABBB-0DCF3E689DAC} = {E20AF96D-3F66-4065-8A89-BEE479D74536}
{DEC8F776-49F7-4D87-836C-FE4DC057D08C} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
{6C95FC87-F5F2-4EEF-BB97-567F2F5DD141} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{64F40A0D-D4C2-4AA7-8470-E9CC437827E4} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{362A98CF-FBF7-4EBB-A11B-990BBF845B15} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
+ {487213C9-E8A9-4F94-85D7-28A05DBBFE3A} = {DEC8F776-49F7-4D87-836C-FE4DC057D08C}
+ {9252A8EB-ABFB-440C-AB4D-1D562753CE0F} = {487213C9-E8A9-4F94-85D7-28A05DBBFE3A}
+ {3DEB504D-7A07-48CE-91A2-8047461CB3D4} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
+ {001F3B4E-FBE4-4001-AFD2-A6A989CD1C25} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
+ {DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
+ {BF66A305-DF10-47E4-8D81-42049B149D2B} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
+ {B4E55B2D-2A92-46E7-B72F-E76D6FD83440} = {7F13E156-3EBA-4021-84A5-CD56BA72F99E}
+ {3E4ABF07-7970-4BE6-B45B-A13D3C397545} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
diff --git a/README.md b/README.md
index 4da710cba4..9ccd06c165 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@ Along with these ML capabilities this first release of ML.NET also brings the fi
ML.NET runs on Windows, Linux, and macOS - any platform where 64 bit [.NET Core](https://github.com/dotnet/core) or later is available.
-The current release is 0.2. Check out the [release notes](docs/release-notes/0.2/release-0.2.md).
+The current release is 0.3. Check out the [release notes](docs/release-notes/0.3/release-0.3.md).
First ensure you have installed [.NET Core 2.0](https://www.microsoft.com/net/learn/get-started) or later. ML.NET also works on the .NET Framework. Note that ML.NET currently must run in a 64 bit process.
@@ -44,9 +44,9 @@ To build ML.NET from source please visit our [developers guide](docs/project-doc
| | x64 Debug | x64 Release |
|:---|----------------:|------------------:|
-|**Linux**|[](https://ci2.dot.net/job/dotnet_machinelearning/job/master/job/linux_debug/lastCompletedBuild)|[](https://ci2.dot.net/job/dotnet_machinelearning/job/master/job/linux_release/lastCompletedBuild)|
+|**Linux**|[](https://dotnet.visualstudio.com/DotNet-Public/_build/latest?definitionId=104&branch=master)|[](https://dotnet.visualstudio.com/DotNet-Public/_build/latest?definitionId=104&branch=master)|
|**macOS**|[](https://ci2.dot.net/job/dotnet_machinelearning/job/master/job/osx10.13_debug/lastCompletedBuild)|[](https://ci2.dot.net/job/dotnet_machinelearning/job/master/job/osx10.13_release/lastCompletedBuild)|
-|**Windows**|[](https://ci2.dot.net/job/dotnet_machinelearning/job/master/job/windows_nt_debug/lastCompletedBuild)|[](https://ci2.dot.net/job/dotnet_machinelearning/job/master/job/windows_nt_release/lastCompletedBuild)|
+|**Windows**|[](https://dotnet.visualstudio.com/DotNet-Public/_build/latest?definitionId=104&branch=master)|[](https://dotnet.visualstudio.com/DotNet-Public/_build/latest?definitionId=104&branch=master)|
## Contributing
@@ -84,6 +84,9 @@ SentimentPrediction prediction = model.Predict(data);
Console.WriteLine("prediction: " + prediction.Sentiment);
```
+## Samples
+
+We have a [repo of samples](https://github.com/dotnet/machinelearning-samples) that you can look at.
## License
diff --git a/build/BranchInfo.props b/build/BranchInfo.props
index 00dbd8e318..193aff7d35 100644
--- a/build/BranchInfo.props
+++ b/build/BranchInfo.props
@@ -1,7 +1,7 @@
0
- 3
+ 4
0
preview
diff --git a/build/Dependencies.props b/build/Dependencies.props
index 5a528bcaa9..5325011f05 100644
--- a/build/Dependencies.props
+++ b/build/Dependencies.props
@@ -6,7 +6,7 @@
4.8.0
4.4.0
4.3.0
- 4.4.0
1.0.0-beta-62824-02
+ 2.1.2.2
diff --git a/build/ci/phase-template.yml b/build/ci/phase-template.yml
new file mode 100644
index 0000000000..bd326afb69
--- /dev/null
+++ b/build/ci/phase-template.yml
@@ -0,0 +1,37 @@
+parameters:
+ name: ''
+ buildScript: ''
+ dockerImage: ''
+ queue: {}
+
+phases:
+ - phase: ${{ parameters.name }}
+ variables:
+ _buildScript: ${{ parameters.buildScript }}
+ _phaseName: ${{ parameters.name }}
+ # if dockerImage is not equal to '' then run under docker container
+ ${{ if ne(parameters.dockerImage, '') }}:
+ _PREVIEW_VSTS_DOCKER_IMAGE: ${{ parameters.dockerImage }}
+ queue:
+ parallel: 2
+ matrix:
+ Build_Debug:
+ _configuration: Debug
+ Build_Release:
+ _configuration: Release
+ ${{ insert }}: ${{ parameters.queue }}
+ steps:
+ - script: $(_buildScript) -$(_configuration) -runtests
+ displayName: Build and Test
+ - task: PublishTestResults@2
+ displayName: Publish Test Results
+ condition: succeededOrFailed()
+ inputs:
+ testRunner: 'vSTest'
+ searchFolder: '$(System.DefaultWorkingDirectory)/bin'
+ testResultsFiles: '**/*.trx'
+ testRunTitle: Machinelearning_Tests_$(_phaseName)_$(_configuration)_$(Build.BuildNumber)
+ configuration: $(_configuration)
+ mergeTestResults: true
+ - script: $(_buildScript) -buildPackages
+ displayName: Build Packages
diff --git a/docs/building/unix-instructions.md b/docs/building/unix-instructions.md
index 0110ae6b8f..855c6b470f 100644
--- a/docs/building/unix-instructions.md
+++ b/docs/building/unix-instructions.md
@@ -42,9 +42,11 @@ macOS 10.12 or higher is needed to build dotnet/machinelearning.
On macOS a few components are needed which are not provided by a default developer setup:
* cmake 3.10.3
+* gcc
* All the requirements necessary to run .NET Core 2.0 applications. To view macOS prerequisites click [here](https://docs.microsoft.com/en-us/dotnet/core/macos-prerequisites?tabs=netcore2x).
-One way of obtaining CMake is via [Homebrew](http://brew.sh):
+One way of obtaining CMake and gcc is via [Homebrew](http://brew.sh):
```sh
$ brew install cmake
+$ brew install gcc
```
diff --git a/docs/code/EntryPoints.md b/docs/code/EntryPoints.md
new file mode 100644
index 0000000000..dbcc4e6bc9
--- /dev/null
+++ b/docs/code/EntryPoints.md
@@ -0,0 +1,231 @@
+# Entry Points And Helper Classes
+
+## Overview
+
+Entry points are a way to interface with ML.NET components, by specifying an execution graph of connected inputs and outputs of those components.
+Both the manifest describing available components and their inputs/outputs, and an "experiment" graph description, are expressed in JSON.
+The recommended way of interacting with ML.NET from other, non-.NET programming languages is by composing and exchanging pipelines or experiment graphs.
+
+Throughout the documentation, we also refer to entry points as 'entry point nodes', because they are the nodes of the graph representing the experiment.
+The graph 'variables', the various values of the experiment graph JSON properties, serve to describe the relationship between the entry point nodes.
+The 'variables' are therefore the edges of the DAG (Directed Acyclic Graph).
+
+All ML.NET entry points are described by their manifest. The manifest is another JSON object that documents and describes the structure of an entry point.
+Manifests are referenced to understand what an entry point does, and how it should be constructed, in a graph.
+
+This document briefly describes the structure of the entry points, the structure of an entry point manifest, and mentions the ML.NET classes that help construct an entry point graph.
+
+## EntryPoint manifest - the definition of an entry point
+
+The components manifest is built by scanning the ML.NET assemblies through reflection and searching for types having the `SignatureEntryPointModule` signature in their `LoadableClass` assembly attribute definition.
+An example of an entry point manifest object, specifically for the `ColumnTypeConverter` transform, is:
+
+```javascript
+{
+ "Name": "Transforms.ColumnTypeConverter",
+ "Desc": "Converts a column to a different type, using standard conversions.",
+ "FriendlyName": "Convert Transform",
+ "ShortName": "Convert",
+ "Inputs": [
+ { "Name": "Column",
+ "Type": {
+ "Kind": "Array",
+ "ItemType": {
+ "Kind": "Struct",
+ "Fields": [
+ {
+ "Name": "ResultType",
+ "Type": {
+ "Kind": "Enum",
+ "Values": [ "I1","I2","U2","I4","U4","I8","U8","R4","Num","R8","TX","Text","TXT","BL","Bool","TimeSpan","TS","DT","DateTime","DZ","DateTimeZone","UG","U16" ]
+ },
+ "Desc": "The result type",
+ "Aliases": [ "type" ],
+ "Required": false,
+ "SortOrder": 150,
+ "IsNullable": true,
+ "Default": null
+ },
+ { "Name": "Range",
+ "Type": "String",
+ "Desc": "For a key column, this defines the range of values",
+ "Aliases": [ "key" ],
+ "Required": false,
+ "SortOrder": 150,
+ "IsNullable": false,
+ "Default": null
+ },
+ { "Name": "Name",
+ "Type": "String",
+ "Desc": "Name of the new column",
+ "Aliases": [ "name" ],
+ "Required": false,
+ "SortOrder": 150,
+ "IsNullable": false,
+ "Default": null
+ },
+ { "Name": "Source",
+ "Type": "String",
+ "Desc": "Name of the source column",
+ "Aliases": [ "src" ],
+ "Required": false,
+ "SortOrder": 150,
+ "IsNullable": false,
+ "Default": null
+ }
+ ]
+ }
+ },
+ "Desc": "New column definition(s) (optional form: name:type:src)",
+ "Aliases": [ "col" ],
+ "Required": true,
+ "SortOrder": 1,
+ "IsNullable": false
+ },
+ { "Name": "Data",
+ "Type": "DataView",
+ "Desc": "Input dataset",
+ "Required": true,
+ "SortOrder": 2,
+ "IsNullable": false
+ },
+ { "Name": "ResultType",
+ "Type": {
+ "Kind": "Enum",
+ "Values": [ "I1","I2","U2","I4","U4","I8","U8","R4","Num","R8","TX","Text","TXT","BL","Bool","TimeSpan","TS","DT","DateTime","DZ","DateTimeZone","UG","U16" ]
+ },
+ "Desc": "The result type",
+ "Aliases": [ "type" ],
+ "Required": false,
+ "SortOrder": 2,
+ "IsNullable": true,
+ "Default": null
+ },
+ { "Name": "Range",
+ "Type": "String",
+ "Desc": "For a key column, this defines the range of values",
+ "Aliases": [ "key" ],
+ "Required": false,
+ "SortOrder": 150,
+ "IsNullable": false,
+ "Default": null
+ }
+ ],
+ "Outputs": [
+ {
+ "Name": "OutputData",
+ "Type": "DataView",
+ "Desc": "Transformed dataset"
+ },
+ {
+ "Name": "Model",
+ "Type": "TransformModel",
+ "Desc": "Transform model"
+ }
+ ],
+ "InputKind": ["ITransformInput" ],
+ "OutputKind": [ "ITransformOutput" ]
+}
+```
+
+The respective entry point, constructed based on this manifest would be:
+
+```javascript
+ {
+ "Name": "Transforms.ColumnTypeConverter",
+ "Inputs": {
+ "Column": [{
+ "Name": "Features",
+ "Source": "Features"
+ }],
+ "Data": "$data0",
+ "ResultType": "R4"
+ },
+ "Outputs": {
+ "OutputData": "$Convert_Output",
+ "Model": "$Convert_TransformModel"
+ }
+ }
+```
+
+## `EntryPointGraph`
+
+This class encapsulates the list of nodes (`EntryPointNode`) and edges
+(`EntryPointVariable` inside a `RunContext`) of the graph.
+
+## `EntryPointNode`
+
+This class represents a node in the graph, and wraps an entry point call. It
+has methods for creating and running entry points. It also has a reference to
+the `RunContext` to allow it to get and set values from `EntryPointVariable`s.
+
+To express the inputs that are set through variables, a set of dictionaries
+are used. The `InputBindingMap` maps an input parameter name to a list of
+`ParameterBinding`s. The `InputMap` maps a `ParameterBinding` to a
+`VariableBinding`. For example, if the JSON looks like this:
+
+```javascript
+'foo': '$bar'
+```
+
+the `InputBindingMap` will have one entry that maps the string "foo" to a list
+that has only one element, a `SimpleParameterBinding` with the name "foo" and
+the `InputMap` will map the `SimpleParameterBinding` to a
+`SimpleVariableBinding` with the name "bar". For a more complicated example,
+let's say we have this JSON:
+
+```javascript
+'foo': [ '$bar[3]', '$baz']
+```
+
+the `InputBindingMap` will have one entry that maps the string "foo" to a list
+that has two elements, an `ArrayIndexParameterBinding` with the name "foo" and
+index 0 and another one with index 1. The `InputMap` will map the first
+`ArrayIndexParameterBinding` to an `ArrayIndexVariableBinding` with name "bar"
+and index 3 and the second `ArrayIndexParameterBinding` to a
+`SimpleVariableBinding` with the name "baz".
+
+For outputs, a node assumes that an output is mapped to a variable, so the
+`OutputMap` is a simple dictionary from string to string.
+
+## `EntryPointVariable`
+
+This class represents an edge in the entry point graph. It has a name, a type
+and a value. Variables can be simple, arrays and/or dictionaries. Currently,
+only data views, file handles, predictor models and transform models are
+allowed as element types for a variable.
+
+## `RunContext`
+
+This class is just a container for all the variables in a graph.
+
+## `VariableBinding` and Derived Classes
+
+The abstract base class represents a "pointer to a (part of a) variable". It
+is used in conjunction with `ParameterBinding`s to specify inputs to an entry
+point node. The `SimpleVariableBinding` is a pointer to an entire variable,
+the `ArrayIndexVariableBinding` is a pointer to a specific index in an array
+variable, and the `DictionaryKeyVariableBinding` is a pointer to a specific
+key in a dictionary variable.
+
+## `ParameterBinding` and Derived Classes
+
+The abstract base class represents a "pointer to a (part of a) parameter". It
+parallels the `VariableBinding` hierarchy and it is used to specify the inputs
+to an entry point node. The `SimpleParameterBinding` is a pointer to a
+non-array, non-dictionary parameter, the `ArrayIndexParameterBinding` is a
+pointer to a specific index of an array parameter and the
+`DictionaryKeyParameterBinding` is a pointer to a specific key of a dictionary
+parameter.
+
+## How to create an entry point for an existing ML.NET component
+
+The steps to take, to create an entry point for an existing ML.NET component, are:
+1. Add the `SignatureEntryPointModule` signature to the `LoadableClass` assembly attribute.
+2. Create a public static method, that:
+ a. Takes as input, among others, an object representing the arguments of the component you want to expose.
+ b. Initializes and runs the component, returning one of the nested classes of `Microsoft.ML.Runtime.EntryPoints.CommonOutputs`.
+ c. Is annotated with the `TlcModule.EntryPoint` attribute.
+
+Based on the type of entry point being created, there are further conventions on the name of the method, for example, the Trainers entry points are typically called: 'TrainMultiClass', 'TrainBinary' etc, based on the task.
+Look at [OnlineGradientDescent](../../src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs) for an example of a component and its entry point.
\ No newline at end of file
diff --git a/docs/code/GraphRunner.md b/docs/code/GraphRunner.md
new file mode 100644
index 0000000000..b7fddc9476
--- /dev/null
+++ b/docs/code/GraphRunner.md
@@ -0,0 +1,128 @@
+# Entry Point JSON Graph format
+
+The entry point graph in ML.NET is an array of _nodes_. More information about the definition of entry points and classes that help construct entry point graphs
+can be found in the [EntryPoints.md document](./EntryPoints.md).
+
+Each node is an object with the following fields:
+
+- _name_: string. Required. Name of the entry point.
+- _inputs_: object. Optional. Specifies non-default inputs to the entry point.
+Note that if the entry point has required inputs (which is very common), the _inputs_ field is required.
+- _outputs_: object. Optional. Specifies the variables that will hold the node's outputs.
+
+## Input and output types
+The following types are supported in JSON graphs:
+
+- `string`. Represented as a JSON string, maps to a C# string.
+- `float`. Represented as a JSON float, maps to a C# float or double.
+- `bool`. Represented as a JSON bool, maps to a C# bool.
+- `enum`. Represented as a JSON string, maps to a C# enum. The allowed values are those of the C# enum (they are also listed in the manifest).
+- `int`. Represented as a JSON integer, maps to a C# int or long.
+- `array` of the above. Represented as a JSON array, maps to a C# array.
+- `dictionary`. Currently not implemented. Represented as a JSON object, maps to a C# `Dictionary`.
+- `component`. Represented as a JSON object with 2 fields: _name_:string and _settings_:object.
+
+## Variables
+The following input/output types cannot be represented as a JSON value:
+- `IDataView`
+- `IFileHandle`
+- `ITransformModel`
+- `IPredictorModel`
+
+These must be passed as _variables_. The variable is represented as a JSON string that begins with `$`.
+Note the following rules:
+
+- A variable can appear in the _outputs_ only once per graph. That is, the variable can be 'assigned' only once.
+- If the variable is present in _inputs_ of one node and in the _outputs_ of another node, this signifies a graph 'edge'.
+The same variable can participate in many edges.
+- If the variable is present only in _inputs_, but never in _outputs_, it is a _graph input_. All graph inputs must be provided before
+a graph can be run.
+- The variable has a type, which is the type of inputs (and, optionally, output) that it appears in. If the type of the variable is
+ambiguous, ML.NET throws an exception.
+- Circular references. The experiment graph is expected to be a DAG. If the circular dependency is detected, ML.NET throws an exception.
+_Currently, this is done lazily: if we couldn't ever run a node because it's waiting for inputs, we throw._
+
+### Variables for arrays and dictionaries.
+It is allowed to define variables for arrays and dictionaries, as long as the item types are valid variable types (the four types listed above).
+They are treated the same way as regular 'scalar' variables.
+
+If we want to reference an item of the collection, we can use the `[]` syntax:
+- `$var[5]` denotes 5th element of an array variable.
+- `$var[foo]` and `$var['foo']` both denote the element with key 'foo' of a dictionary variable.
+_This is not yet implemented._
+
+Conversely, if we want to build a collection (array or dictionary) of variables, we can do it using JSON arrays and objects:
+- `["$v1", "$v2", "$v3"]` denotes an array containing 3 variables.
+- `{"foo": "$v1", "bar": "$v2"}` denotes a collection containing 2 key-value pairs.
+_This is also not yet implemented._
+
+## Example of a JSON entry point manifest object, and the respective entry point graph node
+Let's consider the following manifest snippet, describing an entry point _'CVSplit.Split'_:
+
+```javascript
+ {
+ "name": "CVSplit.Split",
+ "desc": "Split the dataset into the specified number of cross-validation folds (train and test sets)",
+ "inputs": [
+ {
+ "name": "Data",
+ "type": "DataView",
+ "desc": "Input dataset",
+ "required": true
+ },
+ {
+ "name": "NumFolds",
+ "type": "Int",
+ "desc": "Number of folds to split into",
+ "required": false,
+ "default": 2
+ },
+ {
+ "name": "StratificationColumn",
+ "type": "String",
+ "desc": "Stratification column",
+ "aliases": [
+ "strat"
+ ],
+ "required": false,
+ "default": null
+ }
+ ],
+ "outputs": [
+ {
+ "name": "TrainData",
+ "type": {
+ "kind": "Array",
+ "itemType": "DataView"
+ },
+ "desc": "Training data (one dataset per fold)"
+ },
+ {
+ "name": "TestData",
+ "type": {
+ "kind": "Array",
+ "itemType": "DataView"
+ },
+ "desc": "Testing data (one dataset per fold)"
+ }
+ ]
+ }
+```
+
+As we can see, the entry point has 3 inputs (one of them required), and 2 outputs.
+The following is a correct graph containing call to this entry point:
+
+```javascript
+{
+ "nodes": [
+ {
+ "name": "CVSplit.Split",
+ "inputs": {
+ "Data": "$data1"
+ },
+ "outputs": {
+ "TrainData": "$cv"
+ }
+ }]
+}
+```
\ No newline at end of file
diff --git a/docs/code/SchemaComprehension.md b/docs/code/SchemaComprehension.md
new file mode 100644
index 0000000000..37238e0c0c
--- /dev/null
+++ b/docs/code/SchemaComprehension.md
@@ -0,0 +1,222 @@
+# Schema comprehension in ML.NET
+
+This document describes in detail the under-the-hood mechanism that ML.NET uses to automate the creation of `IDataView` schema, with the goal to make it as convenient to the end user as possible, while not incurring extra computational costs.
+
+For a better understanding of `IDataView` principles and type system please refer to:
+* [IDataView Design Principles](IDataViewDesignPrinciples.md)
+* [IDataView Type System](IDataViewTypeSystem.md)
+
+## Introduction
+
+Every dataset in ML.NET is represented as an `IDataView`, which is, for the purposes of this document, a collection of rows that share the same columns. The set of columns, their names, types and other metadata is known as the *schema* of the `IDataView`, and it's represented as an `ISchema` object.
+
+In this document, we will be using the terms *data view* and `IDataView` interchangeably, same for *schema* and `ISchema`.
+
+Before any new data enters ML.NET, the user needs to somehow define what the schema of the data will look like.
+To do this, the following questions need to be answered:
+- What are the column names?
+- What are their types?
+- What other metadata is associated with the columns?
+
+These items above are very similar to the definition of fields in a C# class: names and types of columns correspond to names and types of fields, and metadata can correspond to field attributes.
+Because of this similarity, ML.NET offers a common convenient mechanism for creating a schema: it is done via defining a C# class.
+
+For example, the below class definition can be used to define a data view with 5 float columns:
+```C#
+public class IrisData
+{
+ public float Label;
+ public float SepalLength;
+ public float SepalWidth;
+ public float PetalLength;
+ public float PetalWidth;
+}
+```
+
+## Using schema comprehension to make a data view and to read a data view
+
+The first obvious benefit of schema comprehension is that we can now create `IDataView`s out of in-memory enumerables of user-defined 'data types', without having to define the schema.
+It works in the other direction too: you can take an `IDataView`, and read it as an `IEnumerable` of a user-defined 'data type' (which will fail if the user-provided schema does not match the real schema).
+
+Let's see how we can create a new `IDataView` out of an in-memory array, run some operations on it, and then read it back into the array.
+
+```C#
+public class IrisData
+{
+ public float Label;
+ public float SepalLength;
+ public float SepalWidth;
+ public float PetalLength;
+ public float PetalWidth;
+}
+
+public class IrisVectorData
+{
+ public float Label;
+ public float[] Features;
+}
+
+static void Main(string[] args)
+{
+ // Here's a data array that we want to work on.
+ var dataArray = new[] {
+ new IrisData{Label=1, PetalLength=1, SepalLength=1, PetalWidth=1, SepalWidth=1},
+ new IrisData{Label=0, PetalLength=2, SepalLength=2, PetalWidth=2, SepalWidth=2}
+ };
+
+ // Create the ML.NET environment.
+ var env = new Microsoft.ML.Runtime.Data.TlcEnvironment();
+
+ // Create the data view.
+ // This method will use the definition of IrisData to understand what columns there are in the
+ // data view.
+ var dv = env.CreateDataView(dataArray);
+
+ // Now let's do something to the data view. For example, concatenate all four non-label columns
+ // into 'Features' column.
+ dv = new Microsoft.ML.Runtime.Data.ConcatTransform(env, dv, "Features",
+ "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");
+
+ // Read the data into another array, this time we read the 'Features' and 'Label' columns
+ // of the data, and ignore the rest.
+ // This method will use the definition of IrisVectorData to understand which columns and of which types
+ // are expected to be present in the input data.
+ var arr = dv.AsEnumerable(env, reuseRowObject: false)
+ .ToArray();
+}
+```
+After this code runs, `arr` will contain two `IrisVectorData` objects, each having `Features` filled with the actual values of the features (the 4 concatenated columns).
+
+### Streaming data views
+
+What if the original data doesn't support seeking, like if it's some form of `IEnumerable` instead of `IList`? Well, we can simply use another helper function:
+```C#
+var streamingDv = env.CreateStreamingDataView(dataEnumerable);
+```
+The only subtle difference is, the resulting `streamingDv` will not support shuffling (a property that's useful to some ML applications).
+
+### AsCursorable and reuseRowObject parameter
+
+When you read a data view as `AsEnumerable`, ML.NET will create and populate an object per row. If you do not need multiple row objects to exist in memory (for example, you are writing them to disk one by one, as you scan through the `IEnumerable`), you may want to set `reuseRowObject` to `true`. This will make ML.NET create *only one row object for the entire data view* when you enumerate it, and just re-populate the values every time.
+
+Obviously, in the example above this would lead to incorrect behavior, as the `arr` variable will hold two copies of the same `IrisVectorData` object. Please consider carefully whether you want to reuse the row object, because it is more efficient, but can lead to hard to find issues.
+
+Sometimes, we don't even want to *populate* the row object per row. For example, we only want to see every 100th row of the data, so there's no need to populate the remaining 99% row objects. In this case, you can use `AsCursorable` method:
+
+```C#
+var cursorable = dv.AsCursorable(env);
+// You can create as many simultaneous cursors as you like, they are independent.
+using (var cursor = cursorable.GetCursor())
+{
+ // We are now in charge of creating the row object.
+ var myRow = new IrisVectorData();
+ while (cursor.MoveNext())
+ {
+ if (cursor.Position % 100 == 99)
+ {
+ // Populate the values of the row object.
+ cursor.FillValues(myRow);
+ // Do something to the row.
+ }
+ }
+}
+```
+Please note that **cursors are not thread-safe**: they have mutable state inside, and they are meant to be used by one thread. If you want to read the data in parallel, use multiple cursors.
+
+## PredictionEngine and PredictorModel
+
+ML.NET's `PredictionEngine` is attempting to turn a sequence of data transforms (maybe capped by a predictor, but not necessarily) into a 'black box' that takes strongly typed inputs and returns strongly typed outputs. The name is a little misleading: the `PredictionEngine` object doesn't require a predictor to be present in the pipeline, it can be just a sequence of transforms like in the below example:
+
+```C#
+var engine = env.CreatePredictionEngine(dv);
+var output = engine.Predict(new IrisData { Label = 1, PetalLength = 1, SepalLength = 1, PetalWidth = 1, SepalWidth = 1 });
+```
+It is important to note that the `PredictionEngine` actually *validates* that the 'pipeline' conforms to the input and output schema requirements when it is created.
+
+The same can be said about the `PredictorModel`. This is a somewhat more restricted version of `PredictionEngine` that is created by `LearningPipeline.Train`.
+
+Please note that **`PredictionEngine` and `PredictorModel` are not thread-safe**: they hold an internal cursor object, and therefore cannot be used in a re-entrant fashion.
+If you ever see the error message that says: `An attempt was made to keep iterating after the pipe has been reset`, it most likely means that ML.NET has detected a race condition on the `PredictionEngine`.
+
+## Type system mapping
+
+`IDataView` [type system](IDataViewTypeSystem.md) differs slightly from the C# type system, so a 1-1 mapping between column types and C# types is not always feasible.
+Below are the most notable examples of the differences:
+
+* `IDataView` vector columns often have a fixed (and known) size. The C# array type best corresponds to a 'variable size' vector: the one that can have different number of slots on every row. You can apply the `[VectorType(N)]` attribute to an array field to specify that the column is a vector of fixed size N. This is often necessary: most ML components don't work with variable-size vectors, they require fixed-size ones.
+* `IDataView`'s [key types](IDataViewTypeSystem.md#key-types) don't have a natural underlying C# type either. To declare a key-type column, you need to make your field a `uint`, and decorate it with `[KeyType]` to denote that the field is a key, and not a regular unsigned integer.
+
+### Full list of type mappings
+The below table illustrates what C# types are mapped to what `IDataView` types:
+
+| `IDataView` type | C# type | C# type with extra conversion |
+| ---------------- | ----------- | ------------------------------ |
+| `I1` | `DvInt1` | `sbyte`, `sbyte?` |
+| `I2` | `DvInt2` | `short`, `short?` |
+| `I4` | `DvInt4` | `int`, `int?` |
+| `I8` | `DvInt8` | `long`, `long?` |
+| `U1` | `byte` | `byte?` |
+| `U2` | `ushort` | `ushort?` |
+| `U4` | `uint` | `uint?` |
+| `U8` | `ulong` | `ulong?` |
+| `UG` | `UInt128` | |
+| `R4` | `float` | `float?` |
+| `R8` | `double` | `double?` |
+| `TX` | `DvText`, `string` | |
+| `BL` | `DvBool` | `bool`, `bool?` |
+| `TS` | `DvTimeSpan` | |
+| `DT` | `DvDateTime` | |
+| `DZ` | `DvDateTimeZone` | |
+| Variable-size vector | `VBuffer` | `T[]`, and the vector is always dense |
+| Fixed-size vector | `VBuffer` with `[VectorType(N)]` | `T[]` with `VectorType(N)`, and the vector is always dense |
+| Key type | `uint` with `[KeyType]` | |
+
+### Additional attributes to affect type mapping
+
+There are two more attributes that can affect the way ML.NET conducts schema comprehension:
+* `[ColumnName]` lets you choose a different name for the `IDataView` column. By default it is the same as field name.
+ * This is a way to create or read back an `IDataView` column with a name containing 'invalid' characters (like whitespace).
+* `[NoColumn]` is an attribute that denotes that the below field should not be mapped to a column.
+
+### Using SchemaDefinition for run-time type mapping hints
+
+As you can see from the table and notes above, certain `IDataView` types can only be denoted with an additional field attribute. If the type parameters are not known at compile time (like the size of the fixed-size vector), this is tricky.
+
+You can use a `SchemaDefinition` object to re-map a type to an `IDataView` schema programmatically. It gives you the same powers as the attributes, but at runtime.
+Please see the below example.
+```C#
+// Vector size is only known at runtime.
+int numberOfFeatures = 4;
+
+// Create the default schema definition.
+var schemaDef = SchemaDefinition.Create(typeof(IrisVectorData));
+
+// Specify the right vector size.
+schemaDef["Features"].ColumnType = new VectorType(NumberType.R4, numberOfFeatures);
+
+// Create a data view.
+var dataView = env.CreateDataView(arr, schemaDef);
+
+// Create a prediction engine. You can add custom input and output schema definitions there.
+var predictionEngine = env.CreatePredictionEngine(dv, outputSchemaDefinition: schemaDef);
+```
+
+In addition to the above, you can use `SchemaDefinition` to add per-column metadata:
+```C#
+// Add column metadata.
+schemaDef["Label"].AddMetadata(MetadataUtils.Kinds.HasMissingValues, false);
+```
+
+## Limitations
+
+Certain things are not possible to do at all using the schema comprehensions, but are possible via the native `IDataView` programmatic interface.
+It was our design decision to not allow these scenarios, thus simplifying the other, more common scenarios.
+
+Here is the list of things that are only possible via the low-level interface:
+* Creating or reading a data view, where even column *types* are not known at compile time (so you cannot create a C# class to define the schema)
+ * This can happen if you write a general-purpose machine learning tool that can ingest different kinds of datasets.
+* Reading a subset of columns that differs from one row to another: the cursor always populates the entire row object.
+* Reading column metadata from the data view.
+* Accessing the 'hidden' data view columns by index.
+ * Hidden columns are those that have the same name as other columns and a smaller index. They are not accessible by name.
+* Creating 'cursor sets': this is a feature that lets you iterate over data in multiple parallel threads by splitting the data between multiple 'sibling' cursors.
diff --git a/docs/release-notes/0.3/release-0.3.md b/docs/release-notes/0.3/release-0.3.md
new file mode 100644
index 0000000000..6b88d37f58
--- /dev/null
+++ b/docs/release-notes/0.3/release-0.3.md
@@ -0,0 +1,114 @@
+# ML.NET 0.3 Release Notes
+
+Today we are releasing ML.NET 0.3. This release focuses on adding components
+to ML.NET from the internal codebase (such as Factorization Machines,
+LightGBM, Ensembles, and LightLDA), enabling export to the ONNX model format,
+and bug fixes.
+
+### Installation
+
+ML.NET supports Windows, macOS, and Linux. See [supported OS versions of .NET
+Core
+2.0](https://github.com/dotnet/core/blob/master/release-notes/2.0/2.0-supported-os.md)
+for more details.
+
+You can install ML.NET NuGet from the CLI using:
+```
+dotnet add package Microsoft.ML
+```
+
+From package manager:
+```
+Install-Package Microsoft.ML
+```
+
+### Release Notes
+
+Below are some of the highlights from this release.
+
+* Added Field-Aware Factorization Machines (FFM) as a learner for binary
+ classification (#383)
+
+ * FFM is useful for various large sparse datasets, especially in areas
+ such as recommendations and click prediction. It has been used to win
+ various click prediction competitions such as the [Criteo Display
+ Advertising Challenge on
+ Kaggle](https://www.kaggle.com/c/criteo-display-ad-challenge). You can
+ learn more about the winning solution
+ [here](https://www.csie.ntu.edu.tw/~r01922136/kaggle-2014-criteo.pdf).
+ * FFM is a streaming learner so it does not require the entire dataset to
+ fit in memory.
+ * You can learn more about FFM
+ [here](http://www.csie.ntu.edu.tw/~cjlin/papers/ffm.pdf) and some of the
+ speedup approaches that are used in ML.NET
+ [here](https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf).
+
+* Added [LightGBM](https://github.com/Microsoft/LightGBM) as a learner for
+ binary classification, multiclass classification, and regression (#392)
+
+ * LightGBM is a tree based gradient boosting machine. It is under the
+ umbrella of the [DMTK](http://github.com/microsoft/dmtk) project at
+ Microsoft.
+ * The LightGBM repository shows various [comparison
+ experiments](https://github.com/Microsoft/LightGBM/blob/6488f319f243f7ff679a8e388a33e758c5802303/docs/Experiments.rst#comparison-experiment)
+ that show good accuracy and speed, so it is a great learner to try out.
+ It has also been used in winning solutions in various [ML
+ challenges](https://github.com/Microsoft/LightGBM/blob/a6e878e2fc6e7f545921cbe337cc511fbd1f500d/examples/README.md).
+ * This addition wraps LightGBM and exposes it in ML.NET.
+ * Note that LightGBM can also be used for ranking, but the ranking
+ evaluator is not yet exposed in ML.NET.
+
+* Added Ensemble learners for binary classification, multiclass
+ classification, and regression (#379)
+
+ * [Ensemble learners](https://en.wikipedia.org/wiki/Ensemble_learning)
+ enable using multiple learners in one model. As an example, the Ensemble
+ learner could train both `FastTree` and `AveragedPerceptron` and average
+ their predictions to get the final prediction.
+ * Combining multiple models of similar statistical performance may lead to
+ better performance than each model separately.
+
+* Added LightLDA transform for topic modeling (#377)
+
+ * LightLDA is an implementation of [Latent Dirichlet
+ Allocation](https://en.wikipedia.org/wiki/Latent_Dirichlet_allocation)
+ which infers topical structure from text data.
+ * The implementation of LightLDA in ML.NET is based on [this
+ paper](https://arxiv.org/abs/1412.1576). There is a distributed
+ implementation of LightLDA
+ [here](https://github.com/Microsoft/lightlda).
+
+* Added One-Versus-All (OVA) learner for multiclass classification (#363)
+
+ * [OVA](https://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest)
+ (sometimes known as One-Versus-Rest) is an approach to using binary
+ classifiers in multiclass classification problems.
+ * While some binary classification learners in ML.NET natively support
+ multiclass classification (e.g. Logistic Regression), there are others
+ that do not (e.g. Averaged Perceptron). OVA enables using the latter
+ group for multiclass classification as well.
+
+* Enabled export of ML.NET models to the [ONNX](https://onnx.ai/) format
+ (#248)
+
+ * ONNX is a common format for representing deep learning models (also
+ supporting certain other types of models) which enables developers to
+ move models between different ML toolkits.
+ * ONNX models can be used in [Windows
+ ML](https://docs.microsoft.com/en-us/windows/uwp/machine-learning/overview)
+ which enables evaluating models on Windows 10 devices and taking
+ advantage of capabilities like hardware acceleration.
+ * Currently, only a subset of ML.NET components can be used in a model
+ that is converted to ONNX.
+
+Additional issues closed in this milestone can be found
+[here](https://github.com/dotnet/machinelearning/milestone/2?closed=1).
+
+### Acknowledgements
+
+Shoutout to [pkulikov](https://github.com/pkulikov),
+[veikkoeeva](https://github.com/veikkoeeva),
+[ross-p-smith](https://github.com/ross-p-smith),
+[jwood803](https://github.com/jwood803),
+[Nepomuceno](https://github.com/Nepomuceno), and the ML.NET team for their
+contributions as part of this release!
diff --git a/netci.groovy b/netci.groovy
index eb78c766a5..b955bf669f 100644
--- a/netci.groovy
+++ b/netci.groovy
@@ -6,46 +6,22 @@ import jobs.generation.Utilities;
def project = GithubProject
def branch = GithubBranchName
-['Windows_NT', 'Linux', 'OSX10.13'].each { os ->
+['OSX10.13'].each { os ->
['Debug', 'Release'].each { config ->
[true, false].each { isPR ->
// Calculate job name
def jobName = os.toLowerCase() + '_' + config.toLowerCase()
- def buildFile = '';
def machineAffinity = 'latest-or-auto'
- // Calculate the build command
- if (os == 'Windows_NT') {
- buildFile = ".\\build.cmd"
- } else {
- buildFile = "./build.sh"
- }
-
- def buildCommand = buildFile + " -$config -runtests"
- def packCommand = buildFile + " -buildPackages"
-
def newJob = job(Utilities.getFullJobName(project, jobName, isPR)) {
steps {
- if (os == 'Windows_NT') {
- batchFile(buildCommand)
- batchFile(packCommand)
- }
- else {
- // Shell
- shell(buildCommand)
- shell(packCommand)
- }
+ shell("./build.sh -$config -runtests")
+ shell("./build.sh -buildPackages")
}
}
- def osImageName = os
- if (os == 'Linux') {
- // Trigger a portable Linux build that runs on RHEL7.2
- osImageName = "RHEL7.2"
- }
-
- Utilities.setMachineAffinity(newJob, osImageName, machineAffinity)
+ Utilities.setMachineAffinity(newJob, os, machineAffinity)
Utilities.standardJobSetup(newJob, project, isPR, "*/${branch}")
if (isPR) {
diff --git a/pkg/Microsoft.ML.CpuMath/Microsoft.ML.CpuMath.nupkgproj b/pkg/Microsoft.ML.CpuMath/Microsoft.ML.CpuMath.nupkgproj
new file mode 100644
index 0000000000..918729d99d
--- /dev/null
+++ b/pkg/Microsoft.ML.CpuMath/Microsoft.ML.CpuMath.nupkgproj
@@ -0,0 +1,14 @@
+
+
+
+ netstandard2.0
+ netstandard2.0;netcoreapp3.0
+ Microsoft.ML.CpuMath contains optimized math routines for ML.NET.
+
+
+
+
+
+
+
+
diff --git a/pkg/Microsoft.ML.CpuMath/Microsoft.ML.CpuMath.symbols.nupkgproj b/pkg/Microsoft.ML.CpuMath/Microsoft.ML.CpuMath.symbols.nupkgproj
new file mode 100644
index 0000000000..360980c2c3
--- /dev/null
+++ b/pkg/Microsoft.ML.CpuMath/Microsoft.ML.CpuMath.symbols.nupkgproj
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/pkg/Microsoft.ML.LightGBM/Microsoft.ML.LightGBM.nupkgproj b/pkg/Microsoft.ML.LightGBM/Microsoft.ML.LightGBM.nupkgproj
new file mode 100644
index 0000000000..8cddd1719e
--- /dev/null
+++ b/pkg/Microsoft.ML.LightGBM/Microsoft.ML.LightGBM.nupkgproj
@@ -0,0 +1,13 @@
+
+
+
+ netstandard2.0
+ ML.NET component for LightGBM
+
+
+
+
+
+
+
+
diff --git a/pkg/Microsoft.ML.LightGBM/Microsoft.ML.LightGBM.symbols.nupkgproj b/pkg/Microsoft.ML.LightGBM/Microsoft.ML.LightGBM.symbols.nupkgproj
new file mode 100644
index 0000000000..d7710ff60a
--- /dev/null
+++ b/pkg/Microsoft.ML.LightGBM/Microsoft.ML.LightGBM.symbols.nupkgproj
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.nupkgproj b/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.nupkgproj
new file mode 100644
index 0000000000..bcc86939e2
--- /dev/null
+++ b/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.nupkgproj
@@ -0,0 +1,13 @@
+
+
+
+ netstandard2.0
+ ML.NET component for exporting ONNX Models
+
+
+
+
+
+
+
+
diff --git a/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.symbols.nupkgproj b/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.symbols.nupkgproj
new file mode 100644
index 0000000000..07807bb54b
--- /dev/null
+++ b/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.symbols.nupkgproj
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj
index b0cedd6ad4..fc409ae21f 100644
--- a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj
+++ b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj
@@ -6,12 +6,12 @@
-
+
+
-
diff --git a/pkg/Microsoft.ML/build/Microsoft.ML.props b/pkg/Microsoft.ML/build/netstandard2.0/Microsoft.ML.props
similarity index 87%
rename from pkg/Microsoft.ML/build/Microsoft.ML.props
rename to pkg/Microsoft.ML/build/netstandard2.0/Microsoft.ML.props
index 3970c4d664..fa15d15eb5 100644
--- a/pkg/Microsoft.ML/build/Microsoft.ML.props
+++ b/pkg/Microsoft.ML/build/netstandard2.0/Microsoft.ML.props
@@ -7,7 +7,7 @@
-
PreserveNewest
false
diff --git a/pkg/Microsoft.ML/build/Microsoft.ML.targets b/pkg/Microsoft.ML/build/netstandard2.0/Microsoft.ML.targets
similarity index 100%
rename from pkg/Microsoft.ML/build/Microsoft.ML.targets
rename to pkg/Microsoft.ML/build/netstandard2.0/Microsoft.ML.targets
diff --git a/pkg/_._ b/pkg/_._
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/Directory.Build.props b/src/Directory.Build.props
index cedfa39442..113da3575a 100644
--- a/src/Directory.Build.props
+++ b/src/Directory.Build.props
@@ -11,7 +11,18 @@
$(NoWarn);1591
$(WarningsNotAsErrors);1591
-
+ $(MSBuildThisFileDirectory)\Source.ruleset
+
+
+ false
+ Analyzer
+
+
+
+
diff --git a/src/Microsoft.ML.Api/ApiUtils.cs b/src/Microsoft.ML.Api/ApiUtils.cs
index 5b5936b5a8..8b8cb5871b 100644
--- a/src/Microsoft.ML.Api/ApiUtils.cs
+++ b/src/Microsoft.ML.Api/ApiUtils.cs
@@ -6,7 +6,6 @@
using System.Reflection;
using System.Reflection.Emit;
using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Runtime.Internal.Utilities;
namespace Microsoft.ML.Runtime.Api
{
@@ -19,11 +18,12 @@ internal static class ApiUtils
private static OpCode GetAssignmentOpCode(Type t)
{
// REVIEW: This should be a Dictionary based solution.
- // DvTexts, strings, arrays, and VBuffers.
+ // DvTypes, strings, arrays, all nullable types, VBuffers and UInt128.
if (t == typeof(DvInt8) || t == typeof(DvInt4) || t == typeof(DvInt2) || t == typeof(DvInt1) ||
- t == typeof(DvBool) || t==typeof(bool?) || t == typeof(DvText) || t == typeof(string) || t.IsArray ||
- (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(VBuffer<>)) || t == typeof(DvDateTime) ||
- t == typeof(DvDateTimeZone) || t == typeof(DvTimeSpan) || t == typeof(UInt128))
+ t == typeof(DvBool) || t == typeof(DvText) || t == typeof(string) || t.IsArray ||
+ (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(VBuffer<>)) ||
+ (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Nullable<>)) ||
+ t == typeof(DvDateTime) || t == typeof(DvDateTimeZone) || t == typeof(DvTimeSpan) || t == typeof(UInt128))
{
return OpCodes.Stobj;
}
@@ -46,7 +46,7 @@ private static OpCode GetAssignmentOpCode(Type t)
///
/// Each of the specialized 'peek' methods copies the appropriate field value of an instance of T
- /// into the provided buffer. So, the call is 'peek(userObject, ref destination)' and the logic is
+ /// into the provided buffer. So, the call is 'peek(userObject, ref destination)' and the logic is
/// indentical to 'destination = userObject.##FIELD##', where ##FIELD## is defined per peek method.
///
internal static Delegate GeneratePeek(InternalSchemaDefinition.Column column)
@@ -83,7 +83,7 @@ private static Delegate GeneratePeek(FieldInfo fieldInfo, Op
///
/// Each of the specialized 'poke' methods sets the appropriate field value of an instance of T
+ /// to the provided value. So, the call is 'poke(userObject, providedValue)' and the logic is
+ /// to the provided value. So, the call is 'peek(userObject, providedValue)' and the logic is
/// indentical to 'userObject.##FIELD## = providedValue', where ##FIELD## is defined per poke method.
///
internal static Delegate GeneratePoke(InternalSchemaDefinition.Column column)
diff --git a/src/Microsoft.ML.Api/CodeGenerationUtils.cs b/src/Microsoft.ML.Api/CodeGenerationUtils.cs
index 74f262c57c..7af0fb85ed 100644
--- a/src/Microsoft.ML.Api/CodeGenerationUtils.cs
+++ b/src/Microsoft.ML.Api/CodeGenerationUtils.cs
@@ -97,12 +97,12 @@ public static string GetCSharpString(CSharpCodeProvider codeProvider, string val
}
///
- /// Gets the C# strings representing the type name for a variable corresponding to
- /// the column type.
- ///
- /// If the type is a vector, then controls whether the array field is
+ /// Gets the C# strings representing the type name for a variable corresponding to
+ /// the column type.
+ ///
+ /// If the type is a vector, then controls whether the array field is
/// generated or .
- ///
+ ///
/// If additional attributes are required, they are appended to the list.
///
private static string GetBackingTypeName(ColumnType colType, bool useVBuffer, List attributes)
diff --git a/src/Microsoft.ML.Api/ComponentCreation.cs b/src/Microsoft.ML.Api/ComponentCreation.cs
index 73cfbf91a5..3080a8197c 100644
--- a/src/Microsoft.ML.Api/ComponentCreation.cs
+++ b/src/Microsoft.ML.Api/ComponentCreation.cs
@@ -11,14 +11,14 @@
namespace Microsoft.ML.Runtime.Api
{
///
- /// This class defines extension methods for an to facilitate creating
+ /// This class defines extension methods for an to facilitate creating
/// components (loaders, transforms, trainers, scorers, evaluators, savers).
///
public static class ComponentCreation
{
///
/// Create a new data view which is obtained by appending all columns of all the source data views.
- /// If the data views are of different length, the resulting data view will have the length equal to the
+ /// If the data views are of different length, the resulting data view will have the length equal to the
/// length of the shortest source.
///
/// The host environment to use.
@@ -52,18 +52,18 @@ public static RoleMappedData CreateExamples(this IHostEnvironment env, IDataView
env.CheckValueOrNull(weight);
env.CheckValueOrNull(custom);
- return TrainUtils.CreateExamples(data, label, features, group, weight, name: null, custom: custom);
+ return new RoleMappedData(data, label, features, group, weight, name: null, custom: custom);
}
///
/// Create a new over an in-memory collection of the items of user-defined type.
/// The user maintains ownership of the and the resulting data view will
/// never alter the contents of the .
- /// Since is assumed to be immutable, the user is expected to not
+ /// Since is assumed to be immutable, the user is expected to not
/// modify the contents of while the data view is being actively cursored.
- ///
+ ///
/// One typical usage for in-memory data view could be: create the data view, train a predictor.
- /// Once the predictor is fully trained, modify the contents of the underlying collection and
+ /// Once the predictor is fully trained, modify the contents of the underlying collection and
/// train another predictor.
///
/// The user-defined item type.
@@ -88,9 +88,9 @@ public static IDataView CreateDataView(this IHostEnvironment env, IList
is assumed to be immutable, the user is expected to support
/// multiple enumeration of the that would return the same results, unless
/// the user knows that the data will only be cursored once.
- ///
+ ///
/// One typical usage for streaming data view could be: create the data view that lazily loads data
- /// as needed, then apply pre-trained transformations to it and cursor through it for transformation
+ /// as needed, then apply pre-trained transformations to it and cursor through it for transformation
/// results. This is how is implemented.
///
/// The user-defined item type.
@@ -191,7 +191,7 @@ public static PredictionEngine CreatePredictionEngine(th
///
/// Create a prediction engine.
/// This encapsulates the 'classic' prediction problem, where the input is denoted by the float array of features,
- /// and the output is a float score. For binary classification predictors that can output probability, there are output
+ /// and the output is a float score. For binary classification predictors that can output probability, there are output
/// fields that report the predicted label and probability.
///
/// The host environment to use.
@@ -207,7 +207,7 @@ public static SimplePredictionEngine CreateSimplePredictionEngine(this IHostEnvi
///
+ /// Load the transforms (but not the loader) from the model stream and apply them to the specified data.
- /// It is acceptable to have no transforms in the model stream: in this case the original
+ /// It is acceptable to have no transforms in the model stream: in this case the original
/// will be returned.
///
/// The host environment to use.
diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs
index f185dc6c0b..e940ea9d4d 100644
--- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs
+++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs
@@ -119,7 +119,7 @@ private Delegate CreateGetter(int index)
var column = DataView._schema.SchemaDefn.Columns[index];
var outputType = column.IsComputed ? column.ReturnType : column.FieldInfo.FieldType;
-
+ var genericType = outputType;
Func del;
if (outputType.IsArray)
@@ -129,11 +129,66 @@ private Delegate CreateGetter(int index)
if (outputType.GetElementType() == typeof(string))
{
Ch.Assert(colType.ItemType.IsText);
- return CreateStringArrayToVBufferGetter(index);
+ return CreateConvertingArrayGetterDelegate(index, x => x == null ? DvText.NA : new DvText(x));
+ }
+ else if (outputType.GetElementType() == typeof(int))
+ {
+ Ch.Assert(colType.ItemType == NumberType.I4);
+ return CreateConvertingArrayGetterDelegate(index, x => x);
+ }
+ else if (outputType.GetElementType() == typeof(int?))
+ {
+ Ch.Assert(colType.ItemType == NumberType.I4);
+ return CreateConvertingArrayGetterDelegate(index, x => x ?? DvInt4.NA);
+ }
+ else if (outputType.GetElementType() == typeof(long))
+ {
+ Ch.Assert(colType.ItemType == NumberType.I8);
+ return CreateConvertingArrayGetterDelegate(index, x => x);
+ }
+ else if (outputType.GetElementType() == typeof(long?))
+ {
+ Ch.Assert(colType.ItemType == NumberType.I8);
+ return CreateConvertingArrayGetterDelegate(index, x => x ?? DvInt8.NA);
+ }
+ else if (outputType.GetElementType() == typeof(short))
+ {
+ Ch.Assert(colType.ItemType == NumberType.I2);
+ return CreateConvertingArrayGetterDelegate(index, x => x);
+ }
+ else if (outputType.GetElementType() == typeof(short?))
+ {
+ Ch.Assert(colType.ItemType == NumberType.I2);
+ return CreateConvertingArrayGetterDelegate(index, x => x ?? DvInt2.NA);
+ }
+ else if (outputType.GetElementType() == typeof(sbyte))
+ {
+ Ch.Assert(colType.ItemType == NumberType.I1);
+ return CreateConvertingArrayGetterDelegate(index, x => x);
}
+ else if (outputType.GetElementType() == typeof(sbyte?))
+ {
+ Ch.Assert(colType.ItemType == NumberType.I1);
+ return CreateConvertingArrayGetterDelegate(index, x => x ?? DvInt1.NA);
+ }
+ else if (outputType.GetElementType() == typeof(bool))
+ {
+ Ch.Assert(colType.ItemType.IsBool);
+ return CreateConvertingArrayGetterDelegate(index, x => x);
+ }
+ else if (outputType.GetElementType() == typeof(bool?))
+ {
+ Ch.Assert(colType.ItemType.IsBool);
+ return CreateConvertingArrayGetterDelegate(index, x => x ?? DvBool.NA);
+ }
+
// T[] -> VBuffer
- Ch.Assert(outputType.GetElementType() == colType.ItemType.RawType);
- del = CreateArrayToVBufferGetter;
+ if (outputType.GetElementType().IsGenericType && outputType.GetElementType().GetGenericTypeDefinition() == typeof(Nullable<>))
+ Ch.Assert(Nullable.GetUnderlyingType(outputType.GetElementType()) == colType.ItemType.RawType);
+ else
+ Ch.Assert(outputType.GetElementType() == colType.ItemType.RawType);
+ del = CreateDirectArrayGetterDelegate;
+ genericType = outputType.GetElementType();
}
else if (colType.IsVector)
{
@@ -142,7 +197,8 @@ private Delegate CreateGetter(int index)
Ch.Assert(outputType.IsGenericType);
Ch.Assert(outputType.GetGenericTypeDefinition() == typeof(VBuffer<>));
Ch.Assert(outputType.GetGenericArguments()[0] == colType.ItemType.RawType);
- del = CreateVBufferToVBufferDelegate;
+ del = CreateDirectVBufferGetterDelegate;
+ genericType = colType.ItemType.RawType;
}
else if (colType.IsPrimitive)
{
@@ -150,24 +206,74 @@ private Delegate CreateGetter(int index)
{
// String -> DvText
Ch.Assert(colType.IsText);
- return CreateStringToTextGetter(index);
+ return CreateConvertingGetterDelegate(index, x => x == null ? DvText.NA : new DvText(x));
}
else if (outputType == typeof(bool))
{
// Bool -> DvBool
Ch.Assert(colType.IsBool);
- return CreateBooleanToDvBoolGetter(index);
+ return CreateConvertingGetterDelegate(index, x => x);
}
else if (outputType == typeof(bool?))
{
// Bool? -> DvBool
Ch.Assert(colType.IsBool);
- return CreateNullableBooleanToDvBoolGetter(index);
+ return CreateConvertingGetterDelegate(index, x => x ?? DvBool.NA);
+ }
+ else if (outputType == typeof(int))
+ {
+ // int -> DvInt4
+ Ch.Assert(colType == NumberType.I4);
+ return CreateConvertingGetterDelegate(index, x => x);
+ }
+ else if (outputType == typeof(int?))
+ {
+ // int? -> DvInt4
+ Ch.Assert(colType == NumberType.I4);
+ return CreateConvertingGetterDelegate(index, x => x ?? DvInt4.NA);
+ }
+ else if (outputType == typeof(short))
+ {
+ // short -> DvInt2
+ Ch.Assert(colType == NumberType.I2);
+ return CreateConvertingGetterDelegate(index, x => x);
+ }
+ else if (outputType == typeof(short?))
+ {
+ // short? -> DvInt2
+ Ch.Assert(colType == NumberType.I2);
+ return CreateConvertingGetterDelegate(index, x => x ?? DvInt2.NA);
+ }
+ else if (outputType == typeof(long))
+ {
+ // long -> DvInt8
+ Ch.Assert(colType == NumberType.I8);
+ return CreateConvertingGetterDelegate(index, x => x);
+ }
+ else if (outputType == typeof(long?))
+ {
+ // long? -> DvInt8
+ Ch.Assert(colType == NumberType.I8);
+ return CreateConvertingGetterDelegate(index, x => x ?? DvInt8.NA);
+ }
+ else if (outputType == typeof(sbyte))
+ {
+ // sbyte -> DvInt1
+ Ch.Assert(colType == NumberType.I1);
+ return CreateConvertingGetterDelegate(index, x => x);
+ }
+ else if (outputType == typeof(sbyte?))
+ {
+ // sbyte? -> DvInt1
+ Ch.Assert(colType == NumberType.I1);
+ return CreateConvertingGetterDelegate(index, x => x ?? DvInt1.NA);
}
-
// T -> T
- Ch.Assert(colType.RawType == outputType);
- del = CreateDirectGetter;
+ if (outputType.IsGenericType && outputType.GetGenericTypeDefinition() == typeof(Nullable<>))
+ Ch.Assert(colType.RawType == Nullable.GetUnderlyingType(outputType));
+ else
+ Ch.Assert(colType.RawType == outputType);
+ del = CreateDirectGetterDelegate;
}
else
{
@@ -175,66 +281,43 @@ private Delegate CreateGetter(int index)
throw Ch.ExceptNotImpl("Type '{0}' is not yet supported.", outputType.FullName);
}
MethodInfo meth =
- del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(colType.ItemType.RawType);
+ del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(genericType);
return (Delegate)meth.Invoke(this, new object[] { index });
}
- private Delegate CreateStringArrayToVBufferGetter(int index)
+ // REVIEW: The converting getter invokes a type conversion delegate on every call, so it's inherently slower
+ // than the 'direct' getter. We don't have good indication of this to the user, and the selection
+ // of affected types is pretty arbitrary (signed integers and bools, but not uints and floats).
+ private Delegate CreateConvertingArrayGetterDelegate(int index, Func convert)
{
- var peek = DataView._peeks[index] as Peek;
+ var peek = DataView._peeks[index] as Peek;
Ch.AssertValue(peek);
-
- string[] buf = null;
-
- return (ValueGetter>)((ref VBuffer dst) =>
+ TSrc[] buf = default;
+ return (ValueGetter>)((ref VBuffer dst) =>
{
peek(GetCurrentRowObject(), Position, ref buf);
var n = Utils.Size(buf);
- dst = new VBuffer(n, Utils.Size(dst.Values) < n
- ? new DvText[n]
+ dst = new VBuffer(n, Utils.Size(dst.Values) < n
+ ? new TDst[n]
: dst.Values, dst.Indices);
for (int i = 0; i < n; i++)
- dst.Values[i] = new DvText(buf[i]);
- });
- }
-
- private Delegate CreateStringToTextGetter(int index)
- {
- var peek = DataView._peeks[index] as Peek;
- Ch.AssertValue(peek);
- string buf = null;
- return (ValueGetter)((ref DvText dst) =>
- {
- peek(GetCurrentRowObject(), Position, ref buf);
- dst = new DvText(buf);
- });
- }
-
- private Delegate CreateBooleanToDvBoolGetter(int index)
- {
- var peek = DataView._peeks[index] as Peek;
- Ch.AssertValue(peek);
- bool buf = false;
- return (ValueGetter)((ref DvBool dst) =>
- {
- peek(GetCurrentRowObject(), Position, ref buf);
- dst = (DvBool)buf;
+ dst.Values[i] = convert(buf[i]);
});
}
- private Delegate CreateNullableBooleanToDvBoolGetter(int index)
+ private Delegate CreateConvertingGetterDelegate(int index, Func convert)
{
- var peek = DataView._peeks[index] as Peek;
+ var peek = DataView._peeks[index] as Peek;
Ch.AssertValue(peek);
- bool? buf = null;
- return (ValueGetter)((ref DvBool dst) =>
+ TSrc buf = default;
+ return (ValueGetter)((ref TDst dst) =>
{
peek(GetCurrentRowObject(), Position, ref buf);
- dst = buf.HasValue ? (DvBool)buf.Value : DvBool.NA;
+ dst = convert(buf);
});
}
- private Delegate CreateArrayToVBufferGetter(int index)
+ private Delegate CreateDirectArrayGetterDelegate(int index)
{
var peek = DataView._peeks[index] as Peek;
Ch.AssertValue(peek);
@@ -250,26 +333,29 @@ private Delegate CreateArrayToVBufferGetter(int index)
});
}
- private Delegate CreateVBufferToVBufferDelegate(int index)
+ private Delegate CreateDirectVBufferGetterDelegate(int index)
{
var peek = DataView._peeks[index] as Peek>;
Ch.AssertValue(peek);
VBuffer buf = default(VBuffer);
return (ValueGetter>)((ref VBuffer dst) =>
- {
- // The peek for a VBuffer is just a simple assignment, so there is
- // no copy going on in the peek, so we must do that as a second
- // step to the destination.
- peek(GetCurrentRowObject(), Position, ref buf);
- buf.CopyTo(ref dst);
- });
+ {
+ // The peek for a VBuffer is just a simple assignment, so there is
+ // no copy going on in the peek, so we must do that as a second
+ // step to the destination.
+ peek(GetCurrentRowObject(), Position, ref buf);
+ buf.CopyTo(ref dst);
+ });
}
- private Delegate CreateDirectGetter(int index)
+ private Delegate CreateDirectGetterDelegate(int index)
{
var peek = DataView._peeks[index] as Peek;
Ch.AssertValue(peek);
- return (ValueGetter)((ref TDst dst) => { peek(GetCurrentRowObject(), Position, ref dst); });
+ return (ValueGetter)((ref TDst dst) =>
+ {
+ peek(GetCurrentRowObject(), Position, ref dst);
+ });
}
protected abstract TRow GetCurrentRowObject();
@@ -311,7 +397,7 @@ private void CheckColumnInRange(int columnIndex)
}
///
- /// An in-memory data view based on the IList of data.
+ /// An in-memory data view based on the IList of data.
/// Supports shuffling.
///
private sealed class ListDataView : DataViewBase
@@ -407,11 +493,11 @@ protected override bool MoveManyCore(long count)
}
///
- /// An in-memory data view based on the IEnumerable of data.
+ /// An in-memory data view based on the IEnumerable of data.
/// Doesn't support shuffling.
- ///
+ ///
/// This class is public because prediction engine wants to call its
- /// for performance reasons.
+ /// for performance reasons.
///
public sealed class StreamingDataView : DataViewBase
where TRow : class
@@ -493,7 +579,7 @@ protected override bool MoveNextCore()
///
/// This represents the 'infinite data view' over one (mutable) user-defined object.
- /// The 'current row' object can be updated at any time, this will affect all the
+ /// The 'current row' object can be updated at any time, this will affect all the
/// newly created cursors, but not the ones already existing.
///
public sealed class SingleRowLoopDataView : DataViewBase
@@ -646,7 +732,7 @@ public abstract partial class MetadataInfo
///
public ColumnType MetadataType;
///
- /// The string identifier of the metadata. Some identifiers have special meaning,
+ /// The string identifier of the metadata. Some identifiers have special meaning,
/// like "SlotNames", but any other identifiers can be used.
///
public readonly string Kind;
@@ -672,7 +758,7 @@ public sealed class MetadataInfo : MetadataInfo
///
/// Constructor for metadata of value type T.
///
- /// The string identifier of the metadata. Some identifiers have special meaning,
+ /// The string identifier of the metadata. Some identifiers have special meaning,
/// like "SlotNames", but any other identifiers can be used.
/// Metadata value.
/// Type of the metadata.
diff --git a/src/Microsoft.ML.Api/GenerateCodeCommand.cs b/src/Microsoft.ML.Api/GenerateCodeCommand.cs
index 0b45bcc4bb..26136971af 100644
--- a/src/Microsoft.ML.Api/GenerateCodeCommand.cs
+++ b/src/Microsoft.ML.Api/GenerateCodeCommand.cs
@@ -21,7 +21,7 @@ namespace Microsoft.ML.Runtime.Api
{
///
/// Generates the sample prediction code for a given model file, with correct input and output classes.
- ///
+ ///
/// REVIEW: Consider adding support for generating VBuffers instead of arrays, maybe for high dimensionality vectors.
///
public sealed class GenerateCodeCommand : ICommand
@@ -45,7 +45,7 @@ public sealed class Arguments
ShortName = "sparse", SortOrder = 102)]
public bool SparseVectorDeclaration;
- // REVIEW: currently, it's only used in unit testing to not generate the paths into the test output folder.
+ // REVIEW: currently, it's only used in unit testing to not generate the paths into the test output folder.
// However, it might be handy for automation scenarios, so I've added this as a hidden option.
[Argument(ArgumentType.AtMostOnce, HelpText = "A location of the model file to put into generated file", Hide = true)]
public string ModelNameOverride;
@@ -108,8 +108,8 @@ public void Run()
{
var roles = ModelFileUtils.LoadRoleMappingsOrNull(_host, fs);
scorer = roles != null
- ? _host.CreateDefaultScorer(RoleMappedData.CreateOpt(transformPipe, roles), pred)
- : _host.CreateDefaultScorer(_host.CreateExamples(transformPipe, "Features"), pred);
+ ? _host.CreateDefaultScorer(new RoleMappedData(transformPipe, roles, opt: true), pred)
+ : _host.CreateDefaultScorer(new RoleMappedData(transformPipe, label: null, "Features"), pred);
}
var nonScoreSb = new StringBuilder();
diff --git a/src/Microsoft.ML.Api/InternalSchemaDefinition.cs b/src/Microsoft.ML.Api/InternalSchemaDefinition.cs
index 2b0f056214..3edf7599a4 100644
--- a/src/Microsoft.ML.Api/InternalSchemaDefinition.cs
+++ b/src/Microsoft.ML.Api/InternalSchemaDefinition.cs
@@ -76,12 +76,12 @@ private Column(string columnName, ColumnType columnType, FieldInfo fieldInfo = n
}
///
- /// Function that checks whether the InternalSchemaDefinition.Column is a valid one.
+ /// Function that checks whether the InternalSchemaDefinition.Column is a valid one.
/// To be valid, the Column must:
/// 1. Have non-empty values for ColumnName and ColumnType
/// 2. Have a non-empty value for FieldInfo iff it is a field column, else
/// ReturnParameterInfo and Generator iff it is a computed column
- /// 3. Generator must have the method inputs (TRow rowObject,
+ /// 3. Generator must have the method inputs (TRow rowObject,
/// long position, ref TValue outputValue) in that order.
///
[Conditional("DEBUG")]
@@ -133,7 +133,7 @@ private InternalSchemaDefinition(Column[] columns)
///
/// Given a field info on a type, returns whether this appears to be a vector type,
/// and also the associated data kind for this type. If a data kind could not
- /// be determined, this will throw.
+ /// be determined, this will throw.
///
/// The field info to inspect.
/// Whether this appears to be a vector type.
@@ -149,7 +149,7 @@ public static void GetVectorAndKind(FieldInfo fieldInfo, out bool isVector, out
///
/// Given a parameter info on a type, returns whether this appears to be a vector type,
/// and also the associated data kind for this type. If a data kind could not
- /// be determined, this will throw.
+ /// be determined, this will throw.
///
/// The parameter info to inspect.
/// Whether this appears to be a vector type.
@@ -165,7 +165,7 @@ public static void GetVectorAndKind(ParameterInfo parameterInfo, out bool isVect
///
/// Given a type and name for a variable, returns whether this appears to be a vector type,
/// and also the associated data kind for this type. If a data kind could not
- /// be determined, this will throw.
+ /// be determined, this will throw.
///
/// The type of the variable to inspect.
/// The name of the variable to inspect.
@@ -222,7 +222,7 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us
col.MemberName,
userType.FullName);
- //Clause to handle the field that may be used to expose the cursor channel.
+ //Clause to handle the field that may be used to expose the cursor channel.
//This field does not need a column.
if (fieldInfo.FieldType == typeof(IChannel))
continue;
@@ -251,7 +251,7 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us
}
else
{
- // Make sure that the types are compatible with the declared type, including
+ // Make sure that the types are compatible with the declared type, including
// whether it is a vector type.
if (isVector != col.ColumnType.IsVector)
{
diff --git a/src/Microsoft.ML.Api/LambdaTransform.cs b/src/Microsoft.ML.Api/LambdaTransform.cs
index 93635c6d7d..506c675524 100644
--- a/src/Microsoft.ML.Api/LambdaTransform.cs
+++ b/src/Microsoft.ML.Api/LambdaTransform.cs
@@ -37,7 +37,7 @@ public static class LambdaTransform
/// different data by calling ), and the transformed data (which can be
/// enumerated upon by calling GetRowCursor or AsCursorable{TRow}). If or
/// implement the interface, they will be disposed after use.
- ///
+ ///
/// This is a 'stateless non-savable' version of the transform.
///
/// The host environment to use.
@@ -78,7 +78,7 @@ public static ITransformTemplate CreateMap(IHostEnvironment env, IDa
/// different data by calling ), and the transformed data (which can be
/// enumerated upon by calling GetRowCursor or AsCursorable{TRow}). If or
/// implement the interface, they will be disposed after use.
- ///
+ ///
/// This is a 'stateless savable' version of the transform: save and load routines must be provided.
///
/// The host environment to use.
@@ -123,7 +123,7 @@ public static ITransformTemplate CreateMap(IHostEnvironment env, IDa
///
/// This is a 'stateful non-savable' version of the map transform: the mapping function is guaranteed to be invoked once per
- /// every row of the data set, in sequence; one user-defined state object will be allocated per cursor and passed to the
+ /// every row of the data set, in sequence; one user-defined state object will be allocated per cursor and passed to the
/// map function every time. If , , or
/// implement the interface, they will be disposed after use.
///
@@ -164,7 +164,7 @@ public static ITransformTemplate CreateMap(IHostEnvironment
///
/// This is a 'stateful savable' version of the map transform: the mapping function is guaranteed to be invoked once per
- /// every row of the data set, in sequence (non-parallelizable); one user-defined state object will be allocated per cursor and passed to the
+ /// every row of the data set, in sequence (non-parallelizable); one user-defined state object will be allocated per cursor and passed to the
/// map function every time; save and load routines must be provided. If , ,
/// or implement the interface, they will be disposed after use.
///
@@ -217,8 +217,8 @@ public static ITransformTemplate CreateMap(IHostEnvironment
/// This creates a filter transform that can 'accept' or 'decline' any row of the data based on the contents of the row
/// or state of the cursor.
/// This is a 'stateful non-savable' version of the filter: the filter function is guaranteed to be invoked once per
- /// every row of the data set, in sequence (non-parallelizable); one user-defined state object will be allocated per cursor and passed to the
- /// filter function every time.
+ /// every row of the data set, in sequence (non-parallelizable); one user-defined state object will be allocated per cursor and passed to the
+ /// filter function every time.
/// If or implement the interface, they will be disposed after use.
///
/// The type that describes what 'source' columns are consumed from the
@@ -251,7 +251,7 @@ public static ITransformTemplate CreateFilter(IHostEnvironment env
/// This creates a filter transform that can 'accept' or 'decline' any row of the data based on the contents of the row
/// or state of the cursor.
/// This is a 'stateful savable' version of the filter: the filter function is guaranteed to be invoked once per
- /// every row of the data set, in sequence (non-parallelizable); one user-defined state object will be allocated per cursor and passed to the
+ /// every row of the data set, in sequence (non-parallelizable); one user-defined state object will be allocated per cursor and passed to the
/// filter function every time; save and load routines must be provided.
/// If or implement the interface, they will be disposed after use.
///
@@ -294,11 +294,11 @@ public static ITransformTemplate CreateFilter(IHostEnvironment env
}
///
- /// Defines common ancestor for various flavors of lambda-based user-defined transforms that may or may not be
+ /// Defines common ancestor for various flavors of lambda-based user-defined transforms that may or may not be
/// serializable.
- ///
+ ///
/// In order for the transform to be serializable, the user should specify a save and load delegate.
- /// Specifically, for this the user has to provide the following things:
+ /// Specifically, for this the user has to provide the following things:
/// * a custom save action that serializes the transform 'state' to the binary writer.
/// * a custom load action that de-serializes the transform from the binary reader. This must be a public static method of a public class.
///
diff --git a/src/Microsoft.ML.Api/MapTransform.cs b/src/Microsoft.ML.Api/MapTransform.cs
index 914bb63c07..4426721620 100644
--- a/src/Microsoft.ML.Api/MapTransform.cs
+++ b/src/Microsoft.ML.Api/MapTransform.cs
@@ -14,7 +14,7 @@ namespace Microsoft.ML.Runtime.Api
/// It doesn't change the number of rows, and can be seen as a result of application of the user's function
/// to every row of the input data.
/// Similarly to the existing 's, this object can be treated as both the 'transformation' algorithm
- /// (which can be then applied to different data by calling ), and the transformed data (which can
+ /// (which can be then applied to different data by calling ), and the transformed data (which can
/// be enumerated upon by calling GetRowCursor or AsCursorable{TRow}).
///
/// The type that describes what 'source' columns are consumed from the input .
@@ -36,8 +36,8 @@ internal sealed class MapTransform : LambdaTransformBase, ITransform
private static string RegistrationName { get { return string.Format(RegistrationNameTemplate, typeof(TSrc).FullName, typeof(TDst).FullName); } }
///
- /// Create a a map transform that is savable iff and are
- /// not null.
+ /// Create a map transform that is savable iff and are
+ /// not null.
///
/// The host environment
/// The dataview upon which we construct the transform
@@ -47,7 +47,7 @@ internal sealed class MapTransform : LambdaTransformBase, ITransform
/// A function that given the serialization stream and a data view, returns
/// an . The intent is, this returned object should itself be a
/// , but this is not strictly necessary. This delegate should be
- /// a static non-lambda method that this assembly can legally call. May be null simultaneously with
+ /// a static non-lambda method that this assembly can legally call. May be null simultaneously with
/// .
/// The schema definition overrides for
/// The schema definition overrides for
diff --git a/src/Microsoft.ML.Api/PredictionEngine.cs b/src/Microsoft.ML.Api/PredictionEngine.cs
index eacf8d2218..9410d3b50e 100644
--- a/src/Microsoft.ML.Api/PredictionEngine.cs
+++ b/src/Microsoft.ML.Api/PredictionEngine.cs
@@ -49,8 +49,8 @@ internal BatchPredictionEngine(IHostEnvironment env, Stream modelStream, bool ig
{
var roles = ModelFileUtils.LoadRoleMappingsOrNull(env, modelStream);
pipe = roles != null
- ? env.CreateDefaultScorer(RoleMappedData.CreateOpt(pipe, roles), predictor)
- : env.CreateDefaultScorer(env.CreateExamples(pipe, "Features"), predictor);
+ ? env.CreateDefaultScorer(new RoleMappedData(pipe, roles, opt: true), predictor)
+ : env.CreateDefaultScorer(new RoleMappedData(pipe, label: null, "Features"), predictor);
}
_pipeEngine = new PipeEngine(env, pipe, ignoreMissingColumns, outputSchemaDefinition);
@@ -72,12 +72,12 @@ internal BatchPredictionEngine(IHostEnvironment env, IDataView dataPipeline, boo
}
///
- /// Run the prediction pipe. This will enumerate the exactly once,
- /// cache all the examples (by reference) into its internal representation and then run
+ /// Run the prediction pipe. This will enumerate the exactly once,
+ /// cache all the examples (by reference) into its internal representation and then run
/// the transformation pipe.
///
/// The examples to run the prediction on.
- /// If true, the engine will not allocate memory per output, and
+ /// If true, the engine will not allocate memory per output, and
/// the returned objects will actually always be the same object. The user is
/// expected to clone the values himself if needed.
/// The that contains all the pipeline results.
@@ -141,7 +141,7 @@ public void Reset()
/// in-memory data, one example at a time.
/// This can also be used with trained pipelines that do not end with a predictor: in this case, the
/// 'prediction' will be just the outcome of all the transformations.
- /// This is essentially a wrapper for that throws if
+ /// This is essentially a wrapper for that throws if
/// more than one result is returned per call to .
///
/// The user-defined type that holds the example.
@@ -198,7 +198,7 @@ public TDst Predict(TSrc example)
///
/// This class encapsulates the 'classic' prediction problem, where the input is denoted by the float array of features,
- /// and the output is a float score. For binary classification predictors that can output probability, there are output
+ /// and the output is a float score. For binary classification predictors that can output probability, there are output
/// fields that report the predicted label and probability.
///
public sealed class SimplePredictionEngine
diff --git a/src/Microsoft.ML.Api/SchemaDefinition.cs b/src/Microsoft.ML.Api/SchemaDefinition.cs
index 5f84712625..e08845a87e 100644
--- a/src/Microsoft.ML.Api/SchemaDefinition.cs
+++ b/src/Microsoft.ML.Api/SchemaDefinition.cs
@@ -63,7 +63,7 @@ public VectorTypeAttribute(params int[] dims)
}
///
- /// Describes column information such as name and the source columns indicies that this
+ /// Describes column information such as name and the source columns indices that this
/// column encapsulates.
///
[AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
@@ -81,12 +81,12 @@ public ColumnAttribute(string ordinal, string name = null)
public string Name { get; }
///
- /// Contains positions of indices of source columns in the form
- /// of ranges. Examples of range: if we want to include just column
- /// with index 1 we can write the range as 1, if we want to include
+ /// Contains positions of indices of source columns in the form
+ /// of ranges. Examples of range: if we want to include just column
+ /// with index 1 we can write the range as 1, if we want to include
/// columns 1 to 10 then we can write the range as 1-10 and we want to include all the
/// columns from column with index 1 until end then we can write 1-*.
- ///
+ ///
/// This takes sequence of ranges that are comma seperated, example:
/// 1,2-5,10-*
///
@@ -125,7 +125,7 @@ public sealed class NoColumnAttribute : Attribute
}
///
- /// Mark a member that implements exactly IChannel as being permitted to receive
+ /// Mark a member that implements exactly IChannel as being permitted to receive
/// channel information from an external channel.
///
[AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
@@ -133,11 +133,11 @@ public sealed class CursorChannelAttribute : Attribute
{
///
/// When passed some object, and a channel, it attempts to pass the channel to the object. It
- /// passes the channel to the object iff the object has exactly one field marked with the
- /// CursorChannelAttribute, and that field implements only the IChannel interface.
- ///
- /// The function returns the modified object, as well as a boolean indicator of whether it was
- /// able to pass the channel to the object.
+ /// passes the channel to the object iff the object has exactly one field marked with the
+ /// CursorChannelAttribute, and that field implements only the IChannel interface.
+ ///
+ /// The function returns the modified object, as well as a boolean indicator of whether it was
+ /// able to pass the channel to the object.
///
/// The object that attempts to acquire the channel.
/// The channel to pass to the object.
@@ -206,13 +206,13 @@ public sealed class Column
public ColumnType ColumnType { get; set; }
///
- /// Whether the column is a computed type.
+ /// Whether the column is a computed type.
///
public bool IsComputed { get { return Generator != null; } }
///
- /// The generator function. if the column is computed.
- ///
+ /// The generator function. if the column is computed.
+ ///
public Delegate Generator { get; set; }
public Type ReturnType => Generator?.GetMethodInfo().GetParameters().LastOrDefault().ParameterType.GetElementType();
@@ -277,7 +277,7 @@ public IEnumerable> GetMetadataTypes
}
///
- /// Get or set the column definition by column name.
+ /// Get or set the column definition by column name.
/// If there's no such column:
/// - get returns null,
/// - set adds a new column.
@@ -287,9 +287,7 @@ public IEnumerable> GetMetadataTypes
///
public Column this[string columnName]
{
-#pragma warning disable TLC_NoThis // Do not use 'this' keyword for member access
get => this.FirstOrDefault(x => x.ColumnName == columnName);
-#pragma warning restore TLC_NoThis // Do not use 'this' keyword for member access
set
{
Contracts.CheckValue(value, nameof(value));
@@ -323,12 +321,15 @@ public static SchemaDefinition Create(Type userType)
HashSet colNames = new HashSet();
foreach (var fieldInfo in userType.GetFields())
{
- // Clause to handle the field that may be used to expose the cursor channel.
+ // Clause to handle the field that may be used to expose the cursor channel.
// This field does not need a column.
- // REVIEW: maybe validate the channel attribute now, instead
+ // REVIEW: maybe validate the channel attribute now, instead
// of later at cursor creation.
if (fieldInfo.FieldType == typeof(IChannel))
continue;
+ // Const fields do not need to be mapped.
+ if (fieldInfo.IsLiteral)
+ continue;
if (fieldInfo.GetCustomAttribute() != null)
continue;
diff --git a/src/Microsoft.ML.Api/SerializableLambdaTransform.cs b/src/Microsoft.ML.Api/SerializableLambdaTransform.cs
index 5f761a042b..7de6e522d8 100644
--- a/src/Microsoft.ML.Api/SerializableLambdaTransform.cs
+++ b/src/Microsoft.ML.Api/SerializableLambdaTransform.cs
@@ -79,7 +79,7 @@ public static ITransformTemplate Create(IHostEnvironment env, ModelLoadContext c
/// that method that should be enough to "recover" it, assuming it is a "recoverable" method (recoverable
/// here is a loose definition, meaning that is capable
/// of creating it, which includes among other things that it's static, non-lambda, accessible to
- /// this assembly, etc.).
+ /// this assembly, etc.).
///
/// The method that should be "recoverable"
/// A string array describing the input method
diff --git a/src/Microsoft.ML.Api/StatefulFilterTransform.cs b/src/Microsoft.ML.Api/StatefulFilterTransform.cs
index b7e0cf473b..f47b8620a8 100644
--- a/src/Microsoft.ML.Api/StatefulFilterTransform.cs
+++ b/src/Microsoft.ML.Api/StatefulFilterTransform.cs
@@ -9,10 +9,10 @@
namespace Microsoft.ML.Runtime.Api
{
- // REVIEW: the current interface to 'state' object may be inadequate: instead of insisting on
+ // REVIEW: the current interface to 'state' object may be inadequate: instead of insisting on
// parameterless constructor, we could take a delegate that would create the state per cursor.
///
- /// This transform is similar to , but it allows per-cursor state,
+ /// This transform is similar to , but it allows per-cursor state,
/// as well as the ability to 'accept' or 'filter out' some rows of the supplied .
/// The downside is that the provided lambda is eagerly called on every row (not lazily when needed), and
/// parallel cursors are not allowed.
@@ -38,8 +38,8 @@ internal sealed class StatefulFilterTransform : LambdaTransf
private static string RegistrationName { get { return string.Format(RegistrationNameTemplate, typeof(TSrc).FullName, typeof(TDst).FullName); } }
///
- /// Create a filter transform that is savable iff and are
- /// not null.
+ /// Create a filter transform that is savable iff and are
+ /// not null.
///
/// The host environment
/// The dataview upon which we construct the transform
@@ -51,7 +51,7 @@ internal sealed class StatefulFilterTransform : LambdaTransf
/// A function that given the serialization stream and a data view, returns
/// an . The intent is, this returned object should itself be a
/// , but this is not strictly necessary. This delegate should be
- /// a static non-lambda method that this assembly can legally call. May be null simultaneously with
+ /// a static non-lambda method that this assembly can legally call. May be null simultaneously with
/// .
/// The schema definition overrides for
/// The schema definition overrides for
diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs
index 29fee77a02..cd8198e14d 100644
--- a/src/Microsoft.ML.Api/TypedCursor.cs
+++ b/src/Microsoft.ML.Api/TypedCursor.cs
@@ -57,7 +57,7 @@ public interface ICursorable
///
/// Implementation of the strongly typed Cursorable.
- /// Similarly to the 'DataView{T}, this class uses IL generation to create the 'poke' methods that
+ /// Similarly to the 'DataView{T}, this class uses IL generation to create the 'poke' methods that
/// write directly into the fields of the user-defined type.
///
internal sealed class TypedCursorable : ICursorable
@@ -271,7 +271,7 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit
var colType = input.Schema.GetColumnType(index);
var fieldInfo = column.FieldInfo;
var fieldType = fieldInfo.FieldType;
-
+ var genericType = fieldType;
Func> del;
if (fieldType.IsArray)
{
@@ -280,11 +280,66 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit
if (fieldType.GetElementType() == typeof(string))
{
Ch.Assert(colType.ItemType.IsText);
- return CreateVBufferToStringArraySetter(input, index, poke, peek);
+ return CreateConvertingVBufferSetter(input, index, poke, peek, x => x.ToString());
+ }
+ else if (fieldType.GetElementType() == typeof(bool))
+ {
+ Ch.Assert(colType.ItemType.IsBool);
+ return CreateConvertingVBufferSetter(input, index, poke, peek, x => (bool)x);
+ }
+ else if (fieldType.GetElementType() == typeof(bool?))
+ {
+ Ch.Assert(colType.ItemType.IsBool);
+ return CreateConvertingVBufferSetter(input, index, poke, peek, x => (bool?)x);
+ }
+ else if (fieldType.GetElementType() == typeof(int))
+ {
+ Ch.Assert(colType.ItemType == NumberType.I4);
+ return CreateConvertingVBufferSetter(input, index, poke, peek, x => (int)x);
+ }
+ else if (fieldType.GetElementType() == typeof(int?))
+ {
+ Ch.Assert(colType.ItemType == NumberType.I4);
+ return CreateConvertingVBufferSetter(input, index, poke, peek, x => (int?)x);
+ }
+ else if (fieldType.GetElementType() == typeof(short))
+ {
+ Ch.Assert(colType.ItemType == NumberType.I2);
+ return CreateConvertingVBufferSetter(input, index, poke, peek, x => (short)x);
+ }
+ else if (fieldType.GetElementType() == typeof(short?))
+ {
+ Ch.Assert(colType.ItemType == NumberType.I2);
+ return CreateConvertingVBufferSetter(input, index, poke, peek, x => (short?)x);
+ }
+ else if (fieldType.GetElementType() == typeof(long))
+ {
+ Ch.Assert(colType.ItemType == NumberType.I8);
+ return CreateConvertingVBufferSetter(input, index, poke, peek, x => (long)x);
+ }
+ else if (fieldType.GetElementType() == typeof(long?))
+ {
+ Ch.Assert(colType.ItemType == NumberType.I8);
+ return CreateConvertingVBufferSetter(input, index, poke, peek, x => (long?)x);
}
+ else if (fieldType.GetElementType() == typeof(sbyte))
+ {
+ Ch.Assert(colType.ItemType == NumberType.I1);
+ return CreateConvertingVBufferSetter(input, index, poke, peek, x => (sbyte)x);
+ }
+ else if (fieldType.GetElementType() == typeof(sbyte?))
+ {
+ Ch.Assert(colType.ItemType == NumberType.I1);
+ return CreateConvertingVBufferSetter(input, index, poke, peek, x => (sbyte?)x);
+ }
+
// VBuffer -> T[]
- Ch.Assert(fieldType.GetElementType() == colType.ItemType.RawType);
- del = CreateVBufferToArraySetter;
+ if (fieldType.GetElementType().IsGenericType && fieldType.GetElementType().GetGenericTypeDefinition() == typeof(Nullable<>))
+ Ch.Assert(colType.ItemType.RawType == Nullable.GetUnderlyingType(fieldType.GetElementType()));
+ else
+ Ch.Assert(colType.ItemType.RawType == fieldType.GetElementType());
+ del = CreateDirectVBufferSetter;
+ genericType = fieldType.GetElementType();
}
else if (colType.IsVector)
{
@@ -294,6 +349,7 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit
Ch.Assert(fieldType.GetGenericTypeDefinition() == typeof(VBuffer<>));
Ch.Assert(fieldType.GetGenericArguments()[0] == colType.ItemType.RawType);
del = CreateVBufferToVBufferSetter;
+ genericType = colType.ItemType.RawType;
}
else if (colType.IsPrimitive)
{
@@ -302,53 +358,111 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit
// DvText -> String
Ch.Assert(colType.IsText);
Ch.Assert(peek == null);
- return CreateTextToStringSetter(input, index, poke);
+ return CreateConvertingActionSetter(input, index, poke, x => x.ToString());
}
else if (fieldType == typeof(bool))
{
Ch.Assert(colType.IsBool);
Ch.Assert(peek == null);
- return CreateDvBoolToBoolSetter(input, index, poke);
+ return CreateConvertingActionSetter(input, index, poke, x => (bool)x);
}
- else
+ else if (fieldType == typeof(bool?))
{
- // T -> T
- Ch.Assert(colType.RawType == fieldType);
- del = CreateDirectSetter;
+ Ch.Assert(colType.IsBool);
+ Ch.Assert(peek == null);
+ return CreateConvertingActionSetter(input, index, poke, x => (bool?)x);
+ }
+ else if (fieldType == typeof(int))
+ {
+ Ch.Assert(colType == NumberType.I4);
+ Ch.Assert(peek == null);
+ return CreateConvertingActionSetter(input, index, poke, x => (int)x);
}
+ else if (fieldType == typeof(int?))
+ {
+ Ch.Assert(colType == NumberType.I4);
+ Ch.Assert(peek == null);
+ return CreateConvertingActionSetter(input, index, poke, x => (int?)x);
+ }
+ else if (fieldType == typeof(short))
+ {
+ Ch.Assert(colType == NumberType.I2);
+ Ch.Assert(peek == null);
+ return CreateConvertingActionSetter(input, index, poke, x => (short)x);
+ }
+ else if (fieldType == typeof(short?))
+ {
+ Ch.Assert(colType == NumberType.I2);
+ Ch.Assert(peek == null);
+ return CreateConvertingActionSetter(input, index, poke, x => (short?)x);
+ }
+ else if (fieldType == typeof(long))
+ {
+ Ch.Assert(colType == NumberType.I8);
+ Ch.Assert(peek == null);
+ return CreateConvertingActionSetter(input, index, poke, x => (long)x);
+ }
+ else if (fieldType == typeof(long?))
+ {
+ Ch.Assert(colType == NumberType.I8);
+ Ch.Assert(peek == null);
+ return CreateConvertingActionSetter(input, index, poke, x => (long?)x);
+ }
+ else if (fieldType == typeof(sbyte))
+ {
+ Ch.Assert(colType == NumberType.I1);
+ Ch.Assert(peek == null);
+ return CreateConvertingActionSetter(input, index, poke, x => (sbyte)x);
+ }
+ else if (fieldType == typeof(sbyte?))
+ {
+ Ch.Assert(colType == NumberType.I1);
+ Ch.Assert(peek == null);
+ return CreateConvertingActionSetter(input, index, poke, x => (sbyte?)x);
+ }
+ // T -> T
+ if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Nullable<>))
+ Ch.Assert(colType.RawType == Nullable.GetUnderlyingType(fieldType));
+ else
+ Ch.Assert(colType.RawType == fieldType);
+
+ del = CreateDirectSetter;
}
else
{
// REVIEW: Is this even possible?
throw Ch.ExceptNotImpl("Type '{0}' is not yet supported.", fieldInfo.FieldType.FullName);
}
- MethodInfo meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(colType.ItemType.RawType);
+ MethodInfo meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(genericType);
return (Action)meth.Invoke(this, new object[] { input, index, poke, peek });
}
- private Action CreateVBufferToStringArraySetter(IRow input, int col, Delegate poke, Delegate peek)
+ // REVIEW: The converting getter invokes a type conversion delegate on every call, so it's inherently slower
+ // than the 'direct' getter. We don't have good indication of this to the user, and the selection
+ // of affected types is pretty arbitrary (signed integers and bools, but not uints and floats).
+ private Action CreateConvertingVBufferSetter(IRow input, int col, Delegate poke, Delegate peek, Func convert)
{
- var getter = input.GetGetter>(col);
- var typedPoke = poke as Poke;
- var typedPeek = peek as Peek;
+ var getter = input.GetGetter>(col);
+ var typedPoke = poke as Poke;
+ var typedPeek = peek as Peek;
Contracts.AssertValue(typedPoke);
Contracts.AssertValue(typedPeek);
- VBuffer value = default(VBuffer);
- string[] buf = null;
+ VBuffer value = default;
+ TDst[] buf = null;
return row =>
{
getter(ref value);
typedPeek(row, Position, ref buf);
if (Utils.Size(buf) != value.Length)
- buf = new string[value.Length];
+ buf = new TDst[value.Length];
foreach (var pair in value.Items(true))
- buf[pair.Key] = pair.Value.ToString();
+ buf[pair.Key] = convert(pair.Value);
typedPoke(row, buf);
};
}
- private Action CreateVBufferToArraySetter(IRow input, int col, Delegate poke, Delegate peek)
+ private Action CreateDirectVBufferSetter(IRow input, int col, Delegate poke, Delegate peek)
{
var getter = input.GetGetter>(col);
var typedPoke = poke as Poke;
@@ -386,29 +500,17 @@ private Action CreateVBufferToArraySetter(IRow input, int col, Deleg
};
}
- private static Action CreateTextToStringSetter(IRow input, int col, Delegate poke)
- {
- var getter = input.GetGetter(col);
- var typedPoke = poke as Poke;
- Contracts.AssertValue(typedPoke);
- DvText value = default(DvText);
- return row =>
- {
- getter(ref value);
- typedPoke(row, value.ToString());
- };
- }
-
- private static Action CreateDvBoolToBoolSetter(IRow input, int col, Delegate poke)
+ private static Action CreateConvertingActionSetter(IRow input, int col, Delegate poke, Func convert)
{
- var getter = input.GetGetter(col);
- var typedPoke = poke as Poke;
+ var getter = input.GetGetter(col);
+ var typedPoke = poke as Poke;
Contracts.AssertValue(typedPoke);
- DvBool value = default(DvBool);
+ TSrc value = default;
return row =>
{
getter(ref value);
- typedPoke(row, Convert.ToBoolean(value.RawValue));
+ var toPoke = convert(value);
+ typedPoke(row, toPoke);
};
}
diff --git a/src/Microsoft.ML.Console/Console.cs b/src/Microsoft.ML.Console/Console.cs
index 12e6254cce..152d65951a 100644
--- a/src/Microsoft.ML.Console/Console.cs
+++ b/src/Microsoft.ML.Console/Console.cs
@@ -8,4 +8,4 @@ public static class Console
{
public static int Main(string[] args) => Maml.Main(args);
}
-}
+}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
index 8039371695..ae327a26c2 100644
--- a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
+++ b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
@@ -3,18 +3,33 @@
true
CORECLR
- Microsoft.ML
- netcoreapp2.0
- Exe
- MML
- Microsoft.ML.Runtime.Tools.Console.Console
+ netcoreapp2.0
+ Exe
+ MML
+ Microsoft.ML.Runtime.Tools.Console.Console
+
+
+
+
+
+
+
-
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs
index b9b9506cf9..eb85fcce12 100644
--- a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs
+++ b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs
@@ -493,7 +493,7 @@ public static string ArgumentsUsage(IHostEnvironment env, Type type, object defa
#if CORECLR
///
- /// Fix the window width for the Core build to remove the kernel32.dll dependency.
+ /// Fix the window width for the Core build to remove the kernel32.dll dependency.
///
///
public static int GetConsoleWindowWidth()
@@ -620,7 +620,7 @@ private static ArgumentInfo GetArgumentInfo(Type type, object defaults)
string[] nicks;
// Semantics of ShortName:
// The string provided represents an array of names separated by commas and spaces, once empty entries are removed.
- // 'null' or a singleton array with containing only the long field name means "use the default short name",
+ // 'null' or a singleton array containing only the long field name means "use the default short name",
// and is represented by the null 'nicks' array.
// 'String.Empty' or a string containing only spaces and commas means "no short name", and is represented by an empty 'nicks' array.
if (attr.ShortName == null)
@@ -1666,7 +1666,7 @@ public bool Finish(CmdParser owner, ArgValue val, object destination)
}
else if (IsMultiSubComponent)
{
- // REVIEW: the kind should not be separated from settings: everything related
+ // REVIEW: the kind should not be separated from settings: everything related
// to one item should go into one value, not multiple values
if (IsTaggedCollection)
{
diff --git a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
index 28666e7f44..3b56e8bb36 100644
--- a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
+++ b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
@@ -385,7 +385,7 @@ private static void CacheLoadedAssemblies()
{
if (_assemblyQueue == null)
{
- // Create the loaded assembly queue and dictionary, set up the AssemblyLoad / AssemblyResolve
+ // Create the loaded assembly queue and dictionary, set up the AssemblyLoad / AssemblyResolve
// event handlers and populate the queue / dictionary with all assemblies that are currently loaded.
Contracts.Assert(_assemblyQueue == null);
Contracts.Assert(_loadedAssemblies == null);
@@ -413,7 +413,7 @@ private static void CacheLoadedAssemblies()
// Load all assemblies in our directory.
var moduleName = typeof(ComponentCatalog).Module.FullyQualifiedName;
- // If were are loaded in the context of SQL CLR then the FullyQualifiedName and Name properties are set to
+ // If we are loaded in the context of SQL CLR then the FullyQualifiedName and Name properties are set to
// string "" and we skip scanning current directory.
if (moduleName != "")
{
@@ -451,7 +451,7 @@ private static void CacheLoadedAssemblies()
#if TRACE_ASSEMBLY_LOADING
// The "" no-op argument is necessary because WriteLine has multiple overloads, and with two strings
- // it will be the one that is message/category, rather than format string with
+ // it will be the one that is message/category, rather than format string with
System.Diagnostics.Debug.WriteLine("*** Caching classes in {0}", assembly.FullName, "");
#endif
int added = 0;
diff --git a/src/Microsoft.ML.Core/Data/ColumnType.cs b/src/Microsoft.ML.Core/Data/ColumnType.cs
index 780ef7a7d7..0cff911e77 100644
--- a/src/Microsoft.ML.Core/Data/ColumnType.cs
+++ b/src/Microsoft.ML.Core/Data/ColumnType.cs
@@ -325,7 +325,7 @@ public static PrimitiveType FromKind(DataKind kind)
///
public sealed class TextType : PrimitiveType
{
- private volatile static TextType _instance;
+ private static volatile TextType _instance;
public static TextType Instance
{
get
@@ -370,7 +370,7 @@ private NumberType(DataKind kind, string name)
Contracts.Assert(IsNumber);
}
- private volatile static NumberType _instI1;
+ private static volatile NumberType _instI1;
public static NumberType I1
{
get
@@ -381,7 +381,7 @@ public static NumberType I1
}
}
- private volatile static NumberType _instU1;
+ private static volatile NumberType _instU1;
public static NumberType U1
{
get
@@ -392,7 +392,7 @@ public static NumberType U1
}
}
- private volatile static NumberType _instI2;
+ private static volatile NumberType _instI2;
public static NumberType I2
{
get
@@ -403,7 +403,7 @@ public static NumberType I2
}
}
- private volatile static NumberType _instU2;
+ private static volatile NumberType _instU2;
public static NumberType U2
{
get
@@ -414,7 +414,7 @@ public static NumberType U2
}
}
- private volatile static NumberType _instI4;
+ private static volatile NumberType _instI4;
public static NumberType I4
{
get
@@ -425,7 +425,7 @@ public static NumberType I4
}
}
- private volatile static NumberType _instU4;
+ private static volatile NumberType _instU4;
public static NumberType U4
{
get
@@ -436,7 +436,7 @@ public static NumberType U4
}
}
- private volatile static NumberType _instI8;
+ private static volatile NumberType _instI8;
public static NumberType I8
{
get
@@ -447,7 +447,7 @@ public static NumberType I8
}
}
- private volatile static NumberType _instU8;
+ private static volatile NumberType _instU8;
public static NumberType U8
{
get
@@ -458,7 +458,7 @@ public static NumberType U8
}
}
- private volatile static NumberType _instUG;
+ private static volatile NumberType _instUG;
public static NumberType UG
{
get
@@ -469,7 +469,7 @@ public static NumberType UG
}
}
- private volatile static NumberType _instR4;
+ private static volatile NumberType _instR4;
public static NumberType R4
{
get
@@ -480,7 +480,7 @@ public static NumberType R4
}
}
- private volatile static NumberType _instR8;
+ private static volatile NumberType _instR8;
public static NumberType R8
{
get
@@ -496,7 +496,7 @@ public static NumberType Float
get { return R4; }
}
- public new static NumberType FromKind(DataKind kind)
+ public static new NumberType FromKind(DataKind kind)
{
switch (kind)
{
@@ -557,7 +557,7 @@ public override string ToString()
///
public sealed class BoolType : PrimitiveType
{
- private volatile static BoolType _instance;
+ private static volatile BoolType _instance;
public static BoolType Instance
{
get
@@ -589,7 +589,7 @@ public override string ToString()
public sealed class DateTimeType : PrimitiveType
{
- private volatile static DateTimeType _instance;
+ private static volatile DateTimeType _instance;
public static DateTimeType Instance
{
get
@@ -621,7 +621,7 @@ public override string ToString()
public sealed class DateTimeZoneType : PrimitiveType
{
- private volatile static DateTimeZoneType _instance;
+ private static volatile DateTimeZoneType _instance;
public static DateTimeZoneType Instance
{
get
@@ -656,7 +656,7 @@ public override string ToString()
///
public sealed class TimeSpanType : PrimitiveType
{
- private volatile static TimeSpanType _instance;
+ private static volatile TimeSpanType _instance;
public static TimeSpanType Instance
{
get
@@ -692,11 +692,11 @@ public override string ToString()
/// meaningful. Examples are SSNs, phone numbers, auto-generated/incremented key values,
/// class numbers, etc. For example, in multi-class classification, the label is typically
/// a class number which is naturally a KeyType.
- ///
+ ///
/// KeyTypes can be contiguous (the class number example), in which case they can have
/// a cardinality/Count. For non-contiguous KeyTypes the Count property returns zero.
/// Any KeyType (contiguous or not) can have a Min value. The Min value is always >= 0.
- ///
+ ///
/// Note that the representation value does not necessarily match the logical value.
/// For example, if a KeyType has range 1000-5000, then it has a Min of 1000, Count
/// of 4001, but the representational values are 1-4001. The representation value zero
@@ -951,7 +951,7 @@ public bool IsSubtypeOf(VectorType other)
if (other == null)
return false;
- // REVIEW: Perhaps we should allow the case when _itemType is
+ // REVIEW: Perhaps we should allow the case when _itemType is
// a sub-type of other._itemType (in particular for key types)
if (!_itemType.Equals(other._itemType))
return false;
diff --git a/src/Microsoft.ML.Core/Data/DataKind.cs b/src/Microsoft.ML.Core/Data/DataKind.cs
index 5ed5ded1c1..32325f44a1 100644
--- a/src/Microsoft.ML.Core/Data/DataKind.cs
+++ b/src/Microsoft.ML.Core/Data/DataKind.cs
@@ -30,7 +30,7 @@ public enum DataKind : byte
Num = R4,
TX = 11,
-#pragma warning disable TLC_GeneralName // The data kind enum has its own logic, independnet of C# naming conventions.
+#pragma warning disable MSML_GeneralName // The data kind enum has its own logic, independent of C# naming conventions.
TXT = TX,
Text = TX,
@@ -46,7 +46,7 @@ public enum DataKind : byte
UG = 16, // Unsigned 16-byte integer.
U16 = UG,
-#pragma warning restore TLC_GeneralName
+#pragma warning restore MSML_GeneralName
}
///
@@ -83,22 +83,22 @@ public static ulong ToMaxInt(this DataKind kind)
{
switch (kind)
{
- case DataKind.I1:
- return (ulong)sbyte.MaxValue;
- case DataKind.U1:
- return byte.MaxValue;
- case DataKind.I2:
- return (ulong)short.MaxValue;
- case DataKind.U2:
- return ushort.MaxValue;
- case DataKind.I4:
- return int.MaxValue;
- case DataKind.U4:
- return uint.MaxValue;
- case DataKind.I8:
- return long.MaxValue;
- case DataKind.U8:
- return ulong.MaxValue;
+ case DataKind.I1:
+ return (ulong)sbyte.MaxValue;
+ case DataKind.U1:
+ return byte.MaxValue;
+ case DataKind.I2:
+ return (ulong)short.MaxValue;
+ case DataKind.U2:
+ return ushort.MaxValue;
+ case DataKind.I4:
+ return int.MaxValue;
+ case DataKind.U4:
+ return uint.MaxValue;
+ case DataKind.I8:
+ return long.MaxValue;
+ case DataKind.U8:
+ return ulong.MaxValue;
}
return 0;
@@ -112,22 +112,22 @@ public static long ToMinInt(this DataKind kind)
{
switch (kind)
{
- case DataKind.I1:
- return sbyte.MinValue;
- case DataKind.U1:
- return byte.MinValue;
- case DataKind.I2:
- return short.MinValue;
- case DataKind.U2:
- return ushort.MinValue;
- case DataKind.I4:
- return int.MinValue;
- case DataKind.U4:
- return uint.MinValue;
- case DataKind.I8:
- return long.MinValue;
- case DataKind.U8:
- return 0;
+ case DataKind.I1:
+ return sbyte.MinValue;
+ case DataKind.U1:
+ return byte.MinValue;
+ case DataKind.I2:
+ return short.MinValue;
+ case DataKind.U2:
+ return ushort.MinValue;
+ case DataKind.I4:
+ return int.MinValue;
+ case DataKind.U4:
+ return uint.MinValue;
+ case DataKind.I8:
+ return long.MinValue;
+ case DataKind.U8:
+ return 0;
}
return 1;
@@ -140,38 +140,38 @@ public static Type ToType(this DataKind kind)
{
switch (kind)
{
- case DataKind.I1:
- return typeof(DvInt1);
- case DataKind.U1:
- return typeof(byte);
- case DataKind.I2:
- return typeof(DvInt2);
- case DataKind.U2:
- return typeof(ushort);
- case DataKind.I4:
- return typeof(DvInt4);
- case DataKind.U4:
- return typeof(uint);
- case DataKind.I8:
- return typeof(DvInt8);
- case DataKind.U8:
- return typeof(ulong);
- case DataKind.R4:
- return typeof(Single);
- case DataKind.R8:
- return typeof(Double);
- case DataKind.TX:
- return typeof(DvText);
- case DataKind.BL:
- return typeof(DvBool);
- case DataKind.TS:
- return typeof(DvTimeSpan);
- case DataKind.DT:
- return typeof(DvDateTime);
- case DataKind.DZ:
- return typeof(DvDateTimeZone);
- case DataKind.UG:
- return typeof(UInt128);
+ case DataKind.I1:
+ return typeof(DvInt1);
+ case DataKind.U1:
+ return typeof(byte);
+ case DataKind.I2:
+ return typeof(DvInt2);
+ case DataKind.U2:
+ return typeof(ushort);
+ case DataKind.I4:
+ return typeof(DvInt4);
+ case DataKind.U4:
+ return typeof(uint);
+ case DataKind.I8:
+ return typeof(DvInt8);
+ case DataKind.U8:
+ return typeof(ulong);
+ case DataKind.R4:
+ return typeof(Single);
+ case DataKind.R8:
+ return typeof(Double);
+ case DataKind.TX:
+ return typeof(DvText);
+ case DataKind.BL:
+ return typeof(DvBool);
+ case DataKind.TS:
+ return typeof(DvTimeSpan);
+ case DataKind.DT:
+ return typeof(DvDateTime);
+ case DataKind.DZ:
+ return typeof(DvDateTimeZone);
+ case DataKind.UG:
+ return typeof(UInt128);
}
return null;
@@ -185,29 +185,29 @@ public static bool TryGetDataKind(this Type type, out DataKind kind)
Contracts.CheckValueOrNull(type);
// REVIEW: Make this more efficient. Should we have a global dictionary?
- if (type == typeof(DvInt1))
+ if (type == typeof(DvInt1) || type == typeof(sbyte) || type == typeof(sbyte?))
kind = DataKind.I1;
- else if (type == typeof(byte))
+ else if (type == typeof(byte) || type == typeof(byte?))
kind = DataKind.U1;
- else if (type == typeof(DvInt2))
+ else if (type == typeof(DvInt2) || type == typeof(short) || type == typeof(short?))
kind = DataKind.I2;
- else if (type == typeof(ushort))
+ else if (type == typeof(ushort) || type == typeof(ushort?))
kind = DataKind.U2;
- else if (type == typeof(DvInt4))
+ else if (type == typeof(DvInt4) || type == typeof(int) || type == typeof(int?))
kind = DataKind.I4;
- else if (type == typeof(uint))
+ else if (type == typeof(uint) || type == typeof(uint?))
kind = DataKind.U4;
- else if (type == typeof(DvInt8))
+ else if (type == typeof(DvInt8) || type == typeof(long) || type == typeof(long?))
kind = DataKind.I8;
- else if (type == typeof(ulong))
+ else if (type == typeof(ulong) || type == typeof(ulong?))
kind = DataKind.U8;
- else if (type == typeof(Single))
+ else if (type == typeof(Single) || type == typeof(Single?))
kind = DataKind.R4;
- else if (type == typeof(Double))
+ else if (type == typeof(Double) || type == typeof(Double?))
kind = DataKind.R8;
else if (type == typeof(DvText))
kind = DataKind.TX;
- else if (type == typeof(DvBool) || type == typeof(bool) ||type ==typeof(bool?))
+ else if (type == typeof(DvBool) || type == typeof(bool) || type == typeof(bool?))
kind = DataKind.BL;
else if (type == typeof(DvTimeSpan))
kind = DataKind.TS;
@@ -234,38 +234,38 @@ public static string GetString(this DataKind kind)
{
switch (kind)
{
- case DataKind.I1:
- return "I1";
- case DataKind.I2:
- return "I2";
- case DataKind.I4:
- return "I4";
- case DataKind.I8:
- return "I8";
- case DataKind.U1:
- return "U1";
- case DataKind.U2:
- return "U2";
- case DataKind.U4:
- return "U4";
- case DataKind.U8:
- return "U8";
- case DataKind.R4:
- return "R4";
- case DataKind.R8:
- return "R8";
- case DataKind.BL:
- return "BL";
- case DataKind.TX:
- return "TX";
- case DataKind.TS:
- return "TS";
- case DataKind.DT:
- return "DT";
- case DataKind.DZ:
- return "DZ";
- case DataKind.UG:
- return "UG";
+ case DataKind.I1:
+ return "I1";
+ case DataKind.I2:
+ return "I2";
+ case DataKind.I4:
+ return "I4";
+ case DataKind.I8:
+ return "I8";
+ case DataKind.U1:
+ return "U1";
+ case DataKind.U2:
+ return "U2";
+ case DataKind.U4:
+ return "U4";
+ case DataKind.U8:
+ return "U8";
+ case DataKind.R4:
+ return "R4";
+ case DataKind.R8:
+ return "R8";
+ case DataKind.BL:
+ return "BL";
+ case DataKind.TX:
+ return "TX";
+ case DataKind.TS:
+ return "TS";
+ case DataKind.DT:
+ return "DT";
+ case DataKind.DZ:
+ return "DZ";
+ case DataKind.UG:
+ return "UG";
}
return "";
}
diff --git a/src/Microsoft.ML.Core/Data/DateTime.cs b/src/Microsoft.ML.Core/Data/DateTime.cs
index 52b30b5bb6..d11be2a494 100644
--- a/src/Microsoft.ML.Core/Data/DateTime.cs
+++ b/src/Microsoft.ML.Core/Data/DateTime.cs
@@ -230,7 +230,7 @@ public DvDateTimeZone(DvDateTime dt, DvTimeSpan offset)
/// are within the valid range, and returns a DvDateTime representing the UTC time (dateTime-offset).
///
/// The clock time
- /// The offset. This value is assumed to be validated as a legal offset:
+ /// The offset. This value is assumed to be validated as a legal offset:
/// a value in whole minutes, between -14 and 14 hours.
/// The UTC DvDateTime representing the input clock time minus the offset
private static DvDateTime ValidateDate(DvDateTime dateTime, ref DvInt2 offset)
diff --git a/src/Microsoft.ML.Core/Data/ICursor.cs b/src/Microsoft.ML.Core/Data/ICursor.cs
index 264eaa55bb..e1efc842f4 100644
--- a/src/Microsoft.ML.Core/Data/ICursor.cs
+++ b/src/Microsoft.ML.Core/Data/ICursor.cs
@@ -18,7 +18,7 @@ public interface ICounted
/// This is incremented for ICursor when the underlying contents changes, giving clients a way to detect change.
/// Generally it's -1 when the object is in an invalid state. In particular, for an , this is -1
/// when the is or .
- ///
+ ///
/// Note that this position is not position within the underlying data, but position of this cursor only.
/// If one, for example, opened a set of parallel streaming cursors, or a shuffled cursor, each such cursor's
/// first valid entry would always have position 0.
@@ -30,7 +30,7 @@ public interface ICounted
/// batch numbers should be non-decreasing. Furthermore, any given batch number should only appear in one
/// of the streams. Order is determined by batch number. The reconciler ensures that each stream (that is
/// still active) has at least one item available, then takes the item with the smallest batch number.
- ///
+ ///
/// Note that there is no suggestion that the batches for a particular entry will be consistent from
/// cursoring to cursoring, except for the consistency in resulting in the same overall ordering. The same
/// entry could have different batch numbers from one cursoring to another. There is also no requirement
@@ -45,7 +45,7 @@ public interface ICounted
/// will produce the same data as a serial cursor or any other shuffled cursor, only shuffled. The ID
/// exists for applications that need to reconcile which entry is actually which. Ideally this ID should
/// be unique, but for practical reasons, it suffices if collisions are simply extremely improbable.
- ///
+ ///
/// Note that this ID, while it must be consistent for multiple streams according to the semantics
/// above, is not considered part of the data per se. So, to take the example of a data view specifically,
/// a single data view must render consistent IDs across all cursorings, but there is no suggestion at
@@ -77,7 +77,7 @@ public interface ICursor : ICounted, IDisposable
/// Returns the state of the cursor. Before the first call to or
/// this should be . After
/// any call those move functions that returns true, this should return
- /// ,
+ /// ,
///
CursorState State { get; }
diff --git a/src/Microsoft.ML.Core/Data/IDataView.cs b/src/Microsoft.ML.Core/Data/IDataView.cs
index db83c15fd9..052a07dc9e 100644
--- a/src/Microsoft.ML.Core/Data/IDataView.cs
+++ b/src/Microsoft.ML.Core/Data/IDataView.cs
@@ -89,7 +89,7 @@ public interface IDataView : ISchematized
/// call. This indicates, that the transform does not YET know the number of rows, but
/// may in the future. If lazy is false, then this is permitted to do some work (no more
/// that it would normally do for cursoring) to determine the number of rows.
- ///
+ ///
/// Most components will return the same answer whether lazy is true or false. Some, like
/// a cache, might return null until the cache is fully populated (when lazy is true). When
/// lazy is false, such a cache would block until the cache was populated.
@@ -110,7 +110,7 @@ public interface IDataView : ISchematized
/// has no recommendation, and the implementation should have some default behavior to cover
/// this case. Note that this is strictly a recommendation: it is entirely possible that
/// an implementation can return a different number of cursors.
- ///
+ ///
/// The cursors should return the same data as returned through
/// , except partitioned: no two cursors
/// should return the "same" row as would have been returned through the regular serial cursor,
diff --git a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
index 7589ef13ad..b463e52a8e 100644
--- a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
+++ b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
@@ -62,7 +62,7 @@ public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider
/// Note that IFileHandle derives from IDisposable. Clients may dispose the IFileHandle when it is
/// no longer needed, but they are not required to. The host environment should track all temp file
/// handles and ensure that they are disposed properly when the environment is "shut down".
- ///
+ ///
/// The suffix and prefix are optional. A common use for suffix is to specify an extension, eg, ".txt".
/// The use of suffix and prefix, including whether they have any affect, is up to the host enviroment.
///
diff --git a/src/Microsoft.ML.Core/Data/IMlState.cs b/src/Microsoft.ML.Core/Data/IMlState.cs
index 98c0e8e5aa..52b0828256 100644
--- a/src/Microsoft.ML.Core/Data/IMlState.cs
+++ b/src/Microsoft.ML.Core/Data/IMlState.cs
@@ -5,7 +5,7 @@
namespace Microsoft.ML.Runtime.EntryPoints
{
///
- /// Dummy interface to allow reference to the AutoMlState object in the C# API (since AutoMlState
+ /// Dummy interface to allow reference to the AutoMlState object in the C# API (since AutoMlState
/// has things that reference C# API, leading to circular dependency). Makes state object an opaque
/// black box to the graph. The macro itself will then case to the concrete type.
///
diff --git a/src/Microsoft.ML.Core/Data/IProgressChannel.cs b/src/Microsoft.ML.Core/Data/IProgressChannel.cs
index b5bae12c0b..0f673d9b2a 100644
--- a/src/Microsoft.ML.Core/Data/IProgressChannel.cs
+++ b/src/Microsoft.ML.Core/Data/IProgressChannel.cs
@@ -10,7 +10,7 @@ namespace Microsoft.ML.Runtime
/// This is a factory interface for .
/// Both and implement this interface,
/// to allow for nested progress reporters.
- ///
+ ///
/// REVIEW: make implement this, instead of the environment?
///
public interface IProgressChannelProvider
@@ -24,10 +24,10 @@ public interface IProgressChannelProvider
///
/// A common interface for progress reporting.
/// It is expected that the progress channel interface is used from only one thread.
- ///
+ ///
/// Supported workflow:
/// 1) Create the channel via .
- /// 2) Call as many times as desired (including 0).
+ /// 2) Call as many times as desired (including 0).
/// Each call to supersedes the previous one.
/// 3) Report checkpoints (0 or more) by calling .
/// 4) Repeat steps 2-3 as often as necessary.
@@ -39,13 +39,13 @@ public interface IProgressChannel : IProgressChannelProvider, IDisposable
/// Set up the reporting structure:
/// - Set the 'header' of the progress reports, defining which progress units and metrics are going to be reported.
/// - Provide a thread-safe delegate to be invoked whenever anyone needs to know the progress.
- ///
+ ///
/// It is acceptable to call multiple times (or none), regardless of whether the calculation is running
- /// or not. Because of synchronization, the computation should not deny calls to the 'old'
+ /// or not. Because of synchronization, the computation should not deny calls to the 'old'
/// delegates even after a new one is provided.
///
/// The header object.
- /// The delegate to provide actual progress. The parameter of
+ /// The delegate to provide actual progress. The parameter of
/// the delegate will correspond to the provided .
void SetHeader(ProgressHeader header, Action fillAction);
@@ -53,10 +53,10 @@ public interface IProgressChannel : IProgressChannelProvider, IDisposable
/// Submit a 'checkpoint' entry. These entries are guaranteed to be delivered to the progress listener,
/// if it is interested. Typically, this would contain some intermediate metrics, that are only calculated
/// at certain moments ('checkpoints') of the computation.
- ///
+ ///
/// For example, SDCA may report a checkpoint every time it computes the loss, or LBFGS may report a checkpoint
/// every iteration.
- ///
+ ///
/// The only parameter, , is interpreted in the following fashion:
/// * First MetricNames.Length items, if present, are metrics.
/// * Subsequent ProgressNames.Length items, if present, are progress units.
@@ -92,11 +92,11 @@ public sealed class ProgressHeader
/// progress or metrics to report, it is always better to report them.
///
/// The metrics that the calculation reports. These are completely independent, and there
- /// is no contract on whether the metric values should increase or not. As naming convention,
+ /// is no contract on whether the metric values should increase or not. As naming convention,
/// can have multiple words with spaces, and should be title-cased.
/// The names of the progress units, listed from least granular to most granular.
/// The idea is that the progress should be lexicographically increasing (like [0,0], [0,10], [1,0], [1,15], [2,5] etc.).
- /// As naming convention, should be lower-cased and typically plural
+ /// As naming convention, should be lower-cased and typically plural
/// (e.g. iterations, clusters, examples).
public ProgressHeader(string[] metricNames, string[] unitNames)
{
@@ -108,7 +108,7 @@ public ProgressHeader(string[] metricNames, string[] unitNames)
}
///
- /// A constructor for no metrics, just progress units. As naming convention, should be lower-cased
+ /// A constructor for no metrics, just progress units. As naming convention, should be lower-cased
/// and typically plural (e.g. iterations, clusters, examples).
///
public ProgressHeader(params string[] unitNames)
@@ -118,7 +118,7 @@ public ProgressHeader(params string[] unitNames)
}
///
- /// A metric/progress holder item.
+ /// A metric/progress holder item.
///
public interface IProgressEntry
{
@@ -130,7 +130,7 @@ public interface IProgressEntry
///
/// Set the progress value for the index to ,
- /// and the limit value to . If is a NAN, it is set to null instead.
+ /// and the limit value to . If is a NAN, it is set to null instead.
///
void SetProgress(int index, Double value, Double lim);
diff --git a/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs b/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs
index 466611c11a..6adac55f1b 100644
--- a/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs
+++ b/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs
@@ -9,15 +9,15 @@ namespace Microsoft.ML.Runtime.Data
{
///
/// A mapper that can be bound to a (which is an ISchema, with mappings from column kinds
- /// to columns). Binding an to a produces an
+ /// to columns). Binding an to a produces an
/// , which is an interface that has methods to return the names and indices of the input columns
/// needed by the mapper to compute its output. The is an extention to this interface, that
- /// can also produce an output IRow given an input IRow. The IRow produced generally contains only the output columns of the mapper, and not
+ /// can also produce an output IRow given an input IRow. The IRow produced generally contains only the output columns of the mapper, and not
/// the input columns (but there is nothing preventing an from mapping input columns directly to outputs).
- /// This interface is implemented by wrappers of IValueMapper based predictors, which are predictors that take a single
+ /// This interface is implemented by wrappers of IValueMapper based predictors, which are predictors that take a single
/// features column. New predictors can implement directly. Implementing
/// includes implementing a corresponding (or ) and a corresponding ISchema
- /// for the output schema of the . In case the interface is implemented,
+ /// for the output schema of the . In case the interface is implemented,
/// the SimpleRow class can be used in the method.
///
public interface ISchemaBindableMapper
@@ -54,7 +54,7 @@ public interface ISchemaBoundMapper
///
/// This interface extends with an additional method: . This method
- /// takes an input IRow and a predicate indicating which output columns are active, and returns a new IRow
+ /// takes an input IRow and a predicate indicating which output columns are active, and returns a new IRow
/// containing the output columns.
///
public interface ISchemaBoundRowMapper : ISchemaBoundMapper
@@ -67,11 +67,11 @@ public interface ISchemaBoundRowMapper : ISchemaBoundMapper
///
/// Get an IRow based on the input IRow with the indicated active columns. The active columns are those for which
- /// predicate(col) returns true. The schema of the returned IRow will be the same as the OutputSchema, but getting
+ /// predicate(col) returns true. The schema of the returned IRow will be the same as the OutputSchema, but getting
/// values on inactive columns will throw. Null predicates are disallowed.
/// The schema of input should match the InputSchema.
/// This method creates a live connection between the input IRow and the output IRow. In particular, when the
- /// getters of the output IRow are invoked, they invoke the getters of the input row and base the output values on
+ /// getters of the output IRow are invoked, they invoke the getters of the input row and base the output values on
/// the current values of the input IRow. The output IRow values are re-computed when requested through the getters.
/// The optional disposer is invoked by the cursor wrapping, when it no longer needs the IRow.
/// If no action is needed when the cursor is Disposed, the override should set disposer to null,
@@ -101,7 +101,7 @@ public interface IRowToRowMapper
/// predicate(col) returns true. Getting values on inactive columns will throw. Null predicates are disallowed.
/// The schema of input should match the InputSchema.
/// This method creates a live connection between the input IRow and the output IRow. In particular, when the
- /// getters of the output IRow are invoked, they invoke the getters of the input row and base the output values on
+ /// getters of the output IRow are invoked, they invoke the getters of the input row and base the output values on
/// the current values of the input IRow. The output IRow values are re-computed when requested through the getters.
/// The optional disposer is invoked by the cursor wrapping, when it no longer needs the IRow.
/// If no action is needed when the cursor is Disposed, the override should set disposer to null,
diff --git a/src/Microsoft.ML.Core/Data/ITrainerArguments.cs b/src/Microsoft.ML.Core/Data/ITrainerArguments.cs
index af74a9abfc..e4fdbbdc59 100644
--- a/src/Microsoft.ML.Core/Data/ITrainerArguments.cs
+++ b/src/Microsoft.ML.Core/Data/ITrainerArguments.cs
@@ -6,7 +6,7 @@ namespace Microsoft.ML.Runtime
{
// This is basically a no-op interface put in primarily
// for backward binary compat support for AFx.
- // REVIEW: This interface was removed in TLC 3.0 as part of the
+ // REVIEW: This interface was removed in TLC 3.0 as part of the
// deprecation of the *Factory interfaces, but added back as a temporary
// hack. Remove it asap.
public interface ITrainerArguments
diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs
index 04f31d844a..116d521756 100644
--- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs
+++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs
@@ -74,9 +74,9 @@ public static class Kinds
///
/// Metadata kind that indicates the ranges within a column that are categorical features.
- /// The value is a vector type of ints with dimension of two. The first dimension
+ /// The value is a vector type of ints with dimension of two. The first dimension
/// represents the number of categorical features and second dimension represents the range
- /// and is of size two. The range has start and end index(both inclusive) of categorical
+ /// and is of size two. The range has start and end index(both inclusive) of categorical
/// slots within that column.
///
public const string CategoricalSlotRanges = "CategoricalSlotRanges";
@@ -156,7 +156,7 @@ public static VectorType GetNamesType(int size)
}
///
- /// Returns a vector type with item type int and the given size.
+ /// Returns a vector type with item type int and the given size.
/// The range count must be a positive integer.
/// This is a standard type for metadata consisting of multiple int values that represent
/// categorical slot ranges with in a column.
@@ -312,7 +312,6 @@ public static bool HasSlotNames(this ISchema schema, int col, int vectorSize)
public static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.ColumnRole role, int vectorSize, ref VBuffer slotNames)
{
Contracts.CheckValueOrNull(schema);
- Contracts.CheckValue(role.Value, nameof(role));
Contracts.CheckParam(vectorSize >= 0, nameof(vectorSize));
IReadOnlyList list;
@@ -335,6 +334,22 @@ public static bool HasKeyNames(this ISchema schema, int col, int keyCount)
&& type.ItemType.IsText;
}
+ ///
+ /// Returns whether a column has the metadata set to true.
+ /// That metadata should be set when the data has undergone transforms that would render it
+ /// "normalized."
+ ///
+ /// The schema to query
+ /// Which column in the schema to query
+ /// True if and only if the column has the metadata
+ /// set to the scalar value
+ public static bool IsNormalized(this ISchema schema, int col)
+ {
+ Contracts.CheckValue(schema, nameof(schema));
+ var value = default(DvBool);
+ return schema.TryGetMetadata(BoolType.Instance, Kinds.IsNormalized, col, ref value) && value.IsTrue;
+ }
+
///
/// Tries to get the metadata kind of the specified type for a column.
///
@@ -347,6 +362,9 @@ public static bool HasKeyNames(this ISchema schema, int col, int keyCount)
/// True if the metadata of the right type exists, false otherwise
public static bool TryGetMetadata(this ISchema schema, PrimitiveType type, string kind, int col, ref T value)
{
+ Contracts.CheckValue(schema, nameof(schema));
+ Contracts.CheckValue(type, nameof(type));
+
var metadataType = schema.GetMetadataTypeOrNull(kind, col);
if (!type.Equals(metadataType))
return false;
@@ -363,17 +381,17 @@ public static bool IsHidden(this ISchema schema, int col)
string name = schema.GetColumnName(col);
int top;
bool tmp = schema.TryGetColumnIndex(name, out top);
- Contracts.Assert(tmp, "Why did TryGetColumnIndex return false?");
+ Contracts.Assert(tmp); // This would only be false if the implementation of schema were buggy.
return !tmp || top != col;
}
///
- /// The categoricalFeatures is a vector of the indices of categorical features slots.
+ /// The categoricalFeatures is a vector of the indices of categorical features slots.
/// This vector should always have an even number of elements, and the elements should be parsed in groups of two consecutive numbers.
/// So if its value is the range of numbers: 0,2,3,4,8,9
/// look at it as [0,2],[3,4],[8,9].
/// The way to interpret that is: feature with indices 0, 1, and 2 are one categorical
- /// Features with indices 3 and 4 are another categorical. Features 5 and 6 don't appear there, so they are not categoricals.
+ /// Features with indices 3 and 4 are another categorical. Features 5 and 6 don't appear there, so they are not categoricals.
///
public static bool TryGetCategoricalFeatureIndices(ISchema schema, int colIndex, out int[] categoricalFeatures)
{
diff --git a/src/Microsoft.ML.Core/Data/ProgressReporter.cs b/src/Microsoft.ML.Core/Data/ProgressReporter.cs
index 384e1bfb61..5f9575cca5 100644
--- a/src/Microsoft.ML.Core/Data/ProgressReporter.cs
+++ b/src/Microsoft.ML.Core/Data/ProgressReporter.cs
@@ -202,8 +202,8 @@ private ProgressEntry BuildJointEntry(ProgressEntry rootEntry)
///
/// This is a 'derived' or 'subordinate' progress channel.
- ///
- /// The subordinates' Start/Stop events and checkpoints will not be propagated.
+ ///
+ /// The subordinates' Start/Stop events and checkpoints will not be propagated.
/// When the status is requested, all of the subordinate channels are also invoked,
/// and the resulting metrics are then returned in the order of their 'subordinate level'.
/// If there's more than one channel with the same level, the order is not defined.
@@ -278,7 +278,7 @@ private void Stop()
public void Checkpoint(params Double?[] values)
{
// We are ignoring all checkpoints from subordinates.
- // REVIEW: maybe this could be changed in the future. Right now it seems that
+ // REVIEW: maybe this could be changed in the future. Right now it seems that
// this limitation is reasonable.
}
}
@@ -287,7 +287,7 @@ public void Checkpoint(params Double?[] values)
///
/// This class listens to the progress reporting channels, caches all checkpoints and
/// start/stop events and, on demand, requests current progress on all active calculations.
- ///
+ ///
/// The public methods of this class should only be called from one thread.
///
public sealed class ProgressTracker
@@ -303,7 +303,7 @@ public sealed class ProgressTracker
///
/// For each calculation, its properties.
/// This list is protected by , and it's updated every time a new calculation starts.
- /// The entries are cleaned up when the start and stop events are reported (that is, after the first
+ /// The entries are cleaned up when the start and stop events are reported (that is, after the first
/// pull request after the calculation's 'Stop' event).
///
private readonly List _infos;
@@ -319,8 +319,8 @@ public sealed class ProgressTracker
private readonly HashSet _namesUsed;
///
- /// This class is an 'event log' for one calculation.
- ///
+ /// This class is an 'event log' for one calculation.
+ ///
/// Every time a calculation is 'started', it gets its own log, so if there are multiple 'start' calls,
/// there will be multiple logs.
///
@@ -425,12 +425,12 @@ public void Log(ProgressChannel source, ProgressEvent.EventKind kind, ProgressEn
}
///
- /// Get progress reports from all current calculations.
+ /// Get progress reports from all current calculations.
/// For every calculation the following events will be returned:
/// * A start event.
/// * Each checkpoint.
- /// * If the calculation is finished, the stop event.
- ///
+ /// * If the calculation is finished, the stop event.
+ ///
/// Each of the above events will be returned exactly once.
/// If, for one calculation, there's no events in the above categories, the tracker will
/// request ('pull') the current progress and return this as an event.
@@ -490,14 +490,14 @@ public sealed class ProgressEntry : IProgressEntry
///
/// The actual progress (amount of completed units), in the units that are contained in the header.
/// Parallel to the header's . Null value indicates 'not applicable now'.
- ///
+ ///
/// The computation should not modify these arrays directly, and instead rely on ,
/// and .
///
public readonly Double?[] Progress;
///
- /// The lim values of each progress unit.
+ /// The lim values of each progress unit.
/// Parallel to the header's . Null value indicates unbounded or unknown.
///
public readonly Double?[] ProgressLim;
diff --git a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs
index 7d609454bf..2e35be86b7 100644
--- a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs
+++ b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs
@@ -2,15 +2,17 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
using System.Collections.Generic;
using Microsoft.ML.Runtime.Internal.Utilities;
namespace Microsoft.ML.Runtime.Data
{
///
- /// This contains information about a column in an IDataView. It is essentially a convenience
- /// cache containing the name, column index, and column type for the column.
+ /// This contains information about a column in an . It is essentially a convenience cache
+ /// containing the name, column index, and column type for the column. The intended usage is that users of
+ /// will have a convenient method of getting the index and type without having to separately query it through the ,
+ /// since practically the first thing a consumer of a will want to do once they get a mapping is get
+ /// the type and index of the corresponding column.
///
public sealed class ColumnInfo
{
@@ -31,9 +33,8 @@ private ColumnInfo(string name, int index, ColumnType type)
///
public static ColumnInfo CreateFromName(ISchema schema, string name, string descName)
{
- ColumnInfo colInfo;
- if (!TryCreateFromName(schema, name, out colInfo))
- throw Contracts.ExceptParam(nameof(name), "{0} column '{1}' not found", descName, name);
+ if (!TryCreateFromName(schema, name, out var colInfo))
+ throw Contracts.ExceptParam(nameof(name), $"{descName} column '{name}' not found");
return colInfo;
}
@@ -48,8 +49,7 @@ public static bool TryCreateFromName(ISchema schema, string name, out ColumnInfo
Contracts.CheckNonEmpty(name, nameof(name));
colInfo = null;
- int index;
- if (!schema.TryGetColumnIndex(name, out index))
+ if (!schema.TryGetColumnIndex(name, out int index))
return false;
colInfo = new ColumnInfo(name, index, schema.GetColumnType(index));
@@ -71,13 +71,34 @@ public static ColumnInfo CreateFromIndex(ISchema schema, int index)
}
///
- /// Encapsulates an ISchema plus column role mapping information. It has convenience fields for
- /// several common column roles, but can hold an arbitrary set of column infos. The convenience
- /// fields are non-null iff there is a unique column with the corresponding role. When there are
- /// no such columns or more than one such column, the field is null. The Has, HasUnique, and
- /// HasMultiple methods provide some cardinality information.
- /// Note that all columns assigned roles are guaranteed to be non-hidden in this schema.
+ /// Encapsulates an plus column role mapping information. The purpose of role mappings is to
+ /// provide information on what the intended usage is for. That is: while a given data view may have a column named
+ /// "Features", by itself that is insufficient: the trainer must be fed a role mapping that says that the role
+ /// mapping for features is filled by that "Features" column. This allows things like columns not named "Features"
+ /// to actually fill that role (as opposed to insisting on a hard coding, or having every trainer have to be
+ /// individually configured). Also, by being a one-to-many mapping, it is a way for learners that can consume
+ /// multiple features columns to consume that information.
+ ///
+ /// This class has convenience fields for several common column roles (e.g., , ), but can hold an arbitrary set of column infos. The convenience fields are non-null if and only
+ /// if there is a unique column with the corresponding role. When there are no such columns or more than one such
+ /// column, the field is null. The , , and
+ /// methods provide some cardinality information. Note that all columns assigned roles are guaranteed to be non-hidden
+ /// in this schema.
///
+ ///
+ /// Note that instances of this class are, like instances of , immutable.
+ ///
+ /// It is often the case that one wishes to bundle the actual data with the role mappings, not just the schema. For
+ /// that case, please use the class.
+ ///
+ /// Note that there is no need for components consuming a or
+ /// to make use of every defined mapping. Consuming components are also expected to ignore any
+ /// they do not handle. They may very well however complain if a mapping they wanted to see is not present, or the column(s)
+ /// mapped from the role are not of the form they require.
+ ///
+ ///
+ ///
public sealed class RoleMappedSchema
{
private const string FeatureString = "Feature";
@@ -85,21 +106,59 @@ public sealed class RoleMappedSchema
private const string GroupString = "Group";
private const string WeightString = "Weight";
private const string NameString = "Name";
- private const string IdString = "Id";
private const string FeatureContributionsString = "FeatureContributions";
+ ///
+ /// Instances of this are the keys of a . This class also holds some important
+ /// commonly used pre-defined instances available (e.g., , ) that should
+ /// be used when possible for consistency reasons. However, practitioners should not be afraid to declare custom
+ /// roles if appropriate for their task.
+ ///
public struct ColumnRole
{
- public static ColumnRole Feature { get { return new ColumnRole(FeatureString); } }
- public static ColumnRole Label { get { return new ColumnRole(LabelString); } }
- public static ColumnRole Group { get { return new ColumnRole(GroupString); } }
- public static ColumnRole Weight { get { return new ColumnRole(WeightString); } }
- public static ColumnRole Name { get { return new ColumnRole(NameString); } }
- public static ColumnRole Id { get { return new ColumnRole(IdString); } }
- public static ColumnRole FeatureContributions { get { return new ColumnRole(FeatureContributionsString); } }
-
+ ///
+ /// Role for features. Commonly used as the independent variables given to trainers, and scorers.
+ ///
+ public static ColumnRole Feature => FeatureString;
+
+ ///
+ /// Role for labels. Commonly used as the dependent variables given to trainers, and evaluators.
+ ///
+ public static ColumnRole Label => LabelString;
+
+ ///
+ /// Role for group ID. Commonly used in ranking applications, for defining query boundaries, or
+ /// sequence classification, for defining the boundaries of an utterance.
+ ///
+ public static ColumnRole Group => GroupString;
+
+ ///
+ /// Role for sample weights. Commonly used to point to a number to make trainers give more weight
+ /// to a particular example.
+ ///
+ public static ColumnRole Weight => WeightString;
+
+ ///
+ /// Role for sample names. Useful for informational and tracking purposes when scoring, but typically
+ /// without affecting results.
+ ///
+ public static ColumnRole Name => NameString;
+
+ // REVIEW: Does this really belong here?
+ ///
+ /// Role for feature contributions. Useful for specific diagnostic functionality.
+ ///
+ public static ColumnRole FeatureContributions => FeatureContributionsString;
+
+ ///
+ /// The string value for the role. Guaranteed to be non-empty.
+ ///
public readonly string Value;
+ ///
+ /// Constructor for the column role.
+ ///
+ /// The value for the role. Must be non-empty.
public ColumnRole(string value)
{
Contracts.CheckNonEmpty(value, nameof(value));
@@ -107,55 +166,51 @@ public ColumnRole(string value)
}
public static implicit operator ColumnRole(string value)
- {
- return new ColumnRole(value);
- }
-
+ => new ColumnRole(value);
+
+ ///
+ /// Convenience method for creating a mapping pair from a role to a column name
+ /// for giving to constructors of and .
+ ///
+ /// The column name to map to. Can be null, in which case when used
+ /// to construct a role mapping structure this pair will be ignored
+ /// A key-value pair with this instance as the key and as the value
public KeyValuePair Bind(string name)
- {
- return new KeyValuePair(this, name);
- }
+ => new KeyValuePair(this, name);
}
public static KeyValuePair CreatePair(ColumnRole role, string name)
- {
- return new KeyValuePair(role, name);
- }
-
- ///
- /// The source ISchema.
- ///
- public readonly ISchema Schema;
+ => new KeyValuePair(role, name);
///
- /// The Feature column, when there is exactly one (null otherwise).
+ /// The source .
///
- public readonly ColumnInfo Feature;
+ public ISchema Schema { get; }
///
- /// The Label column, when there is exactly one (null otherwise).
+ /// The column, when there is exactly one (null otherwise).
///
- public readonly ColumnInfo Label;
+ public ColumnInfo Feature { get; }
///
- /// The Group column, when there is exactly one (null otherwise).
+ /// The column, when there is exactly one (null otherwise).
///
- public readonly ColumnInfo Group;
+ public ColumnInfo Label { get; }
///
- /// The Weight column, when there is exactly one (null otherwise).
+ /// The column, when there is exactly one (null otherwise).
///
- public readonly ColumnInfo Weight;
+ public ColumnInfo Group { get; }
///
- /// The Name column, when there is exactly one (null otherwise).
+ /// The column, when there is exactly one (null otherwise).
///
- public readonly ColumnInfo Name;
+ public ColumnInfo Weight { get; }
///
- /// The Id column, when there is exactly one (null otherwise).
+ /// The column, when there is exactly one (null otherwise).
///
- public readonly ColumnInfo Id;
+ public ColumnInfo Name { get; }
// Maps from role to the associated column infos.
private readonly Dictionary> _map;
@@ -179,24 +234,21 @@ private RoleMappedSchema(ISchema schema, Dictionary> map, ColumnRole rol
Contracts.AssertNonEmpty(role.Value);
Contracts.AssertValue(info);
- List list;
- if (!map.TryGetValue(role.Value, out list))
+ if (!map.TryGetValue(role.Value, out var list))
{
list = new List();
map.Add(role.Value, list);
@@ -222,36 +273,21 @@ private static void Add(Dictionary> map, ColumnRole rol
list.Add(info);
}
- private static Dictionary> MapFromNames(ISchema schema, IEnumerable> roles)
- {
- Contracts.AssertValue(schema, "schema");
- Contracts.AssertValue(roles, "roles");
-
- var map = new Dictionary>();
- foreach (var kvp in roles)
- {
- Contracts.CheckNonEmpty(kvp.Key.Value, nameof(roles), "Bad column role");
- if (string.IsNullOrEmpty(kvp.Value))
- continue;
- var info = ColumnInfo.CreateFromName(schema, kvp.Value, kvp.Key.Value);
- Add(map, kvp.Key.Value, info);
- }
- return map;
- }
-
- private static Dictionary> MapFromNamesOpt(ISchema schema, IEnumerable> roles)
+ private static Dictionary> MapFromNames(ISchema schema, IEnumerable> roles, bool opt = false)
{
- Contracts.AssertValue(schema, "schema");
- Contracts.AssertValue(roles, "roles");
+ Contracts.AssertValue(schema);
+ Contracts.AssertValue(roles);
var map = new Dictionary>();
foreach (var kvp in roles)
{
- Contracts.CheckNonEmpty(kvp.Key.Value, nameof(roles), "Bad column role");
+ Contracts.AssertNonEmpty(kvp.Key.Value);
if (string.IsNullOrEmpty(kvp.Value))
continue;
ColumnInfo info;
- if (!ColumnInfo.TryCreateFromName(schema, kvp.Value, out info))
+ if (!opt)
+ info = ColumnInfo.CreateFromName(schema, kvp.Value, kvp.Key.Value);
+ else if (!ColumnInfo.TryCreateFromName(schema, kvp.Value, out info))
continue;
Add(map, kvp.Key.Value, info);
}
@@ -262,39 +298,26 @@ private static Dictionary> MapFromNamesOpt(ISchema sche
/// Returns whether there are any columns with the given column role.
///
public bool Has(ColumnRole role)
- {
- return role.Value != null && _map.ContainsKey(role.Value);
- }
+ => _map.ContainsKey(role.Value);
///
/// Returns whether there is exactly one column of the given role.
///
public bool HasUnique(ColumnRole role)
- {
- IReadOnlyList cols;
- return role.Value != null && _map.TryGetValue(role.Value, out cols) && cols.Count == 1;
- }
+ => _map.TryGetValue(role.Value, out var cols) && cols.Count == 1;
///
/// Returns whether there are two or more columns of the given role.
///
public bool HasMultiple(ColumnRole role)
- {
- IReadOnlyList cols;
- return role.Value != null && _map.TryGetValue(role.Value, out cols) && cols.Count > 1;
- }
+ => _map.TryGetValue(role.Value, out var cols) && cols.Count > 1;
///
/// If there are columns of the given role, this returns the infos as a readonly list. Otherwise,
/// it returns null.
///
public IReadOnlyList GetColumns(ColumnRole role)
- {
- IReadOnlyList list;
- if (role.Value != null && _map.TryGetValue(role.Value, out list))
- return list;
- return null;
- }
+ => _map.TryGetValue(role.Value, out var list) ? list : null;
///
/// An enumerable over all role-column associations within this object.
@@ -326,14 +349,20 @@ public IEnumerable> GetColumnRoleNames()
///
public IEnumerable> GetColumnRoleNames(ColumnRole role)
{
- IReadOnlyList list;
- if (role.Value != null && _map.TryGetValue(role.Value, out list))
+ if (_map.TryGetValue(role.Value, out var list))
{
foreach (var info in list)
yield return new KeyValuePair(role, info.Name);
}
}
+ ///
+ /// Returns the corresponding to if there is
+ /// exactly one such mapping, and otherwise throws an exception.
+ ///
+ /// The role to look up
+ /// The info corresponding to that role, assuming there was only one column
+ /// mapped to that
public ColumnInfo GetUniqueColumn(ColumnRole role)
{
var infos = GetColumns(role);
@@ -355,64 +384,102 @@ private static Dictionary> Copy(Dictionary
- /// Creates a RoleMappedSchema from the given schema with no column role assignments.
+ /// Constructor given a schema, and mapping pairs of roles to columns in the schema.
+ /// This skips null or empty column-names. It will also skip column-names that are not
+ /// found in the schema if is true.
///
- public static RoleMappedSchema Create(ISchema schema)
+ /// The schema over which roles are defined
+ /// Whether to consider the column names specified "optional" or not. If false then any non-empty
+ /// values for the column names that does not appear in will result in an exception being thrown,
+ /// but if true such values will be ignored
+ /// The column role to column name mappings
+ public RoleMappedSchema(ISchema schema, bool opt = false, params KeyValuePair[] roles)
+ : this(Contracts.CheckRef(schema, nameof(schema)), Contracts.CheckRef(roles, nameof(roles)), opt)
{
- Contracts.CheckValue(schema, nameof(schema));
- return new RoleMappedSchema(schema, new Dictionary>());
}
///
- /// Creates a RoleMappedSchema from the given schema and role/column-name pairs.
- /// This skips null or empty column-names.
+ /// Constructor given a schema, and mapping pairs of roles to columns in the schema.
+ /// This skips null or empty column names. It will also skip column-names that are not
+ /// found in the schema if is true.
///
- public static RoleMappedSchema Create(ISchema schema, params KeyValuePair[] roles)
+ /// The schema over which roles are defined
+ /// The column role to column name mappings
+ /// Whether to consider the column names specified "optional" or not. If false then any non-empty
+ /// values for the column names that do not appear in will result in an exception being thrown,
+ /// but if true such values will be ignored
+ public RoleMappedSchema(ISchema schema, IEnumerable> roles, bool opt = false)
+ : this(Contracts.CheckRef(schema, nameof(schema)),
+ MapFromNames(schema, Contracts.CheckRef(roles, nameof(roles)), opt))
{
- Contracts.CheckValue(schema, nameof(schema));
- Contracts.CheckValue(roles, nameof(roles));
- return new RoleMappedSchema(schema, MapFromNames(schema, roles));
}
- ///
- /// Creates a RoleMappedSchema from the given schema and role/column-name pairs.
- /// This skips null or empty column-names.
- ///
- public static RoleMappedSchema Create(ISchema schema, IEnumerable> roles)
+ private static IEnumerable> PredefinedRolesHelper(
+ string label, string feature, string group, string weight, string name,
+ IEnumerable> custom = null)
{
- Contracts.CheckValue(schema, nameof(schema));
- Contracts.CheckValue(roles, nameof(roles));
- return new RoleMappedSchema(schema, MapFromNames(schema, roles));
+ if (!string.IsNullOrWhiteSpace(label))
+ yield return ColumnRole.Label.Bind(label);
+ if (!string.IsNullOrWhiteSpace(feature))
+ yield return ColumnRole.Feature.Bind(feature);
+ if (!string.IsNullOrWhiteSpace(group))
+ yield return ColumnRole.Group.Bind(group);
+ if (!string.IsNullOrWhiteSpace(weight))
+ yield return ColumnRole.Weight.Bind(weight);
+ if (!string.IsNullOrWhiteSpace(name))
+ yield return ColumnRole.Name.Bind(name);
+ if (custom != null)
+ {
+ foreach (var role in custom)
+ yield return role;
+ }
}
///
- /// Creates a RoleMappedSchema from the given schema and role/column-name pairs.
- /// This skips null or empty column-names, or column-names that are not found in the schema.
+ /// Convenience constructor for role-mappings over the commonly used roles. Note that if any column name specified
+ /// is null or whitespace, it is ignored.
///
- public static RoleMappedSchema CreateOpt(ISchema schema, IEnumerable> roles)
+ /// The schema over which roles are defined
+ /// The column name that will be mapped to the role
+ /// The column name that will be mapped to the role
+ /// The column name that will be mapped to the role
+ /// The column name that will be mapped to the role
+ /// The column name that will be mapped to the role
+ /// Any additional desired custom column role mappings
+ /// Whether to consider the column names specified "optional" or not. If false then any non-empty
+ /// values for the column names that do not appear in will result in an exception being thrown,
+ /// but if true such values will be ignored
+ public RoleMappedSchema(ISchema schema, string label, string feature,
+ string group = null, string weight = null, string name = null,
+ IEnumerable> custom = null, bool opt = false)
+ : this(Contracts.CheckRef(schema, nameof(schema)), PredefinedRolesHelper(label, feature, group, weight, name, custom), opt)
{
- Contracts.CheckValue(schema, nameof(schema));
- Contracts.CheckValue(roles, nameof(roles));
- return new RoleMappedSchema(schema, MapFromNamesOpt(schema, roles));
+ Contracts.CheckValueOrNull(label);
+ Contracts.CheckValueOrNull(feature);
+ Contracts.CheckValueOrNull(group);
+ Contracts.CheckValueOrNull(weight);
+ Contracts.CheckValueOrNull(name);
+ Contracts.CheckValueOrNull(custom);
}
}
///
- /// Encapsulates an IDataView plus a corresponding RoleMappedSchema. Note that the schema of the
- /// RoleMappedSchema is guaranteed to be the same schema of the IDataView, that is,
- /// Data.Schema == Schema.Schema.
+ /// Encapsulates an plus a corresponding .
+ /// Note that the schema of of is
+ /// guaranteed to equal the the of .
///
public sealed class RoleMappedData
{
///
/// The data.
///
- public readonly IDataView Data;
+ public IDataView Data { get; }
///
- /// The role mapped schema. Note that Schema.Schema is guaranteed to be the same as Data.Schema.
+ /// The role mapped schema. Note that 's is
+ /// guaranteed to be the same as 's .
///
- public readonly RoleMappedSchema Schema;
+ public RoleMappedSchema Schema { get; }
private RoleMappedData(IDataView data, RoleMappedSchema schema)
{
@@ -424,45 +491,61 @@ private RoleMappedData(IDataView data, RoleMappedSchema schema)
}
///
- /// Creates a RoleMappedData from the given data with no column role assignments.
- ///
- public static RoleMappedData Create(IDataView data)
- {
- Contracts.CheckValue(data, nameof(data));
- return new RoleMappedData(data, RoleMappedSchema.Create(data.Schema));
- }
-
- ///
- /// Creates a RoleMappedData from the given schema and role/column-name pairs.
- /// This skips null or empty column-names.
+ /// Constructor given a data view, and mapping pairs of roles to columns in the data view's schema.
+ /// This skips null or empty column-names. It will also skip column-names that are not
+ /// found in the schema if is true.
///
- public static RoleMappedData Create(IDataView data, params KeyValuePair[] roles)
+ /// The data over which roles are defined
+ /// Whether to consider the column names specified "optional" or not. If false then any non-empty
+ /// values for the column names that do not appear in 's schema will result in an exception being thrown,
+ /// but if true such values will be ignored
+ /// The column role to column name mappings
+ public RoleMappedData(IDataView data, bool opt = false, params KeyValuePair[] roles)
+ : this(Contracts.CheckRef(data, nameof(data)), new RoleMappedSchema(data.Schema, Contracts.CheckRef(roles, nameof(roles)), opt))
{
- Contracts.CheckValue(data, nameof(data));
- Contracts.CheckValue(roles, nameof(roles));
- return new RoleMappedData(data, RoleMappedSchema.Create(data.Schema, roles));
}
///
- /// Creates a RoleMappedData from the given schema and role/column-name pairs.
- /// This skips null or empty column-names.
+ /// Constructor given a data view, and mapping pairs of roles to columns in the data view's schema.
+ /// This skips null or empty column-names. It will also skip column-names that are not
+ /// found in the schema if is true.
///
- public static RoleMappedData Create(IDataView data, IEnumerable> roles)
+ /// The schema over which roles are defined
+ /// The column role to column name mappings
+ /// Whether to consider the column names specified "optional" or not. If false then any non-empty
+ /// values for the column names that do not appear in 's schema will result in an exception being thrown,
+ /// but if true such values will be ignored
+ public RoleMappedData(IDataView data, IEnumerable> roles, bool opt = false)
+ : this(Contracts.CheckRef(data, nameof(data)), new RoleMappedSchema(data.Schema, Contracts.CheckRef(roles, nameof(roles)), opt))
{
- Contracts.CheckValue(data, nameof(data));
- Contracts.CheckValue(roles, nameof(roles));
- return new RoleMappedData(data, RoleMappedSchema.Create(data.Schema, roles));
}
///
- /// Creates a RoleMappedData from the given schema and role/column-name pairs.
- /// This skips null or empty column-names, or column-names that are not found in the schema.
+ /// Convenience constructor for role-mappings over the commonly used roles. Note that if any column name specified
+ /// is null or whitespace, it is ignored.
///
- public static RoleMappedData CreateOpt(IDataView data, IEnumerable> roles)
+ /// The data over which roles are defined
+ /// The column name that will be mapped to the role
+ /// The column name that will be mapped to the role
+ /// The column name that will be mapped to the role
+ /// The column name that will be mapped to the role
+ /// The column name that will be mapped to the role
+ /// Any additional desired custom column role mappings
+ /// Whether to consider the column names specified "optional" or not. If false then any non-empty
+ /// values for the column names that do not appear in 's schema will result in an exception being thrown,
+ /// but if true such values will be ignored
+ public RoleMappedData(IDataView data, string label, string feature,
+ string group = null, string weight = null, string name = null,
+ IEnumerable> custom = null, bool opt = false)
+ : this(Contracts.CheckRef(data, nameof(data)),
+ new RoleMappedSchema(data.Schema, label, feature, group, weight, name, custom, opt))
{
- Contracts.CheckValue(data, nameof(data));
- Contracts.CheckValue(roles, nameof(roles));
- return new RoleMappedData(data, RoleMappedSchema.CreateOpt(data.Schema, roles));
+ Contracts.CheckValueOrNull(label);
+ Contracts.CheckValueOrNull(feature);
+ Contracts.CheckValueOrNull(group);
+ Contracts.CheckValueOrNull(weight);
+ Contracts.CheckValueOrNull(name);
+ Contracts.CheckValueOrNull(custom);
}
}
}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Core/Data/RootCursorBase.cs b/src/Microsoft.ML.Core/Data/RootCursorBase.cs
index d5cc611e1d..1ac3858636 100644
--- a/src/Microsoft.ML.Core/Data/RootCursorBase.cs
+++ b/src/Microsoft.ML.Core/Data/RootCursorBase.cs
@@ -6,7 +6,7 @@
namespace Microsoft.ML.Runtime.Data
{
- // REVIEW: Since each cursor will create a channel, it would be great that the RootCursorBase takes
+ // REVIEW: Since each cursor will create a channel, it would be great that the RootCursorBase takes
// ownership of the channel so the derived classes don't have to.
///
diff --git a/src/Microsoft.ML.Core/Data/ServerChannel.cs b/src/Microsoft.ML.Core/Data/ServerChannel.cs
index 9c75c19937..5cde023e69 100644
--- a/src/Microsoft.ML.Core/Data/ServerChannel.cs
+++ b/src/Microsoft.ML.Core/Data/ServerChannel.cs
@@ -26,7 +26,7 @@ public sealed class ServerChannel : ServerChannel.IPendingBundleNotification, ID
private readonly string _identifier;
// This holds the running collection of named delegates, if any. The dictionary itself
- // is lazily initialized only when a listener
+ // is lazily initialized only when a listener
private Dictionary _toPublish;
private Action _onPublish;
private Bundle _published;
diff --git a/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs b/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs
index b94e25c9e3..ad07ec86a5 100644
--- a/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs
@@ -35,7 +35,7 @@ public static bool IsValueWithinRange(this TlcModule.RangeAttribute range, objec
Contracts.AssertValue(val);
Func fn = IsValueWithinRange;
// Avoid trying to cast double as float. If range
- // was specified using floats, but value being checked
+ // was specified using floats, but value being checked
// is double, change range to be of type double
if (range.Type == typeof(float) && val is double)
range.CastToDouble();
diff --git a/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs b/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs
index 99cfec0dd9..8a4ab8ca43 100644
--- a/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs
@@ -15,7 +15,7 @@
namespace Microsoft.ML.Runtime.EntryPoints
{
///
- /// This class defines attributes to annotate module inputs, outputs, entry points etc. when defining
+ /// This class defines attributes to annotate module inputs, outputs, entry points etc. when defining
/// the module interface.
///
public static class TlcModule
@@ -124,7 +124,7 @@ public sealed class OutputAttribute : Attribute
public string Desc { get; set; }
///
- /// The rank order of the output. Because .NET reflection returns members in an unspecfied order, this
+ /// The rank order of the output. Because .NET reflection returns members in an unspecified order, this
/// is the only way to ensure consistency.
///
public Double SortOrder { get; set; }
@@ -527,6 +527,11 @@ public sealed class EntryPointAttribute : Attribute
/// Short name of the Entry Point
///
public string ShortName { get; set; }
+
+ ///
+ /// The path to the XML documentation on the CSharpAPI component
+ ///
+ public string[] XmlInclude { get; set; }
}
///
@@ -539,11 +544,11 @@ public enum DataKind
///
Unknown = 0,
///
- /// Integer, including long.
+ /// Integer, including long.
///
Int,
///
- /// Unsigned integer, including ulong.
+ /// Unsigned integer, including ulong.
///
UInt,
///
@@ -583,11 +588,11 @@ public enum DataKind
///
Enum,
///
- /// An array (0 or more values of the same type, accessible by index).
+ /// An array (0 or more values of the same type, accessible by index).
///
Array,
///
- /// A dictionary (0 or more values of the same type, identified by a unique string key).
+ /// A dictionary (0 or more values of the same type, identified by a unique string key).
/// The underlying C# representation is
///
Dictionary,
@@ -598,7 +603,7 @@ public enum DataKind
///
Component,
///
- /// An C# object that represents state, such as .
+ /// A C# object that represents state, such as .
///
State
}
@@ -677,8 +682,8 @@ protected Optional(bool isExplicit)
/// This is a 'maybe' class that is able to differentiate the cases when the value is set 'explicitly', or 'implicitly'.
/// The idea is that if the default value is specified by the user, in some cases it needs to be treated differently
/// than if it's auto-filled.
- ///
- /// An example is the weight column: the default behavior is to use 'Weight' column if it's present. But if the user explicitly sets
+ ///
+ /// An example is the weight column: the default behavior is to use 'Weight' column if it's present. But if the user explicitly sets
/// the weight column to be 'Weight', we need to actually enforce the presence of the column.
///
/// The type of the value
@@ -714,7 +719,7 @@ public static implicit operator T(Optional optional)
}
///
- /// The implicit conversion from .
+ /// The implicit conversion from .
/// This will assume that the parameter is set 'explicitly'.
///
public static implicit operator Optional(T value)
diff --git a/src/Microsoft.ML.Core/EntryPoints/ModuleCatalog.cs b/src/Microsoft.ML.Core/EntryPoints/ModuleCatalog.cs
index 498a75c9e5..60511bfd39 100644
--- a/src/Microsoft.ML.Core/EntryPoints/ModuleCatalog.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/ModuleCatalog.cs
@@ -44,6 +44,7 @@ public sealed class EntryPointInfo
public readonly string Description;
public readonly string ShortName;
public readonly string FriendlyName;
+ public readonly string[] XmlInclude;
public readonly MethodInfo Method;
public readonly Type InputType;
public readonly Type OutputType;
@@ -63,6 +64,7 @@ internal EntryPointInfo(IExceptionContext ectx, MethodInfo method,
Method = method;
ShortName = attribute.ShortName;
FriendlyName = attribute.UserName;
+ XmlInclude = attribute.XmlInclude;
ObsoleteAttribute = obsoleteAttribute;
// There are supposed to be 2 parameters, env and input for non-macro nodes.
@@ -259,7 +261,7 @@ private bool ScanForComponents(IExceptionContext ectx, Type nestedType)
}
///
- /// The valid names for the components and entry points must consist of letters, digits, underscores and dots,
+ /// The valid names for the components and entry points must consist of letters, digits, underscores and dots,
/// and begin with a letter or digit.
///
private static readonly Regex _nameRegex = new Regex(@"^\w[_\.\w]*$", RegexOptions.Compiled);
diff --git a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
index 631de3eb77..d4ff5ccd96 100644
--- a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
+++ b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
@@ -87,7 +87,7 @@ public interface IMessageSource
///
/// A that is also a channel listener can attach
- /// listeners for messages, as sent through .
+ /// listeners for messages, as sent through .
///
public interface IMessageDispatcher : IHostEnvironment
{
@@ -109,7 +109,7 @@ public interface IMessageDispatcher : IHostEnvironment
///
/// A basic host environment suited for many environments.
- /// This also supports modifying the concurrency factor, provides the ability to subscribe to pipes via the
+ /// This also supports modifying the concurrency factor, provides the ability to subscribe to pipes via the
/// AddListener/RemoveListener methods, and exposes the to
/// query progress.
///
@@ -315,7 +315,7 @@ protected sealed class Dispatcher : Dispatcher
/// This field is actually used as a , which holds the listener actions
/// for all listeners that are currently subscribed. The action itself is an immutable object, so every time
/// any listener subscribes or unsubscribes, the field is replaced with a modified version of the delegate.
- ///
+ ///
/// The field can be null, if no listener is currently subscribed.
///
private volatile Action _listenerAction;
@@ -488,10 +488,8 @@ protected virtual IProgressChannel StartProgressChannelCore(HostBase host, strin
///
protected virtual IFileHandle OpenInputFileCore(IHostEnvironment env, string path)
{
-#pragma warning disable TLC_NoThis // Do not use 'this' keyword for member access
this.AssertValue(env);
this.CheckNonWhiteSpace(path, nameof(path));
-#pragma warning restore TLC_NoThis // Do not use 'this' keyword for member access
if (Master != null)
return Master.OpenInputFileCore(env, path);
return new SimpleFileHandle(env, path, needsWrite: false, autoDelete: false);
@@ -511,10 +509,8 @@ public IFileHandle CreateOutputFile(string path)
///
protected virtual IFileHandle CreateOutputFileCore(IHostEnvironment env, string path)
{
-#pragma warning disable TLC_NoThis // Do not use 'this' keyword for member access
this.AssertValue(env);
this.CheckNonWhiteSpace(path, nameof(path));
-#pragma warning restore TLC_NoThis // Do not use 'this' keyword for member access
if (Master != null)
return Master.CreateOutputFileCore(env, path);
return new SimpleFileHandle(env, path, needsWrite: true, autoDelete: false);
@@ -532,9 +528,7 @@ public IFileHandle CreateTempFile(string suffix = null, string prefix = null)
///
protected IFileHandle CreateAndRegisterTempFile(IHostEnvironment env, string suffix = null, string prefix = null)
{
-#pragma warning disable TLC_NoThis // Do not use 'this' keyword for member access
this.AssertValue(env);
-#pragma warning restore TLC_NoThis // Do not use 'this' keyword for member access
if (Master != null)
return Master.CreateAndRegisterTempFile(env, suffix, prefix);
@@ -556,10 +550,8 @@ protected IFileHandle CreateAndRegisterTempFile(IHostEnvironment env, string suf
protected virtual IFileHandle CreateTempFileCore(IHostEnvironment env, string suffix = null, string prefix = null)
{
-#pragma warning disable TLC_NoThis // Do not use 'this' keyword for member access
this.CheckParam(!HasBadFileCharacters(suffix), nameof(suffix));
this.CheckParam(!HasBadFileCharacters(prefix), nameof(prefix));
-#pragma warning restore TLC_NoThis // Do not use 'this' keyword for member access
Guid guid = Guid.NewGuid();
string path = Path.GetFullPath(Path.Combine(Path.GetTempPath(), prefix + guid.ToString() + suffix));
diff --git a/src/Microsoft.ML.Core/Environment/TlcEnvironment.cs b/src/Microsoft.ML.Core/Environment/TlcEnvironment.cs
index ccf60dc28a..13781c5c11 100644
--- a/src/Microsoft.ML.Core/Environment/TlcEnvironment.cs
+++ b/src/Microsoft.ML.Core/Environment/TlcEnvironment.cs
@@ -225,7 +225,7 @@ public void GetAndPrintAllProgress(ProgressReporting.ProgressTracker progressTra
if (PrintDot())
{
- // We need to print an extended status line. At this point, every event should be
+ // We need to print an extended status line. At this point, every event should be
// a non-checkpoint progress event.
bool needPrepend = entries.Count > 1;
foreach (var ev in entries)
@@ -306,7 +306,7 @@ private void EnsureNewLine(bool isError = false)
return;
// If _err and _out is the same writer, we need to print new line as well.
- // If _out and _err writes to Console.Out and Console.Error respectively,
+ // If _out and _err writes to Console.Out and Console.Error respectively,
// in the general user scenario they ends up with writing to the same underlying stream,.
// so write a new line to the stream anyways.
if (isError && _err != _out && (_out != Console.Out || _err != Console.Error))
diff --git a/src/Microsoft.ML.Core/Prediction/ISweeper.cs b/src/Microsoft.ML.Core/Prediction/ISweeper.cs
index a0a1850be0..fe887e0ae2 100644
--- a/src/Microsoft.ML.Core/Prediction/ISweeper.cs
+++ b/src/Microsoft.ML.Core/Prediction/ISweeper.cs
@@ -174,6 +174,11 @@ public override string ToString()
{
return string.Join(" ", _parameterValues.Select(kvp => string.Format("{0}={1}", kvp.Value.Name, kvp.Value.ValueText)).ToArray());
}
+
+ public override int GetHashCode()
+ {
+ return _hash;
+ }
}
///
@@ -205,8 +210,8 @@ public sealed class RunResult : IRunResult
private readonly bool _isMetricMaximizing;
///
- /// This switch changes the behavior of the CompareTo function, switching the greater than / less than
- /// behavior, depending on if it is set to True.
+ /// This switch changes the behavior of the CompareTo function, switching the greater than / less than
+ /// behavior, depending on if it is set to True.
///
public bool IsMetricMaximizing { get { return _isMetricMaximizing; } }
@@ -262,8 +267,8 @@ IComparable IRunResult.MetricValue
///
/// The metric class, used by smart sweeping algorithms.
- /// Ideally we would like to move towards the new IDataView/ISchematized, this is
- /// just a simple view instead, and it is decoupled from RunResult so we can move
+ /// Ideally we would like to move towards the new IDataView/ISchematized, this is
+ /// just a simple view instead, and it is decoupled from RunResult so we can move
/// in that direction in the future.
///
public sealed class RunMetric
diff --git a/src/Microsoft.ML.Core/Prediction/ITrainer.cs b/src/Microsoft.ML.Core/Prediction/ITrainer.cs
index cd8c6d12c8..b38a742d9a 100644
--- a/src/Microsoft.ML.Core/Prediction/ITrainer.cs
+++ b/src/Microsoft.ML.Core/Prediction/ITrainer.cs
@@ -2,9 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
using System.Collections.Generic;
-using System.IO;
+using Microsoft.ML.Runtime.Data;
namespace Microsoft.ML.Runtime
{
@@ -27,149 +26,79 @@ namespace Microsoft.ML.Runtime
public delegate void SignatureSequenceTrainer();
public delegate void SignatureMatrixRecommendingTrainer();
- ///
- /// Interface to provide extra information about a trainer.
- ///
- public interface ITrainerEx : ITrainer
- {
- // REVIEW: Ideally trainers should be able to communicate
- // something about the type of data they are capable of being trained
- // on, e.g., what ColumnKinds they want, how many of each, of what type,
- // etc. This interface seems like the most natural conduit for that sort
- // of extra information.
-
- // REVIEW: Can we please have consistent naming here?
- // 'Need' vs. 'Want' looks arbitrary to me, and it's grammatically more correct to
- // be 'Needs' / 'Wants' anyway.
-
- ///
- /// Whether the trainer needs to see data in normalized form.
- ///
- bool NeedNormalization { get; }
-
- ///
- /// Whether the trainer needs calibration to produce probabilities.
- ///
- bool NeedCalibration { get; }
-
- ///
- /// Whether this trainer could benefit from a cached view of the data.
- ///
- bool WantCaching { get; }
- }
-
- public interface ITrainerHost
- {
- Random Rand { get; }
- int Verbosity { get; }
-
- TextWriter StdOut { get; }
- TextWriter StdErr { get; }
- }
-
- // The Trainer (of Factory) can optionally implement this.
- public interface IModelCombiner
- where TPredictor : IPredictor
- {
- TPredictor CombineModels(IEnumerable models);
- }
+ public delegate void SignatureModelCombiner(PredictionKind kind);
///
- /// Weakly typed interface for a trainer "session" that produces a predictor.
+ /// The base interface for a trainers. Implementors should not implement this interface directly,
+ /// but rather implement the more specific .
///
public interface ITrainer
{
///
- /// Return the type of prediction task for the produced predictor.
+ /// Auxiliary information about the trainer in terms of its capabilities
+ /// and requirements.
///
- PredictionKind PredictionKind { get; }
+ TrainerInfo Info { get; }
///
- /// Returns the trained predictor.
- /// REVIEW: Consider removing this.
+ /// Return the type of prediction task for the produced predictor.
///
- IPredictor CreatePredictor();
- }
-
- ///
- /// Interface implemented by the MetalinearLearners base class.
- /// Used to distinguish the MetaLinear Learners from the other learners
- ///
- public interface IMetaLinearTrainer
- {
-
- }
+ PredictionKind PredictionKind { get; }
- public interface ITrainer : ITrainer
- {
///
- /// Trains a predictor using the specified dataset.
+ /// Trains a predictor.
///
- /// Training dataset
- void Train(TDataSet data);
+ /// A context containing at least the training data
+ /// The trained predictor
+ ///
+ IPredictor Train(TrainContext context);
}
///
- /// Strongly typed generic interface for a trainer. A trainer object takes
- /// supervision data and produces a predictor.
+ /// Strongly typed generic interface for a trainer. A trainer object takes training data
+ /// and produces a predictor.
///
- /// Type of the training dataset
/// Type of predictor produced
- public interface ITrainer : ITrainer
+ public interface ITrainer : ITrainer
where TPredictor : IPredictor
{
///
- /// Returns the trained predictor.
+ /// Trains a predictor.
///
- /// Trained predictor ready to make predictions
- new TPredictor CreatePredictor();
+ /// A context containing at least the training data
+ /// The trained predictor
+ new TPredictor Train(TrainContext context);
}
- ///
- /// Trainers that want data to do their own validation implement this interface.
- ///
- public interface IValidatingTrainer : ITrainer
+ public static class TrainerExtensions
{
///
- /// Trains a predictor using the specified dataset.
+ /// Convenience train extension for the case where one has only a training set with no auxiliary information.
+ /// Equivalent to calling
+ /// on a constructed with .
///
- /// Training dataset
- /// Validation dataset
- void Train(TDataSet data, TDataSet validData);
- }
+ /// The trainer
+ /// The training data.
+ /// The trained predictor
+ public static IPredictor Train(this ITrainer trainer, RoleMappedData trainData)
+ => trainer.Train(new TrainContext(trainData));
- public interface IIncrementalTrainer : ITrainer
- {
- ///
- /// Trains a predictor using the specified dataset and a trained predictor.
- ///
- /// Training dataset
- /// A trained predictor
- void Train(TDataSet data, TPredictor predictor);
- }
-
- public interface IIncrementalValidatingTrainer : ITrainer
- {
///
- /// Trains a predictor using the specified dataset and a trained predictor.
+ /// Convenience train extension for the case where one has only a training set with no auxiliary information.
+ /// Equivalent to calling
+ /// on a constructed with .
///
- /// Training dataset
- /// Validation dataset
- /// A trained predictor
- void Train(TDataSet data, TDataSet validData, TPredictor predictor);
+ /// The trainer
+ /// The training data.
+ /// The trained predictor
+ public static TPredictor Train(this ITrainer trainer, RoleMappedData trainData) where TPredictor : IPredictor
+ => trainer.Train(new TrainContext(trainData));
}
-#if FUTURE
- public interface IMultiTrainer :
- IMultiTrainer
- {
- }
-
- public interface IMultiTrainer :
- ITrainer
+ // A trainer can optionally implement this to indicate it can combine multiple models into a single predictor.
+ public interface IModelCombiner
+ where TPredictor : IPredictor
{
- void UpdatePredictor(TDataBatch trainInstance);
- IPredictor GetCurrentPredictor();
+ TPredictor CombineModels(IEnumerable models);
}
-#endif
}
diff --git a/src/Microsoft.ML.Core/Prediction/TrainContext.cs b/src/Microsoft.ML.Core/Prediction/TrainContext.cs
new file mode 100644
index 0000000000..be93ce68aa
--- /dev/null
+++ b/src/Microsoft.ML.Core/Prediction/TrainContext.cs
@@ -0,0 +1,56 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Data;
+
+namespace Microsoft.ML.Runtime
+{
+ ///
+ /// Holds information relevant to trainers. Instances of this class are meant to be constructed and passed
+ /// into or .
+ /// This holds at least a training set, as well as optioonally a predictor.
+ ///
+ public sealed class TrainContext
+ {
+ ///
+ /// The training set. Cannot be null.
+ ///
+ public RoleMappedData TrainingSet { get; }
+
+ ///
+ /// The validation set. Can be null. Note that passing a non-null validation set into
+ /// a trainer that does not support validation sets should not be considered an error condition. It
+ /// should simply be ignored in that case.
+ ///
+ public RoleMappedData ValidationSet { get; }
+
+ ///
+ /// The initial predictor, for incremental training. Note that if a implementor
+ /// does not support incremental training, then it can ignore it similarly to how one would ignore
+ /// . However, if the trainer does support incremental training and there
+ /// is something wrong with a non-null value of this, then the trainer ought to throw an exception.
+ ///
+ public IPredictor InitialPredictor { get; }
+
+ ///
+ /// Constructor, given a training set and optional other arguments.
+ ///
+ /// Will set to this value. This must be specified
+ /// Will set to this value if specified
+ /// Will set to this value if specified
+ public TrainContext(RoleMappedData trainingSet, RoleMappedData validationSet = null, IPredictor initialPredictor = null)
+ {
+ Contracts.CheckValue(trainingSet, nameof(trainingSet));
+ Contracts.CheckValueOrNull(validationSet);
+ Contracts.CheckValueOrNull(initialPredictor);
+
+ // REVIEW: Should there be code here to ensure that the role mappings between the two are compatible?
+ // That is, all the role mappings are the same and the columns between them have identical types?
+
+ TrainingSet = trainingSet;
+ ValidationSet = validationSet;
+ InitialPredictor = initialPredictor;
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Core/Prediction/TrainerInfo.cs b/src/Microsoft.ML.Core/Prediction/TrainerInfo.cs
new file mode 100644
index 0000000000..cce728e09a
--- /dev/null
+++ b/src/Microsoft.ML.Core/Prediction/TrainerInfo.cs
@@ -0,0 +1,71 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.ML.Runtime
+{
+ ///
+ /// Instances of this class posses information about trainers, in terms of their requirements and capabilities.
+ /// The intended usage is as the value for .
+ ///
+ public sealed class TrainerInfo
+ {
+ // REVIEW: Ideally trainers should be able to communicate
+ // something about the type of data they are capable of being trained
+ // on, e.g., what ColumnKinds they want, how many of each, of what type,
+ // etc. This interface seems like the most natural conduit for that sort
+ // of extra information.
+
+ ///
+ /// Whether the trainer needs to see data in normalized form. Only non-parametric learners will tend to produce
+ /// normalization here.
+ ///
+ public bool NeedNormalization { get; }
+
+ ///
+ /// Whether the trainer needs calibration to produce probabilities. As a general rule only trainers that produce
+ /// binary classifier predictors that also do not have a natural probabilistic interpretation should have a
+ /// true value here.
+ ///
+ public bool NeedCalibration { get; }
+
+ ///
+ /// Whether this trainer could benefit from a cached view of the data. Trainers that have few passes over the
+ /// data, or that need to build their own custom data structure over the data, will have a false here.
+ ///
+ public bool WantCaching { get; }
+
+ ///
+ /// Whether the trainer supports validation sets via . Not implementing
+ /// this interface and returning true from this property is an indication the trainer does not support
+ /// that.
+ ///
+ public bool SupportsValidation { get; }
+
+ ///
+ /// Whether the trainer can support incremental trainers via . Not
+ /// implementing this interface and returning true from this property is an indication the trainer does
+ /// not support that.
+ ///
+ public bool SupportsIncrementalTraining { get; }
+
+ ///
+ /// Initializes with the given parameters. The parameters have default values for the most typical values
+ /// for most classical trainers.
+ ///
+ /// The value for the property
+ /// The value for the property
+ /// The value for the property
+ /// The value for the property
+ /// The value for the property
+ public TrainerInfo(bool normalization = true, bool calibration = false, bool caching = true,
+ bool supportValid = false, bool supportIncrementalTrain = false)
+ {
+ NeedNormalization = normalization;
+ NeedCalibration = calibration;
+ WantCaching = caching;
+ SupportsValidation = supportValid;
+ SupportsIncrementalTraining = supportIncrementalTrain;
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Core/Utilities/BigArray.cs b/src/Microsoft.ML.Core/Utilities/BigArray.cs
index d6c6ef7b9b..ba2e67b0d9 100644
--- a/src/Microsoft.ML.Core/Utilities/BigArray.cs
+++ b/src/Microsoft.ML.Core/Utilities/BigArray.cs
@@ -7,14 +7,14 @@
namespace Microsoft.ML.Runtime.Internal.Utilities
{
///
- /// An array-like data structure that supports storing more than
- /// many entries, up to 0x7FEFFFFF00000L.
- /// The entries are indexed by 64-bit integers, and a single entry can be accessed by
+ /// An array-like data structure that supports storing more than
+ /// many entries, up to 0x7FEFFFFF00000L.
+ /// The entries are indexed by 64-bit integers, and a single entry can be accessed by
/// the indexer if no modifications to the entries is desired, or the
/// method. Efficient looping can be accomplished by calling the method.
- /// This data structure employs the "length and capacity" pattern. The logical length
+ /// This data structure employs the "length and capacity" pattern. The logical length
/// can be retrieved from the property, which can possibly be strictly less
- /// than the total capacity.
+ /// than the total capacity.
///
/// The type of entries.
public sealed class BigArray
@@ -38,8 +38,8 @@ public sealed class BigArray
// The 2-D jagged array containing the entries.
// Its total size is larger than or equal to _length, but
// less than Length + BlockSize.
- // Each one-dimension subarray has length equal to BlockSize,
- // except for the last one, which has a positive length
+ // Each one-dimension subarray has length equal to BlockSize,
+ // except for the last one, which has a positive length
// less than or equal to BlockSize.
private T[][] _entries;
@@ -53,13 +53,13 @@ public sealed class BigArray
public long Length { get { return _length; } }
///
- /// Gets or sets the entry at .
+ /// Gets or sets the entry at .
///
///
- /// This indexer is not efficient for looping. If looping access to entries is desired,
+ /// This indexer is not efficient for looping. If looping access to entries is desired,
/// use the method instead.
- /// Note that unlike a normal array, the value returned from this indexer getter cannot be modified
- /// (e.g., by ++ operator or passing into a method as a ref parameter). To modify an entry, use
+ /// Note that unlike a normal array, the value returned from this indexer getter cannot be modified
+ /// (e.g., by ++ operator or passing into a method as a ref parameter). To modify an entry, use
/// the method instead.
///
public T this[long index]
@@ -113,7 +113,7 @@ public BigArray(long size = 0)
public delegate void Visitor(long index, ref T item);
///
- /// Applies a method at a given .
+ /// Applies a method at a given .
///
public void ApplyAt(long index, Visitor manip)
{
@@ -190,16 +190,16 @@ public void FillRange(long min, long lim, T value)
}
///
- /// Resizes the array so that its logical length equals . This method
- /// is more efficient than initialize another array and copy the entries because it preserves
+ /// Resizes the array so that its logical length equals . This method
+ /// is more efficient than initialize another array and copy the entries because it preserves
/// existing blocks. The actual capacity of the array may become larger than .
/// If equals , then no operation is done.
/// If is less than , the array shrinks in size
/// so that both its length and its capacity equal .
/// If is larger than , the array capacity grows
- /// to the smallest integral multiple of that is larger than ,
- /// unless is less than , in which case the capacity
- /// grows to double its current capacity or , which ever is larger,
+ /// to the smallest integral multiple of that is larger than ,
+ /// unless is less than , in which case the capacity
+ /// grows to double its current capacity or , which ever is larger,
/// but up to .
///
public void Resize(long newLength)
@@ -304,7 +304,7 @@ public void TrimCapacity()
}
///
- /// Appends the first elements of to the end.
+ /// Appends the first elements of to the end.
/// This method is thread safe related to calls to (assuming those copy operations
/// are happening over ranges already added), but concurrent calls to
/// should not be attempted. Intended usage is that
@@ -373,10 +373,10 @@ public void AddRange(T[] src, int length)
}
///
- /// Copies the subarray starting from index of length
- /// to the destination array .
- /// Concurrent calls to this method is valid even with one single concurrent call
- /// to .
+ /// Copies the subarray starting from index of length
+ /// to the destination array .
+ /// Concurrent calls to this method is valid even with one single concurrent call
+ /// to .
///
public void CopyTo(long idx, T[] dst, int length)
{
diff --git a/src/Microsoft.ML.Core/Utilities/CharUtils.cs b/src/Microsoft.ML.Core/Utilities/CharUtils.cs
index e459452041..bf7ae4677e 100644
--- a/src/Microsoft.ML.Core/Utilities/CharUtils.cs
+++ b/src/Microsoft.ML.Core/Utilities/CharUtils.cs
@@ -13,8 +13,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities
public static class CharUtils
{
private const int CharsCount = 0x10000;
- private volatile static char[] _lowerInvariantChars;
- private volatile static char[] _upperInvariantChars;
+ private static volatile char[] _lowerInvariantChars;
+ private static volatile char[] _upperInvariantChars;
private static char[] EnsureLowerInvariant()
{
diff --git a/src/Microsoft.ML.Core/Utilities/DoubleParser.cs b/src/Microsoft.ML.Core/Utilities/DoubleParser.cs
index babc135eb0..f2d5573211 100644
--- a/src/Microsoft.ML.Core/Utilities/DoubleParser.cs
+++ b/src/Microsoft.ML.Core/Utilities/DoubleParser.cs
@@ -433,12 +433,6 @@ public static bool TryParse(out Double value, string s, int ichMin, int ichLim,
if (FloatUtils.GetBits(x) != 0 || FloatUtils.GetBits(value) != TopBit || !neg)
{
System.Diagnostics.Debug.WriteLine("*** FloatParser disagrees with Double.TryParse on: {0} ({1} vs {2})", str, FloatUtils.GetBits(x), FloatUtils.GetBits(value));
- //if (!_failed)
- //{
- // // REVIEW: Double.Parse gets several things wrong, like mapping 148e-325 to 0x2 instead of 0x3.
- // _failed = true;
- // Contracts.Assert(false, string.Format("FloatParser disagrees with Double.TryParse on: {0} ({1} vs {2})", str, FloatUtils.GetBits(x), FloatUtils.GetBits(value)));
- //}
}
}
#endif
diff --git a/src/Microsoft.ML.Core/Utilities/HashArray.cs b/src/Microsoft.ML.Core/Utilities/HashArray.cs
index c76ceb9482..27f0ec9b5d 100644
--- a/src/Microsoft.ML.Core/Utilities/HashArray.cs
+++ b/src/Microsoft.ML.Core/Utilities/HashArray.cs
@@ -243,7 +243,7 @@ private static class HashHelpers
{
// Note: This HashHelpers class was adapted from the BCL code base.
- // This is the maximum prime smaller than Array.MaxArrayLength
+ // This is the maximum prime smaller than Array.MaxArrayLength
public const int MaxPrimeArrayLength = 0x7FEFFFFD;
// Table of prime numbers to use as hash table sizes.
@@ -271,7 +271,7 @@ public static int GetPrime(int min)
return min + 1;
}
- // Returns size of hashtable to grow to.
+ // Returns size of hashtable to grow to.
public static int ExpandPrime(int oldSize)
{
int newSize = 2 * oldSize;
diff --git a/src/Microsoft.ML.Core/Utilities/HybridMemoryStream.cs b/src/Microsoft.ML.Core/Utilities/HybridMemoryStream.cs
index 8825369455..73b4c4a828 100644
--- a/src/Microsoft.ML.Core/Utilities/HybridMemoryStream.cs
+++ b/src/Microsoft.ML.Core/Utilities/HybridMemoryStream.cs
@@ -19,26 +19,24 @@ public sealed class HybridMemoryStream : Stream
{
private MemoryStream _memStream;
private Stream _overflowStream;
- private string _overflowPath;
private readonly int _overflowBoundary;
private const int _defaultMaxLen = 1 << 30;
private bool _disposed;
- private Stream MyStream { get { return _memStream ?? _overflowStream; } }
+ private Stream MyStream => _memStream ?? _overflowStream;
- private bool IsMemory { get { return _memStream != null; } }
+ private bool IsMemory => _memStream != null;
- public override long Position
- {
- get { return MyStream.Position; }
- set { Seek(value, SeekOrigin.Begin); }
+ public override long Position {
+ get => MyStream.Position;
+ set => Seek(value, SeekOrigin.Begin);
}
- public override long Length { get { return MyStream.Length; } }
- public override bool CanWrite { get { return MyStream.CanWrite; } }
- public override bool CanSeek { get { return MyStream.CanSeek; } }
- public override bool CanRead { get { return MyStream.CanRead; } }
+ public override long Length => MyStream.Length;
+ public override bool CanWrite => MyStream.CanWrite;
+ public override bool CanSeek => MyStream.CanSeek;
+ public override bool CanRead => MyStream.CanRead;
///
/// Constructs an initially empty read-write stream. Once the number of
@@ -123,27 +121,24 @@ protected override void Dispose(bool disposing)
var overflow = _overflowStream;
_overflowStream = null;
overflow.Dispose();
- Contracts.AssertValue(_overflowPath);
- File.Delete(_overflowPath);
- _overflowPath = null;
}
_disposed = true;
AssertInvariants();
+ base.Dispose(disposing);
}
}
public override void Close()
{
AssertInvariants();
- if (MyStream != null)
- MyStream.Close();
+ // The base Stream class Close will call Dispose(bool).
+ base.Close();
}
public override void Flush()
{
AssertInvariants();
- if (MyStream != null)
- MyStream.Flush();
+ MyStream?.Flush();
AssertInvariants();
}
@@ -164,9 +159,9 @@ private void EnsureOverflow()
// been closed.
Contracts.Check(_memStream.CanRead, "attempt to perform operation on closed stream");
- Contracts.Assert(_overflowPath == null);
- _overflowPath = Path.GetTempFileName();
- _overflowStream = new FileStream(_overflowPath, FileMode.Open, FileAccess.ReadWrite);
+ string overflowPath = Path.GetTempFileName();
+ _overflowStream = new FileStream(overflowPath, FileMode.Open, FileAccess.ReadWrite,
+ FileShare.None, bufferSize: 4096, FileOptions.DeleteOnClose);
// The documentation is not clear on this point, but the source code for
// memory stream makes clear that this buffer is exposable for a memory
diff --git a/src/Microsoft.ML.Core/Utilities/MathUtils.cs b/src/Microsoft.ML.Core/Utilities/MathUtils.cs
index 8106ff5a2c..fb68ee82d6 100644
--- a/src/Microsoft.ML.Core/Utilities/MathUtils.cs
+++ b/src/Microsoft.ML.Core/Utilities/MathUtils.cs
@@ -133,7 +133,7 @@ public static Float Min(Float[] a)
///
/// Finds the first index of the max element of the array.
- /// NaNs are ignored. If all the elements to consider are NaNs, -1 is
+ /// NaNs are ignored. If all the elements to consider are NaNs, -1 is
/// returned. The caller should distinguish in this case between two
/// possibilities:
/// 1) The number of the element to consider is zero.
@@ -147,8 +147,8 @@ public static int ArgMax(Float[] a)
}
///
- /// Finds the first index of the max element of the array.
- /// NaNs are ignored. If all the elements to consider are NaNs, -1 is
+ /// Finds the first index of the max element of the array.
+ /// NaNs are ignored. If all the elements to consider are NaNs, -1 is
/// returned. The caller should distinguish in this case between two
/// possibilities:
/// 1) The number of the element to consider is zero.
@@ -179,7 +179,7 @@ public static int ArgMax(Float[] a, int count)
///
/// Finds the first index of the minimum element of the array.
- /// NaNs are ignored. If all the elements to consider are NaNs, -1 is
+ /// NaNs are ignored. If all the elements to consider are NaNs, -1 is
/// returned. The caller should distinguish in this case between two
/// possibilities:
/// 1) The number of the element to consider is zero.
@@ -194,7 +194,7 @@ public static int ArgMin(Float[] a)
///
/// Finds the first index of the minimum element of the array.
- /// NaNs are ignored. If all the elements to consider are NaNs, -1 is
+ /// NaNs are ignored. If all the elements to consider are NaNs, -1 is
/// returned. The caller should distinguish in this case between two
/// possibilities:
/// 1) The number of the element to consider is zero.
@@ -258,10 +258,6 @@ public static Float SoftMax(Float[] inputs, int count)
if (count == 1)
return max;
- //else if (leng == 2) {
- // return SoftMax(inputs[0], inputs[1]);
- //}
-
double intermediate = 0.0;
Float cutoff = max - LogTolerance;
@@ -335,9 +331,9 @@ public static bool AlmostEqual(Float a, Float b, Float maxRelErr, Float maxAbsEr
return (absDiff / maxAbs) <= maxRelErr;
}
- private readonly static int[] _possiblePrimeMod30 = new int[] { 1, 7, 11, 13, 17, 19, 23, 29 };
- private readonly static double _constantForLogGamma = 0.5 * Math.Log(2 * Math.PI);
- private readonly static double[] _coeffsForLogGamma = { 12.0, -360.0, 1260.0, -1680.0, 1188.0 };
+ private static readonly int[] _possiblePrimeMod30 = new int[] { 1, 7, 11, 13, 17, 19, 23, 29 };
+ private static readonly double _constantForLogGamma = 0.5 * Math.Log(2 * Math.PI);
+ private static readonly double[] _coeffsForLogGamma = { 12.0, -360.0, 1260.0, -1680.0, 1188.0 };
///
/// Returns the log of the gamma function, using the Stirling approximation
@@ -853,7 +849,7 @@ public static Float LnSum(IEnumerable terms)
}
///
- /// Math.Sin returns the input value for inputs with large magnitude. We return NaN instead, for consistency
+ /// Math.Sin returns the input value for inputs with large magnitude. We return NaN instead, for consistency
/// with Math.Sin(infinity).
///
public static double Sin(double a)
@@ -863,7 +859,7 @@ public static double Sin(double a)
}
///
- /// Math.Cos returns the input value for inputs with large magnitude. We return NaN instead, for consistency
+ /// Math.Cos returns the input value for inputs with large magnitude. We return NaN instead, for consistency
/// with Math.Cos(infinity).
///
public static double Cos(double a)
diff --git a/src/Microsoft.ML.Core/Utilities/MemUtils.cs b/src/Microsoft.ML.Core/Utilities/MemUtils.cs
index 736ae90892..1dba9205e9 100644
--- a/src/Microsoft.ML.Core/Utilities/MemUtils.cs
+++ b/src/Microsoft.ML.Core/Utilities/MemUtils.cs
@@ -10,7 +10,7 @@ public static class MemUtils
// .Net 4.6's Buffer.MemoryCopy.
// REVIEW: Remove once we're on a version of .NET which includes
// Buffer.MemoryCopy.
- public unsafe static void MemoryCopy(void* source, void* destination, long destinationSizeInBytes, long sourceBytesToCopy)
+ public static unsafe void MemoryCopy(void* source, void* destination, long destinationSizeInBytes, long sourceBytesToCopy)
{
// MemCpy has undefined behavior when handed overlapping source and
// destination buffers.
diff --git a/src/Microsoft.ML.Core/Utilities/MinWaiter.cs b/src/Microsoft.ML.Core/Utilities/MinWaiter.cs
index d29bfe23c1..8c44315ba6 100644
--- a/src/Microsoft.ML.Core/Utilities/MinWaiter.cs
+++ b/src/Microsoft.ML.Core/Utilities/MinWaiter.cs
@@ -12,7 +12,7 @@ namespace Microsoft.ML.Runtime.Internal.Utilities
/// entities of known count, where you want to iteratively provide critical sections
/// for each depending on which comes first, but you do not necessarily know what
/// constitutes "first" until all such entities tell you where they stand in line.
- ///
+ ///
/// The anticipated usage is that whatever entity is using the
/// to synchronize itself, will register itself using
/// so as to unblock any "lower" waiters as soon as it knows what value it needs to
@@ -65,7 +65,7 @@ public MinWaiter(int waiters)
/// point when we actually want to wait. This method itself has the potential to
/// signal other events, if by registering ourselves the waiter becomes aware of
/// the maximum number of waiters, allowing that waiter to enter its critical state.
- ///
+ ///
/// If multiple events are associated with the minimum value, then only one will
/// be signaled, and the rest will remain unsignaled. Which is chosen is undefined.
///
@@ -75,7 +75,7 @@ public ManualResetEventSlim Register(long position)
lock (_waiters)
{
Contracts.Check(_maxWaiters > 0, "All waiters have been retired, Wait should not be called at this point");
- // We should never reach the state
+ // We should never reach the state
Contracts.Assert(_waiters.Count < _maxWaiters);
ev = new WaitStats(position);
// REVIEW: Optimize the case where this is the minimum?
diff --git a/src/Microsoft.ML.Core/Utilities/ObjectPool.cs b/src/Microsoft.ML.Core/Utilities/ObjectPool.cs
index 46486dc937..4a65286551 100644
--- a/src/Microsoft.ML.Core/Utilities/ObjectPool.cs
+++ b/src/Microsoft.ML.Core/Utilities/ObjectPool.cs
@@ -39,7 +39,7 @@ public abstract class ObjectPoolBase
public int Count => _pool.Count;
public int NumCreated { get { return _numCreated; } }
- protected internal ObjectPoolBase()
+ private protected ObjectPoolBase()
{
_pool = new ConcurrentBag();
}
diff --git a/src/Microsoft.ML.Core/Utilities/PathUtils.cs b/src/Microsoft.ML.Core/Utilities/PathUtils.cs
index 74ccec30c0..6698c11f7f 100644
--- a/src/Microsoft.ML.Core/Utilities/PathUtils.cs
+++ b/src/Microsoft.ML.Core/Utilities/PathUtils.cs
@@ -36,19 +36,19 @@ private static string DllDir
/// Attempts to find a file that is expected to be distributed with a TLC component. Searches
/// in the following order:
/// 1. In the customSearchDir directory, if it is provided.
- /// 2. In the custom search directory specified by the
+ /// 2. In the custom search directory specified by the
/// environment variable.
/// 3. In the root folder of the provided assembly.
/// 4. In the folder of this assembly.
/// In each case it searches the file in the directory provided and combined with folderPrefix.
- ///
+ ///
/// If any of these locations contain the file, a full local path will be returned, otherwise this
/// method will return null.
///
/// File name to find
/// folder prefix, relative to the current or customSearchDir
///
- /// Custom directory to search for resources.
+ /// Custom directory to search for resources.
/// If null, the path specified in the environment variable
/// will be used.
///
diff --git a/src/Microsoft.ML.Core/Utilities/ReservoirSampler.cs b/src/Microsoft.ML.Core/Utilities/ReservoirSampler.cs
index a755788fb4..69b57fea45 100644
--- a/src/Microsoft.ML.Core/Utilities/ReservoirSampler.cs
+++ b/src/Microsoft.ML.Core/Utilities/ReservoirSampler.cs
@@ -9,8 +9,8 @@
namespace Microsoft.ML.Runtime.Internal.Utilities
{
///
- /// This is an interface for creating samples of a requested size from a stream of data of type .
- /// The sample is created in one pass by calling for every data point in the stream. Implementations should have
+ /// This is an interface for creating samples of a requested size from a stream of data of type .
+ /// The sample is created in one pass by calling for every data point in the stream. Implementations should have
/// a delegate for getting the next data point, which is invoked if the current data point should go into the reservoir.
///
public interface IReservoirSampler
@@ -44,10 +44,10 @@ public interface IReservoirSampler
}
///
- /// This class produces a sample without replacement from a stream of data of type .
- /// It is instantiated with a delegate that gets the next data point, and builds a reservoir in one pass by calling
+ /// This class produces a sample without replacement from a stream of data of type .
+ /// It is instantiated with a delegate that gets the next data point, and builds a reservoir in one pass by calling
/// for every data point in the stream. In case the next data point does not get 'picked' into the reservoir, the delegate is not invoked.
- /// Sampling is done according to the algorithm in this paper: .
+ /// Sampling is done according to the algorithm in this paper: http://epubs.siam.org/doi/pdf/10.1137/1.9781611972740.53.
///
public sealed class ReservoirSamplerWithoutReplacement : IReservoirSampler
{
@@ -117,10 +117,10 @@ public IEnumerable GetSample()
}
///
- /// This class produces a sample with replacement from a stream of data of type .
- /// It is instantiated with a delegate that gets the next data point, and builds a reservoir in one pass by calling
+ /// This class produces a sample with replacement from a stream of data of type .
+ /// It is instantiated with a delegate that gets the next data point, and builds a reservoir in one pass by calling
/// for every data point in the stream. In case the next data point does not get 'picked' into the reservoir, the delegate is not invoked.
- /// Sampling is done according to the algorithm in this paper: .
+ /// Sampling is done according to the algorithm in this paper: http://epubs.siam.org/doi/pdf/10.1137/1.9781611972740.53.
///
public sealed class ReservoirSamplerWithReplacement : IReservoirSampler
{
@@ -237,7 +237,7 @@ public void Lock()
}
///
- /// Gets a reservoir sample with replacement of the elements sampled so far. Users should not change the
+ /// Gets a reservoir sample with replacement of the elements sampled so far. Users should not change the
/// elements returned since multiple elements in the reservoir might be pointing to the same memory.
///
public IEnumerable GetSample()
diff --git a/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs b/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs
index ccb4b0c90c..2cfa8c185a 100644
--- a/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs
+++ b/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs
@@ -18,7 +18,7 @@ namespace Microsoft.ML.Runtime.Internal.Utilities
///
public sealed class ResourceManagerUtils
{
- private volatile static ResourceManagerUtils _instance;
+ private static volatile ResourceManagerUtils _instance;
public static ResourceManagerUtils Instance
{
get
@@ -91,7 +91,7 @@ public static string GetUrl(string suffix)
/// The relative url from which to download.
/// This is appended to the url defined in .
/// The name of the file to save.
- /// The directory where the file should be saved to. The file will be saved in a directory with the specified name inside
+ /// The directory where the file should be saved to. The file will be saved in a directory with the specified name inside
/// a folder called "tlc-resources" in the directory.
/// An integer indicating the number of milliseconds to wait before timing out while downloading a resource.
/// The download results, containing the file path where the resources was (or should have been) downloaded to, and an error message
diff --git a/src/Microsoft.ML.Core/Utilities/Stream.cs b/src/Microsoft.ML.Core/Utilities/Stream.cs
index 8b22e46380..41c794e17f 100644
--- a/src/Microsoft.ML.Core/Utilities/Stream.cs
+++ b/src/Microsoft.ML.Core/Utilities/Stream.cs
@@ -979,7 +979,7 @@ public static BitArray ReadBitArray(this BinaryReader reader)
return returnArray;
}
- public unsafe static void ReadBytes(this BinaryReader reader, void* destination, long destinationSizeInBytes, long bytesToRead, ref byte[] work)
+ public static unsafe void ReadBytes(this BinaryReader reader, void* destination, long destinationSizeInBytes, long bytesToRead, ref byte[] work)
{
Contracts.AssertValue(reader);
Contracts.Assert(bytesToRead >= 0);
@@ -1007,7 +1007,7 @@ public unsafe static void ReadBytes(this BinaryReader reader, void* destination,
}
}
- public unsafe static void ReadBytes(this BinaryReader reader, void* destination, long destinationSizeInBytes, long bytesToRead)
+ public static unsafe void ReadBytes(this BinaryReader reader, void* destination, long destinationSizeInBytes, long bytesToRead)
{
byte[] work = null;
ReadBytes(reader, destination, destinationSizeInBytes, bytesToRead, ref work);
@@ -1097,10 +1097,10 @@ public static bool TryGetBuffer(this MemoryStream mem, out ArraySegment bu
// REVIEW: need to plumb IExceptionContext into the method.
///
/// Checks that the directory of the file name passed in already exists.
- /// This is meant to be called before calling an API that creates the file,
+ /// This is meant to be called before calling an API that creates the file,
/// so the file need not exist.
///
- /// An absolute or relative file path, or null to skip the check
+ /// An absolute or relative file path, or null to skip the check
/// (useful for optional user parameters)
/// The user level parameter name, as exposed by the command line help
public static void CheckOptionalUserDirectory(string file, string userArgument)
@@ -1113,7 +1113,7 @@ public static void CheckOptionalUserDirectory(string file, string userArgument)
return;
string dir;
-#pragma warning disable TLC_ContractsNameUsesNameof
+#pragma warning disable MSML_ContractsNameUsesNameof
try
{
// Relative paths are interpreted as local.
@@ -1134,6 +1134,6 @@ public static void CheckOptionalUserDirectory(string file, string userArgument)
if (!Directory.Exists(dir))
throw Contracts.ExceptUserArg(userArgument, "Cannot find directory '{0}'.", dir);
}
-#pragma warning restore TLC_ContractsNameUsesNameof
+#pragma warning restore MSML_ContractsNameUsesNameof
}
}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Core/Utilities/SupervisedBinFinder.cs b/src/Microsoft.ML.Core/Utilities/SupervisedBinFinder.cs
index 00dbf68d2c..a96e8df5c2 100644
--- a/src/Microsoft.ML.Core/Utilities/SupervisedBinFinder.cs
+++ b/src/Microsoft.ML.Core/Utilities/SupervisedBinFinder.cs
@@ -11,12 +11,12 @@
namespace Microsoft.ML.Runtime.Internal.Utilities
{
///
- /// This class performs discretization of (value, label) pairs into bins in a way that minimizes
+ /// This class performs discretization of (value, label) pairs into bins in a way that minimizes
/// the target function "minimum description length".
/// The algorithm is outlineed in an article
/// "Multi-Interval Discretization of Continuous-Valued Attributes for Classification Learning"
/// [Fayyad, Usama M.; Irani, Keki B. (1993)] http://ijcai.org/Past%20Proceedings/IJCAI-93-VOL2/PDF/022.pdf
- ///
+ ///
/// The class can be used several times sequentially, it is stateful and not thread-safe.
/// Both Single and Double precision processing is implemented, and is identical.
///
@@ -117,7 +117,7 @@ public Single[] FindBins(int maxBins, int minBinSize, int nLabels, IList
result[i] = BinFinderBase.GetSplitValue(distinctValues[split - 1], distinctValues[split]);
// Even though distinctValues may contain infinities, the boundaries may not be infinite:
- // GetSplitValue(a,b) only returns +-inf if a==b==+-inf,
+ // GetSplitValue(a,b) only returns +-inf if a==b==+-inf,
// and distinctValues won't contain more than one +inf or -inf.
Contracts.Assert(FloatUtils.IsFinite(result[i]));
}
@@ -195,7 +195,7 @@ public Double[] FindBins(int maxBins, int minBinSize, int nLabels, IList
result[i] = BinFinderBase.GetSplitValue(distinctValues[split - 1], distinctValues[split]);
// Even though distinctValues may contain infinities, the boundaries may not be infinite:
- // GetSplitValue(a,b) only returns +-inf if a==b==+-inf,
+ // GetSplitValue(a,b) only returns +-inf if a==b==+-inf,
// and distinctValues won't contain more than one +inf or -inf.
Contracts.Assert(FloatUtils.IsFinite(result[i]));
}
@@ -259,7 +259,7 @@ public SplitInterval(SupervisedBinFinder binFinder, int min, int lim, bool skipS
Contracts.Assert(leftCount + rightCount == totalCount);
// This term corresponds to the 'fixed cost associated with a split'
- // It's a simplification of a Delta(A,T;S) term calculated in the paper
+ // It's a simplification of a Delta(A,T;S) term calculated in the paper
var delta = logN - binFinder._labelCardinality * (totalEntropy - leftEntropy - rightEntropy);
var curGain = totalCount * totalEntropy // total cost of transmitting non-split content
diff --git a/src/Microsoft.ML.Core/Utilities/TextReaderStream.cs b/src/Microsoft.ML.Core/Utilities/TextReaderStream.cs
index ab83bf40f0..5c05275ba7 100644
--- a/src/Microsoft.ML.Core/Utilities/TextReaderStream.cs
+++ b/src/Microsoft.ML.Core/Utilities/TextReaderStream.cs
@@ -14,7 +14,7 @@ namespace Microsoft.ML.Runtime.Internal.Utilities
/// compensates by inserting \n line feed characters at the end of every
/// input line, including the last one.
///
- public class TextReaderStream : Stream
+ public sealed class TextReaderStream : Stream
{
private readonly TextReader _baseReader;
private readonly Encoding _encoding;
@@ -38,19 +38,11 @@ public class TextReaderStream : Stream
public override bool CanWrite => false;
public override long Length
- {
- get
- {
- throw Contracts.ExceptNotSupp("Stream cannot determine length.");
- }
- }
+ => throw Contracts.ExceptNotSupp("Stream cannot determine length.");
public override long Position
{
- get
- {
- return _position;
- }
+ get => _position;
set
{
if (value != Position)
@@ -96,6 +88,7 @@ public override void Close()
protected override void Dispose(bool disposing)
{
_baseReader.Dispose();
+ base.Dispose(disposing);
}
public override void Flush()
@@ -182,18 +175,12 @@ public override int ReadByte()
}
public override long Seek(long offset, SeekOrigin origin)
- {
- throw Contracts.ExceptNotSupp("Stream cannot seek.");
- }
+ => throw Contracts.ExceptNotSupp("Stream cannot seek.");
public override void Write(byte[] buffer, int offset, int count)
- {
- throw Contracts.ExceptNotSupp("Stream is not writable.");
- }
+ => throw Contracts.ExceptNotSupp("Stream is not writable.");
public override void SetLength(long value)
- {
- throw Contracts.ExceptNotSupp("Stream is not writable.");
- }
+ => throw Contracts.ExceptNotSupp("Stream is not writable.");
}
}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Core/Utilities/ThreadUtils.cs b/src/Microsoft.ML.Core/Utilities/ThreadUtils.cs
index e7bc27235f..859ae7b28d 100644
--- a/src/Microsoft.ML.Core/Utilities/ThreadUtils.cs
+++ b/src/Microsoft.ML.Core/Utilities/ThreadUtils.cs
@@ -64,7 +64,7 @@ public sealed class ExceptionMarshaller : IDisposable
private readonly CancellationTokenSource _ctSource;
private readonly object _lock;
- // The stored exception
+ // The stored exception
private string _component;
private Exception _ex;
diff --git a/src/Microsoft.ML.Core/Utilities/Tree.cs b/src/Microsoft.ML.Core/Utilities/Tree.cs
index 880afc4083..7d030cf46c 100644
--- a/src/Microsoft.ML.Core/Utilities/Tree.cs
+++ b/src/Microsoft.ML.Core/Utilities/Tree.cs
@@ -53,7 +53,7 @@ public Tree this[TKey key]
///
/// This is the key for this child node in its parent, if any. If this is not
- /// a child of any parent, that is, it is the root of its own tree, then
+ /// a child of any parent, that is, it is the root of its own tree, then
///
public TKey Key { get { return _key; } }
@@ -129,7 +129,7 @@ public void Add(KeyValuePair> item)
}
///
- /// Adds a node as a child of this node. This will disconnect the
+ /// Adds a node as a child of this node. This will disconnect the
///
///
///
diff --git a/src/Microsoft.ML.Core/Utilities/Utils.cs b/src/Microsoft.ML.Core/Utilities/Utils.cs
index 48993de785..96c23a0fe3 100644
--- a/src/Microsoft.ML.Core/Utilities/Utils.cs
+++ b/src/Microsoft.ML.Core/Utilities/Utils.cs
@@ -898,7 +898,7 @@ private static MethodInfo MarshalInvokeCheckAndCreate(Type genArg, Delegat
/// but whose code depends on some sort of generic type parameter. This utility method exists to make
/// this common pattern more convenient, and also safer so that the arguments, if any, can be type
/// checked at compile time instead of at runtime.
- ///
+ ///
/// Because it is strongly typed, this can only be applied to methods whose return type
/// is known at compile time, that is, that do not depend on the type parameter of the method itself.
///
diff --git a/src/Microsoft.ML.Core/Utilities/VBufferUtils.cs b/src/Microsoft.ML.Core/Utilities/VBufferUtils.cs
index ae2463da0c..1be9c77ee4 100644
--- a/src/Microsoft.ML.Core/Utilities/VBufferUtils.cs
+++ b/src/Microsoft.ML.Core/Utilities/VBufferUtils.cs
@@ -349,7 +349,7 @@ public static void Apply(ref VBuffer dst, SlotValueManipulator manip)
/// The vector to modify
/// The slot of the vector to modify
/// The manipulation function
- /// A predicate that returns true if we should skip insertion of a value into
+ /// A predicate that returns true if we should skip insertion of a value into
/// sparse vector if it was default. If the predicate is null, we insert any non-default.
public static void ApplyAt(ref VBuffer dst, int slot, SlotValueManipulator manip, ValuePredicate pred = null)
{
@@ -489,7 +489,7 @@ public static void DensifyFirst(ref VBuffer dst, int denseCount)
}
///
- /// Creates a maybe sparse copy of a VBuffer.
+ /// Creates a maybe sparse copy of a VBuffer.
/// Whether the created copy is sparse or not is determined by the proportion of non-default entries compared to the sparsity parameter.
///
public static void CreateMaybeSparseCopy(ref VBuffer src, ref VBuffer dst, RefPredicate isDefaultPredicate, float sparsityThreshold = SparsityThreshold)
@@ -580,9 +580,9 @@ public static void ApplyWith(ref VBuffer src, ref VBuffer
/// Applies the to each pair of elements
- /// where is defined, in order of index. It stores the result
- /// in another vector. If there is some value at an index in
- /// that is not defined in , that slot value is copied to the
+ /// where is defined, in order of index. It stores the result
+ /// in another vector. If there is some value at an index in
+ /// that is not defined in , that slot value is copied to the
/// corresponding slot in the result vector without any further modification.
/// If either of the vectors are dense, the resulting
/// will be dense. Otherwise, if both are sparse, the output will be sparse iff
@@ -616,7 +616,7 @@ public static void ApplyWithEitherDefined(ref VBuffer src, ref
///
/// Applies the to each pair of elements
/// where either or , has an element
- /// defined at that index. It stores the result in another vector .
+ /// defined at that index. It stores the result in another vector .
/// If either of the vectors are dense, the resulting
/// will be dense. Otherwise, if both are sparse, the output will be sparse iff
/// there is any slot that is not explicitly represented in either vector.
@@ -1147,11 +1147,11 @@ private static void ApplyWithCoreCopy(ref VBuffer src, ref VBu
/// storing the result in , overwriting any of its existing contents.
/// The contents of do not affect calculation. If you instead wish
/// to calculate a function that reads and writes , see
- /// and . Post-operation,
+ /// and . Post-operation,
/// will be dense iff is dense.
///
- ///
- ///
+ ///
+ ///
public static void ApplyIntoEitherDefined(ref VBuffer src, ref VBuffer dst, Func func)
{
Contracts.CheckValue(func, nameof(func));
diff --git a/src/Microsoft.ML.CpuMath/AlignedArray.cs b/src/Microsoft.ML.CpuMath/AlignedArray.cs
index 1dc8e3ee46..87583a8ef6 100644
--- a/src/Microsoft.ML.CpuMath/AlignedArray.cs
+++ b/src/Microsoft.ML.CpuMath/AlignedArray.cs
@@ -13,7 +13,7 @@ namespace Microsoft.ML.Runtime.Internal.CpuMath
/// To pin and force alignment, call the GetPin method, typically wrapped in a using (since it
/// returns a Pin struct that is IDisposable). From the pin, you can get the IntPtr to pass to
/// native code.
- ///
+ ///
/// The ctor takes an alignment value, which must be a power of two at least sizeof(Float).
///
public sealed class AlignedArray
diff --git a/src/Microsoft.ML.CpuMath/AlignedMatrix.cs b/src/Microsoft.ML.CpuMath/AlignedMatrix.cs
index 5ec9b53cca..67f05ee7cf 100644
--- a/src/Microsoft.ML.CpuMath/AlignedMatrix.cs
+++ b/src/Microsoft.ML.CpuMath/AlignedMatrix.cs
@@ -80,7 +80,7 @@ private void AssertValid()
}
///
- /// The physical AligenedArray items.
+ /// The physical AligenedArray items.
///
public AlignedArray Items { get { return _items; } }
@@ -155,7 +155,7 @@ public void CopyTo(Float[] dst, ref int ivDst)
}
///
- /// Copy the values from this vector starting at slot ivSrc into dst, starting at slot ivDst.
+ /// Copy the values from this vector starting at slot ivSrc into dst, starting at slot ivDst.
/// The number of values that are copied is determined by count.
///
/// The staring index in this vector
@@ -525,7 +525,7 @@ public CpuAlignedMatrixRow(int crow, int ccol, int cbAlign)
public override int ColCountPhy { get { return RunLenPhy; } }
///
- /// Copy the values from this matrix, starting from the row into dst, starting at slot ivDst and advancing ivDst.
+ /// Copy the values from this matrix, starting from the row into dst, starting at slot ivDst and advancing ivDst.
///
/// The starting row in this matrix
/// The destination array
@@ -606,7 +606,7 @@ public void CopyTo(Float[] dst, ref int ivDst)
}
///
- /// Copy the values from this matrix, starting from the row into dst, starting at slot ivDst and advancing ivDst.
+ /// Copy the values from this matrix, starting from the row into dst, starting at slot ivDst and advancing ivDst.
///
/// The starting row in this matrix
/// The destination array
diff --git a/src/Microsoft.ML.CpuMath/AssemblyInfo.cs b/src/Microsoft.ML.CpuMath/AssemblyInfo.cs
new file mode 100644
index 0000000000..cb45bf5608
--- /dev/null
+++ b/src/Microsoft.ML.CpuMath/AssemblyInfo.cs
@@ -0,0 +1,9 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Reflection;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+
+[assembly: InternalsVisibleTo("Microsoft.ML.StandardLearners, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")]
\ No newline at end of file
diff --git a/src/Microsoft.ML.CpuMath/Avx.cs b/src/Microsoft.ML.CpuMath/Avx.cs
index 68e751c86b..6dcf898b6f 100644
--- a/src/Microsoft.ML.CpuMath/Avx.cs
+++ b/src/Microsoft.ML.CpuMath/Avx.cs
@@ -7,7 +7,7 @@
namespace Microsoft.ML.Runtime.Internal.CpuMath
{
///
- /// Keep Avx.cs in sync with Sse.cs. When making changes to one, use BeyondCompare or a similar tool
+ /// Keep Avx.cs in sync with Sse.cs. When making changes to one, use BeyondCompare or a similar tool
/// to view diffs and propagate appropriate changes to the other.
///
public static class AvxUtils
@@ -21,7 +21,7 @@ private static bool Compat(AlignedArray a)
return a.CbAlign == CbAlign;
}
- private unsafe static float* Ptr(AlignedArray a, float* p)
+ private static unsafe float* Ptr(AlignedArray a, float* p)
{
Contracts.AssertValue(a);
float* q = p + a.GetBase((long)p);
diff --git a/src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs b/src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs
index 363c40007b..ad53810ff3 100644
--- a/src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs
+++ b/src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs
@@ -115,7 +115,7 @@ public static void MatTranTimesSrc(bool add, ICpuFullMatrix mat, ICpuVector src,
public static class GeneralUtils
{
///
- /// Count the number of zero bits in the lonest string of zero's from the lowest significant bit of the input integer.
+ /// Count the number of zero bits in the lonest string of zero's from the lowest significant bit of the input integer.
///
/// The input integer
///
diff --git a/src/Microsoft.ML.CpuMath/ICpuBuffer.cs b/src/Microsoft.ML.CpuMath/ICpuBuffer.cs
index e58a453f9f..ad55f5c8c6 100644
--- a/src/Microsoft.ML.CpuMath/ICpuBuffer.cs
+++ b/src/Microsoft.ML.CpuMath/ICpuBuffer.cs
@@ -77,8 +77,8 @@ public interface ICpuFullMatrix : ICpuMatrix
///
/// Zero out the items with the given indices.
- /// The indices contain the logical indices to the vectorized representation of the matrix,
- /// which can be different depending on whether the matrix is row-major or column-major.
+ /// The indices contain the logical indices to the vectorized representation of the matrix,
+ /// which can be different depending on whether the matrix is row-major or column-major.
///
void ZeroItems(int[] indices);
}
diff --git a/src/Microsoft.ML.CpuMath/IntUtils.cs b/src/Microsoft.ML.CpuMath/IntUtils.cs
index b0aed315c3..2492dddaff 100644
--- a/src/Microsoft.ML.CpuMath/IntUtils.cs
+++ b/src/Microsoft.ML.CpuMath/IntUtils.cs
@@ -84,7 +84,7 @@ private static ulong Div64(ulong lo, ulong hi, ulong den, out ulong rem)
return Div64Core(lo, hi, den, out rem);
}
- // REVIEW: on Linux, the hardware divide-by-zero exception is not translated into
+ // REVIEW: on Linux, the hardware divide-by-zero exception is not translated into
// a managed exception properly by CoreCLR so the process will crash. This is a temporary fix
// until CoreCLR addresses this issue.
[DllImport(Thunk.NativePath, CharSet = CharSet.Unicode, EntryPoint = "Div64"), SuppressUnmanagedCodeSecurity]
diff --git a/src/Microsoft.ML.CpuMath/Microsoft.ML.CpuMath.csproj b/src/Microsoft.ML.CpuMath/Microsoft.ML.CpuMath.csproj
index 62ba4f3a6a..bde7ae89f5 100644
--- a/src/Microsoft.ML.CpuMath/Microsoft.ML.CpuMath.csproj
+++ b/src/Microsoft.ML.CpuMath/Microsoft.ML.CpuMath.csproj
@@ -1,14 +1,22 @@
-
+
- netstandard2.0
- Microsoft.ML
+ Debug;Release;Debug-Intrinsics;Release-Intrinsics
+ $(Configuration.EndsWith('-Intrinsics'))
+
+ netstandard2.0
+ netstandard2.0;netcoreapp3.0
+ Microsoft.ML.CpuMath
true
- CORECLR
+ $(DefineConstants);CORECLR;PRIVATE_CONTRACTS
-
+
+
+
+
+
diff --git a/src/Microsoft.ML.CpuMath/Sse.cs b/src/Microsoft.ML.CpuMath/Sse.cs
index 77be547b69..68e6ee906b 100644
--- a/src/Microsoft.ML.CpuMath/Sse.cs
+++ b/src/Microsoft.ML.CpuMath/Sse.cs
@@ -7,7 +7,7 @@
namespace Microsoft.ML.Runtime.Internal.CpuMath
{
///
- /// Keep Sse.cs in sync with Avx.cs. When making changes to one, use BeyondCompare or a similar tool
+ /// Keep Sse.cs in sync with Avx.cs. When making changes to one, use BeyondCompare or a similar tool
/// to view diffs and propagate appropriate changes to the other.
///
public static class SseUtils
@@ -21,7 +21,7 @@ private static bool Compat(AlignedArray a)
return a.CbAlign == CbAlign;
}
- private unsafe static float* Ptr(AlignedArray a, float* p)
+ private static unsafe float* Ptr(AlignedArray a, float* p)
{
Contracts.AssertValue(a);
float* q = p + a.GetBase((long)p);
diff --git a/src/Microsoft.ML.CpuMath/Thunk.cs b/src/Microsoft.ML.CpuMath/Thunk.cs
index bc23963bbe..d7082c8313 100644
--- a/src/Microsoft.ML.CpuMath/Thunk.cs
+++ b/src/Microsoft.ML.CpuMath/Thunk.cs
@@ -9,7 +9,7 @@
namespace Microsoft.ML.Runtime.Internal.CpuMath
{
- internal unsafe static class Thunk
+ internal static unsafe class Thunk
{
internal const string NativePath = "CpuMathNative";
diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
index fc78e72c53..26ec32d3fe 100644
--- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
@@ -254,7 +254,7 @@ private RoleMappedData ApplyAllTransformsToData(IHostEnvironment env, IChannel c
RoleMappedData srcData, IDataView marker)
{
var pipe = ApplyTransformUtils.ApplyAllTransformsToData(env, srcData.Data, dstData, marker);
- return RoleMappedData.Create(pipe, srcData.Schema.GetColumnRoleNames());
+ return new RoleMappedData(pipe, srcData.Schema.GetColumnRoleNames());
}
///
@@ -277,7 +277,7 @@ private RoleMappedData CreateRoleMappedData(IHostEnvironment env, IChannel ch, I
// Training pipe and examples.
var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
- return TrainUtils.CreateExamples(data, label, features, group, weight, name, customCols);
+ return new RoleMappedData(data, label, features, group, weight, name, customCols);
}
private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output)
@@ -538,7 +538,7 @@ private FoldResult RunFold(int fold)
if (_getValidationDataView != null)
{
ch.Assert(_applyTransformsToValidationData != null);
- if (!TrainUtils.CanUseValidationData(trainer))
+ if (!trainer.Info.SupportsValidation)
ch.Warning("Trainer does not accept validation dataset.");
else
{
@@ -568,7 +568,7 @@ private FoldResult RunFold(int fold)
{
using (var file = host.CreateOutputFile(modelFileName))
{
- var rmd = RoleMappedData.Create(
+ var rmd = new RoleMappedData(
CompositeDataLoader.ApplyTransform(host, _loader, null, null,
(e, newSource) => ApplyTransformUtils.ApplyAllTransformsToData(e, trainData.Data, newSource)),
trainData.Schema.GetColumnRoleNames());
@@ -581,17 +581,17 @@ private FoldResult RunFold(int fold)
if (!evalComp.IsGood())
evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
var eval = evalComp.CreateInstance(host);
- // Note that this doesn't require the provided columns to exist (because of "Opt").
+ // Note that this doesn't require the provided columns to exist (because of the "opt" parameter).
// We don't normally expect the scorer to drop columns, but if it does, we should not require
// all the columns in the test pipeline to still be present.
- var dataEval = RoleMappedData.CreateOpt(scorePipe, testData.Schema.GetColumnRoleNames());
+ var dataEval = new RoleMappedData(scorePipe, testData.Schema.GetColumnRoleNames(), opt: true);
var dict = eval.Evaluate(dataEval);
RoleMappedData perInstance = null;
if (_savePerInstance)
{
var perInst = eval.GetPerInstanceMetrics(dataEval);
- perInstance = RoleMappedData.CreateOpt(perInst, dataEval.Schema.GetColumnRoleNames());
+ perInstance = new RoleMappedData(perInst, dataEval.Schema.GetColumnRoleNames(), opt: true);
}
ch.Done();
return new FoldResult(dict, dataEval.Schema.Schema, perInstance, trainData.Schema);
diff --git a/src/Microsoft.ML.Data/Commands/DataCommand.cs b/src/Microsoft.ML.Data/Commands/DataCommand.cs
index 435c25bf5b..2a62d78901 100644
--- a/src/Microsoft.ML.Data/Commands/DataCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/DataCommand.cs
@@ -305,7 +305,7 @@ protected void LoadModelObjects(
// can be loaded with no data at all, to get their schemas.
if (trainPipe == null)
trainPipe = ModelFileUtils.LoadLoader(Host, rep, new MultiFileSource(null), loadTransforms: true);
- trainSchema = RoleMappedSchema.Create(trainPipe.Schema, trainRoleMappings);
+ trainSchema = new RoleMappedSchema(trainPipe.Schema, trainRoleMappings);
}
// If the role mappings are null, an alternative would be to fail. However the idea
// is that the scorer should always still succeed, although perhaps with reduced
diff --git a/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs b/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs
index d0e066d789..77bdf0e32f 100644
--- a/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs
@@ -19,7 +19,7 @@
namespace Microsoft.ML.Runtime.Data
{
- // REVIEW: For simplicity (since this is currently the case),
+ // REVIEW: For simplicity (since this is currently the case),
// we assume that all metrics are either numeric, or numeric vectors.
///
/// This class contains information about an overall metric, namely its name and whether it is a vector
@@ -92,7 +92,7 @@ public string GetNameMatch(string input)
public interface IEvaluator
{
///
- /// Compute the aggregate metrics. Return a dictionary from the metric kind
+ /// Compute the aggregate metrics. Return a dictionary from the metric kind
/// (overal/per-fold/confusion matrix/PR-curves etc.), to a data view containing the metric.
///
Dictionary Evaluate(RoleMappedData data);
@@ -158,7 +158,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
evalComp = EvaluateUtils.GetEvaluatorType(ch, input.Schema);
var eval = evalComp.CreateInstance(env);
- var data = TrainUtils.CreateExamples(input, label, null, group, weight, null, customCols);
+ var data = new RoleMappedData(input, label, null, group, weight, null, customCols);
return eval.GetPerInstanceMetrics(data);
}
}
@@ -236,7 +236,7 @@ private void RunCore(IChannel ch)
if (!evalComp.IsGood())
evalComp = EvaluateUtils.GetEvaluatorType(ch, view.Schema);
var evaluator = evalComp.CreateInstance(Host);
- var data = TrainUtils.CreateExamples(view, label, null, group, weight, name, customCols);
+ var data = new RoleMappedData(view, label, null, group, weight, name, customCols);
var metrics = evaluator.Evaluate(data);
MetricWriter.PrintWarnings(ch, metrics);
evaluator.PrintFoldResults(ch, metrics);
@@ -248,7 +248,7 @@ private void RunCore(IChannel ch)
if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
{
var perInst = evaluator.GetPerInstanceMetrics(data);
- var perInstData = TrainUtils.CreateExamples(perInst, label, null, group, weight, name, customCols);
+ var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
var idv = evaluator.GetPerInstanceDataViewToSave(perInstData);
MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
}
diff --git a/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs b/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs
index 505d3a28e6..e1057d18b4 100644
--- a/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs
@@ -219,7 +219,7 @@ public static void LoadModel(IHostEnvironment env, Stream modelStream, bool load
if (roles != null)
{
var emptyView = ModelFileUtils.LoadPipeline(env, rep, new MultiFileSource(null));
- schema = RoleMappedSchema.CreateOpt(emptyView.Schema, roles);
+ schema = new RoleMappedSchema(emptyView.Schema, roles, opt: true);
}
else
{
diff --git a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
index 02d655b48b..607bf119d7 100644
--- a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
@@ -97,10 +97,7 @@ private void RunCore(IChannel ch)
ch.Trace("Creating loader");
- IPredictor predictor;
- IDataLoader loader;
- RoleMappedSchema trainSchema;
- LoadModelObjects(ch, true, out predictor, true, out trainSchema, out loader);
+ LoadModelObjects(ch, true, out var predictor, true, out var trainSchema, out var loader);
ch.AssertValue(predictor);
ch.AssertValueOrNull(trainSchema);
ch.AssertValue(loader);
@@ -116,7 +113,7 @@ private void RunCore(IChannel ch)
string group = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema,
nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
- var schema = TrainUtils.CreateRoleMappedSchemaOpt(loader.Schema, feat, group, customCols);
+ var schema = new RoleMappedSchema(loader.Schema, label: null, feature: feat, group: group, custom: customCols, opt: true);
var mapper = bindable.Bind(Host, schema);
if (!scorer.IsGood())
@@ -153,22 +150,20 @@ private void RunCore(IChannel ch)
Args.OutputAllColumns == true || Utils.Size(Args.OutputColumn) == 0;
if (Args.OutputAllColumns == true && Utils.Size(Args.OutputColumn) != 0)
- ch.Warning("outputAllColumns=+ always writes all columns irrespective of outputColumn specified.");
+ ch.Warning(nameof(Args.OutputAllColumns) + "=+ always writes all columns irrespective of " + nameof(Args.OutputColumn) + " specified.");
if (!outputAllColumns && Utils.Size(Args.OutputColumn) != 0)
{
foreach (var outCol in Args.OutputColumn)
{
- int dummyColIndex;
- if (!loader.Schema.TryGetColumnIndex(outCol, out dummyColIndex))
+ if (!loader.Schema.TryGetColumnIndex(outCol, out int dummyColIndex))
throw ch.ExceptUserArg(nameof(Arguments.OutputColumn), "Column '{0}' not found.", outCol);
}
}
- int colMax;
uint maxScoreId = 0;
if (!outputAllColumns)
- maxScoreId = loader.Schema.GetMaxMetadataKind(out colMax, MetadataUtils.Kinds.ScoreColumnSetId);
+ maxScoreId = loader.Schema.GetMaxMetadataKind(out int colMax, MetadataUtils.Kinds.ScoreColumnSetId);
ch.Assert(outputAllColumns || maxScoreId > 0); // score set IDs are one-based
var cols = new List();
for (int i = 0; i < loader.Schema.ColumnCount; i++)
@@ -211,12 +206,12 @@ private bool ShouldAddColumn(ISchema schema, int i, uint scoreSet, bool outputNa
{
switch (schema.GetColumnName(i))
{
- case "Label":
- case "Name":
- case "Names":
- return true;
- default:
- break;
+ case "Label":
+ case "Name":
+ case "Names":
+ return true;
+ default:
+ break;
}
}
if (Args.OutputColumn != null && Array.FindIndex(Args.OutputColumn, schema.GetColumnName(i).Equals) >= 0)
@@ -229,8 +224,7 @@ public static class ScoreUtils
{
public static IDataScorerTransform GetScorer(IPredictor predictor, RoleMappedData data, IHostEnvironment env, RoleMappedSchema trainSchema)
{
- ISchemaBoundMapper mapper;
- var sc = GetScorerComponentAndMapper(predictor, null, data.Schema, env, out mapper);
+ var sc = GetScorerComponentAndMapper(predictor, null, data.Schema, env, out var mapper);
return sc.CreateInstance(env, data.Data, mapper, trainSchema);
}
@@ -247,9 +241,8 @@ public static IDataScorerTransform GetScorer(SubComponent GetScorerC
Contracts.AssertValue(mapper);
string loadName = null;
- DvText scoreKind = default(DvText);
+ DvText scoreKind = default;
if (mapper.OutputSchema.ColumnCount > 0 &&
mapper.OutputSchema.TryGetMetadata(TextType.Instance, MetadataUtils.Kinds.ScoreColumnKind, 0, ref scoreKind) &&
scoreKind.HasChars)
@@ -298,9 +291,9 @@ public static SubComponent GetScorerC
///
/// Given a predictor and an optional scorer SubComponent, produces a compatible ISchemaBindableMapper.
/// First, it tries to instantiate the bindable mapper using the
- /// (this will only succeed if there's a registered BindableMapper creation method with load name equal to the one
+ /// (this will only succeed if there's a registered BindableMapper creation method with load name equal to the one
/// of the scorer).
- /// If the above fails, it checks whether the predictor implements
+ /// If the above fails, it checks whether the predictor implements
/// directly.
/// If this also isn't true, it will create a 'matching' standard mapper.
///
@@ -311,10 +304,8 @@ public static ISchemaBindableMapper GetSchemaBindableMapper(IHostEnvironment env
env.CheckValue(predictor, nameof(predictor));
env.CheckValueOrNull(scorerSettings);
- ISchemaBindableMapper bindable;
-
// See if we can instantiate a mapper using scorer arguments.
- if (scorerSettings.IsGood() && TryCreateBindableFromScorer(env, predictor, scorerSettings, out bindable))
+ if (scorerSettings.IsGood() && TryCreateBindableFromScorer(env, predictor, scorerSettings, out var bindable))
return bindable;
// The easy case is that the predictor implements the interface.
diff --git a/src/Microsoft.ML.Data/Commands/TestCommand.cs b/src/Microsoft.ML.Data/Commands/TestCommand.cs
index 79e7bd5458..d0ebbd5a05 100644
--- a/src/Microsoft.ML.Data/Commands/TestCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/TestCommand.cs
@@ -114,7 +114,7 @@ private void RunCore(IChannel ch)
if (!evalComp.IsGood())
evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
var evaluator = evalComp.CreateInstance(Host);
- var data = TrainUtils.CreateExamples(scorePipe, label, null, group, weight, name, customCols);
+ var data = new RoleMappedData(scorePipe, label, null, group, weight, name, customCols);
var metrics = evaluator.Evaluate(data);
MetricWriter.PrintWarnings(ch, metrics);
evaluator.PrintFoldResults(ch, metrics);
@@ -128,7 +128,7 @@ private void RunCore(IChannel ch)
if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
{
var perInst = evaluator.GetPerInstanceMetrics(data);
- var perInstData = TrainUtils.CreateExamples(perInst, label, null, group, weight, name, customCols);
+ var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
var idv = evaluator.GetPerInstanceDataViewToSave(perInstData);
MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
}
diff --git a/src/Microsoft.ML.Data/Commands/TrainCommand.cs b/src/Microsoft.ML.Data/Commands/TrainCommand.cs
index e55a5a3992..69370ad3ef 100644
--- a/src/Microsoft.ML.Data/Commands/TrainCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/TrainCommand.cs
@@ -157,13 +157,13 @@ private void RunCore(IChannel ch, string cmd)
ch.Trace("Binding columns");
var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
- var data = TrainUtils.CreateExamples(view, label, feature, group, weight, name, customCols);
+ var data = new RoleMappedData(view, label, feature, group, weight, name, customCols);
// REVIEW: Unify the code that creates validation examples in Train, TrainTest and CV commands.
RoleMappedData validData = null;
if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
{
- if (!TrainUtils.CanUseValidationData(trainer))
+ if (!trainer.Info.SupportsValidation)
{
ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
}
@@ -172,7 +172,7 @@ private void RunCore(IChannel ch, string cmd)
ch.Trace("Constructing the validation pipeline");
IDataView validPipe = CreateRawLoader(dataFile: Args.ValidationFile);
validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, view, validPipe);
- validData = RoleMappedData.Create(validPipe, data.Schema.GetColumnRoleNames());
+ validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames());
}
}
@@ -222,9 +222,9 @@ public static string MatchNameOrDefaultOrNull(IExceptionContext ectx, ISchema sc
return userName;
if (userName == defaultName)
return null;
-#pragma warning disable TLC_ContractsNameUsesNameof
+#pragma warning disable MSML_ContractsNameUsesNameof
throw ectx.ExceptUserArg(argName, $"Could not find column '{userName}'");
-#pragma warning restore TLC_ContractsNameUsesNameof
+#pragma warning restore MSML_ContractsNameUsesNameof
}
public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name,
@@ -235,14 +235,14 @@ public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData
}
public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData,
- SubComponent calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inpPredictor = null)
+ SubComponent calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null)
{
ICalibratorTrainer caliTrainer = !calibrator.IsGood() ? null : calibrator.CreateInstance(env);
- return TrainCore(env, ch, data, trainer, name, validData, caliTrainer, maxCalibrationExamples, cacheData, inpPredictor);
+ return TrainCore(env, ch, data, trainer, name, validData, caliTrainer, maxCalibrationExamples, cacheData, inputPredictor);
}
private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData,
- ICalibratorTrainer calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inpPredictor = null)
+ ICalibratorTrainer calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ch, nameof(ch));
@@ -250,79 +250,22 @@ private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappe
ch.CheckValue(trainer, nameof(trainer));
ch.CheckNonEmpty(name, nameof(name));
ch.CheckValueOrNull(validData);
- ch.CheckValueOrNull(inpPredictor);
+ ch.CheckValueOrNull(inputPredictor);
- var trainerRmd = trainer as ITrainer;
- if (trainerRmd == null)
- throw ch.ExceptUserArg(nameof(TrainCommand.Arguments.Trainer), "Trainer '{0}' does not accept known training data type", name);
-
- Action, object, object, object> trainCoreAction = TrainCore;
- IPredictor predictor;
AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
ch.Trace("Training");
if (validData != null)
AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);
- var genericExam = trainCoreAction.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(
- typeof(RoleMappedData),
- inpPredictor != null ? inpPredictor.GetType() : typeof(IPredictor));
- Action trainExam = trainerRmd.Train;
- genericExam.Invoke(null, new object[] { ch, trainerRmd, trainExam, data, validData, inpPredictor });
-
- ch.Trace("Constructing predictor");
- predictor = trainerRmd.CreatePredictor();
- return CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibrator, maxCalibrationExamples, trainer, predictor, data);
- }
-
- public static bool CanUseValidationData(ITrainer trainer)
- {
- Contracts.CheckValue(trainer, nameof(trainer));
-
- if (trainer is ITrainer)
- return trainer is IValidatingTrainer;
-
- return false;
- }
-
- private static void TrainCore(IChannel ch, ITrainer trainer, Action train, TDataSet data, TDataSet validData = null, TPredictor predictor = null)
- where TDataSet : class
- where TPredictor : class
- {
- const string inputModelArg = nameof(TrainCommand.Arguments.InputModelFile);
- if (validData != null)
- {
- if (predictor != null)
- {
- var incValidTrainer = trainer as IIncrementalValidatingTrainer;
- if (incValidTrainer != null)
- {
- incValidTrainer.Train(data, validData, predictor);
- return;
- }
-
- ch.Warning("Ignoring " + inputModelArg + ": Trainer is not an incremental trainer.");
- }
-
- var validTrainer = trainer as IValidatingTrainer;
- ch.AssertValue(validTrainer);
- validTrainer.Train(data, validData);
- }
- else
+ if (inputPredictor != null && !trainer.Info.SupportsIncrementalTraining)
{
- if (predictor != null)
- {
- var incTrainer = trainer as IIncrementalTrainer;
- if (incTrainer != null)
- {
- incTrainer.Train(data, predictor);
- return;
- }
-
- ch.Warning("Ignoring " + inputModelArg + ": Trainer is not an incremental trainer.");
- }
-
- train(data);
+ ch.Warning("Ignoring " + nameof(TrainCommand.Arguments.InputModelFile) +
+ ": Trainer does not support incremental training.");
+ inputPredictor = null;
}
+ ch.Assert(validData == null || trainer.Info.SupportsValidation);
+ var predictor = trainer.Train(new TrainContext(data, validData, inputPredictor));
+ return CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibrator, maxCalibrationExamples, trainer, predictor, data);
}
public static bool TryLoadPredictor(IChannel ch, IHostEnvironment env, string inputModelFile, out IPredictor inputPredictor)
@@ -348,7 +291,7 @@ public static bool TryLoadPredictor(IChannel ch, IHostEnvironment env, string in
///
/// Save the model to the output path.
- /// The method saves the loader and the transformations of dataPipe and saves optionally predictor
+ /// The method saves the loader and the transformations of dataPipe and saves optionally predictor
/// and command. It also uses featureColumn, if provided, to extract feature names.
///
/// The host environment to use.
@@ -373,7 +316,7 @@ public static void SaveModel(IHostEnvironment env, IChannel ch, IFileHandle outp
///
/// Save the model to the stream.
- /// The method saves the loader and the transformations of dataPipe and saves optionally predictor
+ /// The method saves the loader and the transformations of dataPipe and saves optionally predictor
/// and command. It also uses featureColumn, if provided, to extract feature names.
///
/// The host environment to use.
@@ -438,9 +381,8 @@ public static void SaveDataPipe(IHostEnvironment env, RepositoryWriter repositor
IDataView pipeStart;
var xfs = BacktrackPipe(dataPipe, out pipeStart);
- IDataLoader loader;
Action saveAction;
- if (!blankLoader && (loader = pipeStart as IDataLoader) != null)
+ if (!blankLoader && pipeStart is IDataLoader loader)
saveAction = loader.Save;
else
{
@@ -458,7 +400,7 @@ public static void SaveDataPipe(IHostEnvironment env, RepositoryWriter repositor
///
/// Traces back the .Source chain of the transformation pipe up to the moment it no longer can.
- /// Returns all the transforms of and the first data view (a non-transform).
+ /// Returns all the transforms of and the first data view (a non-transform).
///
/// The transformation pipe to traverse.
/// The beginning data view of the transform chain
@@ -468,16 +410,11 @@ private static List BacktrackPipe(IDataView dataPipe, out IDataV
Contracts.AssertValue(dataPipe);
var transforms = new List();
- while (true)
+ while (dataPipe is IDataTransform xf)
{
// REVIEW: a malicious user could construct a loop in the Source chain, that would
- // cause this method to iterate forever (and throw something when the list overflows). There's
+ // cause this method to iterate forever (and throw something when the list overflows). There's
// no way to insulate from ALL malicious behavior.
-
- var xf = dataPipe as IDataTransform;
- if (xf == null)
- break;
-
transforms.Add(xf);
dataPipe = xf.Source;
Contracts.AssertValue(dataPipe);
@@ -514,11 +451,8 @@ public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITra
{
if (autoNorm != NormalizeOption.Yes)
{
- var nn = trainer as ITrainerEx;
DvBool isNormalized = DvBool.False;
- if (nn == null || !nn.NeedNormalization ||
- (schema.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, featCol, ref isNormalized) &&
- isNormalized.IsTrue))
+ if (!trainer.Info.NeedNormalization || schema.IsNormalized(featCol))
{
ch.Info("Not adding a normalizer.");
return false;
@@ -530,20 +464,13 @@ public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITra
}
}
ch.Info("Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off.");
- // Quote the feature column name
- string quotedFeatureColumnName = featureColumn;
- StringBuilder sb = new StringBuilder();
- if (CmdQuoter.QuoteValue(quotedFeatureColumnName, sb))
- quotedFeatureColumnName = sb.ToString();
- var component = new SubComponent("MinMax", string.Format("col={{ name={0} source={0} }}", quotedFeatureColumnName));
- var loader = view as IDataLoader;
- if (loader != null)
- {
- view = CompositeDataLoader.Create(env, loader,
- new KeyValuePair>(null, component));
- }
+ IDataView ApplyNormalizer(IHostEnvironment innerEnv, IDataView input)
+ => NormalizeTransform.CreateMinMaxNormalizer(innerEnv, input, featureColumn);
+
+ if (view is IDataLoader loader)
+ view = CompositeDataLoader.ApplyTransform(env, loader, tag: null, creationArgs: null, ApplyNormalizer);
else
- view = component.CreateInstance(env, view);
+ view = ApplyNormalizer(env, view);
return true;
}
return false;
@@ -556,8 +483,7 @@ private static bool AddCacheIfWanted(IHostEnvironment env, IChannel ch, ITrainer
ch.AssertValue(trainer, nameof(trainer));
ch.AssertValue(data, nameof(data));
- ITrainerEx trainerEx = trainer as ITrainerEx;
- bool shouldCache = cacheData ?? (!(data.Data is BinaryLoader) && (trainerEx == null || trainerEx.WantCaching));
+ bool shouldCache = cacheData ?? !(data.Data is BinaryLoader) && trainer.Info.WantCaching;
if (shouldCache)
{
@@ -565,7 +491,7 @@ private static bool AddCacheIfWanted(IHostEnvironment env, IChannel ch, ITrainer
var prefetch = data.Schema.GetColumnRoles().Select(kc => kc.Value.Index).ToArray();
var cacheView = new CacheDataView(env, data.Data, prefetch);
// Because the prefetching worked, we know that these are valid columns.
- data = RoleMappedData.Create(cacheView, data.Schema.GetColumnRoleNames());
+ data = new RoleMappedData(cacheView, data.Schema.GetColumnRoleNames());
}
else
ch.Trace("Not caching");
@@ -586,97 +512,5 @@ public static IEnumerable> CheckAndGenerateCust
}
return customColumnArg.Select(kindName => new ColumnRole(kindName.Key).Bind(kindName.Value));
}
-
- ///
- /// Given a schema and a bunch of column names, create the BoundSchema object. Any or all of the column
- /// names may be null or whitespace, in which case they are ignored. Any columns that are specified but not
- /// valid columns of the schema are also ignored.
- ///
- public static RoleMappedSchema CreateRoleMappedSchemaOpt(ISchema schema, string feature, string group, IEnumerable> custom = null)
- {
- Contracts.CheckValueOrNull(feature);
- Contracts.CheckValueOrNull(custom);
-
- var list = new List>();
- if (!string.IsNullOrWhiteSpace(feature))
- list.Add(ColumnRole.Feature.Bind(feature));
- if (!string.IsNullOrWhiteSpace(group))
- list.Add(ColumnRole.Group.Bind(group));
- if (custom != null)
- list.AddRange(custom);
-
- return RoleMappedSchema.CreateOpt(schema, list);
- }
-
- ///
- /// Given a view and a bunch of column names, create the RoleMappedData object. Any or all of the column
- /// names may be null or whitespace, in which case they are ignored. Any columns that are specified must
- /// be valid columns of the schema.
- ///
- public static RoleMappedData CreateExamples(IDataView view, string label, string feature,
- string group = null, string weight = null, string name = null,
- IEnumerable> custom = null)
- {
- Contracts.CheckValueOrNull(label);
- Contracts.CheckValueOrNull(feature);
- Contracts.CheckValueOrNull(group);
- Contracts.CheckValueOrNull(weight);
- Contracts.CheckValueOrNull(name);
- Contracts.CheckValueOrNull(custom);
-
- var list = new List>();
- if (!string.IsNullOrWhiteSpace(label))
- list.Add(ColumnRole.Label.Bind(label));
- if (!string.IsNullOrWhiteSpace(feature))
- list.Add(ColumnRole.Feature.Bind(feature));
- if (!string.IsNullOrWhiteSpace(group))
- list.Add(ColumnRole.Group.Bind(group));
- if (!string.IsNullOrWhiteSpace(weight))
- list.Add(ColumnRole.Weight.Bind(weight));
- if (!string.IsNullOrWhiteSpace(name))
- list.Add(ColumnRole.Name.Bind(name));
- if (custom != null)
- list.AddRange(custom);
-
- return RoleMappedData.Create(view, list);
- }
-
- ///
- /// Given a view and a bunch of column names, create the RoleMappedData object. Any or all of the column
- /// names may be null or whitespace, in which case they are ignored. Any columns that are specified but not
- /// valid columns of the schema are also ignored.
- ///
- public static RoleMappedData CreateExamplesOpt(IDataView view, string label, string feature,
- string group = null, string weight = null, string name = null,
- IEnumerable> custom = null)
- {
- Contracts.CheckValueOrNull(label);
- Contracts.CheckValueOrNull(feature);
- Contracts.CheckValueOrNull(group);
- Contracts.CheckValueOrNull(weight);
- Contracts.CheckValueOrNull(name);
- Contracts.CheckValueOrNull(custom);
-
- var list = new List>();
- if (!string.IsNullOrWhiteSpace(label))
- list.Add(ColumnRole.Label.Bind(label));
- if (!string.IsNullOrWhiteSpace(feature))
- list.Add(ColumnRole.Feature.Bind(feature));
- if (!string.IsNullOrWhiteSpace(group))
- list.Add(ColumnRole.Group.Bind(group));
- if (!string.IsNullOrWhiteSpace(weight))
- list.Add(ColumnRole.Weight.Bind(weight));
- if (!string.IsNullOrWhiteSpace(name))
- list.Add(ColumnRole.Name.Bind(name));
- if (custom != null)
- list.AddRange(custom);
-
- return RoleMappedData.CreateOpt(view, list);
- }
-
- private static KeyValuePair Pair(ColumnRole kind, T value)
- {
- return new KeyValuePair(kind, value);
- }
}
}
diff --git a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs
index f6ffa772f9..03ee7cdf12 100644
--- a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs
@@ -147,12 +147,12 @@ private void RunCore(IChannel ch, string cmd)
ch.Trace("Binding columns");
var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
- var data = TrainUtils.CreateExamples(trainPipe, label, features, group, weight, name, customCols);
+ var data = new RoleMappedData(trainPipe, label, features, group, weight, name, customCols);
RoleMappedData validData = null;
if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
{
- if (!TrainUtils.CanUseValidationData(trainer))
+ if (!trainer.Info.SupportsValidation)
{
ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
}
@@ -161,7 +161,7 @@ private void RunCore(IChannel ch, string cmd)
ch.Trace("Constructing the validation pipeline");
IDataView validPipe = CreateRawLoader(dataFile: Args.ValidationFile);
validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, validPipe);
- validData = RoleMappedData.Create(validPipe, data.Schema.GetColumnRoleNames());
+ validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames());
}
}
@@ -189,8 +189,8 @@ private void RunCore(IChannel ch, string cmd)
if (!evalComp.IsGood())
evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
var evaluator = evalComp.CreateInstance(Host);
- var dataEval = TrainUtils.CreateExamplesOpt(scorePipe, label, features,
- group, weight, name, customCols);
+ var dataEval = new RoleMappedData(scorePipe, label, features,
+ group, weight, name, customCols, opt: true);
var metrics = evaluator.Evaluate(dataEval);
MetricWriter.PrintWarnings(ch, metrics);
evaluator.PrintFoldResults(ch, metrics);
@@ -204,7 +204,7 @@ private void RunCore(IChannel ch, string cmd)
if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
{
var perInst = evaluator.GetPerInstanceMetrics(dataEval);
- var perInstData = TrainUtils.CreateExamples(perInst, label, null, group, weight, name, customCols);
+ var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
var idv = evaluator.GetPerInstanceDataViewToSave(perInstData);
MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
}
diff --git a/src/Microsoft.ML.Data/Data/Combiner.cs b/src/Microsoft.ML.Data/Data/Combiner.cs
index 9a5de27ff6..ee45aee3e3 100644
--- a/src/Microsoft.ML.Data/Data/Combiner.cs
+++ b/src/Microsoft.ML.Data/Data/Combiner.cs
@@ -21,7 +21,7 @@ public abstract class Combiner
public sealed class TextCombiner : Combiner
{
- private volatile static TextCombiner _instance;
+ private static volatile TextCombiner _instance;
public static TextCombiner Instance
{
get
@@ -46,7 +46,7 @@ public override void Combine(ref DvText dst, DvText src)
public sealed class FloatAdder : Combiner
{
- private volatile static FloatAdder _instance;
+ private static volatile FloatAdder _instance;
public static FloatAdder Instance
{
get
@@ -67,7 +67,7 @@ private FloatAdder()
public sealed class R4Adder : Combiner
{
- private volatile static R4Adder _instance;
+ private static volatile R4Adder _instance;
public static R4Adder Instance
{
get
@@ -88,7 +88,7 @@ private R4Adder()
public sealed class R8Adder : Combiner
{
- private volatile static R8Adder _instance;
+ private static volatile R8Adder _instance;
public static R8Adder Instance
{
get
@@ -110,7 +110,7 @@ private R8Adder()
// REVIEW: Delete this!
public sealed class U4Adder : Combiner
{
- private volatile static U4Adder _instance;
+ private static volatile U4Adder _instance;
public static U4Adder Instance
{
get
diff --git a/src/Microsoft.ML.Data/Data/Conversion.cs b/src/Microsoft.ML.Data/Data/Conversion.cs
index 974f40c39d..0a9833064a 100644
--- a/src/Microsoft.ML.Data/Data/Conversion.cs
+++ b/src/Microsoft.ML.Data/Data/Conversion.cs
@@ -53,7 +53,7 @@ public sealed class Conversions
// REVIEW: Reconcile implementations with TypeUtils, and clarify the distinction.
// Singleton pattern.
- private volatile static Conversions _instance;
+ private static volatile Conversions _instance;
public static Conversions Instance
{
get
diff --git a/src/Microsoft.ML.Data/Data/DataViewUtils.cs b/src/Microsoft.ML.Data/Data/DataViewUtils.cs
index 4772228fa0..1db4d5ad0a 100644
--- a/src/Microsoft.ML.Data/Data/DataViewUtils.cs
+++ b/src/Microsoft.ML.Data/Data/DataViewUtils.cs
@@ -286,9 +286,9 @@ private sealed class Splitter
private enum ExtraIndex
{
Id,
-#pragma warning disable TLC_GeneralName // Allow for this private enum.
+#pragma warning disable MSML_GeneralName // Allow for this private enum.
_Lim
-#pragma warning restore TLC_GeneralName
+#pragma warning restore MSML_GeneralName
}
private Splitter(ISchema schema)
diff --git a/src/Microsoft.ML.Data/Data/IColumn.cs b/src/Microsoft.ML.Data/Data/IColumn.cs
index 28d6ffd057..2f2f496f99 100644
--- a/src/Microsoft.ML.Data/Data/IColumn.cs
+++ b/src/Microsoft.ML.Data/Data/IColumn.cs
@@ -13,16 +13,16 @@ namespace Microsoft.ML.Runtime.Data
///
/// This interface is an analogy to that encapsulates the contents of a single
/// column.
- ///
+ ///
/// Note that in the same sense that is not thread safe, implementors of this interface
/// by similar token must not be considered thread safe by users of the interface, and by the same token
/// implementors should feel free to write their implementations with the expectation that only one thread
/// will be calling it at a time.
- ///
+ ///
/// Similarly, in the same sense that an can have its values "change under it" by having
/// the underlying cursor move, so too might this item have its values change under it, and they will if
/// they were directly instantiated from a row.
- ///
+ ///
/// Generally actual implementors of this interface should not implement this directly, but instead implement
/// .
///
@@ -495,7 +495,7 @@ public override ValueGetter GetGetter()
///
private sealed class RowColumnRow : IRow
{
- private readonly static DefaultCountedImpl _defCount = new DefaultCountedImpl();
+ private static readonly DefaultCountedImpl _defCount = new DefaultCountedImpl();
private readonly ICounted _counted;
private readonly IColumn[] _columns;
private readonly SchemaImpl _schema;
diff --git a/src/Microsoft.ML.Data/Data/IRowSeekable.cs b/src/Microsoft.ML.Data/Data/IRowSeekable.cs
index c2fb54bf70..3c0bf0db08 100644
--- a/src/Microsoft.ML.Data/Data/IRowSeekable.cs
+++ b/src/Microsoft.ML.Data/Data/IRowSeekable.cs
@@ -6,7 +6,7 @@
namespace Microsoft.ML.Runtime.Data
{
- // REVIEW: Would it be a better apporach to add something akin to CanSeek,
+ // REVIEW: Would it be a better approach to add something akin to CanSeek,
// as we have a CanShuffle? The idea is trying to make IRowSeekable propagate along certain transforms.
///
/// Represents a data view that supports random access to a specific row.
@@ -18,14 +18,14 @@ public interface IRowSeekable : ISchematized
///
/// Represents a row seeker with random access that can retrieve a specific row by the row index.
- /// For IRowSeeker, when the state is valid (that is when MoveTo() returns true), it returns the
- /// current row index. Otherwise it's -1.
+ /// For IRowSeeker, when the state is valid (that is when MoveTo() returns true), it returns the
+ /// current row index. Otherwise it's -1.
///
public interface IRowSeeker : IRow, IDisposable
{
///
/// Moves the seeker to a row at a specific row index.
- /// If the row index specified is out of range (less than zero or not less than the
+ /// If the row index specified is out of range (less than zero or not less than the
/// row count), it returns false and sets its Position property to -1.
///
/// The row index to move to.
diff --git a/src/Microsoft.ML.Data/Data/ITransposeDataView.cs b/src/Microsoft.ML.Data/Data/ITransposeDataView.cs
index 2188c766de..f247bc9859 100644
--- a/src/Microsoft.ML.Data/Data/ITransposeDataView.cs
+++ b/src/Microsoft.ML.Data/Data/ITransposeDataView.cs
@@ -18,7 +18,7 @@ namespace Microsoft.ML.Runtime.Data
/// ). This interface is intended to be implemented by classes that
/// want to provide an option for an alternate way of accessing the data stored in a
/// .
- ///
+ ///
/// The interface only advertises that columns may be accessible in slot-wise fashion. A column
/// is accessible in this fashion iff 's
/// returns a non-null value.
diff --git a/src/Microsoft.ML.Data/Data/RowCursorUtils.cs b/src/Microsoft.ML.Data/Data/RowCursorUtils.cs
index 3f57266b0f..091fe26cb2 100644
--- a/src/Microsoft.ML.Data/Data/RowCursorUtils.cs
+++ b/src/Microsoft.ML.Data/Data/RowCursorUtils.cs
@@ -39,7 +39,7 @@ private static Delegate GetGetterAsDelegateCore(IRow row, int col)
///
/// Given a destination type, IRow, and column index, return a ValueGetter for the column
- /// with a conversion to typeDst, if needed. This is a weakly typed version of
+ /// with a conversion to typeDst, if needed. This is a weakly typed version of
/// .
///
///
@@ -293,7 +293,7 @@ private static ValueGetter> GetVecGetterAsCore(VectorT
///
/// This method returns a small helper delegate that returns whether we are at the start
- /// of a new group, that is, we have just started, or the key-value at indicated column
+ /// of a new group, that is, we have just started, or the key-value at indicated column
/// is different than it was, in the last call. This is practically useful for determining
/// group boundaries. Note that the delegate will return true on the first row.
///
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs
index 4a58e097fb..7bc0a8d2ad 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs
@@ -79,7 +79,7 @@ private sealed class TableOfContentsEntry
public readonly ColumnType Type;
///
- /// The compression scheme used on this column's blocks.
+ /// The compression scheme used on this column's blocks.
///
public readonly CompressionKind Compression;
@@ -971,7 +971,7 @@ public void Save(ModelSaveContext ctx)
}
///
- /// Write the parameters of a loader to the save context. Can be called by , where there's no actual
+ /// Write the parameters of a loader to the save context. Can be called by , where there's no actual
/// loader, only default parameters.
///
private static void SaveParameters(ModelSaveContext ctx, int threads, string generatedRowIndexName, Double shuffleBlocks)
@@ -991,7 +991,7 @@ private static void SaveParameters(ModelSaveContext ctx, int threads, string gen
}
///
- /// Save a zero-row dataview that will be used to infer schema information, used in the case
+ /// Save a zero-row dataview that will be used to infer schema information, used in the case
/// where the binary loader is instantiated with no input streams.
///
private static void SaveSchema(IHostEnvironment env, ModelSaveContext ctx, ISchema schema, out int[] unsavableColIndices)
@@ -1017,10 +1017,10 @@ private static void SaveSchema(IHostEnvironment env, ModelSaveContext ctx, ISche
}
///
- /// Given the schema and a model context, save an imaginary instance of a binary loader with the
- /// specified schema. Deserialization from this context should produce a real binary loader that
+ /// Given the schema and a model context, save an imaginary instance of a binary loader with the
+ /// specified schema. Deserialization from this context should produce a real binary loader that
/// has the specified schema.
- ///
+ ///
/// This is used in an API scenario, when the data originates from something other than a loader.
/// Since our model file requires a loader at the beginning, we have to construct a bogus 'binary' loader
/// to begin the pipe with, with the assumption that the user will bypass the loader at deserialization
@@ -1042,9 +1042,9 @@ public static void SaveInstance(IHostEnvironment env, ModelSaveContext ctx, ISch
int[] unsavable;
SaveSchema(env, ctx, schema, out unsavable);
// REVIEW: we silently ignore unsavable columns.
- // This method is invoked only in an API scenario, where we need to save a loader but we only have a schema.
- // In this case, the API user is likely not subscribed to our environment's channels. Also, in this case, the presence of
- // unsavable columns is not necessarily a bad thing: the user typically provides his own data when loading the transforms,
+ // This method is invoked only in an API scenario, where we need to save a loader but we only have a schema.
+ // In this case, the API user is likely not subscribed to our environment's channels. Also, in this case, the presence of
+ // unsavable columns is not necessarily a bad thing: the user typically provides his own data when loading the transforms,
// thus bypassing the bogus loader.
}
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinarySaver.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinarySaver.cs
index 7fe9fbbf4a..e2f44df2a4 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinarySaver.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinarySaver.cs
@@ -850,7 +850,7 @@ public ColumnType LoadTypeDescriptionOrNull(Stream stream)
/// The type of the codec to write and utilize
/// The value to encode and write
/// The number of bytes written
- /// Whether the write was successful or not
+ /// Whether the write was successful or not
public bool TryWriteTypeAndValue(Stream stream, ColumnType type, ref T value, out int bytesWritten)
{
_host.CheckValue(stream, nameof(stream));
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/IValueCodec.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/IValueCodec.cs
index fd68a34cc9..9c6e607022 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Binary/IValueCodec.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/IValueCodec.cs
@@ -13,7 +13,7 @@ namespace Microsoft.ML.Runtime.Data.IO
/// on the appropriate ColumnType, then opens multiple writers to write blocks of data
/// to some stream. The idea is that each writer or reader is called on some "managable chunk"
/// of data.
- ///
+ ///
/// Codecs should be thread safe, though the readers and writers they spawn do not need to
/// be thread safe.
///
@@ -60,7 +60,7 @@ internal interface IValueCodec : IValueCodec
/// Stream on which we open reader.
/// The number of items expected to be encoded in the block
/// starting from the current position of the stream. Implementors should, if
- /// possible, throw if it seems if the block contains a different number of
+ /// possible, throw if it seems the block contains a different number of
/// elements.
IValueReader OpenReader(Stream stream, int items);
}
@@ -89,7 +89,7 @@ internal interface IValueWriter : IDisposable
/// be spawned from an , its write methods called some
/// number of times to write to the stream, and then Commit will be called when
/// all values have been written, the stream now being at the end of the written block.
- ///
+ ///
/// The intended usage of the value writers is that blocks are composed of some small
/// number of values (perhaps a few thousand), the idea being that a block is something
/// that should easily fit in main memory, both for reading and writing. Some writers
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/Zlib.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/Zlib.cs
index 4dcc82ac9b..024eaef4a2 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/Zlib.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/Zlib.cs
@@ -13,20 +13,20 @@ internal static class Zlib
public const string DllPath = "zlib.dll";
[DllImport(DllPath), SuppressUnmanagedCodeSecurity]
- private static unsafe extern Constants.RetCode deflateInit2_(ZStream* strm, int level, int method, int windowBits,
+ private static extern unsafe Constants.RetCode deflateInit2_(ZStream* strm, int level, int method, int windowBits,
int memLevel, Constants.Strategy strategy, byte* version, int streamSize);
[DllImport(DllPath), SuppressUnmanagedCodeSecurity]
- private static unsafe extern Constants.RetCode inflateInit2_(ZStream* strm, int windowBits, byte* version, int streamSize);
+ private static extern unsafe Constants.RetCode inflateInit2_(ZStream* strm, int windowBits, byte* version, int streamSize);
[DllImport(DllPath), SuppressUnmanagedCodeSecurity]
- private static unsafe extern byte* zlibVersion();
+ private static extern unsafe byte* zlibVersion();
[DllImport(DllPath), SuppressUnmanagedCodeSecurity]
- public static unsafe extern Constants.RetCode deflateEnd(ZStream* strm);
+ public static extern unsafe Constants.RetCode deflateEnd(ZStream* strm);
[DllImport(DllPath), SuppressUnmanagedCodeSecurity]
- public static unsafe extern Constants.RetCode deflate(ZStream* strm, Constants.Flush flush);
+ public static extern unsafe Constants.RetCode deflate(ZStream* strm, Constants.Flush flush);
public static unsafe Constants.RetCode DeflateInit2(ZStream* strm, int level, int method, int windowBits,
int memLevel, Constants.Strategy strategy)
@@ -40,10 +40,10 @@ public static unsafe Constants.RetCode InflateInit2(ZStream* strm, int windowBit
}
[DllImport(DllPath), SuppressUnmanagedCodeSecurity]
- public static unsafe extern Constants.RetCode inflate(ZStream* strm, Constants.Flush flush);
+ public static extern unsafe Constants.RetCode inflate(ZStream* strm, Constants.Flush flush);
[DllImport(DllPath), SuppressUnmanagedCodeSecurity]
- public static unsafe extern Constants.RetCode inflateEnd(ZStream* strm);
+ public static extern unsafe Constants.RetCode inflateEnd(ZStream* strm);
}
[StructLayout(LayoutKind.Sequential)]
diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
index 4a7106c2df..a2ab3a7b16 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
@@ -41,7 +41,7 @@ public sealed class Arguments
public KeyValuePair>[] Transform;
}
- internal struct TransformEx
+ private struct TransformEx
{
public readonly string Tag;
public readonly string ArgsString;
@@ -78,16 +78,14 @@ private static VersionInfo GetVersionInfo()
// The composition of loader plus transforms in order.
private readonly IDataLoader _loader;
private readonly TransformEx[] _transforms;
- private readonly IDataView _view;
private readonly ITransposeDataView _tview;
- private readonly ITransposeSchema _tschema;
private readonly IHost _host;
///
/// Returns the underlying data view of the composite loader.
/// This can be used to programmatically explore the chain of transforms that's inside the composite loader.
///
- internal IDataView View { get { return _view; } }
+ public IDataView View { get; }
///
/// Creates a loader according to the specified .
@@ -200,7 +198,7 @@ private static IDataLoader ApplyTransformsCore(IHost host, IDataLoader srcLoader
IDataLoader pipeStart;
if (composite != null)
{
- srcView = composite._view;
+ srcView = composite.View;
exes.AddRange(composite._transforms);
pipeStart = composite._loader;
}
@@ -409,9 +407,9 @@ private CompositeDataLoader(IHost host, TransformEx[] transforms)
_host = host;
_host.AssertNonEmpty(transforms);
- _view = transforms[transforms.Length - 1].Transform;
- _tview = _view as ITransposeDataView;
- _tschema = _tview == null ? new TransposerUtils.SimpleTransposeSchema(_view.Schema) : _tview.TransposeSchema;
+ View = transforms[transforms.Length - 1].Transform;
+ _tview = View as ITransposeDataView;
+ TransposeSchema = _tview?.TransposeSchema ?? new TransposerUtils.SimpleTransposeSchema(View.Schema);
var srcLoader = transforms[0].Transform.Source as IDataLoader;
@@ -561,29 +559,20 @@ private static string GenerateTag(int index)
public long? GetRowCount(bool lazy = true)
{
- return _view.GetRowCount(lazy);
+ return View.GetRowCount(lazy);
}
- public bool CanShuffle
- {
- get { return _view.CanShuffle; }
- }
+ public bool CanShuffle => View.CanShuffle;
- public ISchema Schema
- {
- get { return _view.Schema; }
- }
+ public ISchema Schema => View.Schema;
- public ITransposeSchema TransposeSchema
- {
- get { return _tschema; }
- }
+ public ITransposeSchema TransposeSchema { get; }
public IRowCursor GetRowCursor(Func predicate, IRandom rand = null)
{
_host.CheckValue(predicate, nameof(predicate));
_host.CheckValueOrNull(rand);
- return _view.GetRowCursor(predicate, rand);
+ return View.GetRowCursor(predicate, rand);
}
public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator,
@@ -591,13 +580,13 @@ public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator,
{
_host.CheckValue(predicate, nameof(predicate));
_host.CheckValueOrNull(rand);
- return _view.GetRowCursorSet(out consolidator, predicate, n, rand);
+ return View.GetRowCursorSet(out consolidator, predicate, n, rand);
}
public ISlotCursor GetSlotCursor(int col)
{
_host.CheckParam(0 <= col && col < Schema.ColumnCount, nameof(col));
- if (_tschema == null || _tschema.GetSlotType(col) == null)
+ if (TransposeSchema?.GetSlotType(col) == null)
{
throw _host.ExceptParam(nameof(col), "Bad call to GetSlotCursor on untransposable column '{0}'",
Schema.GetColumnName(col));
diff --git a/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs
index 69eb3bbb3b..10bf816dc1 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs
@@ -682,13 +682,12 @@ private bool TryTruncatePath(int dirCount, string path, out string truncPath)
Ch.Warning($"Path {path} did not have {dirCount} directories necessary for parsing.");
return false;
}
-
+
// Rejoin segments to create a valid path.
truncPath = String.Join(Path.DirectorySeparatorChar.ToString(), segments);
return true;
}
-
///
/// Parse all column values from the directory path.
///
diff --git a/src/Microsoft.ML.Data/DataLoadSave/PartitionedPathParser.cs b/src/Microsoft.ML.Data/DataLoadSave/PartitionedPathParser.cs
index ca3aa075ab..70d8f898ab 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/PartitionedPathParser.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/PartitionedPathParser.cs
@@ -76,7 +76,7 @@ public class Arguments : IPartitionedPathParserFactory
{
[Argument(ArgumentType.Multiple, HelpText = "Column definitions used to override the Partitioned Path Parser. Expected with the format name:type:numeric-source, e.g. col=MyFeature:R4:1",
ShortName = "col", SortOrder = 1)]
- public Microsoft.ML.Runtime.Data.PartitionedFileLoader.Column[] Columns;
+ public PartitionedFileLoader.Column[] Columns;
[Argument(ArgumentType.AtMostOnce, HelpText = "Data type of each column.")]
public DataKind Type = DataKind.Text;
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
index 3678c749ba..babca545c8 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
@@ -34,7 +34,7 @@ public sealed partial class TextLoader : IDataLoader
///
/// Vector column of I4 that contains values from columns 1, 3 to 10
/// col=ColumnName:I4:1,3-10
- ///
+ ///
/// Key range column of KeyType with underlying storage type U4 that contains values from columns 1, 3 to 10, that can go from 1 to 100 (0 reserved for out of range)
/// col=ColumnName:U4[1-100]:1,3-10
///
@@ -554,7 +554,7 @@ public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile)
{
var range = col.Source[i];
- // Check for remaining range, raise flag.
+ // Check for remaining range, raise flag.
if (range.AllOther)
{
ch.CheckUserArg(iinfoOther < 0, nameof(Range.AllOther), "At most one all other range can be specified");
@@ -605,7 +605,7 @@ public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile)
NameToInfoIndex[name] = iinfo;
}
- // Note that segsOther[isegOther] is not a real segment to be included.
+ // Note that segsOther[isegOther] is not a real segment to be included.
// It only persists segment information such as Min, Max, autoEnd, variableEnd for later processing.
// Process all other range.
if (iinfoOther >= 0)
@@ -641,7 +641,7 @@ public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile)
foreach (var seg in segsAll)
{
- // At this step, all indices less than min is contained in some segment, either in
+ // At this step, all indices less than min is contained in some segment, either in
// segsAll or segsNew.
ch.Assert(min < lim);
if (min < seg.Min)
@@ -1014,7 +1014,7 @@ public TextLoader(IHostEnvironment env, Arguments args, IMultiStreamSource files
_host.CheckNonEmpty(args.Separator, nameof(args.Separator), "Must specify a separator");
//Default arg.Separator is tab and default args.SeparatorChars is also a '\t'.
- //At a time only one default can be different and whichever is different that will
+ //At a time only one default can be different and whichever is different that will
//be used.
if (args.SeparatorChars.Length > 1 || args.SeparatorChars[0] != '\t')
{
@@ -1110,7 +1110,7 @@ private static bool TryParseSchema(IHost host, IMultiStreamSource files,
// Get settings just for core arguments, not everything.
string tmp = CmdParser.GetSettings(host, args, new ArgumentsCore());
- // Try to get the schema information from the file.
+ // Try to get the schema information from the file.
string str = Cursor.GetEmbeddedArgs(files);
if (string.IsNullOrWhiteSpace(str))
return false;
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs
index 48e44b31e7..582d81b546 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs
@@ -27,7 +27,7 @@ public sealed partial class TextLoader : IDataLoader
///
private sealed class ValueCreatorCache
{
- private volatile static ValueCreatorCache _instance;
+ private static volatile ValueCreatorCache _instance;
public static ValueCreatorCache Instance
{
get
@@ -137,9 +137,9 @@ private sealed class ParseStats
private volatile int _cref;
// Total number of rows, number of unparsable values, number of format errors.
- private /*volatile*/ long _rowCount;
- private /*volatile*/ long _badCount;
- private /*volatile*/ long _fmtCount;
+ private long _rowCount;
+ private long _badCount;
+ private long _fmtCount;
public ParseStats(IChannelProvider provider, int cref, long maxShow = MaxShow)
{
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs
index 173c588607..37cbe23b92 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs
@@ -531,7 +531,7 @@ public void Save(ModelSaveContext ctx)
}
///
- /// Save a zero-row dataview that will be used to infer schema information, used in the case
+ /// Save a zero-row dataview that will be used to infer schema information, used in the case
/// where the tranpsose loader is instantiated with no input streams.
///
private static void SaveSchema(IHostEnvironment env, ModelSaveContext ctx, ISchema schema)
diff --git a/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs b/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs
index 6633e2535f..dc713a82b2 100644
--- a/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs
+++ b/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs
@@ -24,7 +24,7 @@ namespace Microsoft.ML.Runtime.Data
/// This class provides the functionality to combine multiple IDataView objects which share the same schema
/// All sources must contain the same number of columns and their column names, sizes, and item types must match.
/// The row count of the resulting IDataView will be the sum over that of each individual.
- ///
+ ///
/// An AppendRowsDataView instance is shuffleable iff all of its sources are shuffleable and their row counts are known.
///
public sealed class AppendRowsDataView : IDataView
@@ -46,8 +46,8 @@ public sealed class AppendRowsDataView : IDataView
///
/// Create a dataview by appending the rows of the sources.
- ///
- /// All sources must be consistent with the passed-in schema in the number of columns, column names,
+ ///
+ /// All sources must be consistent with the passed-in schema in the number of columns, column names,
/// and column types. If schema is null, the first source's schema will be used.
///
/// The host environment.
@@ -203,7 +203,7 @@ public bool IsColumnActive(int col)
}
///
- /// The deterministic cursor. It will scan through the sources sequentially.
+ /// The deterministic cursor. It will scan through the sources sequentially.
///
private sealed class Cursor : CursorBase
{
@@ -293,7 +293,7 @@ public override void Dispose()
///
/// A RandCursor will ask each subordinate cursor to shuffle itself.
- /// Then, at each step, it randomly calls a subordinate to move next with probability (roughly) proportional to
+ /// Then, at each step, it randomly calls a subordinate to move next with probability (roughly) proportional to
/// the number of the subordinate's remaining rows.
///
private sealed class RandCursor : CursorBase
@@ -383,16 +383,16 @@ public override void Dispose()
///
/// Given k classes with counts (N_0, N_2, N_3, ..., N_{k-1}), the goal of this sampler is to select the i-th
- /// class with probability N_i/M, where M = N_0 + N_1 + ... + N_{k-1}.
+ /// class with probability N_i/M, where M = N_0 + N_1 + ... + N_{k-1}.
/// Once the i-th class is selected, its count will be updated to N_i - 1.
- ///
+ ///
/// For efficiency consideration, the sampling distribution is only an approximation of the desired distribution.
///
private sealed class MultinomialWithoutReplacementSampler
{
// Implementation: generate a batch array of size BatchSize.
// Each class will claim a fraction of the batch proportional to its remaining row count.
- // Shuffle the array. The sampler reads from the array one at a time until the batch is consumed.
+ // Shuffle the array. The sampler reads from the array one at a time until the batch is consumed.
// The sampler then generates a new batch and repeat the process.
private const int BatchSize = 1000;
diff --git a/src/Microsoft.ML.Data/DataView/CacheDataView.cs b/src/Microsoft.ML.Data/DataView/CacheDataView.cs
index 3bca858d1d..72fb4b18a5 100644
--- a/src/Microsoft.ML.Data/DataView/CacheDataView.cs
+++ b/src/Microsoft.ML.Data/DataView/CacheDataView.cs
@@ -618,7 +618,7 @@ private interface IWaiter
/// is equivalent to also having waited on i-1, i-2, etc.
/// Note that this is position within the cache, that is, a row index,
/// as opposed to position within the cursor.
- ///
+ ///
/// This method should be thread safe because in the parallel cursor
/// case it will be used by multiple threads.
///
@@ -955,23 +955,23 @@ public Wrapper(RandomIndex index)
/// next job ids before they push the completed jobs to the consumer. So the workers are
/// then subject to being blocked until their current completed jobs are fully accepted
/// (i.e. added to the to-consume queue).
- ///
+ ///
/// How it works:
/// Suppose we have 7 workers (w0,..,w6) and 14 jobs (j0,..,j13).
/// Initially, jobs get assigned to workers using a shared counter.
/// Here is an example outcome of using a shared counter:
/// w1->j0, w6->j1, w0->j2, w3->j3, w4->j4, w5->j5, w2->j6.
- ///
+ ///
/// Suppose workers finished jobs in the following order:
/// w5->j5, w0->j2, w6->j1, w4->j4, w3->j3,w1->j0, w2->j6.
- ///
+ ///
/// w5 finishes processing j5 first, but will be blocked until the processing of jobs
/// j0,..,j4 completes since the consumer can consume jobs in order only.
/// Therefore, the next available job (j7) should not be assigned to w5. It should be
- /// assigned to the worker whose job *get consumed first* (w1 since it processes j0
- /// which is the first job) *not* to the worker who completes its job first (w5 in
+ /// assigned to the worker whose job *gets consumed first* (w1 since it processes j0
+ /// which is the first job) *not* to the worker who completes its job first (w5 in
/// this example).
- ///
+ ///
/// So, a shared counter can be used to assign jobs to workers initially but should
/// not be used onwards.
///
diff --git a/src/Microsoft.ML.Data/DataView/CompositeSchema.cs b/src/Microsoft.ML.Data/DataView/CompositeSchema.cs
index 4d387de1d5..81aef4b01e 100644
--- a/src/Microsoft.ML.Data/DataView/CompositeSchema.cs
+++ b/src/Microsoft.ML.Data/DataView/CompositeSchema.cs
@@ -37,7 +37,7 @@ public CompositeSchema(ISchema[] sources)
///
/// Returns an array of input predicated for sources, corresponding to the input predicate.
- /// The returned array size is equal to the number of sources, but if a given source is not needed at all,
+ /// The returned array size is equal to the number of sources, but if a given source is not needed at all,
/// the corresponding predicate will be null.
///
public Func[] GetInputPredicates(Func predicate)
diff --git a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
index b0fde835d8..d69379d5da 100644
--- a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
+++ b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
@@ -30,7 +30,7 @@ public RowMapperColumnInfo(string name, ColumnType type, ColumnMetadataInfo meta
}
///
- /// This interface is used to create a .
+ /// This interface is used to create a .
/// Implementations should be given an in their constructor, and should have a
/// ctor or Create method with , along with a corresponding
/// .
@@ -44,7 +44,7 @@ public interface IRowMapper : ICanSaveModel
///
/// Returns the getters for the output columns given an active set of output columns. The length of the getters
- /// array should be equal to the number of columns added by the IRowMapper. It should contain the getter for the
+ /// array should be equal to the number of columns added by the IRowMapper. It should contain the getter for the
/// i'th output column if activeOutput(i) is true, and null otherwise.
///
Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer);
diff --git a/src/Microsoft.ML.Data/DataView/Transposer.cs b/src/Microsoft.ML.Data/DataView/Transposer.cs
index 3baf8a7379..91bb9c8b6a 100644
--- a/src/Microsoft.ML.Data/DataView/Transposer.cs
+++ b/src/Microsoft.ML.Data/DataView/Transposer.cs
@@ -1041,7 +1041,6 @@ private static Splitter CreateCore(IDataView view, int col, int[] ends)
}
#region ISchema implementation
-
// Subclasses should implement ColumnCount and GetColumnType.
public override bool TryGetColumnIndex(string name, out int col)
{
@@ -1062,8 +1061,6 @@ public override string GetColumnName(int col)
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
return _view.Schema.GetColumnName(SrcCol);
}
-
- public override abstract ColumnType GetColumnType(int col);
#endregion
private abstract class RowBase : IRow
@@ -1215,7 +1212,7 @@ private sealed class Row : RowBase>
private VBuffer _inputValue;
// The delegate to get the input value.
private readonly ValueGetter> _inputGetter;
- // The limit of _inputValue.Indices
+ // The limit of _inputValue.Indices
private readonly int[] _srcIndicesLims;
// Convenient accessor since we use this all over the place.
private int[] Lims { get { return Parent._lims; } }
@@ -1405,7 +1402,7 @@ public static void GetSingleSlotValue(this ITransposeDataView view, int col,
}
///
- /// The is parameterized by a type that becomes the
+ /// The is parameterized by a type that becomes the
/// type parameter for a , and this is generally preferable and more
/// sensible but for various reasons it's often a lot simpler to have a get-getter be over
/// the actual type returned by the getter, that is, parameterize this by the actual
diff --git a/src/Microsoft.ML.Data/DataView/ZipDataView.cs b/src/Microsoft.ML.Data/DataView/ZipDataView.cs
index 9a7e79bab8..a472b48b36 100644
--- a/src/Microsoft.ML.Data/DataView/ZipDataView.cs
+++ b/src/Microsoft.ML.Data/DataView/ZipDataView.cs
@@ -11,7 +11,7 @@ namespace Microsoft.ML.Runtime.Data
{
///
/// This is a data view that is a 'zip' of several data views.
- /// The length of the zipped data view is equal to the shortest of the lengths of the components.
+ /// The length of the zipped data view is equal to the shortest of the lengths of the components.
///
public sealed class ZipDataView : IDataView
{
@@ -77,7 +77,7 @@ public IRowCursor GetRowCursor(Func predicate, IRandom rand = null)
var srcPredicates = _schema.GetInputPredicates(predicate);
- // REVIEW: if we know the row counts, we could only open cursor if it has needed columns, and have the
+ // REVIEW: if we know the row counts, we could only open cursor if it has needed columns, and have the
// outer cursor handle the early stopping. If we don't know row counts, we need to open all the cursors because
// we don't know which one will be the shortest.
// One reason this is not done currently is because the API has 'somewhat mutable' data views, so potentially this
@@ -88,8 +88,8 @@ public IRowCursor GetRowCursor(Func predicate, IRandom rand = null)
}
///
- /// Create an with no requested columns on a data view.
- /// Potentially, this can be optimized by calling GetRowCount(lazy:true) first, and if the count is not known,
+ /// Create an with no requested columns on a data view.
+ /// Potentially, this can be optimized by calling GetRowCount(lazy:true) first, and if the count is not known,
/// wrapping around GetCursor().
///
private IRowCursor GetMinimumCursor(IDataView dv)
diff --git a/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs b/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs
index 3dd16f141b..f08d52fe85 100644
--- a/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs
+++ b/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs
@@ -364,9 +364,7 @@ private sealed class Dense : FeatureNameCollection
private readonly int _count;
private readonly string[] _names;
- private readonly RoleMappedSchema _schema;
-
- public override RoleMappedSchema Schema => _schema;
+ public override RoleMappedSchema Schema { get; }
public Dense(int count, string[] names)
{
@@ -379,8 +377,9 @@ public Dense(int count, string[] names)
if (size > 0)
Array.Copy(names, _names, size);
- _schema = RoleMappedSchema.Create(new FeatureNameCollectionSchema(this),
- RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Feature, RoleMappedSchema.ColumnRole.Feature.Value));
+ // REVIEW: This seems wrong. The default feature column name is "Features" yet the role is named "Feature".
+ Schema = new RoleMappedSchema(new FeatureNameCollectionSchema(this),
+ roles: RoleMappedSchema.ColumnRole.Feature.Bind(RoleMappedSchema.ColumnRole.Feature.Value));
}
public override int Count => _count;
@@ -470,8 +469,9 @@ public Sparse(int count, string[] names, int cnn)
}
Contracts.Assert(cv == cnn);
- _schema = RoleMappedSchema.Create(new FeatureNameCollectionSchema(this),
- RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Feature, RoleMappedSchema.ColumnRole.Feature.Value));
+ // REVIEW: This seems wrong. The default feature column name is "Features" yet the role is named "Feature".
+ _schema = new RoleMappedSchema(new FeatureNameCollectionSchema(this),
+ roles: RoleMappedSchema.ColumnRole.Feature.Bind(RoleMappedSchema.ColumnRole.Feature.Value));
}
///
diff --git a/src/Microsoft.ML.Data/Depricated/TGUIAttribute.cs b/src/Microsoft.ML.Data/Depricated/TGUIAttribute.cs
index 7a51e1a5ee..5f09c604bb 100644
--- a/src/Microsoft.ML.Data/Depricated/TGUIAttribute.cs
+++ b/src/Microsoft.ML.Data/Depricated/TGUIAttribute.cs
@@ -7,12 +7,12 @@
namespace Microsoft.ML.Runtime.Internal.Internallearn
{
-#pragma warning disable TLC_GeneralName // This structure should be deprecated anyway.
+#pragma warning disable MSML_GeneralName // This structure should be deprecated anyway.
// REVIEW: Get rid of this. Everything should be in the ArgumentAttribute (or a class
// derived from ArgumentAttribute).
[AttributeUsage(AttributeTargets.Field)]
public class TGUIAttribute : Attribute
-#pragma warning restore TLC_GeneralName
+#pragma warning restore MSML_GeneralName
{
// Display parameters
public string Label { get; set; }
@@ -32,7 +32,7 @@ public class TGUIAttribute : Attribute
public bool NoSweep { get; set; }
//Settings are automatically populated for fields that are classes.
- //The below is an extension of the framework to add settings for
+ //The below is an extension of the framework to add settings for
//boolean type fields.
public bool ShowSettingsForCheckbox { get; set; }
public object Settings { get; set; }
diff --git a/src/Microsoft.ML.Data/Depricated/Vector/VBufferMathUtils.cs b/src/Microsoft.ML.Data/Depricated/Vector/VBufferMathUtils.cs
index 0199c30915..045c5d30a1 100644
--- a/src/Microsoft.ML.Data/Depricated/Vector/VBufferMathUtils.cs
+++ b/src/Microsoft.ML.Data/Depricated/Vector/VBufferMathUtils.cs
@@ -350,7 +350,7 @@ public static void AddMultWithOffset(ref VBuffer src, Float c, ref VBuffe
/// Perform in-place scaling of a vector into another vector as
/// = * .
/// This is more or less equivalent to performing the same operation with
- /// except perhaps more efficiently,
+ /// except perhaps more efficiently,
/// with one exception: if is 0 and
/// is sparse, will have a count of zero, instead of the
/// same count as .
diff --git a/src/Microsoft.ML.Data/Depricated/Vector/VectorUtils.cs b/src/Microsoft.ML.Data/Depricated/Vector/VectorUtils.cs
index fc335bb52e..5b46aad0b8 100644
--- a/src/Microsoft.ML.Data/Depricated/Vector/VectorUtils.cs
+++ b/src/Microsoft.ML.Data/Depricated/Vector/VectorUtils.cs
@@ -57,7 +57,7 @@ public static Float DotProduct(ref VBuffer a, ref VBuffer b)
}
///
- /// Sparsify vector A (keep at most + values)
+ /// Sparsify vector A (keep at most + values)
/// and optionally rescale values to the [-1, 1] range.
/// Vector to be sparsified and normalized.
/// How many top (positive) elements to preserve after sparsification.
diff --git a/src/Microsoft.ML.Data/Dirty/PredictorBase.cs b/src/Microsoft.ML.Data/Dirty/PredictorBase.cs
index e9d02db58b..35c9a49133 100644
--- a/src/Microsoft.ML.Data/Dirty/PredictorBase.cs
+++ b/src/Microsoft.ML.Data/Dirty/PredictorBase.cs
@@ -19,7 +19,7 @@ public abstract class PredictorBase : IPredictorProducing
{
public const string NormalizerWarningFormat =
"Ignoring integrated normalizer while loading a predictor of type {0}.{1}" +
- " Please contact tlcsupp for assistance with converting legacy models.";
+ " Please refer to https://aka.ms/MLNetIssue for assistance with converting legacy models.";
protected readonly IHost Host;
@@ -41,9 +41,9 @@ protected PredictorBase(IHostEnvironment env, string name, ModelLoadContext ctx)
// Verify that the Float type matches.
int cbFloat = ctx.Reader.ReadInt32();
-#pragma warning disable TLC_NoMessagesForLoadContext // This one is actually useful.
+#pragma warning disable MSML_NoMessagesForLoadContext // This one is actually useful.
Host.CheckDecode(cbFloat == sizeof(Float), "This file was saved by an incompatible version");
-#pragma warning restore TLC_NoMessagesForLoadContext
+#pragma warning restore MSML_NoMessagesForLoadContext
}
public virtual void Save(ModelSaveContext ctx)
diff --git a/src/Microsoft.ML.Data/EntryPoints/CommonOutputs.cs b/src/Microsoft.ML.Data/EntryPoints/CommonOutputs.cs
index 9e99bf8993..37f37f6c64 100644
--- a/src/Microsoft.ML.Data/EntryPoints/CommonOutputs.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/CommonOutputs.cs
@@ -191,7 +191,7 @@ public interface ITrainerOutput
}
///
- /// Macro output class base.
+ /// Macro output class base.
///
public abstract class MacroOutput
{
diff --git a/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs b/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs
index 1ff3daee02..d2d59eb94b 100644
--- a/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs
@@ -475,7 +475,7 @@ public float Cost
private EntryPointNode(IHostEnvironment env, IChannel ch, ModuleCatalog moduleCatalog, RunContext context,
string id, string entryPointName, JObject inputs, JObject outputs, bool checkpoint = false,
- string stageId = "", float cost = float.NaN, string label = null, string group = null, string weight = null)
+ string stageId = "", float cost = float.NaN, string label = null, string group = null, string weight = null, string name = null)
{
Contracts.AssertValue(env);
env.AssertNonEmpty(id);
@@ -510,49 +510,10 @@ private EntryPointNode(IHostEnvironment env, IChannel ch, ModuleCatalog moduleCa
throw _host.Except($"The following required inputs were not provided: {String.Join(", ", missing)}");
var inputInstance = _inputBuilder.GetInstance();
- var warning = "Different {0} column specified in trainer and in macro: '{1}', '{2}'." +
- " Using column '{2}'. To column use '{1}' instead, please specify this name in" +
- "the trainer node arguments.";
- if (!string.IsNullOrEmpty(label) && Utils.Size(_entryPoint.InputKinds) > 0 &&
- _entryPoint.InputKinds.Contains(typeof(CommonInputs.ITrainerInputWithLabel)))
- {
- var labelColField = _inputBuilder.GetFieldNameOrNull("LabelColumn");
- ch.AssertNonEmpty(labelColField);
- var labelColFieldType = _inputBuilder.GetFieldTypeOrNull(labelColField);
- ch.Assert(labelColFieldType == typeof(string));
- var inputLabel = inputInstance.GetType().GetField(labelColField).GetValue(inputInstance);
- if (label != (string)inputLabel)
- ch.Warning(warning, "label", label, inputLabel);
- else
- _inputBuilder.TrySetValue(labelColField, label);
- }
- if (!string.IsNullOrEmpty(group) && Utils.Size(_entryPoint.InputKinds) > 0 &&
- _entryPoint.InputKinds.Contains(typeof(CommonInputs.ITrainerInputWithGroupId)))
- {
- var groupColField = _inputBuilder.GetFieldNameOrNull("GroupIdColumn");
- ch.AssertNonEmpty(groupColField);
- var groupColFieldType = _inputBuilder.GetFieldTypeOrNull(groupColField);
- ch.Assert(groupColFieldType == typeof(string));
- var inputGroup = inputInstance.GetType().GetField(groupColField).GetValue(inputInstance);
- if (group != (Optional)inputGroup)
- ch.Warning(warning, "group Id", label, inputGroup);
- else
- _inputBuilder.TrySetValue(groupColField, label);
- }
- if (!string.IsNullOrEmpty(weight) && Utils.Size(_entryPoint.InputKinds) > 0 &&
- (_entryPoint.InputKinds.Contains(typeof(CommonInputs.ITrainerInputWithWeight)) ||
- _entryPoint.InputKinds.Contains(typeof(CommonInputs.IUnsupervisedTrainerWithWeight))))
- {
- var weightColField = _inputBuilder.GetFieldNameOrNull("WeightColumn");
- ch.AssertNonEmpty(weightColField);
- var weightColFieldType = _inputBuilder.GetFieldTypeOrNull(weightColField);
- ch.Assert(weightColFieldType == typeof(string));
- var inputWeight = inputInstance.GetType().GetField(weightColField).GetValue(inputInstance);
- if (weight != (Optional)inputWeight)
- ch.Warning(warning, "weight", label, inputWeight);
- else
- _inputBuilder.TrySetValue(weightColField, label);
- }
+ SetColumnArgument(ch, inputInstance, "LabelColumn", label, "label", typeof(CommonInputs.ITrainerInputWithLabel));
+ SetColumnArgument(ch, inputInstance, "GroupIdColumn", group, "group Id", typeof(CommonInputs.ITrainerInputWithGroupId));
+ SetColumnArgument(ch, inputInstance, "WeightColumn", weight, "weight", typeof(CommonInputs.ITrainerInputWithWeight), typeof(CommonInputs.IUnsupervisedTrainerWithWeight));
+ SetColumnArgument(ch, inputInstance, "NameColumn", name, "name");
// Validate outputs.
_outputHelper = new OutputHelper(_host, _entryPoint.OutputType);
@@ -568,6 +529,38 @@ private EntryPointNode(IHostEnvironment env, IChannel ch, ModuleCatalog moduleCa
Cost = cost;
}
+ private void SetColumnArgument(IChannel ch, object inputInstance, string argName, string colName, string columnRole, params Type[] inputKinds)
+ {
+ Contracts.AssertValue(ch);
+ ch.AssertValue(inputInstance);
+ ch.AssertNonEmpty(argName);
+ ch.AssertValueOrNull(colName);
+ ch.AssertNonEmpty(columnRole);
+ ch.AssertValueOrNull(inputKinds);
+
+ var colField = _inputBuilder.GetFieldNameOrNull(argName);
+ if (string.IsNullOrEmpty(colField))
+ return;
+
+ const string warning = "Different {0} column specified in trainer and in macro: '{1}', '{2}'." +
+ " Using column '{2}'. To use column '{1}' instead, please specify this name in " +
+ "the trainer node arguments.";
+ if (!string.IsNullOrEmpty(colName) && Utils.Size(_entryPoint.InputKinds) > 0 &&
+ (Utils.Size(inputKinds) == 0 || _entryPoint.InputKinds.Intersect(inputKinds).Any()))
+ {
+ ch.AssertNonEmpty(colField);
+ var colFieldType = _inputBuilder.GetFieldTypeOrNull(colField);
+ ch.Assert(colFieldType == typeof(string));
+ var inputColName = inputInstance.GetType().GetField(colField).GetValue(inputInstance);
+ ch.Assert(inputColName is string || inputColName is Optional);
+ var str = inputColName is string ? (string)inputColName : ((Optional)inputColName).Value;
+ if (colName != str)
+ ch.Warning(warning, columnRole, colName, inputColName);
+ else
+ _inputBuilder.TrySetValue(colField, colName);
+ }
+ }
+
public static EntryPointNode Create(
IHostEnvironment env,
string entryPointName,
@@ -639,7 +632,7 @@ public static EntryPointNode Create(
///
/// Checks the given JSON object key-value pair is a valid EntryPoint input and
/// extracts out any variables that need to be populated. These variables will be
- /// added to the EntryPoint context. Input parameters that are not set to variables
+ /// added to the EntryPoint context. Input parameters that are not set to variables
/// will be immediately set using the input builder instance.
///
private void CheckAndSetInputValue(KeyValuePair pair)
@@ -699,7 +692,7 @@ private void CheckAndSetInputValue(KeyValuePair pair)
///
/// Checks the given JSON object key-value pair is a valid EntryPoint output.
- /// Extracts out any variables that need to be populated and adds them to the
+ /// Extracts out any variables that need to be populated and adds them to the
/// EntryPoint context.
///
private void CheckAndMarkOutputValue(KeyValuePair pair)
@@ -902,7 +895,7 @@ private object BuildParameterValue(List bindings)
}
public static List ValidateNodes(IHostEnvironment env, RunContext context, JArray nodes,
- ModuleCatalog moduleCatalog, string label = null, string group = null, string weight = null)
+ ModuleCatalog moduleCatalog, string label = null, string group = null, string weight = null, string name = null)
{
Contracts.AssertValue(env);
env.AssertValue(context);
@@ -918,7 +911,7 @@ public static List ValidateNodes(IHostEnvironment env, RunContex
if (node == null)
throw env.Except("Unexpected node token: '{0}'", nodes[i]);
- string name = node[FieldNames.Name].Value();
+ string nodeName = node[FieldNames.Name].Value();
var inputs = node[FieldNames.Inputs] as JObject;
if (inputs == null && node[FieldNames.Inputs] != null)
throw env.Except("Unexpected {0} token: '{1}'", FieldNames.Inputs, node[FieldNames.Inputs]);
@@ -927,7 +920,7 @@ public static List ValidateNodes(IHostEnvironment env, RunContex
if (outputs == null && node[FieldNames.Outputs] != null)
throw env.Except("Unexpected {0} token: '{1}'", FieldNames.Outputs, node[FieldNames.Outputs]);
- var id = context.GenerateId(name);
+ var id = context.GenerateId(nodeName);
var unexpectedFields = node.Properties().Where(
x => x.Name != FieldNames.Name && x.Name != FieldNames.Inputs && x.Name != FieldNames.Outputs
&& x.Name != FieldNames.StageId && x.Name != FieldNames.Checkpoint && x.Name != FieldNames.Cost);
@@ -942,7 +935,7 @@ public static List ValidateNodes(IHostEnvironment env, RunContex
ch.Warning("Node '{0}' has unexpected fields that are ignored: {1}", id, string.Join(", ", unexpectedFields.Select(x => x.Name)));
}
- result.Add(new EntryPointNode(env, ch, moduleCatalog, context, id, name, inputs, outputs, checkpoint, stageId, cost, label, group, weight));
+ result.Add(new EntryPointNode(env, ch, moduleCatalog, context, id, nodeName, inputs, outputs, checkpoint, stageId, cost, label, group, weight, name));
}
ch.Done();
@@ -1080,8 +1073,8 @@ protected VariableBinding(string varName)
VariableName = varName;
}
- // A regex to validate an EntryPoint variable value accessor string. Valid EntryPoint variable names
- // can be any sequence of alphanumeric characters and underscores. They must start with a letter or underscore.
+ // A regex to validate an EntryPoint variable value accessor string. Valid EntryPoint variable names
+ // can be any sequence of alphanumeric characters and underscores. They must start with a letter or underscore.
// An EntryPoint variable can be followed with an array or dictionary specifier, which begins
// with '[', contains either an integer or alphanumeric string, optionally wrapped in single-quotes,
// followed with ']'.
diff --git a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs
index 5583c66df0..94b67af670 100644
--- a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs
@@ -146,7 +146,7 @@ public static TOut Train(IHost host, TArg input,
TrainUtils.AddNormalizerIfNeeded(host, ch, trainer, ref view, feature, input.NormalizeFeatures);
ch.Trace("Binding columns");
- var roleMappedData = TrainUtils.CreateExamples(view, label, feature, group, weight, name, custom);
+ var roleMappedData = new RoleMappedData(view, label, feature, group, weight, name, custom);
RoleMappedData cachedRoleMappedData = roleMappedData;
Cache.CachingType? cachingType = null;
@@ -164,9 +164,8 @@ public static TOut Train(IHost host, TArg input,
}
case CachingOptions.Auto:
{
- ITrainerEx trainerEx = trainer as ITrainerEx;
// REVIEW: we should switch to hybrid caching in future.
- if (!(input.TrainingData is BinaryLoader) && (trainerEx == null || trainerEx.WantCaching))
+ if (!(input.TrainingData is BinaryLoader) && trainer.Info.WantCaching)
// default to Memory so mml is on par with maml
cachingType = Cache.CachingType.Memory;
break;
@@ -184,7 +183,7 @@ public static TOut Train(IHost host, TArg input,
Data = roleMappedData.Data,
Caching = cachingType.Value
}).OutputData;
- cachedRoleMappedData = RoleMappedData.Create(cacheView, roleMappedData.Schema.GetColumnRoleNames());
+ cachedRoleMappedData = new RoleMappedData(cacheView, roleMappedData.Schema.GetColumnRoleNames());
}
var predictor = TrainUtils.Train(host, ch, cachedRoleMappedData, trainer, "Train", calibrator, maxCalibrationExamples);
diff --git a/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs b/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs
index e5afd8dbb5..4d3b765114 100644
--- a/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs
@@ -14,8 +14,8 @@
namespace Microsoft.ML.Runtime.EntryPoints.JsonUtils
{
///
- /// The class that creates and wraps around an instance of an input object and gradually populates all fields, keeping track of missing
- /// required values. The values can be set from their JSON representation (during the graph parsing stage), as well as directly
+ /// The class that creates and wraps around an instance of an input object and gradually populates all fields, keeping track of missing
+ /// required values. The values can be set from their JSON representation (during the graph parsing stage), as well as directly
/// (in the process of graph execution).
///
public sealed class InputBuilder
@@ -515,7 +515,7 @@ private static object ParseJsonValue(IExceptionContext ectx, Type type, Attribut
}
///
- /// Ensures that the given value can be assigned to an entry point field with
+ /// Ensures that the given value can be assigned to an entry point field with
/// type . This method will wrap the value in the option
/// type if needed and throw an exception if the value isn't assignable.
///
@@ -791,7 +791,7 @@ public static class Range
///
public static class Deprecated
{
- public new static string ToString() => "Deprecated";
+ public static new string ToString() => "Deprecated";
public const string Message = "Message";
}
@@ -800,7 +800,7 @@ public static class Deprecated
///
public static class SweepableLongParam
{
- public new static string ToString() => "SweepRange";
+ public static new string ToString() => "SweepRange";
public const string RangeType = "RangeType";
public const string Max = "Max";
public const string Min = "Min";
@@ -814,7 +814,7 @@ public static class SweepableLongParam
///
public static class SweepableFloatParam
{
- public new static string ToString() => "SweepRange";
+ public static new string ToString() => "SweepRange";
public const string RangeType = "RangeType";
public const string Max = "Max";
public const string Min = "Min";
@@ -828,14 +828,14 @@ public static class SweepableFloatParam
///
public static class SweepableDiscreteParam
{
- public new static string ToString() => "SweepRange";
+ public static new string ToString() => "SweepRange";
public const string RangeType = "RangeType";
public const string Options = "Values";
}
public static class PipelineSweeperSupportedMetrics
{
- public new static string ToString() => "SupportedMetric";
+ public static new string ToString() => "SupportedMetric";
public const string Auc = BinaryClassifierEvaluator.Auc;
public const string AccuracyMicro = Data.MultiClassClassifierEvaluator.AccuracyMicro;
public const string AccuracyMacro = MultiClassClassifierEvaluator.AccuracyMacro;
diff --git a/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs b/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs
index af726fa758..055b2fa299 100644
--- a/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs
@@ -74,13 +74,13 @@ public void Save(IHostEnvironment env, Stream stream)
{
// REVIEW: address the asymmetry in the way we're loading and saving the model.
// Effectively, we have methods to load the transform model from a model.zip, but don't have
- // methods to compose the model.zip out of transform model, predictor and role mappings
+ // methods to compose the model.zip out of transform model, predictor and role mappings
// (we use the TrainUtils.SaveModel that does all three).
// Create the chain of transforms for saving.
IDataView data = new EmptyDataView(env, _transformModel.InputSchema);
data = _transformModel.Apply(env, data);
- var roleMappedData = RoleMappedData.CreateOpt(data, _roleMappings);
+ var roleMappedData = new RoleMappedData(data, _roleMappings, opt: true);
TrainUtils.SaveModel(env, ch, stream, _predictor, roleMappedData);
ch.Done();
@@ -102,7 +102,7 @@ public void PrepareData(IHostEnvironment env, IDataView input, out RoleMappedDat
env.CheckValue(input, nameof(input));
input = _transformModel.Apply(env, input);
- roleMappedData = RoleMappedData.CreateOpt(input, _roleMappings);
+ roleMappedData = new RoleMappedData(input, _roleMappings, opt: true);
predictor = _predictor;
}
@@ -141,7 +141,7 @@ public RoleMappedSchema GetTrainingSchema(IHostEnvironment env)
{
Contracts.CheckValue(env, nameof(env));
var predInput = _transformModel.Apply(env, new EmptyDataView(env, _transformModel.InputSchema));
- var trainRms = RoleMappedSchema.CreateOpt(predInput.Schema, _roleMappings);
+ var trainRms = new RoleMappedSchema(predInput.Schema, _roleMappings, opt: true);
return trainRms;
}
}
diff --git a/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs b/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs
index 96ce0acac9..312a92bccc 100644
--- a/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs
@@ -15,9 +15,9 @@ namespace Microsoft.ML.Runtime.EntryPoints
///
/// This module handles scoring a against a new dataset.
/// As a result, we return both the scored data and the scoring transform as a .
- ///
- /// REVIEW: This module does not support 'exotic' scoring scenarios, like recommendation and quantile regression
- /// (those where the user-defined scorer settings are necessary to identify the scorer). We could resolve this by
+ ///
+ /// REVIEW: This module does not support 'exotic' scoring scenarios, like recommendation and quantile regression
+ /// (those where the user-defined scorer settings are necessary to identify the scorer). We could resolve this by
/// adding a sub-component for extra scorer args, or by creating specialized EPs for these scenarios.
///
public static partial class ScoreModel
diff --git a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs
index 9edc87df6d..ed8e7d56e2 100644
--- a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs
@@ -43,7 +43,7 @@ public sealed class TransformModel : ITransformModel
///
/// The resulting schema once applied to this model. The might have
- /// columns that are not needed by this transform and these columns will be seen in the
+ /// columns that are not needed by this transform and these columns will be seen in the
/// produced by this transform.
///
public ISchema OutputSchema => _chain.Schema;
diff --git a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs
index 39a5f31c38..8e4f3be56c 100644
--- a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs
@@ -57,7 +57,7 @@ public static class OverallMetrics
}
///
- /// The anomaly detection evaluator outputs a data view by this name, which contains the the examples
+ /// The anomaly detection evaluator outputs a data view by this name, which contains the examples
/// with the top scores in the test set. It contains the three columns listed below, with each row corresponding
/// to one test example.
///
@@ -796,7 +796,7 @@ public static CommonOutputs.CommonEvaluateOutput AnomalyDetection(IHostEnvironme
string name;
MatchColumns(host, input, out label, out weight, out name);
var evaluator = new AnomalyDetectionMamlEvaluator(host, input);
- var data = TrainUtils.CreateExamples(input.Data, label, null, null, weight, name);
+ var data = new RoleMappedData(input.Data, label, null, null, weight, name);
var metrics = evaluator.Evaluate(data);
var warnings = ExtractWarnings(host, metrics);
diff --git a/src/Microsoft.ML.Data/Evaluators/AucAggregator.cs b/src/Microsoft.ML.Data/Evaluators/AucAggregator.cs
index f45aacd58e..342e1d3529 100644
--- a/src/Microsoft.ML.Data/Evaluators/AucAggregator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/AucAggregator.cs
@@ -408,7 +408,7 @@ public UnweightedAuPrcAggregator(IRandom rand, int reservoirSize)
///
/// Compute the AUPRC using the "lower trapesoid" estimator, as described in the paper
- /// .
+ /// http://www.ecmlpkdd2013.org/wp-content/uploads/2013/07/aucpr_2013ecml_corrected.pdf.
///
protected override Double ComputeWeightedAuPrcCore(out Double unweighted)
{
@@ -482,7 +482,7 @@ public WeightedAuPrcAggregator(IRandom rand, int reservoirSize)
///
/// Compute the AUPRC using the "lower trapesoid" estimator, as described in the paper
- /// .
+ /// http://www.ecmlpkdd2013.org/wp-content/uploads/2013/07/aucpr_2013ecml_corrected.pdf.
///
protected override Double ComputeWeightedAuPrcCore(out Double unweighted)
{
diff --git a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
index 90078da9ee..71c08eecd0 100644
--- a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
@@ -88,7 +88,7 @@ public enum Metrics
///
/// Binary classification evaluator outputs a data view with this name, which contains the p/r data.
- /// It contains the columns listed below, and in case data also contains a weight column, it contains
+ /// It contains the columns listed below, and in case data also contains a weight column, it contains
/// also columns for the weighted values.
/// and false positive rate.
///
@@ -1211,7 +1211,7 @@ public override IEnumerable GetOverallMetricColumns()
}
// This method saves the p/r plots, and returns the p/r metrics data view.
- // In case there are results from multiple folds, they are averaged using
+ // In case there are results from multiple folds, they are averaged using
// vertical averaging for the p/r plot, and appended using AppendRowsDataView for
// the p/r data view.
private bool TryGetPrMetrics(Dictionary[] metrics, out IDataView pr)
@@ -1455,7 +1455,7 @@ public static CommonOutputs.ClassificationEvaluateOutput Binary(IHostEnvironment
string name;
MatchColumns(host, input, out label, out weight, out name);
var evaluator = new BinaryClassifierMamlEvaluator(host, input);
- var data = TrainUtils.CreateExamples(input.Data, label, null, null, weight, name);
+ var data = new RoleMappedData(input.Data, label, null, null, weight, name);
var metrics = evaluator.Evaluate(data);
var warnings = ExtractWarnings(host, metrics);
diff --git a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs
index 907760649f..bec1ac144a 100644
--- a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs
@@ -776,7 +776,7 @@ public ClusteringMamlEvaluator(IHostEnvironment env, Arguments args)
string feat = EvaluateUtils.GetColName(_featureCol, schema.Feature, DefaultColumnNames.Features);
if (!schema.Schema.TryGetColumnIndex(feat, out int featCol))
throw Host.ExceptUserArg(nameof(Arguments.FeatureColumn), "Features column '{0}' not found", feat);
- yield return RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Feature, feat);
+ yield return RoleMappedSchema.ColumnRole.Feature.Bind(feat);
}
}
@@ -867,7 +867,7 @@ public static CommonOutputs.CommonEvaluateOutput Clustering(IHostEnvironment env
nameof(ClusteringMamlEvaluator.Arguments.FeatureColumn),
input.FeatureColumn, DefaultColumnNames.Features);
var evaluator = new ClusteringMamlEvaluator(host, input);
- var data = TrainUtils.CreateExamples(input.Data, label, features, null, weight, name);
+ var data = new RoleMappedData(input.Data, label, features, null, weight, name);
var metrics = evaluator.Evaluate(data);
var warnings = ExtractWarnings(host, metrics);
diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs
index ef7183c2fa..c628cff1e4 100644
--- a/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs
+++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs
@@ -217,7 +217,7 @@ protected ValueGetter> GetKeyValueGetter(AggregatorDictionaryBas
///
/// This is a helper class for evaluators deriving from EvaluatorBase, used for computing aggregate metrics.
/// Aggregators should keep track of the number of passes done. The method should get
- /// the input getters of the given IRow that are needed for the current pass, assuming that all the needed column
+ /// the input getters of the given IRow that are needed for the current pass, assuming that all the needed column
/// information is stored in the given .
/// In the aggregator should call the getters once, and process the input as needed.
/// increments the pass count after each pass.
@@ -251,7 +251,7 @@ public bool Start()
return IsActive();
}
- ///
+ ///
/// This method should get the getters of the new IRow that are needed for the next pass.
///
public abstract void InitializeNextPass(IRow row, RoleMappedSchema schema);
@@ -370,7 +370,7 @@ private static AggregatorDictionaryBase CreateDictionary(RoleMappedSchem
}
///
- /// This method calls the getter of the stratification column, and returns the aggregator corresponding to
+ /// This method calls the getter of the stratification column, and returns the aggregator corresponding to
/// the stratification value.
///
///
diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs
index 0e2de21530..942d139425 100644
--- a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs
+++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs
@@ -115,10 +115,10 @@ public static ColumnInfo GetScoreColumnInfo(IExceptionContext ectx, ISchema sche
ColumnInfo info;
if (!string.IsNullOrWhiteSpace(name))
{
-#pragma warning disable TLC_ContractsNameUsesNameof
+#pragma warning disable MSML_ContractsNameUsesNameof
if (!ColumnInfo.TryCreateFromName(schema, name, out info))
throw ectx.ExceptUserArg(argName, "Score column is missing");
-#pragma warning restore TLC_ContractsNameUsesNameof
+#pragma warning restore MSML_ContractsNameUsesNameof
return info;
}
@@ -145,9 +145,9 @@ public static ColumnInfo GetScoreColumnInfo(IExceptionContext ectx, ISchema sche
if (!string.IsNullOrWhiteSpace(defName) && ColumnInfo.TryCreateFromName(schema, defName, out info))
return info;
-#pragma warning disable TLC_ContractsNameUsesNameof
+#pragma warning disable MSML_ContractsNameUsesNameof
throw ectx.ExceptUserArg(argName, "Score column is missing");
-#pragma warning restore TLC_ContractsNameUsesNameof
+#pragma warning restore MSML_ContractsNameUsesNameof
}
///
@@ -168,12 +168,12 @@ public static ColumnInfo GetOptAuxScoreColumnInfo(IExceptionContext ectx, ISchem
if (!string.IsNullOrWhiteSpace(name))
{
ColumnInfo info;
-#pragma warning disable TLC_ContractsNameUsesNameof
+#pragma warning disable MSML_ContractsNameUsesNameof
if (!ColumnInfo.TryCreateFromName(schema, name, out info))
throw ectx.ExceptUserArg(argName, "{0} column is missing", valueKind);
if (!testType(info.Type))
throw ectx.ExceptUserArg(argName, "{0} column has incompatible type", valueKind);
-#pragma warning restore TLC_ContractsNameUsesNameof
+#pragma warning restore MSML_ContractsNameUsesNameof
return info;
}
@@ -332,15 +332,15 @@ public static IEnumerable> GetMetrics(IDataView met
if (getters[i] != null)
{
getters[i](ref metricVal);
- // For R8 valued columns the metric name is the column name.
+ // For R8 valued columns the metric name is the column name.
yield return new KeyValuePair(schema.GetColumnName(i), metricVal);
}
else if (getVectorMetrics && vBufferGetters[i] != null)
{
vBufferGetters[i](ref metricVals);
- // For R8 vector valued columns the names of the metrics are the column name,
- // followed by the slot name if it exists, or "Label_i" if it doesn't.
+ // For R8 vector valued columns the names of the metrics are the column name,
+ // followed by the slot name if it exists, or "Label_i" if it doesn't.
VBuffer names = default(VBuffer);
var size = schema.GetColumnType(i).VectorSize;
var slotNamesType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, i);
@@ -386,7 +386,7 @@ public static IDataView AddFoldIndex(IHostEnvironment env, IDataView input, int
env.CheckValue(input, nameof(input));
env.CheckParam(curFold >= 0, nameof(curFold));
- // We use the first column in the data view as an input column to the LambdaColumnMapper,
+ // We use the first column in the data view as an input column to the LambdaColumnMapper,
// because it must have an input.
int inputCol = 0;
while (inputCol < input.Schema.ColumnCount && input.Schema.IsHidden(inputCol))
@@ -428,7 +428,7 @@ public static IDataView AddFoldIndex(IHostEnvironment env, IDataView input, int
env.CheckParam(curFold >= 0, nameof(curFold));
env.CheckParam(numFolds > 0, nameof(numFolds));
- // We use the first column in the data view as an input column to the LambdaColumnMapper,
+ // We use the first column in the data view as an input column to the LambdaColumnMapper,
// because it must have an input.
int inputCol = 0;
while (inputCol < input.Schema.ColumnCount && input.Schema.IsHidden(inputCol))
@@ -444,7 +444,7 @@ public static IDataView AddFoldIndex(IHostEnvironment env, IDataView input, int
///
/// This method takes an array of data views and a specified input vector column, and adds a new output column to each of the data views.
- /// First, we find the union set of the slot names in the different data views. Next we define a new vector column for each
+ /// First, we find the union set of the slot names in the different data views. Next we define a new vector column for each
/// data view, indexed by the union of the slot names. For each data view, every slot value is the value in the slot corresponding
/// to its slot name in the original column. If a reconciled slot name does not exist in an input column, the value in the output
/// column is def.
@@ -552,14 +552,15 @@ public static IDataView AddFoldIndex(IHostEnvironment env, IDataView input, int
}
}
- private static int[][] MapKeys(ISchema[] schemas, string columnName, bool isVec,
- out int[] indices, out Dictionary reconciledKeyNames)
+ private static int[][] MapKeys(ISchema[] schemas, string columnName, bool isVec,
+ int[] indices, Dictionary reconciledKeyNames)
{
+ Contracts.AssertValue(indices);
+ Contracts.AssertValue(reconciledKeyNames);
+
var dvCount = schemas.Length;
var keyValueMappers = new int[dvCount][];
- var keyNamesCur = default(VBuffer);
- indices = new int[dvCount];
- reconciledKeyNames = new Dictionary();
+ var keyNamesCur = default(VBuffer);
for (int i = 0; i < dvCount; i++)
{
var schema = schemas[i];
@@ -567,10 +568,11 @@ private static int[][] MapKeys(ISchema[] schemas, string columnName, bool isVec,
throw Contracts.Except($"Schema number {i} does not contain column '{columnName}'");
var type = schema.GetColumnType(indices[i]);
+ var keyValueType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, indices[i]);
if (type.IsVector != isVec)
throw Contracts.Except($"Column '{columnName}' in schema number {i} does not have the correct type");
- if (!schema.HasKeyNames(indices[i], type.ItemType.KeyCount))
- throw Contracts.Except($"Column '{columnName}' in schema number {i} does not have text key values");
+ if (keyValueType == null || keyValueType.ItemType.RawType != typeof(T))
+ throw Contracts.Except($"Column '{columnName}' in schema number {i} does not have the correct type of key values");
if (!type.ItemType.IsKey || type.ItemType.RawKind != DataKind.U4)
throw Contracts.Except($"Column '{columnName}' must be a U4 key type, but is '{type.ItemType}'");
@@ -580,7 +582,7 @@ private static int[][] MapKeys(ISchema[] schemas, string columnName, bool isVec,
foreach (var kvp in keyNamesCur.Items(true))
{
var key = kvp.Key;
- var name = kvp.Value;
+ var name = new DvText(kvp.Value.ToString());
if (!reconciledKeyNames.ContainsKey(name))
reconciledKeyNames[name] = reconciledKeyNames.Count;
keyValueMappers[i][key] = reconciledKeyNames[name];
@@ -591,21 +593,22 @@ private static int[][] MapKeys(ISchema[] schemas, string columnName, bool isVec,
///
/// This method takes an array of data views and a specified input key column, and adds a new output column to each of the data views.
- /// First, we find the union set of the key values in the different data views. Next we define a new key column for each
+ /// First, we find the union set of the key values in the different data views. Next we define a new key column for each
/// data view, with the union of the key values as the new key values. For each data view, the value in the output column is the value
/// corresponding to the key value in the original column.
///
- public static void ReconcileKeyValues(IHostEnvironment env, IDataView[] views, string columnName)
+ public static void ReconcileKeyValues(IHostEnvironment env, IDataView[] views, string columnName, ColumnType keyValueType)
{
Contracts.CheckNonEmpty(views, nameof(views));
Contracts.CheckNonEmpty(columnName, nameof(columnName));
var dvCount = views.Length;
- Dictionary keyNames;
- int[] indices;
// Create mappings from the original key types to the reconciled key type.
- var keyValueMappers = MapKeys(views.Select(view => view.Schema).ToArray(), columnName, false, out indices, out keyNames);
+ var indices = new int[dvCount];
+ var keyNames = new Dictionary();
+ // We use MarshalInvoke so that we can call MapKeys with the correct generic: keyValueType.RawType.
+ var keyValueMappers = Utils.MarshalInvoke(MapKeys, keyValueType.RawType, views.Select(view => view.Schema).ToArray(), columnName, false, indices, keyNames);
var keyType = new KeyType(DataKind.U4, 0, keyNames.Count);
var keyNamesVBuffer = new VBuffer(keyNames.Count, keyNames.Keys.ToArray());
ValueGetter> keyValueGetter =
@@ -629,20 +632,51 @@ public static void ReconcileKeyValues(IHostEnvironment env, IDataView[] views, s
}
}
+ ///
+ /// This method takes an array of data views and a specified input key column, and adds a new output column to each of the data views.
+ /// The new column is a U4 key column with the given key count. Values in the original column that exceed
+ /// the key count are mapped to the missing value (0); all other values are passed through unchanged.
+ ///
+ public static void ReconcileKeyValuesWithNoNames(IHostEnvironment env, IDataView[] views, string columnName, int keyCount)
+ {
+ Contracts.CheckNonEmpty(views, nameof(views));
+ Contracts.CheckNonEmpty(columnName, nameof(columnName));
+
+ var keyType = new KeyType(DataKind.U4, 0, keyCount);
+
+ // For each input data view, create the reconciled key column by wrapping it in a LambdaColumnMapper.
+ for (int i = 0; i < views.Length; i++)
+ {
+ if (!views[i].Schema.TryGetColumnIndex(columnName, out var index))
+ throw env.Except($"Data view {i} doesn't contain a column '{columnName}'");
+ ValueMapper mapper =
+ (ref uint src, ref uint dst) =>
+ {
+ if (src > keyCount)
+ dst = 0;
+ else
+ dst = src;
+ };
+ views[i] = LambdaColumnMapper.Create(env, "ReconcileKeyValues", views[i], columnName, columnName,
+ views[i].Schema.GetColumnType(index), keyType, mapper);
+ }
+ }
+
///
/// This method is similar to , but it reconciles the key values over vector
/// input columns.
///
- public static void ReconcileVectorKeyValues(IHostEnvironment env, IDataView[] views, string columnName)
+ public static void ReconcileVectorKeyValues(IHostEnvironment env, IDataView[] views, string columnName, ColumnType keyValueType)
{
Contracts.CheckNonEmpty(views, nameof(views));
Contracts.CheckNonEmpty(columnName, nameof(columnName));
var dvCount = views.Length;
- Dictionary keyNames;
- int[] columnIndices;
- var keyValueMappers = MapKeys(views.Select(view => view.Schema).ToArray(), columnName, true, out columnIndices, out keyNames);
+ var keyNames = new Dictionary();
+ var columnIndices = new int[dvCount];
+ var keyValueMappers = Utils.MarshalInvoke(MapKeys, keyValueType.RawType, views.Select(view => view.Schema).ToArray(), columnName, true, columnIndices, keyNames);
var keyType = new KeyType(DataKind.U4, 0, keyNames.Count);
var keyNamesVBuffer = new VBuffer(keyNames.Count, keyNames.Keys.ToArray());
ValueGetter> keyValueGetter =
@@ -736,7 +770,7 @@ public static IDataView[] ConcatenatePerInstanceDataViews(IHostEnvironment env,
var foldDataViews = perInstance.Select(getPerInstance).ToArray();
if (collate)
{
- var combined = AppendPerInstanceDataViews(env, foldDataViews, out variableSizeVectorColumnNames);
+ var combined = AppendPerInstanceDataViews(env, perInstance[0].Schema.Label?.Name, foldDataViews, out variableSizeVectorColumnNames);
return new[] { combined };
}
else
@@ -767,7 +801,8 @@ public static IDataView ConcatenateOverallMetrics(IHostEnvironment env, IDataVie
return AppendRowsDataView.Create(env, overallList[0].Schema, overallList.ToArray());
}
- private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, IEnumerable foldDataViews, out string[] variableSizeVectorColumnNames)
+ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, string labelColName,
+ IEnumerable foldDataViews, out string[] variableSizeVectorColumnNames)
{
Contracts.AssertValue(env);
env.AssertValue(foldDataViews);
@@ -776,7 +811,9 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, IEnume
// This is a dictionary from the column name to its vector size.
var vectorSizes = new Dictionary();
var firstDvSlotNames = new Dictionary>();
- var firstDvKeyColumns = new List();
+ ColumnType labelColKeyValuesType = null;
+ var firstDvKeyWithNamesColumns = new List();
+ var firstDvKeyNoNamesColumns = new Dictionary();
var firstDvVectorKeyColumns = new List();
var variableSizeVectorColumnNamesList = new List();
var list = new List();
@@ -822,10 +859,20 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, IEnume
else
vectorSizes.Add(name, type.VectorSize);
}
- else if (dvNumber == 0 && dv.Schema.HasKeyNames(i, type.KeyCount))
+ else if (dvNumber == 0 && name == labelColName)
{
// The label column can be a key. Reconcile the key values, and wrap with a KeyToValue transform.
- firstDvKeyColumns.Add(name);
+ labelColKeyValuesType = dv.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, i);
+ }
+ else if (dvNumber == 0 && dv.Schema.HasKeyNames(i, type.KeyCount))
+ firstDvKeyWithNamesColumns.Add(name);
+ else if (type.KeyCount > 0 && name != labelColName && !dv.Schema.HasKeyNames(i, type.KeyCount))
+ {
+ // For any other key column (such as GroupId) we do not reconcile the key values, we only convert to U4.
+ if (!firstDvKeyNoNamesColumns.ContainsKey(name))
+ firstDvKeyNoNamesColumns[name] = type.KeyCount;
+ if (firstDvKeyNoNamesColumns[name] < type.KeyCount)
+ firstDvKeyNoNamesColumns[name] = type.KeyCount;
}
}
var idv = dv;
@@ -839,26 +886,34 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, IEnume
list.Add(idv);
dvNumber++;
}
-
variableSizeVectorColumnNames = variableSizeVectorColumnNamesList.ToArray();
- if (variableSizeVectorColumnNamesList.Count == 0 && firstDvKeyColumns.Count == 0)
- return AppendRowsDataView.Create(env, null, list.ToArray());
var views = list.ToArray();
- foreach (var keyCol in firstDvKeyColumns)
- ReconcileKeyValues(env, views, keyCol);
+ foreach (var keyCol in firstDvKeyWithNamesColumns)
+ ReconcileKeyValues(env, views, keyCol, TextType.Instance);
+ if (labelColKeyValuesType != null)
+ ReconcileKeyValues(env, views, labelColName, labelColKeyValuesType.ItemType);
+ foreach (var keyCol in firstDvKeyNoNamesColumns)
+ ReconcileKeyValuesWithNoNames(env, views, keyCol.Key, keyCol.Value);
foreach (var vectorKeyCol in firstDvVectorKeyColumns)
- ReconcileVectorKeyValues(env, views, vectorKeyCol);
+ ReconcileVectorKeyValues(env, views, vectorKeyCol, TextType.Instance);
Func keyToValue =
(idv, i) =>
{
- foreach (var keyCol in firstDvKeyColumns.Concat(firstDvVectorKeyColumns))
+ foreach (var keyCol in firstDvVectorKeyColumns.Concat(firstDvKeyWithNamesColumns).Prepend(labelColName))
{
+ if (keyCol == labelColName && labelColKeyValuesType == null)
+ continue;
idv = new KeyToValueTransform(env, new KeyToValueTransform.Arguments() { Column = new[] { new KeyToValueTransform.Column() { Name = keyCol }, } }, idv);
var hidden = FindHiddenColumns(idv.Schema, keyCol);
idv = new ChooseColumnsByIndexTransform(env, new ChooseColumnsByIndexTransform.Arguments() { Drop = true, Index = hidden.ToArray() }, idv);
}
+ foreach (var keyCol in firstDvKeyNoNamesColumns)
+ {
+ var hidden = FindHiddenColumns(idv.Schema, keyCol.Key);
+ idv = new ChooseColumnsByIndexTransform(env, new ChooseColumnsByIndexTransform.Arguments() { Drop = true, Index = hidden.ToArray() }, idv);
+ }
return idv;
};
@@ -938,7 +993,7 @@ private static List GetMetricNames(IChannel ch, ISchema schema, IRow row
ch.Assert(Utils.Size(vBufferGetters) == schema.ColumnCount);
// Get the names of the metrics. For R8 valued columns the metric name is the column name. For R8 vector valued columns
- // the names of the metrics are the column name, followed by the slot name if it exists, or "Label_i" if it doesn't.
+ // the names of the metrics are the column name, followed by the slot name if it exists, or "Label_i" if it doesn't.
VBuffer names = default(VBuffer);
int metricCount = 0;
var metricNames = new List();
@@ -1271,7 +1326,7 @@ private static void AddScalarColumn(this ArrayDataViewBuilder dvBldr, ISchema sc
}
///
- /// Takes a data view containing one or more rows of metrics, and returns a data view containing additional
+ /// Takes a data view containing one or more rows of metrics, and returns a data view containing additional
/// rows with the average and the standard deviation of the metrics in the input data view.
///
public static IDataView CombineFoldMetricsDataViews(IHostEnvironment env, IDataView data, int numFolds)
@@ -1454,8 +1509,8 @@ private static string GetOverallMetricsAsString(double[] sumMetrics, double[] su
}
// This method returns a string representation of a set of metrics. If there are stratification columns, it looks for columns named
- // StratCol and StratVal, and outputs the metrics in the rows with NA in the StratCol column. If weighted is true, it looks
- // for a DvBool column named "IsWeighted" and outputs the metrics in the rows with a value of true in that column.
+ // StratCol and StratVal, and outputs the metrics in the rows with NA in the StratCol column. If weighted is true, it looks
+ // for a DvBool column named "IsWeighted" and outputs the metrics in the rows with a value of true in that column.
// If nonAveragedCols is non-null, it computes the average and standard deviation over all the relevant rows and populates
// nonAveragedCols with columns that are either hidden, or are not of a type that we can display (i.e., either a numeric column,
// or a known length vector of doubles).
@@ -1694,7 +1749,7 @@ public static class MetricKinds
{
///
/// This data view contains the confusion matrix for N-class classification. It has N rows, and each row has
- /// the following columns:
+ /// the following columns:
/// * Count (vector indicating how many examples of this class were predicted as each one of the classes). This column
/// should have metadata containing the class names.
/// * (Optional) Weight (vector with the total weight of the examples of this class that were predicted as each one of the classes).
diff --git a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs
index bbb53ba631..2af1b54d92 100644
--- a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs
@@ -10,10 +10,10 @@
namespace Microsoft.ML.Runtime.Data
{
///
- /// This interface is used by Maml components (the , the
+ /// This interface is used by Maml components (the , the
/// and the to evaluate, print and save the results.
- /// The input to the and the methods
- /// should be assumed to contain only the following column roles: label, group, weight and name. Any other columns needed for
+ /// The input to the and the methods
+ /// should be assumed to contain only the following column roles: label, group, weight and name. Any other columns needed for
/// evaluation should be searched for by name in the .
///
public interface IMamlEvaluator : IEvaluator
@@ -95,7 +95,7 @@ protected MamlEvaluatorBase(ArgumentsBase args, IHostEnvironment env, string sco
public Dictionary Evaluate(RoleMappedData data)
{
- data = RoleMappedData.Create(data.Data, GetInputColumnRoles(data.Schema, needStrat: true));
+ data = new RoleMappedData(data.Data, GetInputColumnRoles(data.Schema, needStrat: true));
return Evaluator.Evaluate(data);
}
@@ -108,7 +108,7 @@ public Dictionary Evaluate(RoleMappedData data)
: StratCols.Select(col => RoleMappedSchema.CreatePair(Strat, col));
if (needName && schema.Name != null)
- roles = roles.Prepend(RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Name, schema.Name.Name));
+ roles = roles.Prepend(RoleMappedSchema.ColumnRole.Name.Bind(schema.Name.Name));
return roles.Concat(GetInputColumnRolesCore(schema));
}
@@ -126,12 +126,12 @@ public Dictionary Evaluate(RoleMappedData data)
yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreInfo.Name);
// Get the label column information.
- string lab = EvaluateUtils.GetColName(LabelCol, schema.Label, DefaultColumnNames.Label);
- yield return RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, lab);
+ string label = EvaluateUtils.GetColName(LabelCol, schema.Label, DefaultColumnNames.Label);
+ yield return RoleMappedSchema.ColumnRole.Label.Bind(label);
- var weight = EvaluateUtils.GetColName(WeightCol, schema.Weight, null);
+ string weight = EvaluateUtils.GetColName(WeightCol, schema.Weight, null);
if (!string.IsNullOrEmpty(weight))
- yield return RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Weight, weight);
+ yield return RoleMappedSchema.ColumnRole.Weight.Bind(weight);
}
public virtual IEnumerable GetOverallMetricColumns()
@@ -203,7 +203,7 @@ public IDataTransform GetPerInstanceMetrics(RoleMappedData scoredData)
Host.AssertValue(scoredData);
var schema = scoredData.Schema;
- var dataEval = RoleMappedData.Create(scoredData.Data, GetInputColumnRoles(schema));
+ var dataEval = new RoleMappedData(scoredData.Data, GetInputColumnRoles(schema));
return Evaluator.GetPerInstanceMetrics(dataEval);
}
@@ -260,7 +260,7 @@ protected virtual IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMap
public IDataView GetPerInstanceDataViewToSave(RoleMappedData perInstance)
{
Host.CheckValue(perInstance, nameof(perInstance));
- var data = RoleMappedData.Create(perInstance.Data, GetInputColumnRoles(perInstance.Schema, needName: true));
+ var data = new RoleMappedData(perInstance.Data, GetInputColumnRoles(perInstance.Schema, needName: true));
return WrapPerInstance(data);
}
diff --git a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs
index 3b5e5fc910..d57835b168 100644
--- a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs
@@ -784,7 +784,7 @@ public static CommonOutputs.CommonEvaluateOutput MultiOutputRegression(IHostEnvi
string name;
MatchColumns(host, input, out label, out weight, out name);
var evaluator = new MultiOutputRegressionMamlEvaluator(host, input);
- var data = TrainUtils.CreateExamples(input.Data, label, null, null, weight, name);
+ var data = new RoleMappedData(input.Data, label, null, null, weight, name);
var metrics = evaluator.Evaluate(data);
var warnings = ExtractWarnings(host, metrics);
diff --git a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs
index 5507176fb4..fd23e7c3b0 100644
--- a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs
@@ -256,6 +256,8 @@ public Double MacroAvgAccuracy
{
get
{
+ if (_numInstances == 0)
+ return 0;
Double macroAvgAccuracy = 0;
int countOfNonEmptyClasses = 0;
for (int i = 0; i < _numClasses; ++i)
@@ -267,8 +269,7 @@ public Double MacroAvgAccuracy
}
}
- Contracts.Assert(countOfNonEmptyClasses > 0);
- return macroAvgAccuracy / countOfNonEmptyClasses;
+ return countOfNonEmptyClasses > 0 ? macroAvgAccuracy / countOfNonEmptyClasses : 0;
}
}
@@ -1069,7 +1070,7 @@ public static CommonOutputs.ClassificationEvaluateOutput MultiClass(IHostEnviron
MatchColumns(host, input, out string label, out string weight, out string name);
var evaluator = new MultiClassMamlEvaluator(host, input);
- var data = TrainUtils.CreateExamples(input.Data, label, null, null, weight, name);
+ var data = new RoleMappedData(input.Data, label, null, null, weight, name);
var metrics = evaluator.Evaluate(data);
var warnings = ExtractWarnings(host, metrics);
diff --git a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs
index 6d61f6b965..fb8d9c1249 100644
--- a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs
@@ -556,7 +556,7 @@ public static CommonOutputs.CommonEvaluateOutput QuantileRegression(IHostEnviron
string name;
MatchColumns(host, input, out label, out weight, out name);
var evaluator = new QuantileRegressionMamlEvaluator(host, input);
- var data = TrainUtils.CreateExamples(input.Data, label, null, null, weight, name);
+ var data = new RoleMappedData(input.Data, label, null, null, weight, name);
var metrics = evaluator.Evaluate(data);
var warnings = ExtractWarnings(host, metrics);
diff --git a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
index ae9c2a8594..616cff8394 100644
--- a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
@@ -48,7 +48,7 @@ public sealed class Arguments
public const string MaxDcg = "MaxDCG";
///
- /// The ranking evaluator outputs a data view by this name, which contains metrics aggregated per group.
+ /// The ranking evaluator outputs a data view by this name, which contains metrics aggregated per group.
/// It contains four columns: GroupId, NDCG, DCG and MaxDCG. Each row in the data view corresponds to one
/// group in the scored data.
///
@@ -851,7 +851,7 @@ public RankerMamlEvaluator(IHostEnvironment env, Arguments args)
{
var cols = base.GetInputColumnRolesCore(schema);
var groupIdCol = EvaluateUtils.GetColName(_groupIdCol, schema.Group, DefaultColumnNames.GroupId);
- return cols.Prepend(RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Group, groupIdCol));
+ return cols.Prepend(RoleMappedSchema.ColumnRole.Group.Bind(groupIdCol));
}
protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics)
@@ -1039,7 +1039,7 @@ public static CommonOutputs.CommonEvaluateOutput Ranking(IHostEnvironment env, R
nameof(RankerMamlEvaluator.Arguments.GroupIdColumn),
input.GroupIdColumn, DefaultColumnNames.GroupId);
var evaluator = new RankerMamlEvaluator(host, input);
- var data = TrainUtils.CreateExamples(input.Data, label, null, groupId, weight, name);
+ var data = new RoleMappedData(input.Data, label, null, groupId, weight, name);
var metrics = evaluator.Evaluate(data);
var warnings = ExtractWarnings(host, metrics);
diff --git a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs
index 4292e13b8d..1804ce429f 100644
--- a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs
@@ -354,7 +354,7 @@ public static CommonOutputs.CommonEvaluateOutput Regression(IHostEnvironment env
string name;
MatchColumns(host, input, out label, out weight, out name);
var evaluator = new RegressionMamlEvaluator(host, input);
- var data = TrainUtils.CreateExamples(input.Data, label, null, null, weight, name);
+ var data = new RoleMappedData(input.Data, label, null, null, weight, name);
var metrics = evaluator.Evaluate(data);
var warnings = ExtractWarnings(host, metrics);
diff --git a/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj b/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj
index 38fb6075ce..8d5b0fd2d0 100644
--- a/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj
+++ b/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj
@@ -15,7 +15,6 @@
-
diff --git a/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs b/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs
index 0e25840c3b..36d839b93d 100644
--- a/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs
+++ b/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs
@@ -19,7 +19,7 @@ public interface ICanSaveOnnx
}
///
- /// This data model component is savable as Onnx.
+ /// This data model component is savable as ONNX.
///
public interface ITransformCanSaveOnnx: ICanSaveOnnx, IDataTransform
{
diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
index 9759ce0c6c..230f2600a3 100644
--- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
+++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
@@ -2,245 +2,101 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
using System.Collections.Generic;
-using System.Linq;
-using Microsoft.ML.Runtime.UniversalModelFormat.Onnx;
using Microsoft.ML.Runtime.Data;
namespace Microsoft.ML.Runtime.Model.Onnx
{
///
- /// A context for defining a ONNX output.
+ /// A context for defining an ONNX output. The context internally contains the model-in-progress being built. This
+ /// same context object is iteratively given to exportable components via the interface
+ /// and subinterfaces, that attempt to express their operations as ONNX nodes, if they can. At the point that it is
+ /// given to a component, all other components up to that component have already attempted to express themselves in
+ /// this context, with their outputs possibly available in the ONNX graph.
///
- public sealed class OnnxContext
+ public abstract class OnnxContext
{
- private readonly List _nodes;
- private readonly List _inputs;
- private readonly List _intermediateValues;
- private readonly List _outputs;
- private readonly Dictionary _columnNameMap;
- private readonly HashSet _variableMap;
- private readonly HashSet _nodeNames;
- private readonly string _name;
- private readonly string _producerName;
- private readonly IHost _host;
- private readonly string _domain;
- private readonly string _producerVersion;
- private readonly long _modelVersion;
-
- public OnnxContext(IHostEnvironment env, string name, string producerName,
- string producerVersion, long modelVersion, string domain)
- {
- Contracts.CheckValue(env, nameof(env));
- Contracts.CheckValue(name, nameof(name));
- Contracts.CheckValue(name, nameof(domain));
-
- _host = env.Register(nameof(OnnxContext));
- _nodes = new List();
- _intermediateValues = new List();
- _inputs = new List();
- _outputs = new List();
- _columnNameMap = new Dictionary();
- _variableMap = new HashSet();
- _nodeNames = new HashSet();
- _name = name;
- _producerName = producerName;
- _producerVersion = producerVersion;
- _modelVersion = modelVersion;
- _domain = domain;
- }
-
- public bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName);
-
- ///
- /// Stops tracking a column. If removeVariable is true then it also removes the
- /// variable associated with it, this is useful in the event where an output variable is
- /// created before realizing the transform cannot actually save as ONNX.
- ///
- /// IDataView column name to stop tracking
- /// Remove associated ONNX variable at the time.
- public void RemoveColumn(string colName, bool removeVariable)
- {
-
- if (removeVariable)
- {
- foreach (var val in _intermediateValues)
- {
- if (val.Name == _columnNameMap[colName])
- {
- _intermediateValues.Remove(val);
- break;
- }
- }
- }
-
- if (_columnNameMap.ContainsKey(colName))
- _columnNameMap.Remove(colName);
- }
-
- ///
- /// Removes an ONNX variable. If removeColumn is true then it also removes the
- /// IDataView column associated with it.
- ///
- /// ONNX variable to remove.
- /// IDataView column to stop tracking
- public void RemoveVariable(string variableName, bool removeColumn)
- {
- _host.Assert(_columnNameMap.ContainsValue(variableName));
- if (removeColumn)
- {
- foreach (var val in _intermediateValues)
- {
- if (val.Name == variableName)
- {
- _intermediateValues.Remove(val);
- break;
- }
- }
- }
-
- string columnName = _columnNameMap.Single(kvp => string.Compare(kvp.Value, variableName) == 0).Key;
-
- Contracts.Assert(_variableMap.Contains(columnName));
-
- _columnNameMap.Remove(columnName);
- _variableMap.Remove(columnName);
- }
-
///
/// Generates a unique name for the node based on a prefix.
///
- public string GetNodeName(string prefix)
- {
- _host.CheckValue(prefix, nameof(prefix));
- return GetUniqueName(prefix, c => _nodeNames.Contains(c));
- }
+ /// The prefix for the node
+ /// A name that has not yet been returned from this function, starting with
+ public abstract string GetNodeName(string prefix);
///
- /// Adds a node to the node list of the graph.
+ /// Looks up whether a given data view column has a mapping in the ONNX context. Once confirmed, callers can
+ /// safely call .
///
- ///
- public void AddNode(NodeProto node)
- {
- _host.CheckValue(node, nameof(node));
- _host.Assert(!_nodeNames.Contains(node.Name));
-
- _nodeNames.Add(node.Name);
- _nodes.Add(node);
- }
+ /// The data view column name
+ /// Whether the column is mapped in this context
+ public abstract bool ContainsColumn(string colName);
///
- /// Generates a unique name based on a prefix.
+ /// Stops tracking a column.
///
- private string GetUniqueName(string prefix, Func pred)
- {
- _host.CheckValue(prefix, nameof(prefix));
- _host.CheckValue(pred, nameof(pred));
-
- if (!pred(prefix))
- return prefix;
-
- int count = 0;
- while (pred(prefix + count++)) ;
- return prefix + --count;
- }
+ /// Column name to stop tracking
+ /// Remove associated ONNX variable. This is useful in the event where an output
+ /// variable is created through before realizing
+ /// the transform cannot actually save as ONNX.
+ public abstract void RemoveColumn(string colName, bool removeVariable = false);
///
- /// Retrieves the variable name that maps to the IDataView column name at a
- /// given point in the pipeline execution.
+ /// Removes an ONNX variable. If removeColumn is true then it also removes the tracking for the column associated with it.
///
- /// Column Name mapping.
- public string GetVariableName(string colName)
- {
- _host.CheckValue(colName, nameof(colName));
- _host.Assert(_columnNameMap.ContainsKey(colName));
-
- return _columnNameMap[colName];
- }
-
- ///
- /// Retrieves the variable name that maps to the IDataView column name at a
- /// given point in the pipeline execution.
- ///
- /// Column Name mapping.
- public string TryGetVariableName(string colName)
- {
- if (_columnNameMap.ContainsKey(colName))
- return GetVariableName(colName);
-
- return null;
- }
-
- ///
- /// Generates a unique column name based on the IDataView column name if
- /// there is a collision between names in the pipeline at any point.
- ///
- /// IDataView column name.
- /// Unique variable name.
- private string AddVariable(string colName)
- {
- _host.CheckValue(colName, nameof(colName));
-
- if (!_columnNameMap.ContainsKey(colName))
- _columnNameMap.Add(colName, colName);
- else
- _columnNameMap[colName] = GetUniqueName(colName, s => _variableMap.Contains(s));
-
- _variableMap.Add(_columnNameMap[colName]);
- return _columnNameMap[colName];
- }
+ /// ONNX variable to remove. Note that this is an ONNX variable name, not an column name
+ /// IDataView column to stop tracking
+ public abstract void RemoveVariable(string variableName, bool removeColumn);
///
- /// Adds an intermediate column to the list.
+ /// ONNX variables are referred to by name. At each stage of a ML.NET pipeline, the corresponding
+ /// 's column names will map to a variable in the ONNX graph if the intermediate steps
+ /// used to calculate that value are things we knew how to save as ONNX. Retrieves the variable name that maps
+ /// to the column name at a given point in the pipeline execution. Callers should
+ /// probably confirm with whether a mapping for that data view column
+ /// already exists.
///
- public string AddIntermediateVariable(ColumnType type, string colName, bool skip = false)
- {
-
- colName = AddVariable(colName);
-
- //Let the runtime figure the shape.
- if (!skip)
- {
- _host.CheckValue(type, nameof(type));
-
- _intermediateValues.Add(OnnxUtils.GetModelArgs(type, colName));
- }
-
- return colName;
- }
+ /// The data view column name
+ /// The ONNX variable name corresponding to that data view column
+ public abstract string GetVariableName(string colName);
///
- /// Adds an output variable to the list.
+ /// Establishes a new mapping from a data view column in the context, if necessary generates a unique name, and
+ /// returns that newly allocated name.
///
- public string AddOutputVariable(ColumnType type, string colName, List dim = null)
- {
- _host.CheckValue(type, nameof(type));
-
- if (!ContainsColumn(colName))
- AddVariable(colName);
-
- colName = GetVariableName(colName);
- _outputs.Add(OnnxUtils.GetModelArgs(type, colName, dim));
- return colName;
- }
+ /// The data view type associated with this column name
+ /// The data view column name
+ /// Whether we should skip the process of establishing the mapping from data view column to
+ /// ONNX variable name.
+ /// The returned value is the name of the variable corresponding to the data view column name
+ public abstract string AddIntermediateVariable(ColumnType type, string colName, bool skip = false);
///
- /// Adds an input variable to the list.
+ /// Creates an ONNX node
///
- public void AddInputVariable(ColumnType type, string colName)
- {
- _host.CheckValue(type, nameof(type));
- _host.CheckValue(colName, nameof(colName));
-
- colName = AddVariable(colName);
- _inputs.Add(OnnxUtils.GetModelArgs(type, colName));
- }
+ /// The name of the ONNX operator to apply
+ /// The names of the variables as inputs
+ /// The names of the variables to create as outputs,
+ /// which ought to have been something returned from
+ /// The name of the operator, which ought to be something returned from
+ /// The domain of the ONNX operator, if non-default
+ /// A node added to the in-progress ONNX graph, that attributes can be set on
+ public abstract OnnxNode CreateNode(string opType, IEnumerable inputs,
+ IEnumerable outputs, string name, string domain = null);
///
- /// Makes the ONNX model based on the context.
+ /// Convenience alternative to
+ /// for the case where there is exactly one input and output.
///
- public ModelProto MakeModel()
- => OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues);
+ /// The name of the ONNX operator to apply
+ /// The name of the variable as input
+ /// The name of the variable as output,
+ /// which ought to have been something returned from
+ /// The name of the operator, which ought to be something returned from
+ /// The domain of the ONNX operator, if non-default
+ /// A node added to the in-progress ONNX graph, that attributes can be set on
+ public OnnxNode CreateNode(string opType, string input, string output, string name, string domain = null)
+ => CreateNode(opType, new[] { input }, new[] { output }, name, domain);
}
}
diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs
new file mode 100644
index 0000000000..259a6d27d4
--- /dev/null
+++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs
@@ -0,0 +1,32 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using Microsoft.ML.Runtime.Data;
+
+namespace Microsoft.ML.Runtime.Model.Onnx
+{
+ ///
+ /// An abstraction for an ONNX node as created by
+ /// .
+ /// That method creates a with inputs and outputs, but this object can modify the node further
+ /// by adding attributes (in ONNX parlance, attributes are more or less constant parameterizations).
+ ///
+ public abstract class OnnxNode
+ {
+ public abstract void AddAttribute(string argName, double value);
+ public abstract void AddAttribute(string argName, long value);
+ public abstract void AddAttribute(string argName, DvText value);
+ public abstract void AddAttribute(string argName, string value);
+ public abstract void AddAttribute(string argName, bool value);
+
+ public abstract void AddAttribute(string argName, IEnumerable value);
+ public abstract void AddAttribute(string argName, IEnumerable value);
+ public abstract void AddAttribute(string argName, IEnumerable value);
+ public abstract void AddAttribute(string argName, IEnumerable value);
+ public abstract void AddAttribute(string argName, string[] value);
+ public abstract void AddAttribute(string argName, IEnumerable value);
+ public abstract void AddAttribute(string argName, IEnumerable value);
+ }
+}
diff --git a/src/Microsoft.ML.Data/Model/Pfa/BoundPfaContext.cs b/src/Microsoft.ML.Data/Model/Pfa/BoundPfaContext.cs
index d0923a9962..dfd5ef55fb 100644
--- a/src/Microsoft.ML.Data/Model/Pfa/BoundPfaContext.cs
+++ b/src/Microsoft.ML.Data/Model/Pfa/BoundPfaContext.cs
@@ -33,7 +33,7 @@ public sealed class BoundPfaContext
///
private readonly Dictionary _nameToVarName;
///
- /// This contains a map of those names in
+ /// This contains a map of those names in
///
private readonly HashSet _unavailable;
diff --git a/src/Microsoft.ML.Data/Model/Pfa/PfaContext.cs b/src/Microsoft.ML.Data/Model/Pfa/PfaContext.cs
index 55122535d4..c0996beea1 100644
--- a/src/Microsoft.ML.Data/Model/Pfa/PfaContext.cs
+++ b/src/Microsoft.ML.Data/Model/Pfa/PfaContext.cs
@@ -215,7 +215,7 @@ public static JObject CreateFuncBlock(JArray prms, JToken returnType, JToken doB
/// declaration. So, if you use a record type three times, that means one of the three usages must be
/// accompanied by a full type declaration, whereas the other two can just then identify it by name.
/// This is extremely silly, but there you go.
- ///
+ ///
/// Anyway: this will attempt to add a type to the list of registered types. If it returns true
/// then the caller is responsible, then, for ensuring that their PFA code they are generating contains
/// not only a reference of the type, but a declaration of the type. If however this returns false
diff --git a/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs b/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs
index 6c28dac997..dfec0913ca 100644
--- a/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs
+++ b/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs
@@ -147,13 +147,13 @@ private void Run(IChannel ch)
{
RoleMappedData data;
if (trainSchema != null)
- data = RoleMappedData.Create(end, trainSchema.GetColumnRoleNames());
+ data = new RoleMappedData(end, trainSchema.GetColumnRoleNames());
else
{
// We had a predictor, but no roles stored in the model. Just suppose
// default column names are OK, if present.
- data = TrainUtils.CreateExamplesOpt(end, DefaultColumnNames.Label,
- DefaultColumnNames.Features, DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name);
+ data = new RoleMappedData(end, DefaultColumnNames.Label,
+ DefaultColumnNames.Features, DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name, opt: true);
}
var scorePipe = ScoreUtils.GetScorer(rawPred, data, Host, trainSchema);
diff --git a/src/Microsoft.ML.Data/Model/Repository.cs b/src/Microsoft.ML.Data/Model/Repository.cs
index 7556cc970e..eb665f1bfc 100644
--- a/src/Microsoft.ML.Data/Model/Repository.cs
+++ b/src/Microsoft.ML.Data/Model/Repository.cs
@@ -6,7 +6,6 @@
using System.Collections.Generic;
using System.IO;
using System.IO.Compression;
-using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Utilities;
namespace Microsoft.ML.Runtime.Model
@@ -73,7 +72,7 @@ public void Dispose()
}
}
- // These are the open entries that may contain streams into our _dirTemp.
+ // These are the open entries that may contain streams into our DirTemp.
private List _open;
private bool _disposed;
@@ -108,19 +107,37 @@ internal Repository(bool needDir, IExceptionContext ectx)
PathMap = new Dictionary();
_open = new List();
if (needDir)
- {
- DirTemp = GetTempPath();
- Directory.CreateDirectory(DirTemp);
- }
+ DirTemp = GetShortTempDir();
else
GC.SuppressFinalize(this);
}
- // REVIEW: This should use host environment functionality.
- private static string GetTempPath()
+ private static string GetShortTempDir()
+ {
+ var rnd = RandomUtils.Create();
+ string path;
+ do
+ {
+ path = Path.Combine(Path.GetTempPath(), "TLC_" + rnd.Next().ToString("X"));
+ path = Path.GetFullPath(path);
+ Directory.CreateDirectory(path);
+ }
+ while (!EnsureDirectory(path));
+ return path;
+ }
+
+ private static bool EnsureDirectory(string path)
{
- Guid guid = Guid.NewGuid();
- return Path.GetFullPath(Path.Combine(Path.GetTempPath(), "TLC_" + guid.ToString()));
+ path = Path.GetFullPath(Path.Combine(path, ".lock"));
+ try
+ {
+ using (var stream = new FileStream(path, FileMode.CreateNew))
+ return true;
+ }
+ catch
+ {
+ return false;
+ }
}
~Repository()
@@ -214,7 +231,7 @@ protected void RemoveEntry(Entry ent)
///
/// When building paths to our local file system, we want to force both forward and backward slashes
/// to the system directory separator character. We do this for cases where we either used Windows-specific
- /// path building logic, or concatenated filesystem paths with zip archive entries on Linux.
+ /// path building logic, or concatenated filesystem paths with zip archive entries on Linux.
///
private static string NormalizeForFileSystem(string path) =>
path?.Replace('/', Path.DirectorySeparatorChar).Replace('\\', Path.DirectorySeparatorChar);
@@ -232,7 +249,7 @@ protected void GetPath(out string pathEnt, out string pathTemp, string dir, stri
_ectx.CheckParam(!name.Contains(".."), nameof(name));
// The gymnastics below are meant to deal with bad invocations including absolute paths, etc.
- // That's why we go through it even if _dirTemp is null.
+ // That's why we go through it even if DirTemp is null.
string root = Path.GetFullPath(DirTemp ?? @"x:\dummy");
string entityPath = Path.Combine(root, dir ?? "", name);
entityPath = Path.GetFullPath(entityPath);
diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
index efa52d2ff6..237afb400e 100644
--- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs
+++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
@@ -111,7 +111,7 @@ public interface ICalibratorTrainer
///
public interface ICalibrator
{
- /// Given a classifier output, produce the probability
+ /// Given a classifier output, produce the probability
Float PredictProbability(Float output);
/// Get the summary of current calibrator settings
@@ -687,8 +687,7 @@ public static class CalibratorUtils
private static bool NeedCalibration(IHostEnvironment env, IChannel ch, ICalibratorTrainer calibrator,
ITrainer trainer, IPredictor predictor, RoleMappedSchema schema)
{
- var trainerEx = trainer as ITrainerEx;
- if (trainerEx == null || !trainerEx.NeedCalibration)
+ if (!trainer.Info.NeedCalibration)
{
ch.Info("Not training a calibrator because it is not needed.");
return false;
@@ -746,13 +745,10 @@ private static bool NeedCalibration(IHostEnvironment env, IChannel ch, ICalibrat
/// The trainer used to train the predictor.
/// The predictor that needs calibration.
/// The examples to used for calibrator training.
- /// Indicates whether the predictor returned needs to be an .
- /// This parameter is needed for OVA that uses the predictors as s. If it is false,
- /// The predictor returned is an an .
- /// The original predictor, if no calibration is needed,
+ /// The original predictor, if no calibration is needed,
/// or a metapredictor that wraps the original predictor and the newly trained calibrator.
public static IPredictor TrainCalibratorIfNeeded(IHostEnvironment env, IChannel ch, ICalibratorTrainer calibrator,
- int maxRows, ITrainer trainer, IPredictor predictor, RoleMappedData data, bool needValueMapper = false)
+ int maxRows, ITrainer trainer, IPredictor predictor, RoleMappedData data)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ch, nameof(ch));
@@ -763,7 +759,7 @@ public static IPredictor TrainCalibratorIfNeeded(IHostEnvironment env, IChannel
if (!NeedCalibration(env, ch, calibrator, trainer, predictor, data.Schema))
return predictor;
- return TrainCalibrator(env, ch, calibrator, maxRows, predictor, data, needValueMapper);
+ return TrainCalibrator(env, ch, calibrator, maxRows, predictor, data);
}
///
@@ -775,13 +771,10 @@ public static IPredictor TrainCalibratorIfNeeded(IHostEnvironment env, IChannel
/// The maximum rows to use for calibrator training.
/// The predictor that needs calibration.
/// The examples to used for calibrator training.
- /// Indicates whether the predictor returned needs to be an .
- /// This parameter is needed for OVA that uses the predictors as s. If it is false,
- /// The predictor returned is an an .
- /// The original predictor, if no calibration is needed,
+ /// The original predictor, if no calibration is needed,
/// or a metapredictor that wraps the original predictor and the newly trained calibrator.
public static IPredictor TrainCalibrator(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer,
- int maxRows, IPredictor predictor, RoleMappedData data, bool needValueMapper = false)
+ int maxRows, IPredictor predictor, RoleMappedData data)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ch, nameof(ch));
@@ -834,10 +827,10 @@ public static IPredictor TrainCalibrator(IHostEnvironment env, IChannel ch, ICal
}
}
var cali = caliTrainer.FinishTraining(ch);
- return CreateCalibratedPredictor(env, (IPredictorProducing)predictor, cali, needValueMapper);
+ return CreateCalibratedPredictor(env, (IPredictorProducing)predictor, cali);
}
- public static IPredictorProducing CreateCalibratedPredictor(IHostEnvironment env, IPredictorProducing predictor, ICalibrator cali, bool needValueMapper = false)
+ public static IPredictorProducing CreateCalibratedPredictor(IHostEnvironment env, IPredictorProducing predictor, ICalibrator cali)
{
Contracts.Assert(predictor != null);
if (cali == null)
@@ -853,7 +846,7 @@ public static IPredictorProducing CreateCalibratedPredictor(IHostEnvironm
var predWithFeatureScores = predictor as IPredictorWithFeatureWeights;
if (predWithFeatureScores != null && predictor is IParameterMixer && cali is IParameterMixer)
return new ParameterMixingCalibratedPredictor(env, predWithFeatureScores, cali);
- if (needValueMapper)
+ if (predictor is IValueMapper)
return new CalibratedPredictor(env, predictor, cali);
return new SchemaBindableCalibratedPredictor(env, predictor, cali);
}
@@ -1443,19 +1436,14 @@ public bool SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColumnNames, str
string opType = "Affine";
string linearOutput = ctx.AddIntermediateVariable(null, "linearOutput", true);
- var node = OnnxUtils.MakeNode(opType, new List { scoreProbablityColumnNames[0] },
- new List { linearOutput }, ctx.GetNodeName(opType), "ai.onnx");
-
- OnnxUtils.NodeAddAttributes(node, "alpha", ParamA * -1);
- OnnxUtils.NodeAddAttributes(node, "beta", -0.0000001);
-
- ctx.AddNode(node);
+ var node = ctx.CreateNode(opType, new[] { scoreProbablityColumnNames[0] },
+ new[] { linearOutput }, ctx.GetNodeName(opType), "");
+ node.AddAttribute("alpha", ParamA * -1);
+ node.AddAttribute("beta", -0.0000001);
opType = "Sigmoid";
- node = OnnxUtils.MakeNode(opType, new List { linearOutput },
- new List { scoreProbablityColumnNames[1] }, ctx.GetNodeName(opType), "ai.onnx");
-
- ctx.AddNode(node);
+ node = ctx.CreateNode(opType, new[] { linearOutput },
+ new[] { scoreProbablityColumnNames[1] }, ctx.GetNodeName(opType), "");
return true;
}
diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs
index 6fde9a5815..6da402431d 100644
--- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs
+++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs
@@ -82,7 +82,7 @@ private static ISchemaBoundMapper WrapIfNeeded(IHostEnvironment env, ISchemaBoun
/// The type of the label names from the metadata (either
/// originating from the key value metadata of the training label column, or deserialized
/// from the model of a bindable mapper)
- /// Whether we can call with
+ /// Whether we can call with
/// this mapper and expect it to succeed
private static bool CanWrap(ISchemaBoundMapper mapper, ColumnType labelNameType)
{
@@ -201,16 +201,14 @@ public override void SaveAsOnnx(OnnxContext ctx)
for (int iinfo = 0; iinfo < Bindings.InfoCount; ++iinfo)
outColumnNames[iinfo] = Bindings.GetColumnName(Bindings.MapIinfoToCol(iinfo));
- //Check if "Probability" column was generated by the base class, only then
+ //Check if "Probability" column was generated by the base class, only then
//label can be predicted.
if (Bindings.InfoCount >= 3 && ctx.ContainsColumn(outColumnNames[2]))
{
string opType = "Binarizer";
- var node = OnnxUtils.MakeNode(opType, new List { ctx.GetVariableName(outColumnNames[2]) },
- new List { ctx.GetVariableName(outColumnNames[0]) }, ctx.GetNodeName(opType));
-
- OnnxUtils.NodeAddAttributes(node, "threshold", 0.5);
- ctx.AddNode(node);
+ var node = ctx.CreateNode(opType, new[] { ctx.GetVariableName(outColumnNames[2]) },
+ new[] { ctx.GetVariableName(outColumnNames[0]) }, ctx.GetNodeName(opType));
+ node.AddAttribute("threshold", 0.5);
}
}
diff --git a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs
index 91c84a0734..41c12e94ed 100644
--- a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs
+++ b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs
@@ -20,7 +20,7 @@ namespace Microsoft.ML.Runtime.Data
{
///
/// This class is a scorer that passes through all the ISchemaBound columns without adding any "derived columns".
- /// It also passes through all metadata (except for possibly changing the score column kind), and adds the
+ /// It also passes through all metadata (except for possibly changing the score column kind), and adds the
/// score set id metadata.
///
@@ -70,7 +70,7 @@ private static Bindings Create(IHostEnvironment env, ISchemaBindableMapper binda
Contracts.AssertValue(roles);
Contracts.AssertValueOrNull(suffix);
- var mapper = bindable.Bind(env, RoleMappedSchema.Create(input, roles));
+ var mapper = bindable.Bind(env, new RoleMappedSchema(input, roles));
// We don't actually depend on this invariant, but if this assert fires it means the bindable
// did the wrong thing.
Contracts.Assert(mapper.InputSchema.Schema == input);
diff --git a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs
index 4832f92cd4..c12fd9b4d1 100644
--- a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs
+++ b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs
@@ -452,7 +452,7 @@ private static ISchemaBoundMapper WrapIfNeeded(IHostEnvironment env, ISchemaBoun
/// The type of the label names from the metadata (either
/// originating from the key value metadata of the training label column, or deserialized
/// from the model of a bindable mapper)
- /// Whether we can call with
+ /// Whether we can call with
/// this mapper and expect it to succeed
public static bool CanWrap(ISchemaBoundMapper mapper, ColumnType labelNameType)
{
diff --git a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs
index fe69585b78..2fd039897a 100644
--- a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs
+++ b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs
@@ -117,7 +117,7 @@ public BindingsImpl ApplyToSchema(ISchema input, ISchemaBindableMapper bindable,
env.AssertValue(bindable);
string scoreCol = RowMapper.OutputSchema.GetColumnName(ScoreColumnIndex);
- var schema = RoleMappedSchema.Create(input, RowMapper.GetInputColumnRoles());
+ var schema = new RoleMappedSchema(input, RowMapper.GetInputColumnRoles());
// Checks compatibility of the predictor input types.
var mapper = bindable.Bind(env, schema);
@@ -148,7 +148,7 @@ public static BindingsImpl Create(ModelLoadContext ctx, ISchema input,
string scoreKind = ctx.LoadNonEmptyString();
string scoreCol = ctx.LoadNonEmptyString();
- var mapper = bindable.Bind(env, RoleMappedSchema.Create(input, roles));
+ var mapper = bindable.Bind(env, new RoleMappedSchema(input, roles));
var rowMapper = mapper as ISchemaBoundRowMapper;
env.CheckParam(rowMapper != null, nameof(bindable), "Bindable expected to be an " + nameof(ISchemaBindableMapper) + "!");
diff --git a/src/Microsoft.ML.Data/Scorers/ScoreMapperSchema.cs b/src/Microsoft.ML.Data/Scorers/ScoreMapperSchema.cs
index ddb05e3686..0f115bb2f0 100644
--- a/src/Microsoft.ML.Data/Scorers/ScoreMapperSchema.cs
+++ b/src/Microsoft.ML.Data/Scorers/ScoreMapperSchema.cs
@@ -251,7 +251,7 @@ public SequencePredictorSchema(ColumnType type, ref VBuffer keyNames, st
Contracts.CheckParam(keyNames.Length == type.ItemType.KeyCount,
nameof(keyNames), "keyNames length must match type's key count");
// REVIEW: Assuming the caller takes some care, it seems
- // like we can get away with
+ // like we can get away with
_keyNames = keyNames;
_keyNamesType = new VectorType(TextType.Instance, keyNames.Length);
_getKeyNames = GetKeyNames;
diff --git a/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs b/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs
index 13cdb126ee..1da5a5562a 100644
--- a/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs
+++ b/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs
@@ -123,7 +123,7 @@ public override bool CheckScore(Float validationScore, Float trainingScore, out
}
// For the detail of the following rules, see the following paper.
- // Lodwich, Aleksander, Yves Rangoni, and Thomas Breuel. "Evaluation of robustness and performance of early stopping rules with multi layer perceptrons."
+ // Lodwich, Aleksander, Yves Rangoni, and Thomas Breuel. "Evaluation of robustness and performance of early stopping rules with multi layer perceptrons."
// Neural Networks, 2009. IJCNN 2009. International Joint Conference on. IEEE, 2009.
public abstract class MovingWindowEarlyStoppingCriterion : EarlyStoppingCriterion
@@ -139,9 +139,9 @@ public class Arguments : ArgumentsBase
public int WindowSize = 5;
}
- protected internal Queue PastScores;
+ protected Queue PastScores;
- internal MovingWindowEarlyStoppingCriterion(Arguments args, bool lowerIsBetter)
+ private protected MovingWindowEarlyStoppingCriterion(Arguments args, bool lowerIsBetter)
: base(args, lowerIsBetter)
{
Contracts.CheckUserArg(0 <= Args.Threshold && args.Threshold <= 1, nameof(args.Threshold), "Must be in range [0,1].");
diff --git a/src/Microsoft.ML.Data/Training/TrainerBase.cs b/src/Microsoft.ML.Data/Training/TrainerBase.cs
index 90f8b64a7c..ca2f2c7b64 100644
--- a/src/Microsoft.ML.Data/Training/TrainerBase.cs
+++ b/src/Microsoft.ML.Data/Training/TrainerBase.cs
@@ -4,59 +4,32 @@
namespace Microsoft.ML.Runtime.Training
{
- public abstract class TrainerBase : ITrainer, ITrainerEx
+ public abstract class TrainerBase : ITrainer
+ where TPredictor : IPredictor
{
- public const string NoTrainingInstancesMessage = "No valid training instances found, all instances have missing features.";
+ ///
+ /// A standard string to use in errors or warnings by subclasses, to communicate the idea that no valid
+ /// instances were able to be found.
+ ///
+ protected const string NoTrainingInstancesMessage = "No valid training instances found, all instances have missing features.";
- protected readonly IHost Host;
+ protected IHost Host { get; }
public string Name { get; }
public abstract PredictionKind PredictionKind { get; }
- public abstract bool NeedNormalization { get; }
- public abstract bool NeedCalibration { get; }
- public abstract bool WantCaching { get; }
+ public abstract TrainerInfo Info { get; }
protected TrainerBase(IHostEnvironment env, string name)
{
Contracts.CheckValue(env, nameof(env));
- Contracts.CheckNonEmpty(name, nameof(name));
+ env.CheckNonEmpty(name, nameof(name));
Name = name;
Host = env.Register(name);
}
- IPredictor ITrainer.CreatePredictor()
- {
- return CreatePredictorCore();
- }
-
- protected abstract IPredictor CreatePredictorCore();
- }
-
- public abstract class TrainerBase : TrainerBase
- where TPredictor : IPredictor
- {
- protected TrainerBase(IHostEnvironment env, string name)
- : base(env, name)
- {
- }
-
- public abstract TPredictor CreatePredictor();
-
- protected sealed override IPredictor CreatePredictorCore()
- {
- return CreatePredictor();
- }
- }
-
- public abstract class TrainerBase : TrainerBase, ITrainer
- where TPredictor : IPredictor
- {
- protected TrainerBase(IHostEnvironment env, string name)
- : base(env, name)
- {
- }
+ IPredictor ITrainer.Train(TrainContext context) => Train(context);
- public abstract void Train(TDataSet data);
+ public abstract TPredictor Train(TrainContext context);
}
}
diff --git a/src/Microsoft.ML.Data/Training/TrainerUtils.cs b/src/Microsoft.ML.Data/Training/TrainerUtils.cs
index b2032bfc38..33d3d1490d 100644
--- a/src/Microsoft.ML.Data/Training/TrainerUtils.cs
+++ b/src/Microsoft.ML.Data/Training/TrainerUtils.cs
@@ -400,10 +400,10 @@ protected static IRowCursor CreateCursor(RoleMappedData data, CursOpt opt, IRand
/// delegate of the cursor, indicating what additional options should be specified on subsequent
/// passes over the data. The base implementation checks if any rows were skipped, and if none were
/// skipped, it signals the context that it needn't bother with any filtering checks.
- ///
+ ///
/// Because the result will be "or"-red, a perfectly acceptable implementation is that this
/// return the default , in which case the flags will not ever change.
- ///
+ ///
/// If the cursor was created with a signal delegate, the return value of this method will be sent
/// to that delegate.
///
diff --git a/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs
index 5efc9264f1..1459f55cab 100644
--- a/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs
@@ -58,6 +58,20 @@ public bool TryUnparse(StringBuilder sb)
public sealed class Arguments
{
+ public Arguments()
+ {
+
+ }
+
+ internal Arguments(params string[] columns)
+ {
+ Column = new Column[columns.Length];
+ for (int i = 0; i < columns.Length; i++)
+ {
+ Column[i] = new Column() { Source = columns[i], Name = columns[i] };
+ }
+ }
+
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
public Column[] Column;
@@ -442,6 +456,17 @@ private static VersionInfo GetVersionInfo()
private const string RegistrationName = "ChooseColumns";
+ ///
+ /// Convenience constructor for public facing API.
+ ///
+ /// Host Environment.
+ /// Input . This is the output from previous transform or loader.
+ /// Names of the columns to choose.
+ public ChooseColumnsTransform(IHostEnvironment env, IDataView input, params string[] columns)
+ : this(env, new Arguments(columns), input)
+ {
+ }
+
///
/// Public constructor corresponding to SignatureDataTransform.
///
diff --git a/src/Microsoft.ML.Data/Transforms/ColumnBindingsBase.cs b/src/Microsoft.ML.Data/Transforms/ColumnBindingsBase.cs
index 58eee5430b..2347d2c679 100644
--- a/src/Microsoft.ML.Data/Transforms/ColumnBindingsBase.cs
+++ b/src/Microsoft.ML.Data/Transforms/ColumnBindingsBase.cs
@@ -324,17 +324,17 @@ protected ColumnBindingsBase(ISchema input, bool user, params string[] names)
if (string.IsNullOrWhiteSpace(name))
{
throw user ?
-#pragma warning disable TLC_ContractsNameUsesNameof // Unfortunately, there is no base class for the columns bindings.
+#pragma warning disable MSML_ContractsNameUsesNameof // Unfortunately, there is no base class for the columns bindings.
Contracts.ExceptUserArg(standardColumnArgName, "New column needs a name") :
-#pragma warning restore TLC_ContractsNameUsesNameof
+#pragma warning restore MSML_ContractsNameUsesNameof
Contracts.ExceptDecode("New column needs a name");
}
if (_nameToInfoIndex.ContainsKey(name))
{
throw user ?
-#pragma warning disable TLC_ContractsNameUsesNameof // Unfortunately, there is no base class for the columns bindings.
+#pragma warning disable MSML_ContractsNameUsesNameof // Unfortunately, there is no base class for the columns bindings.
Contracts.ExceptUserArg(standardColumnArgName, "New column '{0}' specified multiple times", name) :
-#pragma warning restore TLC_ContractsNameUsesNameof
+#pragma warning restore MSML_ContractsNameUsesNameof
Contracts.ExceptDecode("New column '{0}' specified multiple times", name);
}
_nameToInfoIndex.Add(name, iinfo);
@@ -686,10 +686,10 @@ protected ManyToOneColumnBindingsBase(ManyToOneColumn[] column, ISchema input, F
for (int j = 0; j < src.Length; j++)
{
Contracts.CheckUserArg(!string.IsNullOrWhiteSpace(src[j]), nameof(ManyToOneColumn.Source));
-#pragma warning disable TLC_ContractsNameUsesNameof // Unfortunately, there is no base class for the columns bindings.
+#pragma warning disable MSML_ContractsNameUsesNameof // Unfortunately, there is no base class for the columns bindings.
if (!input.TryGetColumnIndex(src[j], out srcIndices[j]))
throw Contracts.ExceptUserArg(standardColumnArgName, "Source column '{0}' not found", src[j]);
-#pragma warning restore TLC_ContractsNameUsesNameof
+#pragma warning restore MSML_ContractsNameUsesNameof
srcTypes[j] = input.GetColumnType(srcIndices[j]);
var size = srcTypes[j].ValueCount;
srcSize = size == 0 ? null : checked(srcSize + size);
@@ -700,10 +700,10 @@ protected ManyToOneColumnBindingsBase(ManyToOneColumn[] column, ISchema input, F
string reason = testTypes(srcTypes);
if (reason != null)
{
-#pragma warning disable TLC_ContractsNameUsesNameof // Unfortunately, there is no base class for the columns bindings.
+#pragma warning disable MSML_ContractsNameUsesNameof // Unfortunately, there is no base class for the columns bindings.
throw Contracts.ExceptUserArg(standardColumnArgName, "Column '{0}' has invalid source types: {1}. Source types: '{2}'.",
item.Name, reason, string.Join(", ", srcTypes.Select(type => type.ToString())));
-#pragma warning restore TLC_ContractsNameUsesNameof
+#pragma warning restore MSML_ContractsNameUsesNameof
}
}
Infos[i] = new ColInfo(srcSize.GetValueOrDefault(), srcIndices, srcTypes);
@@ -861,7 +861,7 @@ public Func GetDependencies(Func predicate)
}
///
- /// Parsing utilities for converting between transform column argument objects and
+ /// Parsing utilities for converting between transform column argument objects and
/// command line representations.
///
public static class ColumnParsingUtils
diff --git a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs
index 544bce0aeb..b2024cc18c 100644
--- a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs
@@ -55,7 +55,7 @@ public sealed class TaggedColumn
public string Name;
// The tag here (the key of the KeyValuePair) is the string that will be the prefix of the slot name
- // in the output column. For non-vector columns, the slot name will be either the column name or the
+ // in the output column. For non-vector columns, the slot name will be either the column name or the
// tag if it is non empty. For vector columns, the slot names will be 'ColumnName.SlotName' if the
// tag is empty, 'Tag.SlotName' if tag is non empty, and simply the slot name if tag is non empty
// and equal to the column name.
@@ -90,6 +90,19 @@ public bool TryUnparse(StringBuilder sb)
public sealed class Arguments : TransformInputBase
{
+ public Arguments()
+ {
+ }
+
+ public Arguments(string name, params string[] source)
+ {
+ Column = new[] { new Column()
+ {
+ Name = name,
+ Source = source
+ }};
+ }
+
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:srcs)", ShortName = "col", SortOrder = 1)]
public Column[] Column;
}
@@ -232,11 +245,8 @@ private void CacheTypes(out ColumnType[] types, out ColumnType[] typesSlotNames,
{
// All meta-data is passed through in this case, so don't need the slot names type.
echoSrc[i] = true;
- DvBool b = DvBool.False;
isNormalized[i] =
- info.SrcTypes[0].ItemType.IsNumber &&
- Input.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, info.SrcIndices[0], ref b) &&
- b.IsTrue;
+ info.SrcTypes[0].ItemType.IsNumber && Input.IsNormalized(info.SrcIndices[0]);
types[i] = info.SrcTypes[0];
continue;
}
@@ -247,9 +257,7 @@ private void CacheTypes(out ColumnType[] types, out ColumnType[] typesSlotNames,
{
foreach (var srcCol in info.SrcIndices)
{
- DvBool b = DvBool.False;
- if (!Input.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, srcCol, ref b) ||
- !b.IsTrue)
+ if (!Input.IsNormalized(srcCol))
{
isNormalized[i] = false;
break;
@@ -497,7 +505,7 @@ private void GetSlotNames(int iinfo, ref VBuffer dst)
}
}
- public const string Summary = "Concatenates two columns of the same item type.";
+ public const string Summary = "Concatenates one or more columns of the same item type.";
public const string UserName = "Concat Transform";
public const string LoadName = "Concat";
@@ -527,6 +535,18 @@ private static VersionInfo GetVersionInfo()
public override ISchema Schema => _bindings;
+ ///
+ /// Convenience constructor for public facing API.
+ ///
+ /// Host Environment.
+ /// Input . This is the output from previous transform or loader.
+ /// Name of the output column.
+ /// Input columns to concatenate.
+ public ConcatTransform(IHostEnvironment env, IDataView input, string name, params string[] source)
+ : this(env, new Arguments(name, source), input)
+ {
+ }
+
///
/// Public constructor corresponding to SignatureDataTransform.
///
@@ -700,13 +720,10 @@ public void SaveAsOnnx(OnnxContext ctx)
Source.Schema.GetColumnType(srcIndex).ValueCount));
}
- var node = OnnxUtils.MakeNode(opType, new List(inputList.Select(t => t.Key)),
- new List { ctx.AddIntermediateVariable(outColType, outName) }, ctx.GetNodeName(opType));
-
- ctx.AddNode(node);
+ var node = ctx.CreateNode(opType, inputList.Select(t => t.Key),
+ new[] { ctx.AddIntermediateVariable(outColType, outName) }, ctx.GetNodeName(opType));
- OnnxUtils.NodeAddAttributes(node, "inputList", inputList.Select(x => x.Key));
- OnnxUtils.NodeAddAttributes(node, "inputdimensions", inputList.Select(x => x.Value));
+ node.AddAttribute("inputdimensions", inputList.Select(x => x.Value));
}
}
diff --git a/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs b/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs
index c37f0a6983..52005c7558 100644
--- a/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs
@@ -169,6 +169,23 @@ private static VersionInfo GetVersionInfo()
// This is parallel to Infos.
private readonly ColInfoEx[] _exes;
+ ///
+ /// Convenience constructor for public facing API.
+ ///
+ /// Host Environment.
+ /// Input . This is the output from previous transform or loader.
+ /// The expected type of the converted column.
+ /// Name of the output column.
+ /// Name of the column to be converted. If this is null '' will be used.
+ public ConvertTransform(IHostEnvironment env,
+ IDataView input,
+ DataKind resultType,
+ string name,
+ string source = null)
+ : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, ResultType = resultType }, input)
+ {
+ }
+
public ConvertTransform(IHostEnvironment env, Arguments args, IDataView input)
: base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column,
input, null)
diff --git a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
index f365dd9e98..2729a48e3e 100644
--- a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
@@ -64,6 +64,18 @@ private static VersionInfo GetVersionInfo()
private const string RegistrationName = "CopyColumns";
+ ///
+ /// Convenience constructor for public facing API.
+ ///
+ /// Host Environment.
+ /// Input . This is the output from previous transform or loader.
+ /// Name of the output column.
+ /// Name of the column to be copied.
+ public CopyColumnsTransform(IHostEnvironment env, IDataView input, string name, string source)
+ : this(env, new Arguments(){ Column = new[] { new Column() { Source = source, Name = name }}}, input)
+ {
+ }
+
public CopyColumnsTransform(IHostEnvironment env, Arguments args, IDataView input)
: base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, null)
{
diff --git a/src/Microsoft.ML.Data/Transforms/DropColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/DropColumnsTransform.cs
index 502e6f395d..3e15199ff7 100644
--- a/src/Microsoft.ML.Data/Transforms/DropColumnsTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/DropColumnsTransform.cs
@@ -237,6 +237,17 @@ private static VersionInfo GetVersionInfo()
private const string DropRegistrationName = "DropColumns";
private const string KeepRegistrationName = "KeepColumns";
+ ///
+ /// Convenience constructor for public facing API.
+ ///
+ /// Host Environment.
+ /// Input . This is the output from previous transform or loader.
+ /// Name of the columns to be dropped.
+ public DropColumnsTransform(IHostEnvironment env, IDataView input, params string[] columnsToDrop)
+ :this(env, new Arguments() { Column = columnsToDrop }, input)
+ {
+ }
+
///
/// Public constructor corresponding to SignatureDataTransform.
///
@@ -383,4 +394,17 @@ public ValueGetter GetGetter(int col)
}
}
}
+
+ public class KeepColumnsTransform
+ {
+ ///
+ /// A helper method to create for public facing API.
+ ///
+ /// Host Environment.
+ /// Input . This is the output from previous transform or loader.
+ /// Name of the columns to be kept. All other columns will be removed.
+ ///
+ public static IDataTransform Create(IHostEnvironment env, IDataView input, params string[] columnsToKeep)
+ => new DropColumnsTransform(env, new DropColumnsTransform.KeepArguments() { Column = columnsToKeep }, input);
+ }
}
diff --git a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs
index 9a40f404ea..230cfbe680 100644
--- a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs
@@ -313,7 +313,7 @@ private void GetSlotsMinMax(Column col, out int[] slotsMin, out int[] slotsMax)
slotsMin[j] = range.Min;
// There are two reasons for setting the max to int.MaxValue - 1:
// 1. max is an index, so it has to be strictly less than int.MaxValue.
- // 2. to prevent overflows when adding 1 to it.
+ // 2. to prevent overflows when adding 1 to it.
slotsMax[j] = range.Max ?? int.MaxValue - 1;
}
Array.Sort(slotsMin, slotsMax);
@@ -473,7 +473,7 @@ private void GetCategoricalSlotRangesCore(int iinfo, int[] slotsMin, int[] slots
// Six possible ways a drop slot range interacts with categorical slots range.
//
- // +--------------Drop-------------+
+ // +--------------Drop-------------+
// | |
//
// +---Drop---+ +---Drop---+ +---Drop---+
diff --git a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs
index f80589bdab..cacd681141 100644
--- a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs
@@ -24,9 +24,9 @@
namespace Microsoft.ML.Runtime.Data
{
///
- /// This transform adds columns containing either random numbers distributed
+ /// This transform adds columns containing either random numbers distributed
/// uniformly between 0 and 1 or an auto-incremented integer starting at zero.
- /// It will be used in conjunction with a filter transform to create random
+ /// It will be used in conjunction with a filter transform to create random
/// partitions of the data, used in cross validation.
///
public sealed class GenerateNumberTransform : RowToRowTransformBase
@@ -77,16 +77,22 @@ private bool TryParse(string str)
}
}
+ private static class Defaults
+ {
+ public const bool UseCounter = false;
+ public const uint Seed = 42;
+ }
+
public sealed class Arguments : TransformInputBase
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:seed)", ShortName = "col", SortOrder = 1)]
public Column[] Column;
[Argument(ArgumentType.AtMostOnce, HelpText = "Use an auto-incremented integer starting at zero instead of a random number", ShortName = "cnt")]
- public bool UseCounter;
+ public bool UseCounter = Defaults.UseCounter;
[Argument(ArgumentType.AtMostOnce, HelpText = "The random seed")]
- public uint Seed = 42;
+ public uint Seed = Defaults.Seed;
}
private sealed class Bindings : ColumnBindingsBase
@@ -250,6 +256,18 @@ private static VersionInfo GetVersionInfo()
private const string RegistrationName = "GenerateNumber";
+ ///
+ /// Convenience constructor for public facing API.
+ ///
+ /// Host Environment.
+ /// Input . This is the output from previous transform or loader.
+ /// Name of the output column.
+ /// Use an auto-incremented integer starting at zero instead of a random number.
+ public GenerateNumberTransform(IHostEnvironment env, IDataView input, string name, bool useCounter = Defaults.UseCounter)
+ : this(env, new Arguments() { Column = new[] { new Column() { Name = name } }, UseCounter = useCounter }, input)
+ {
+ }
+
///
/// Public constructor corresponding to SignatureDataTransform.
///
diff --git a/src/Microsoft.ML.Data/Transforms/HashTransform.cs b/src/Microsoft.ML.Data/Transforms/HashTransform.cs
index ca959069f7..0519428284 100644
--- a/src/Microsoft.ML.Data/Transforms/HashTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/HashTransform.cs
@@ -25,7 +25,7 @@ namespace Microsoft.ML.Runtime.Data
///
/// This transform can hash either single valued columns or vector columns. For vector columns,
- /// it hashes each slot separately.
+ /// it hashes each slot separately.
/// It can hash either text values or key values.
///
public sealed class HashTransform : OneToOneTransformBase, ITransformTemplate
@@ -33,6 +33,14 @@ public sealed class HashTransform : OneToOneTransformBase, ITransformTemplate
public const int NumBitsMin = 1;
public const int NumBitsLim = 32;
+ private static class Defaults
+ {
+ public const int HashBits = NumBitsLim - 1;
+ public const uint Seed = 314489979;
+ public const bool Ordered = false;
+ public const int InvertHash = 0;
+ }
+
public sealed class Arguments
{
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col",
@@ -41,18 +49,18 @@ public sealed class Arguments
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of bits to hash into. Must be between 1 and 31, inclusive",
ShortName = "bits", SortOrder = 2)]
- public int HashBits = NumBitsLim - 1;
+ public int HashBits = Defaults.HashBits;
[Argument(ArgumentType.AtMostOnce, HelpText = "Hashing seed")]
- public uint Seed = 314489979;
+ public uint Seed = Defaults.Seed;
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether the position of each term should be included in the hash",
ShortName = "ord")]
- public bool Ordered;
+ public bool Ordered = Defaults.Ordered;
[Argument(ArgumentType.AtMostOnce, HelpText = "Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.",
ShortName = "ih")]
- public int InvertHash;
+ public int InvertHash = Defaults.InvertHash;
}
public sealed class Column : OneToOneColumn
@@ -234,6 +242,27 @@ public override void Save(ModelSaveContext ctx)
TextModelHelper.SaveAll(Host, ctx, Infos.Length, _keyValues);
}
+ ///
+ /// Convenience constructor for public facing API.
+ ///
+ /// Host Environment.
+ /// Input . This is the output from previous transform or loader.
+ /// Name of the output column.
+ /// Name of the column to be transformed. If this is null '' will be used.
+ /// Number of bits to hash into. Must be between 1 and 31, inclusive.
+ /// Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.
+ public HashTransform(IHostEnvironment env,
+ IDataView input,
+ string name,
+ string source = null,
+ int hashBits = Defaults.HashBits,
+ int invertHash = Defaults.InvertHash)
+ : this(env, new Arguments() {
+ Column = new[] { new Column() { Source = source ?? name, Name = name } },
+ HashBits = hashBits, InvertHash = invertHash }, input)
+ {
+ }
+
public HashTransform(IHostEnvironment env, Arguments args, IDataView input)
: base(Contracts.CheckRef(env, nameof(env)), RegistrationName, env.CheckRef(args, nameof(args)).Column,
input, TestType)
diff --git a/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs b/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs
index 7a7e8fafda..d615b96894 100644
--- a/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs
+++ b/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs
@@ -265,7 +265,7 @@ public VBuffer GetMetadata()
public void Add(int dstSlot, ValueGetter getter, ref T key)
{
- // REVIEW: I only call the getter if I determine I have to, but
+ // REVIEW: I only call the getter if I determine I have to, but
// at the cost of passing along this getter and ref argument (as opposed
// to just the argument). Is this really appropriate or helpful?
Contracts.Assert(0 <= dstSlot && dstSlot < _slots);
diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs
index 165ab7e0df..997fa22d03 100644
--- a/src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs
@@ -73,6 +73,18 @@ private static VersionInfo GetVersionInfo()
private readonly ColumnType[] _types;
private KeyToValueMap[] _kvMaps;
+ ///
+ /// Convenience constructor for public facing API.
+ ///
+ /// Host Environment.
+ /// Input . This is the output from previous transform or loader.
+ /// Name of the output column.
+ /// Name of the input column. If this is null '' will be used.
+ public KeyToValueTransform(IHostEnvironment env, IDataView input, string name, string source = null)
+ : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } } }, input)
+ {
+ }
+
///
/// Public constructor corresponding to SignatureDataTransform.
///
diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs
index bffbaa881c..0f4b616a49 100644
--- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs
@@ -70,6 +70,11 @@ public bool TryUnparse(StringBuilder sb)
}
}
+ private static class Defaults
+ {
+ public const bool Bag = false;
+ }
+
public sealed class Arguments
{
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
@@ -77,7 +82,7 @@ public sealed class Arguments
[Argument(ArgumentType.AtMostOnce,
HelpText = "Whether to combine multiple indicator vectors into a single bag vector instead of concatenating them. This is only relevant when the input is a vector.")]
- public bool Bag;
+ public bool Bag = Defaults.Bag;
}
internal const string Summary = "Converts a key column to an indicator vector.";
@@ -112,6 +117,23 @@ private static VersionInfo GetVersionInfo()
private readonly bool[] _concat;
private readonly VectorType[] _types;
+ ///
+ /// Convenience constructor for public facing API.
+ ///
+ /// Host Environment.
+ /// Input . This is the output from previous transform or loader.
+ /// Name of the output column.
+ /// Name of the input column. If this is null '' will be used.
+ /// Whether to combine multiple indicator vectors into a single bag vector instead of concatenating them. This is only relevant when the input is a vector.
+ public KeyToVectorTransform(IHostEnvironment env,
+ IDataView input,
+ string name,
+ string source = null,
+ bool bag = Defaults.Bag)
+ : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, Bag = bag }, input)
+ {
+ }
+
///
/// Public constructor corresponding to SignatureDataTransform.
///
@@ -244,10 +266,9 @@ protected override JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo
protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
string opType = "OneHotEncoder";
- var node = OnnxUtils.MakeNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
- OnnxUtils.NodeAddAttributes(node, "cats_int64s", Enumerable.Range(1, info.TypeSrc.ItemType.KeyCount).Select(x => (long)x));
- OnnxUtils.NodeAddAttributes(node, "zeros", true);
- ctx.AddNode(node);
+ var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
+ node.AddAttribute("cats_int64s", Enumerable.Range(1, info.TypeSrc.ItemType.KeyCount).Select(x => (long)x));
+ node.AddAttribute("zeros", true);
return true;
}
diff --git a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs
index 5329d89a57..8817833f40 100644
--- a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs
@@ -64,6 +64,18 @@ private static VersionInfo GetVersionInfo()
private const string RegistrationName = "LabelConvert";
private VectorType _slotType;
+ ///
+ /// Convenience constructor for public facing API.
+ ///
+ /// Host Environment.
+ /// Input . This is the output from previous transform or loader.
+ /// Name of the output column.
+ /// Name of the input column. If this is null '' will be used.
+ public LabelConvertTransform(IHostEnvironment env, IDataView input, string name, string source = null)
+ : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } } }, input)
+ {
+ }
+
public LabelConvertTransform(IHostEnvironment env, Arguments args, IDataView input)
: base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column, input, RowCursorUtils.TestGetLabelGetter)
{
diff --git a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs
index 81a91b5f17..a7672b5a1c 100644
--- a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs
@@ -111,6 +111,23 @@ private static string TestIsMulticlassLabel(ColumnType type)
return $"Label column type is not supported for binary remapping: {type}. Supported types: key, float, double.";
}
+ ///
+ /// Convenience constructor for public facing API.
+ ///
+ /// Host Environment.
+ /// Input . This is the output from previous transform or loader.
+ /// Label of the positive class.
+ /// Name of the output column.
+ /// Name of the input column. If this is null '' will be used.
+ public LabelIndicatorTransform(IHostEnvironment env,
+ IDataView input,
+ int classIndex,
+ string name,
+ string source = null)
+ : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, ClassIndex = classIndex }, input)
+ {
+ }
+
public LabelIndicatorTransform(IHostEnvironment env, Arguments args, IDataView input)
: base(env, LoadName, Contracts.CheckRef(args, nameof(args)).Column,
input, TestIsMulticlassLabel)
diff --git a/src/Microsoft.ML.Data/Transforms/NAFilter.cs b/src/Microsoft.ML.Data/Transforms/NAFilter.cs
index 96c2111366..c8515291f3 100644
--- a/src/Microsoft.ML.Data/Transforms/NAFilter.cs
+++ b/src/Microsoft.ML.Data/Transforms/NAFilter.cs
@@ -26,15 +26,21 @@
namespace Microsoft.ML.Runtime.Data
{
+ ///
public sealed class NAFilter : FilterBase
{
+ private static class Defaults
+ {
+ public const bool Complement = false;
+ }
+
public sealed class Arguments : TransformInputBase
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Column", ShortName = "col", SortOrder = 1)]
public string[] Column;
[Argument(ArgumentType.Multiple, HelpText = "If true, keep only rows that contain NA values, and filter the rest.")]
- public bool Complement;
+ public bool Complement = Defaults.Complement;
}
private sealed class ColInfo
@@ -72,6 +78,18 @@ private static VersionInfo GetVersionInfo()
private readonly bool _complement;
private const string RegistrationName = "MissingValueFilter";
+ ///
+ /// Convenience constructor for public facing API.
+ ///
+ /// Host Environment.
+ /// Input . This is the output from previous transform or loader.
+ /// If true, keep only rows that contain NA values, and filter the rest.
+ /// Name of the columns. Only these columns will be used to filter rows having 'NA' values.
+ public NAFilter(IHostEnvironment env, IDataView input, bool complement = Defaults.Complement, params string[] columns)
+ : this(env, new Arguments() { Column = columns, Complement = complement }, input)
+ {
+ }
+
public NAFilter(IHostEnvironment env, Arguments args, IDataView input)
: base(env, RegistrationName, input)
{
diff --git a/src/Microsoft.ML.Transforms/NormalizeColumn.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
similarity index 88%
rename from src/Microsoft.ML.Transforms/NormalizeColumn.cs
rename to src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
index a5769ec90a..a01b584a97 100644
--- a/src/Microsoft.ML.Transforms/NormalizeColumn.cs
+++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
@@ -135,12 +135,21 @@ public bool TryUnparse(StringBuilder sb)
}
}
+ private static class Defaults
+ {
+ public const bool FixZero = true;
+ public const bool MeanVarCdf = false;
+ public const bool LogMeanVarCdf = true;
+ public const int NumBins = 1024;
+ public const int MinBinSize = 10;
+ }
+
public abstract class FixZeroArgumentsBase : ArgumentsBase
{
// REVIEW: This only allows mapping either zero or min to zero. It might make sense to allow also max, midpoint and mean to be mapped to zero.
// REVIEW: Convert this to bool? or even an enum{Auto, No, Yes}, and automatically map zero to zero when it is null/Auto.
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to map zero to zero, preserving sparsity", ShortName = "zero")]
- public bool FixZero = true;
+ public bool FixZero = Defaults.FixZero;
}
public abstract class AffineArgumentsBase : FixZeroArgumentsBase
@@ -158,13 +167,13 @@ public sealed class MinMaxArguments : AffineArgumentsBase
public sealed class MeanVarArguments : AffineArgumentsBase
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use CDF as the output", ShortName = "cdf")]
- public bool UseCdf;
+ public bool UseCdf = Defaults.MeanVarCdf;
}
public sealed class LogMeanVarArguments : ArgumentsBase
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use CDF as the output", ShortName = "cdf")]
- public bool UseCdf = true;
+ public bool UseCdf = Defaults.LogMeanVarCdf;
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
public LogNormalColumn[] Column;
@@ -179,7 +188,7 @@ public abstract class BinArgumentsBase : FixZeroArgumentsBase
[Argument(ArgumentType.AtMostOnce, HelpText = "Max number of bins, power of 2 recommended", ShortName = "bins")]
[TGUI(Label = "Max number of bins")]
- public int NumBins = 1024;
+ public int NumBins = Defaults.NumBins;
public override OneToOneColumn[] GetColumns() => Column;
}
@@ -196,7 +205,7 @@ public sealed class SupervisedBinArguments : BinArgumentsBase
public string LabelColumn;
[Argument(ArgumentType.AtMostOnce, HelpText = "Minimum number of examples per bin")]
- public int MinBinSize = 10;
+ public int MinBinSize = Defaults.MinBinSize;
}
public const string MinMaxNormalizerSummary = "Normalizes the data based on the observed minimum and maximum values of the data.";
@@ -218,6 +227,26 @@ public sealed class SupervisedBinArguments : BinArgumentsBase
public const string BinNormalizerShortName = "Bin";
public const string SupervisedBinNormalizerShortName = "SupBin";
+ ///
+ /// A helper method to create MinMaxNormalizer transform for public facing API.
+ ///
+ /// Host Environment.
+ /// Input . This is the output from previous transform or loader.
+ /// Name of the output column.
+ /// Name of the column to be transformed. If this is null '' will be used.
+ public static NormalizeTransform CreateMinMaxNormalizer(IHostEnvironment env, IDataView input, string name, string source = null)
+ {
+ var args = new MinMaxArguments()
+ {
+ Column = new[] { new AffineColumn(){
+ Source = source ?? name,
+ Name = name
+ }
+ }
+ };
+ return Create(env, args, input);
+ }
+
///
/// Public create method corresponding to SignatureDataTransform.
///
@@ -234,6 +263,32 @@ public static NormalizeTransform Create(IHostEnvironment env, MinMaxArguments ar
return func;
}
+ ///
+ /// A helper method to create MeanVarNormalizer transform for public facing API.
+ ///
+ /// Host Environment.
+ /// Input . This is the output from previous transform or loader.
+ /// Name of the output column.
+ /// Name of the column to be transformed. If this is null '' will be used.
+ /// Whether to use CDF as the output.
+ public static NormalizeTransform CreateMeanVarNormalizer(IHostEnvironment env,
+ IDataView input,
+ string name,
+ string source = null,
+ bool useCdf = Defaults.MeanVarCdf)
+ {
+ var args = new MeanVarArguments()
+ {
+ Column = new[] { new AffineColumn(){
+ Source = source ?? name,
+ Name = name
+ }
+ },
+ UseCdf = useCdf
+ };
+ return Create(env, args, input);
+ }
+
///
/// Public create method corresponding to SignatureDataTransform.
///
@@ -250,6 +305,32 @@ public static NormalizeTransform Create(IHostEnvironment env, MeanVarArguments a
return func;
}
+ ///
+ /// A helper method to create LogMeanVarNormalizer transform for public facing API.
+ ///
+ /// Host Environment.
+ /// Input . This is the output from previous transform or loader.
+ /// Name of the output column.
+ /// Name of the column to be transformed. If this is null '' will be used.
+ /// Whether to use CDF as the output.
+ public static NormalizeTransform CreateLogMeanVarNormalizer(IHostEnvironment env,
+ IDataView input,
+ string name,
+ string source = null,
+ bool useCdf = Defaults.LogMeanVarCdf)
+ {
+ var args = new LogMeanVarArguments()
+ {
+ Column = new[] { new LogNormalColumn(){
+ Source = source ?? name,
+ Name = name
+ }
+ },
+ UseCdf = useCdf
+ };
+ return Create(env, args, input);
+ }
+
///
/// Public create method corresponding to SignatureDataTransform.
///
@@ -266,6 +347,24 @@ public static NormalizeTransform Create(IHostEnvironment env, LogMeanVarArgument
return func;
}
+ public static NormalizeTransform CreateBinningNormalizer(IHostEnvironment env,
+ IDataView input,
+ string name,
+ string source = null,
+ int numBins = Defaults.NumBins)
+ {
+ var args = new BinArguments()
+ {
+ Column = new[] { new BinColumn(){
+ Source = source ?? name,
+ Name = name
+ }
+ },
+ NumBins = numBins
+ };
+ return Create(env, args, input);
+ }
+
///
/// Public create method corresponding to SignatureDataTransform.
///
@@ -282,6 +381,28 @@ public static NormalizeTransform Create(IHostEnvironment env, BinArguments args,
return func;
}
+ public static NormalizeTransform CreateSupervisedBinningNormalizer(IHostEnvironment env,
+ IDataView input,
+ string labelColumn,
+ string name,
+ string source = null,
+ int numBins = Defaults.NumBins,
+ int minBinSize = Defaults.MinBinSize)
+ {
+ var args = new SupervisedBinArguments()
+ {
+ Column = new[] { new BinColumn(){
+ Source = source ?? name,
+ Name = name
+ }
+ },
+ LabelColumn = labelColumn,
+ NumBins = numBins,
+ MinBinSize = minBinSize
+ };
+ return Create(env, args, input);
+ }
+
///
/// Public create method corresponding to SignatureDataTransform.
///
@@ -313,8 +434,8 @@ private AffineColumnFunction(IHost host)
public abstract void Save(ModelSaveContext ctx);
public abstract JToken PfaInfo(BoundPfaContext ctx, JToken srcToken);
-
- public abstract bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount);
+ public bool CanSaveOnnx => true;
+ public abstract bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount);
public abstract Delegate GetGetter(IRow input, int icol);
@@ -425,10 +546,10 @@ public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken)
return null;
}
- public bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount)
- {
- return false;
- }
+ public bool CanSaveOnnx => false;
+
+ public bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount)
+ => throw Host.ExceptNotSupp();
public abstract Delegate GetGetter(IRow input, int icol);
@@ -550,10 +671,10 @@ public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken)
return null;
}
- public bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount)
- {
- return false;
- }
+ public bool CanSaveOnnx => false;
+
+ public bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount)
+ => throw Host.ExceptNotSupp();
public abstract Delegate GetGetter(IRow input, int icol);
diff --git a/src/Microsoft.ML.Transforms/NormalizeColumnDbl.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs
similarity index 98%
rename from src/Microsoft.ML.Transforms/NormalizeColumnDbl.cs
rename to src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs
index 41e55ee338..6cad82c127 100644
--- a/src/Microsoft.ML.Transforms/NormalizeColumnDbl.cs
+++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs
@@ -542,7 +542,7 @@ public ImplOne(IHost host, TFloat scale, TFloat offset)
{
}
- public new static ImplOne Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
+ public static new ImplOne Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
{
host.Check(typeSrc.RawType == typeof(TFloat), "The column type must be R8.");
List nz = null;
@@ -577,10 +577,10 @@ public override void Save(ModelSaveContext ctx)
public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken)
=> PfaUtils.Call("*", PfaUtils.Call("-", srcToken, Offset), Scale);
- public override bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount)
+ public override bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount)
{
- OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "offset", Enumerable.Repeat(Offset, featureCount));
- OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "scale", Enumerable.Repeat(Scale, featureCount));
+ nodeProtoWrapper.AddAttribute("offset", Enumerable.Repeat(Offset, featureCount));
+ nodeProtoWrapper.AddAttribute("scale", Enumerable.Repeat(Scale, featureCount));
return true;
}
@@ -605,7 +605,7 @@ public ImplVec(IHost host, TFloat[] scale, TFloat[] offset, int[] indicesNonZero
{
}
- public new static ImplVec Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
+ public static new ImplVec Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
{
host.Check(typeSrc.ItemType.RawType == typeof(TFloat), "The column type must be vector of R8.");
int cv = Math.Max(1, typeSrc.VectorSize);
@@ -648,12 +648,12 @@ public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken)
return PfaUtils.Call("a.zipmap", srcToken, scaleCell, PfaUtils.FuncRef(ctx.Pfa.EnsureMul(itemType)));
}
- public override bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount)
+ public override bool OnnxInfo(OnnxContext ctx, OnnxNode node, int featureCount)
{
- if (Offset != null)
- OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "offset", Offset);
- OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "scale", Scale);
+ if (Offset != null)
+ node.AddAttribute("offset", Offset);
+ node.AddAttribute("scale", Scale);
return true;
}
@@ -867,7 +867,7 @@ public ImplOne(IHost host, TFloat mean, TFloat stddev, bool useLog)
{
}
- public new static ImplOne Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
+ public static new ImplOne Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
{
host.Check(typeSrc.RawType == typeof(TFloat), "The column type must be R8.");
host.CheckValue(ctx, nameof(ctx));
@@ -932,7 +932,7 @@ public ImplVec(IHost host, TFloat[] mean, TFloat[] stddev, bool useLog)
{
}
- public new static ImplVec Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
+ public static new ImplVec Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
{
host.Check(typeSrc.ItemType.RawType == typeof(TFloat), "The column type must be vector of R8.");
int cv = Math.Max(1, typeSrc.VectorSize);
@@ -1051,7 +1051,7 @@ public ImplOne(IHost host, TFloat[] binUpperBounds, bool fixZero)
Host.Assert(0 <= _offset & _offset <= 1);
}
- public new static ImplOne Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
+ public static new ImplOne Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
{
host.Check(typeSrc.RawType == typeof(TFloat), "The column type must be R8.");
host.CheckValue(ctx, nameof(ctx));
@@ -1133,7 +1133,7 @@ public ImplVec(IHost host, TFloat[][] binUpperBounds, bool fixZero)
}
}
- public new static ImplVec Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
+ public static new ImplVec Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
{
host.Check(typeSrc.ItemType.RawType == typeof(TFloat), "The column type must be vector of R8.");
int cv = Math.Max(1, typeSrc.VectorSize);
@@ -1280,7 +1280,7 @@ private static void ComputeScaleAndOffset(TFloat max, TFloat min, out TFloat sca
// but infinities and NaN to NaN.
// REVIEW: If min <= 0 and max >= 0, then why not fix zero for this slot and simply scale by 1 / max(abs(..))?
// We could even be more aggressive about it, and fix zero if 0 < min < max <= 2 * min.
- // Then the common case where features are in the range [1, N] (and integer valued) wouldn't subtract 1 every time....
+ // Then the common case where features are in the range [1, N] (and integer valued) wouldn't subtract 1 every time....
if (!(max > min))
scale = offset = 0;
else if ((scale = 1 / (max - min)) == 0)
@@ -1302,7 +1302,7 @@ private static void ComputeScaleAndOffsetFixZero(TFloat max, TFloat min, out TFl
// In the case where max <= min, the slot contains no useful information (since it is either constant, or
// is all NaNs, or has no rows), so we force it to zero.
// Note that setting scale to zero effectively maps finite values to zero,
- // but infinities and NaN to NaN.
+ // but infinities and NaN to NaN.
offset = 0;
if (!(max > min))
scale = 0;
@@ -1321,7 +1321,7 @@ public static void ComputeScaleAndOffset(Double mean, Double stddev, out TFloat
// In the case where stdev==0, the slot contains no useful information (since it is constant),
// so we force it to zero. Note that setting scale to zero effectively maps finite values to zero,
- // but infinities and NaN to NaN.
+ // but infinities and NaN to NaN.
if (stddev == 0)
scale = offset = 0;
else if ((scale = 1 / (TFloat)stddev) == 0)
@@ -1338,7 +1338,7 @@ public static void ComputeScaleAndOffsetFixZero(Double mean, Double meanSquaredE
// In the case where stdev==0, the slot contains no useful information (since it is constant),
// so we force it to zero. Note that setting scale to zero effectively maps finite values to zero,
- // but infinities and NaN to NaN.
+ // but infinities and NaN to NaN.
offset = 0;
if (meanSquaredError == 0)
scale = 0;
diff --git a/src/Microsoft.ML.Transforms/NormalizeColumnSng.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs
similarity index 98%
rename from src/Microsoft.ML.Transforms/NormalizeColumnSng.cs
rename to src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs
index ef5eef8551..af94f31454 100644
--- a/src/Microsoft.ML.Transforms/NormalizeColumnSng.cs
+++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs
@@ -542,7 +542,7 @@ public ImplOne(IHost host, TFloat scale, TFloat offset)
{
}
- public new static ImplOne Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
+ public static new ImplOne Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
{
host.Check(typeSrc.RawType == typeof(TFloat), "The column type must be R4.");
List nz = null;
@@ -577,10 +577,10 @@ public override void Save(ModelSaveContext ctx)
public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken)
=> PfaUtils.Call("*", PfaUtils.Call("-", srcToken, Offset), Scale);
- public override bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount)
+ public override bool OnnxInfo(OnnxContext ctx, OnnxNode node, int featureCount)
{
- OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "offset", Enumerable.Repeat(Offset, featureCount));
- OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "scale", Enumerable.Repeat(Scale, featureCount));
+ node.AddAttribute("offset", Enumerable.Repeat(Offset, featureCount));
+ node.AddAttribute("scale", Enumerable.Repeat(Scale, featureCount));
return true;
}
@@ -605,7 +605,7 @@ public ImplVec(IHost host, TFloat[] scale, TFloat[] offset, int[] indicesNonZero
{
}
- public new static ImplVec Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
+ public static new ImplVec Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
{
host.Check(typeSrc.ItemType.RawType == typeof(TFloat), "The column type must be vector of R4.");
int cv = Math.Max(1, typeSrc.VectorSize);
@@ -648,14 +648,14 @@ public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken)
return PfaUtils.Call("a.zipmap", srcToken, scaleCell, PfaUtils.FuncRef(ctx.Pfa.EnsureMul(itemType)));
}
- public override bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount)
+ public override bool OnnxInfo(OnnxContext ctx, OnnxNode node, int featureCount)
{
if (Offset != null)
- OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "offset", Offset);
+ node.AddAttribute("offset", Offset);
else
- OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "offset", Enumerable.Repeat(0, featureCount));
+ node.AddAttribute("offset", Enumerable.Repeat(0, featureCount));
- OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "scale", Scale);
+ node.AddAttribute("scale", Scale);
return true;
}
@@ -869,7 +869,7 @@ public ImplOne(IHost host, TFloat mean, TFloat stddev, bool useLog)
{
}
- public new static ImplOne Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
+ public static new ImplOne Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
{
host.Check(typeSrc.RawType == typeof(TFloat), "The column type must be R4.");
host.CheckValue(ctx, nameof(ctx));
@@ -934,7 +934,7 @@ public ImplVec(IHost host, TFloat[] mean, TFloat[] stddev, bool useLog)
{
}
- public new static ImplVec Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
+ public static new ImplVec Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
{
host.Check(typeSrc.ItemType.RawType == typeof(TFloat), "The column type must be vector of R4.");
int cv = Math.Max(1, typeSrc.VectorSize);
@@ -1053,7 +1053,7 @@ public ImplOne(IHost host, TFloat[] binUpperBounds, bool fixZero)
Host.Assert(0 <= _offset & _offset <= 1);
}
- public new static ImplOne Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
+ public static new ImplOne Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
{
host.Check(typeSrc.RawType == typeof(TFloat), "The column type must be R4.");
host.CheckValue(ctx, nameof(ctx));
@@ -1135,7 +1135,7 @@ public ImplVec(IHost host, TFloat[][] binUpperBounds, bool fixZero)
}
}
- public new static ImplVec Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
+ public static new ImplVec Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
{
host.Check(typeSrc.ItemType.RawType == typeof(TFloat), "The column type must be vector of R4.");
int cv = Math.Max(1, typeSrc.VectorSize);
@@ -1282,7 +1282,7 @@ private static void ComputeScaleAndOffset(TFloat max, TFloat min, out TFloat sca
// but infinities and NaN to NaN.
// REVIEW: If min <= 0 and max >= 0, then why not fix zero for this slot and simply scale by 1 / max(abs(..))?
// We could even be more aggressive about it, and fix zero if 0 < min < max <= 2 * min.
- // Then the common case where features are in the range [1, N] (and integer valued) wouldn't subtract 1 every time....
+ // Then the common case where features are in the range [1, N] (and integer valued) wouldn't subtract 1 every time....
if (!(max > min))
scale = offset = 0;
else if ((scale = 1 / (max - min)) == 0)
@@ -1304,7 +1304,7 @@ private static void ComputeScaleAndOffsetFixZero(TFloat max, TFloat min, out TFl
// In the case where max <= min, the slot contains no useful information (since it is either constant, or
// is all NaNs, or has no rows), so we force it to zero.
// Note that setting scale to zero effectively maps finite values to zero,
- // but infinities and NaN to NaN.
+ // but infinities and NaN to NaN.
offset = 0;
if (!(max > min))
scale = 0;
@@ -1323,7 +1323,7 @@ public static void ComputeScaleAndOffset(Double mean, Double stddev, out TFloat
// In the case where stdev==0, the slot contains no useful information (since it is constant),
// so we force it to zero. Note that setting scale to zero effectively maps finite values to zero,
- // but infinities and NaN to NaN.
+ // but infinities and NaN to NaN.
if (stddev == 0)
scale = offset = 0;
else if ((scale = 1 / (TFloat)stddev) == 0)
@@ -1340,7 +1340,7 @@ public static void ComputeScaleAndOffsetFixZero(Double mean, Double meanSquaredE
// In the case where stdev==0, the slot contains no useful information (since it is constant),
// so we force it to zero. Note that setting scale to zero effectively maps finite values to zero,
- // but infinities and NaN to NaN.
+ // but infinities and NaN to NaN.
offset = 0;
if (meanSquaredError == 0)
scale = 0;
diff --git a/src/Microsoft.ML.Transforms/NormalizeTransform.cs b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs
similarity index 84%
rename from src/Microsoft.ML.Transforms/NormalizeTransform.cs
rename to src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs
index e5602918b6..bf9d77ed49 100644
--- a/src/Microsoft.ML.Transforms/NormalizeTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs
@@ -66,7 +66,9 @@ public interface IColumnFunction : ICanSaveModel
JToken PfaInfo(BoundPfaContext ctx, JToken srcToken);
- bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount);
+ bool CanSaveOnnx { get; }
+
+ bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount);
}
public sealed partial class NormalizeTransform : OneToOneTransformBase
@@ -168,7 +170,7 @@ private NormalizeTransform(IHost host, ArgumentsBase args, IDataView input,
while (cursor.MoveNext())
{
// If the row has bad values, the good values are still being used for training.
- // The comparisons in the code below are arranged so that NaNs in the input are not recorded.
+ // The comparisons in the code below are arranged so that NaNs in the input are not recorded.
// REVIEW: Should infinities and/or NaNs be filtered before the normalization? Should we not record infinities for min/max?
// Currently, infinities are recorded and will result in zero scale which in turn will result in NaN output for infinity input.
bool any = false;
@@ -197,6 +199,33 @@ private NormalizeTransform(IHost host, ArgumentsBase args, IDataView input,
SetMetadata();
}
+ ///
+ /// Potentially apply a min-max normalizer to the data's feature column, keeping all existing role
+ /// mappings except for the feature role mapping.
+ ///
+ /// The host environment to use to potentially instantiate the transform
+ /// The role-mapped data that is potentially going to be modified by this method.
+ /// The trainer to query as to whether it wants normalization. If the
+ /// 's is true
+ /// True if the normalizer was applied and was modified
+ public static bool CreateIfNeeded(IHostEnvironment env, ref RoleMappedData data, ITrainer trainer)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(data, nameof(data));
+ env.CheckValue(trainer, nameof(trainer));
+
+ // If the trainer does not need normalization, or if the features either don't exist
+ // or are not normalized, return false.
+ if (!trainer.Info.NeedNormalization || data.Schema.FeaturesAreNormalized() != false)
+ return false;
+ var featInfo = data.Schema.Feature;
+ env.AssertValue(featInfo); // Should be defined, if FeaturesAreNormalized returned a definite value.
+
+ var view = CreateMinMaxNormalizer(env, data.Data, name: featInfo.Name);
+ data = new RoleMappedData(view, data.Schema.GetColumnRoleNames());
+ return true;
+ }
+
private NormalizeTransform(IHost host, ModelLoadContext ctx, IDataView input)
: base(host, ctx, input, null)
{
@@ -212,7 +241,7 @@ private NormalizeTransform(IHost host, ModelLoadContext ctx, IDataView input)
for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
{
var typeSrc = Infos[iinfo].TypeSrc;
- // REVIEW: this check (was even an assert) here is too late. Apparently, no-one tests compatibility
+ // REVIEW: this check (was even an assert) here is too late. Apparently, no-one tests compatibility
// of the types at deserialization (aka re-application), which is a bug.
if (typeSrc.ValueCount == 0)
throw Host.Except("Column '{0}' is a vector of variable size, which is not supported for normalizers", Infos[iinfo].Name);
@@ -286,11 +315,11 @@ protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info,
if (info.TypeSrc.ValueCount == 0)
return false;
- string opType = "Scaler";
- var node = OnnxUtils.MakeNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
- if (_functions[iinfo].OnnxInfo(ctx, new OnnxUtils.NodeProtoWrapper(node), info.TypeSrc.ValueCount))
+ if (_functions[iinfo].CanSaveOnnx)
{
- ctx.AddNode(node);
+ string opType = "Scaler";
+ var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
+ _functions[iinfo].OnnxInfo(ctx, node, info.TypeSrc.ValueCount);
return true;
}
@@ -329,6 +358,30 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, ou
}
}
+ public static class NormalizeUtils
+ {
+ ///
+ /// Returns whether the feature column in the schema is indicated to be normalized. If the features column is not
+ /// specified on the schema, then this will return null.
+ ///
+ /// The role-mapped schema to query
+ /// Returns null if does not have
+ /// defined, and otherwise returns a Boolean value as returned from
+ /// on that feature column
+ ///
+ public static bool? FeaturesAreNormalized(this RoleMappedSchema schema)
+ {
+ // REVIEW: The role mapped data has the ability to have multiple columns fill the role of features, which is
+ // useful in some trainers that are nonetheless parametric and can therefore benefit from normalization.
+ Contracts.CheckValue(schema, nameof(schema));
+ var featInfo = schema.Feature;
+ return featInfo == null ? default(bool?) : schema.Schema.IsNormalized(featInfo.Index);
+ }
+ }
+
+ ///
+ /// This contains entry-point definitions related to .
+ ///
public static class Normalize
{
[TlcModule.EntryPoint(Name = "Transforms.MinMaxNormalizer", Desc = NormalizeTransform.MinMaxNormalizerSummary, UserName = NormalizeTransform.MinMaxNormalizerUserName, ShortName = NormalizeTransform.MinMaxNormalizerShortName)]
@@ -402,14 +455,10 @@ public static CommonOutputs.TransformOutput SupervisedBin(IHostEnvironment env,
var columnsToNormalize = new List();
foreach (var column in input.Column)
{
- int col;
- if (!schema.TryGetColumnIndex(column.Source, out col))
+ if (!schema.TryGetColumnIndex(column.Source, out int col))
throw env.ExceptUserArg(nameof(input.Column), $"Column '{column.Source}' does not exist.");
- if (!schema.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, col, ref isNormalized) ||
- isNormalized.IsFalse)
- {
+ if (!schema.IsNormalized(col))
columnsToNormalize.Add(column);
- }
}
var entryPoints = new List();
diff --git a/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs b/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs
index 35f37d39a8..7b42008b15 100644
--- a/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs
+++ b/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs
@@ -308,7 +308,7 @@ protected override bool MoveNextCore()
if (!_newGroupInInputCursorDel())
return true;
- // If this is the first step, we need to move next on _groupCursor. Otherwise, the position of _groupCursor is
+ // If this is the first step, we need to move next on _groupCursor. Otherwise, the position of _groupCursor is
// at the start of the next group.
if (_groupCursor.State == CursorState.NotStarted)
{
diff --git a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs
index b9ab10f4c1..142779dee2 100644
--- a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs
+++ b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs
@@ -77,6 +77,19 @@ private static VersionInfo GetVersionInfo()
private readonly bool _includeMin;
private readonly bool _includeMax;
+ ///
+ /// Convenience constructor for public facing API.
+ ///
+ /// Host Environment.
+ /// Input . This is the output from previous transform or loader.
+ /// Name of the input column.
+ /// Minimum value (0 to 1 for key types).
+ /// Maximum value (0 to 1 for key types).
+ public RangeFilter(IHostEnvironment env, IDataView input, string column, Double? minimum = null, Double? maximum = null)
+ : this(env, new Arguments() { Column = column, Min = minimum, Max = maximum }, input)
+ {
+ }
+
public RangeFilter(IHostEnvironment env, Arguments args, IDataView input)
: base(env, RegistrationName, input)
{
@@ -171,9 +184,9 @@ public override void Save(ModelSaveContext ctx)
// int: id of column name
// double: min
// double: max
- // byte: complement
- // byte: includeMin
- // byte: includeMax
+ // byte: complement
+ // byte: includeMin
+ // byte: includeMax
ctx.Writer.Write(sizeof(Float));
ctx.SaveNonEmptyString(Source.Schema.GetColumnName(_index));
Host.Assert(_min < _max);
diff --git a/src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs b/src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs
index 37e52ee2da..3940bbe979 100644
--- a/src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs
@@ -33,18 +33,25 @@ namespace Microsoft.ML.Runtime.Data
///
public sealed class ShuffleTransform : RowToRowTransformBase
{
+ private static class Defaults
+ {
+ public const int PoolRows = 1000;
+ public const bool PoolOnly = false;
+ public const bool ForceShuffle = false;
+ }
+
public sealed class Arguments
{
// REVIEW: A more intelligent heuristic, based on the expected size of the inputs, perhaps?
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The pool will have this many rows", ShortName = "rows")]
- public int PoolRows = 1000;
+ public int PoolRows = Defaults.PoolRows;
// REVIEW: Come up with a better way to specify the desired set of functionality.
[Argument(ArgumentType.LastOccurenceWins, HelpText = "If true, the transform will not attempt to shuffle the input cursor but only shuffle based on the pool. This parameter has no effect if the input data was not itself shufflable.", ShortName = "po")]
- public bool PoolOnly;
+ public bool PoolOnly = Defaults.PoolOnly;
[Argument(ArgumentType.LastOccurenceWins, HelpText = "If true, the transform will always provide a shuffled view.", ShortName = "force")]
- public bool ForceShuffle;
+ public bool ForceShuffle = Defaults.ForceShuffle;
[Argument(ArgumentType.LastOccurenceWins, HelpText = "If true, the transform will always shuffle the input. The default value is the same as forceShuffle.", ShortName = "forceSource")]
public bool? ForceShuffleSource;
@@ -79,6 +86,23 @@ private static VersionInfo GetVersionInfo()
// know how to copy other types of values.
private readonly IDataView _subsetInput;
+ ///
+ /// Convenience constructor for public facing API.
+ ///
+ /// Host Environment.
+ /// Input . This is the output from previous transform or loader.
+ /// The pool will have this many rows
+ /// If true, the transform will not attempt to shuffle the input cursor but only shuffle based on the pool. This parameter has no effect if the input data was not itself shufflable.
+ /// If true, the transform will always provide a shuffled view.
+ public ShuffleTransform(IHostEnvironment env,
+ IDataView input,
+ int poolRows = Defaults.PoolRows,
+ bool poolOnly = Defaults.PoolOnly,
+ bool forceShuffle = Defaults.ForceShuffle)
+ : this(env, new Arguments() { PoolRows = poolRows, PoolOnly = poolOnly, ForceShuffle = forceShuffle }, input)
+ {
+ }
+
///
/// Public constructor corresponding to SignatureDataTransform.
///
@@ -236,7 +260,7 @@ protected override IRowCursor GetRowCursorCore(Func predicate, IRando
// The desired functionality is to support some permutations of whether we allow
// shuffling at the source level, or not.
- //
+ //
// Pool | Source | Options
// -----------+----------+--------
// Randonly | Never | poolOnly+
@@ -277,14 +301,14 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid
/// over a pool of size P. Logically, externally, the cursor acts as if you have this pool
/// P and whenever you randomly sample and yield a row from it, that row is then discarded
/// and replaced with the next row from the input source cursor.
- ///
+ ///
/// It would also be possible to implement in a way that cleaves closely to this logical
/// interpretation, but this would be inefficient. We instead have a buffer of larger size
/// P+B. A consumer (running presumably in the main thread) sampling and fetching items and a
/// producer (running in a task, which may be running in a different thread) filling the buffer
/// with items to sample, utilizing this extra space to enable an efficient possibly
/// multithreaded scheme.
- ///
+ ///
/// The consumer, for its part, at any given time "owns" a contiguous portion of this buffer.
/// (A contiguous portion of this buffer we consider to be able to wrap around, from the end
/// to the beginning. The buffer is accessed in a "circular" fashion.) Consider that this portion
@@ -295,18 +319,18 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid
/// rows ready to be sampled in future iterations, but that we are not sampling yet (in order
/// to behave equivalently to the simple logical model of at any given time sampling P items).
/// The producer owns the complement of the portion owned by the consumer.
- ///
+ ///
/// As the cursor progresses, the producer fills in successive items in its portion of the
/// buffer it owns, and passes them off to the consumer (not one item at a time, but rather in
/// batches, to keep down the amount of intertask communication). The consumer in addition to
/// taking ownership of these items, will also periodically pass dead items back to the producer
/// (again, not one dead item at a time, but in batches when the number of dead items reaches
/// a certain threshold).
- ///
+ ///
/// This communication is accomplished using a pair of BufferBlock instances, through which
/// the producer and consumer are notified how many additional items they can take ownership
/// of.
- ///
+ ///
/// As the consumer "selects" a row from the pool of selectable rows each time it moves to
/// the next row, this randomly selected row is considered to be the "first" index, since this
/// makes its subsequent transition to being a dead row much simpler. It would be inefficient to
@@ -314,7 +338,7 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid
/// first, of course, so one rather swaps an index, so that these nicely behavior contiguous
/// circular indices, get mapped in an index within the buffers, through a permutation maintained
/// in the pipeIndices array.
- ///
+ ///
/// The result is something functionally equivalent to but but considerably faster than the
/// simple implementation described in the first paragraph.
///
diff --git a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs
index 278f3ee418..2adb17258e 100644
--- a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs
+++ b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs
@@ -60,13 +60,13 @@ public sealed class Arguments : TransformInputBase
public sealed class TakeArguments : TransformInputBase
{
[Argument(ArgumentType.Required, HelpText = Arguments.TakeHelp, ShortName = "c,n,t", SortOrder = 1)]
- public long Count = long.MaxValue;
+ public long Count = Arguments.DefaultTake;
}
public sealed class SkipArguments : TransformInputBase
{
[Argument(ArgumentType.Required, HelpText = Arguments.SkipHelp, ShortName = "c,n,s", SortOrder = 1)]
- public long Count = 0;
+ public long Count = Arguments.DefaultSkip;
}
private static VersionInfo GetVersionInfo()
@@ -164,7 +164,7 @@ public override void Save(ModelSaveContext ctx)
public override bool CanShuffle { get { return false; } }
///
- /// Returns the computed count of rows remaining after skip and take operation.
+ /// Returns the computed count of rows remaining after skip and take operation.
/// Returns null if count is unknown.
///
public override long? GetRowCount(bool lazy = true)
@@ -270,4 +270,30 @@ protected override bool MoveManyCore(long count)
}
}
}
+
+ public static class SkipFilter
+ {
+ ///
+ /// A helper method to create transform for skipping the number of rows defined by the parameter.
+ /// when created with behaves as 'SkipFilter'.
+ ///
+ /// Host Environment.
+ Input . This is the output from previous transform or loader.
+ /// Number of rows to skip
+ public static IDataTransform Create(IHostEnvironment env, IDataView input, long count = SkipTakeFilter.Arguments.DefaultSkip)
+ => SkipTakeFilter.Create(env, new SkipTakeFilter.SkipArguments() { Count = count }, input);
+ }
+
+ public static class TakeFilter
+ {
+ ///
+ /// A helper method to create transform by taking the top rows defined by the parameter.
+ /// when created with behaves as 'TakeFilter'.
+ ///
+ /// Host Environment.
+ Input . This is the output from previous transform or loader.
+ /// Number of rows to take
+ public static IDataTransform Create(IHostEnvironment env, IDataView input, long count = SkipTakeFilter.Arguments.DefaultTake)
+ => SkipTakeFilter.Create(env, new SkipTakeFilter.TakeArguments() { Count = count }, input);
+ }
}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Data/Transforms/TermTransform.cs b/src/Microsoft.ML.Data/Transforms/TermTransform.cs
index da7442f90e..6eaf48e995 100644
--- a/src/Microsoft.ML.Data/Transforms/TermTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/TermTransform.cs
@@ -29,14 +29,14 @@
namespace Microsoft.ML.Runtime.Data
{
- ///
- /// TermTransform builds up term vocabularies (dictionaries).
- /// Notes:
- /// * Each column builds/uses exactly one "vocabulary" (dictionary).
- /// * Output columns are KeyType-valued.
- /// * The Key value is the one-based index of the item in the dictionary.
- /// * Not found is assigned the value zero.
- ///
+
+ // TermTransform builds up term vocabularies (dictionaries).
+ // Notes:
+ // * Each column builds/uses exactly one "vocabulary" (dictionary).
+ // * Output columns are KeyType-valued.
+ // * The Key value is the one-based index of the item in the dictionary.
+ // * Not found is assigned the value zero.
+ ///
public sealed partial class TermTransform : OneToOneTransformBase, ITransformTemplate
{
public abstract class ColumnBase : OneToOneColumn
@@ -97,10 +97,16 @@ public enum SortOrder : byte
// other things, like case insensitive (where appropriate), culturally aware, etc.?
}
+ private static class Defaults
+ {
+ public const int MaxNumTerms = 1000000;
+ public const SortOrder Sort = SortOrder.Occurrence;
+ }
+
public abstract class ArgumentsBase : TransformInputBase
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum number of terms to keep per column when auto-training", ShortName = "max", SortOrder = 5)]
- public int MaxNumTerms = 1000000;
+ public int MaxNumTerms = Defaults.MaxNumTerms;
[Argument(ArgumentType.AtMostOnce, HelpText = "Comma separated list of terms", SortOrder = 105, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
public string Terms;
@@ -124,7 +130,7 @@ public abstract class ArgumentsBase : TransformInputBase
// REVIEW: Should we always sort? Opinions are mixed. See work item 7797429.
[Argument(ArgumentType.AtMostOnce, HelpText = "How items should be ordered when vectorized. By default, they will be in the order encountered. " +
"If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a').", SortOrder = 113)]
- public SortOrder Sort = SortOrder.Occurrence;
+ public SortOrder Sort = Defaults.Sort;
// REVIEW: Should we do this here, or correct the various pieces of code here and in MRS etc. that
// assume key-values will be string? Once we correct these things perhaps we can see about removing it.
@@ -196,6 +202,26 @@ private CodecFactory CodecFactory
public override bool CanSavePfa => true;
public override bool CanSaveOnnx => true;
+ ///
+ /// Convenience constructor for public facing API.
+ ///
+ /// Host Environment.
+ /// Input . This is the output from previous transform or loader.
+ /// Name of the output column.
+ /// Name of the column to be transformed. If this is null '' will be used.
+ /// Maximum number of terms to keep per column when auto-training.
+ /// How items should be ordered when vectorized. By default, they will be in the order encountered.
+ /// If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a').
+ public TermTransform(IHostEnvironment env,
+ IDataView input,
+ string name,
+ string source = null,
+ int maxNumTerms = Defaults.MaxNumTerms,
+ SortOrder sort = Defaults.Sort)
+ : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, MaxNumTerms = maxNumTerms, Sort = sort }, input)
+ {
+ }
+
///
/// Public constructor corresponding to SignatureDataTransform.
///
@@ -586,10 +612,10 @@ private TermTransform(IHost host, ModelLoadContext ctx, IDataView input)
termMap[i] = TermMap.TextImpl.Create(c, host);
}
});
-#pragma warning disable TLC_NoMessagesForLoadContext // Vaguely useful.
+#pragma warning disable MSML_NoMessagesForLoadContext // Vaguely useful.
if (!b)
throw Host.ExceptDecode("Missing {0} model", dir);
-#pragma warning restore TLC_NoMessagesForLoadContext
+#pragma warning restore MSML_NoMessagesForLoadContext
_termMap = new BoundTermMap[cinfo];
for (int i = 0; i < cinfo; ++i)
_termMap[i] = termMap[i].Bind(this, i);
@@ -690,11 +716,13 @@ protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info,
TermMap map = (TermMap)_termMap[iinfo].Map;
map.GetTerms(ref terms);
string opType = "LabelEncoder";
- var node = OnnxUtils.MakeNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
- OnnxUtils.NodeAddAttributes(node, "classes_strings", terms.DenseValues());
- OnnxUtils.NodeAddAttributes(node, "default_int64", -1);
- OnnxUtils.NodeAddAttributes(node, "default_string", DvText.Empty);
- ctx.AddNode(node);
+ var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
+ node.AddAttribute("classes_strings", terms.DenseValues());
+ node.AddAttribute("default_int64", -1);
+ //default_string needs to be an empty string but there is a BUG in Lotus that
+ //throws a validation error when default_string is empty. As a work around, set
+ //default_string to a space.
+ node.AddAttribute("default_string", " ");
return true;
}
diff --git a/src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs b/src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs
index a81575b9c9..9a43dc5517 100644
--- a/src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs
+++ b/src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs
@@ -447,10 +447,10 @@ private enum MapType : byte
/// type. The input type, whatever it is, must have as its input item
/// type, and will produce either , or a vector type with that output
/// type if the input was a vector.
- ///
+ ///
/// Note that instances of this class can be shared among multiple
/// instances. To associate this with a particular transform, use the method.
- ///
+ ///
/// These are the immutable and serializable analogs to the used in
/// training.
///
diff --git a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs
index bdf1a36d41..bba2f58256 100644
--- a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs
@@ -181,7 +181,7 @@ private static RoleMappedData CreateDataFromArgs(IExceptionContext
var name = TrainUtils.MatchNameOrDefaultOrNull(ectx, schema, nameof(args.NameColumn), args.NameColumn,
DefaultColumnNames.Name);
var customCols = TrainUtils.CheckAndGenerateCustomColumns(ectx, args.CustomColumn);
- return TrainUtils.CreateExamples(input, label, feat, group, weight, name, customCols);
+ return new RoleMappedData(input, label, feat, group, weight, name, customCols);
}
}
}
diff --git a/src/Microsoft.ML.Data/Transforms/TransformBase.cs b/src/Microsoft.ML.Data/Transforms/TransformBase.cs
index 263a3cf4ca..2d9cedb17b 100644
--- a/src/Microsoft.ML.Data/Transforms/TransformBase.cs
+++ b/src/Microsoft.ML.Data/Transforms/TransformBase.cs
@@ -596,7 +596,7 @@ public void SaveAsOnnx(OnnxContext ctx)
///
/// Called by . Should be implemented by subclasses that return
- /// true from . Will be called
+ /// true from . Will be called
///
/// The context. Can be used to declare cells, access other information,
/// and whatnot. This method should not actually, however, declare the variable corresponding
diff --git a/src/Microsoft.ML.Data/Transforms/doc.xml b/src/Microsoft.ML.Data/Transforms/doc.xml
new file mode 100644
index 0000000000..13f108a107
--- /dev/null
+++ b/src/Microsoft.ML.Data/Transforms/doc.xml
@@ -0,0 +1,100 @@
+
+
+
+
+
+ Removes missing values from vector type columns.
+
+
+ This transform removes the entire row if any of the input columns have a missing value in that row.
+ This preprocessing is required for many ML algorithms that cannot work with missing values.
+ Useful if any missing entry invalidates the entire row.
+ If the is set to true, this transform would do the exact opposite,
+ it will keep only the rows that have missing values.
+
+
+
+
+
+
+ pipeline.Add(new MissingValuesRowDropper("Column1"));
+
+
+
+
+
+
+ Converts input values (words, numbers, etc.) to index in a dictionary.
+
+
+ The TextToKeyConverter transform builds up term vocabularies (dictionaries).
+ The TextToKeyConverter and the are the two primary mechanisms by which raw input is transformed into keys.
+ If multiple columns are used, each column builds/uses exactly one vocabulary.
+ The output columns are KeyType-valued.
+ The Key value is the one-based index of the item in the dictionary.
+ If the key is not found in the dictionary, it is assigned the missing value indicator.
+ This dictionary mapping values to keys is most commonly learnt from the unique values in input data,
+ but can be defined through other means: either with the mapping defined directly on the command line, or as loaded from an external file.
+
+
+
+
+
+
+
+ pipeline.Add(new TextToKeyConverter(("Column", "OutColumn"))
+ {
+ Sort = TermTransformSortOrder.Occurrence
+ });
+
+
+
+
+
+
+ Handle missing values by replacing them with either the default value or the indicated value.
+
+
+ This transform handles missing values in the input columns. For each input column, it creates an output column
+ where the missing values are replaced by one of these specified values:
+
+ -
+ The default value of the appropriate type.
+
+ -
+ The mean value of the appropriate type.
+
+ -
+ The max value of the appropriate type.
+
+ -
+ The min value of the appropriate type.
+
+
+ The last three work only for numeric/TimeSpan/DateTime kind columns.
+
+ The output column can also optionally include an indicator vector for which slots were missing in the input column.
+ This can be done only when the indicator vector type can be converted to the input column type, i.e. only for numeric columns.
+
+
+ When computing the mean/max/min value, there is also an option to compute it over the whole column instead of per slot.
+ This option has a default value of true for variable length vectors, and false for known length vectors.
+ It can be changed to true for known length vectors, but it results in an error if changed to false for variable length vectors.
+
+
+
+
+
+
+
+
+ pipeline.Add(new MissingValueHandler("FeatureCol", "CleanFeatureCol")
+ {
+ ReplaceWith = NAHandleTransformReplacementKind.Mean
+ });
+
+
+
+
+
+
diff --git a/src/Microsoft.ML.Data/Utilities/ApplyTransformUtils.cs b/src/Microsoft.ML.Data/Utilities/ApplyTransformUtils.cs
index e0a138fc91..7d0695083b 100644
--- a/src/Microsoft.ML.Data/Utilities/ApplyTransformUtils.cs
+++ b/src/Microsoft.ML.Data/Utilities/ApplyTransformUtils.cs
@@ -73,7 +73,7 @@ public static IDataView ApplyAllTransformsToData(IHostEnvironment env, IDataView
// Backtrack the chain until we reach a chain start or a non-transform.
// REVIEW: we 'unwrap' the composite data loader here and step through its pipeline.
- // It's probably more robust to make CompositeDataLoader not even be an IDataView, this
+ // It's probably more robust to make CompositeDataLoader not even be an IDataView, this
// would force the user to do the right thing and unwrap on his end.
var cdl = chain as CompositeDataLoader;
if (cdl != null)
diff --git a/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs b/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs
index 3e39a0008a..5b99b173fa 100644
--- a/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs
+++ b/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs
@@ -79,7 +79,7 @@ public static IDataView LoadPipeline(IHostEnvironment env, RepositoryReader rep,
}
///
- /// Loads all transforms from the model stream, applies them sequentially to the provided data, and returns
+ /// Loads all transforms from the model stream, applies them sequentially to the provided data, and returns
/// the resulting data. If there are no transforms in the stream, or if there's no DataLoader stream at all
/// (this can happen if the model is produced by old TL), returns the source data.
/// If the DataLoader stream is invalid, throws.
@@ -101,7 +101,7 @@ public static IDataView LoadTransforms(IHostEnvironment env, IDataView data, Str
}
///
- /// Loads all transforms from the model stream, applies them sequentially to the provided data, and returns
+ /// Loads all transforms from the model stream, applies them sequentially to the provided data, and returns
/// the resulting data. If there are no transforms in the stream, or if there's no DataLoader stream at all
/// (this can happen if the model is produced by old TL), returns the source data.
/// If the DataLoader stream is invalid, throws.
@@ -157,8 +157,8 @@ public static ModelSaveContext GetDataModelSavingContext(RepositoryWriter rep)
}
///
- /// Loads data view (loader and transforms) from if is set to true,
- /// otherwise loads loader only.
+ /// Loads data view (loader and transforms) from if is set to true,
+ /// otherwise loads loader only.
///
public static IDataLoader LoadLoader(IHostEnvironment env, RepositoryReader rep, IMultiStreamSource files, bool loadTransforms)
{
@@ -188,7 +188,7 @@ public static IDataLoader LoadLoader(IHostEnvironment env, RepositoryReader rep,
}
///
- /// REVIEW: consider adding an overload that returns
+ /// REVIEW: consider adding an overload that returns
/// Loads optionally feature names from the repository directory.
/// Returns false iff no stream was found for feature names, iff result is set to null.
///
@@ -338,11 +338,11 @@ public static RoleMappedSchema LoadRoleMappedSchemaOrNull(IHostEnvironment env,
if (roleMappings == null)
return null;
var pipe = ModelFileUtils.LoadLoader(h, rep, new MultiFileSource(null), loadTransforms: true);
- return RoleMappedSchema.Create(pipe.Schema, roleMappings);
+ return new RoleMappedSchema(pipe.Schema, roleMappings);
}
///
- /// The RepositoryStreamWrapper is a IMultiStreamSource wrapper of a Stream object in a repository.
+ /// The RepositoryStreamWrapper is a IMultiStreamSource wrapper of a Stream object in a repository.
/// It is used to deserialize RoleMappings.txt from a model zip file.
///
private sealed class RepositoryStreamWrapper : IMultiStreamSource
@@ -382,7 +382,7 @@ public Stream Open(int index)
public TextReader OpenTextReader(int index) { return new StreamReader(Open(index)); }
///
- /// A custom entry stream wrapper that includes custom dispose logic for disposing the entry
+ /// A custom entry stream wrapper that includes custom dispose logic for disposing the entry
/// when the stream is disposed.
///
private sealed class EntryStream : Stream
diff --git a/src/Microsoft.ML.Data/Utilities/SlotDropper.cs b/src/Microsoft.ML.Data/Utilities/SlotDropper.cs
index cd74463291..64b510a655 100644
--- a/src/Microsoft.ML.Data/Utilities/SlotDropper.cs
+++ b/src/Microsoft.ML.Data/Utilities/SlotDropper.cs
@@ -91,7 +91,7 @@ public ValueGetter> SubsetGetter(ValueGetter> getter)
}
///
- /// Drops slots from src and populates the dst with the resulting vector. Slots are
+ /// Drops slots from src and populates the dst with the resulting vector. Slots are
/// dropped based on min and max slots that were passed at the constructor.
///
public void DropSlots(ref VBuffer src, ref VBuffer dst)
diff --git a/src/Microsoft.ML.Data/Utils/IntSequencePool.cs b/src/Microsoft.ML.Data/Utils/IntSequencePool.cs
index 3efb038e6e..e27b297025 100644
--- a/src/Microsoft.ML.Data/Utils/IntSequencePool.cs
+++ b/src/Microsoft.ML.Data/Utils/IntSequencePool.cs
@@ -173,7 +173,7 @@ private int GetCore(uint[] sequence, int min, int lim, out uint hash)
Contracts.Assert(ibCur <= ibLim);
if (i >= lim)
{
- // Need to make sure that we have reached the end of the sequence in the pool at the
+ // Need to make sure that we have reached the end of the sequence in the pool at the
// same time that we reached the end of sequence.
if (ibCur == ibLim)
return idCur;
diff --git a/src/Microsoft.ML.Data/Utils/LossFunctions.cs b/src/Microsoft.ML.Data/Utils/LossFunctions.cs
index 7ff47a4f9e..7df431c3d1 100644
--- a/src/Microsoft.ML.Data/Utils/LossFunctions.cs
+++ b/src/Microsoft.ML.Data/Utils/LossFunctions.cs
@@ -124,9 +124,9 @@ public Float ComputeDualUpdateInvariant(Float scaledFeaturesNormSquared)
return 1 / Math.Max(1, (Float)0.25 + scaledFeaturesNormSquared);
}
- // REVIEW: this dual update uses a different log loss formulation,
+ // REVIEW: this dual update uses a different log loss formulation,
//although the two are equivalents if the labels are restricted to 0 and 1
- //Need to update so that it can handle probability label and true to the
+ //Need to update so that it can handle probability label and true to the
//definition, which is a smooth loss function
public Float DualUpdate(Float output, Float label, Float dual, Float invariant, int maxNumThreads)
{
diff --git a/src/Microsoft.ML.Ensemble/Batch.cs b/src/Microsoft.ML.Ensemble/Batch.cs
new file mode 100644
index 0000000000..e9c8fcf179
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/Batch.cs
@@ -0,0 +1,22 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Data;
+
+namespace Microsoft.ML.Runtime.Ensemble
+{
+ public sealed class Batch
+ {
+ public readonly RoleMappedData TrainInstances;
+ public readonly RoleMappedData TestInstances;
+
+ public Batch(RoleMappedData trainData, RoleMappedData testData)
+ {
+ Contracts.CheckValue(trainData, nameof(trainData));
+ Contracts.CheckValue(testData, nameof(testData));
+ TrainInstances = trainData;
+ TestInstances = testData;
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/EnsembleUtils.cs b/src/Microsoft.ML.Ensemble/EnsembleUtils.cs
new file mode 100644
index 0000000000..ae6c2adac6
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/EnsembleUtils.cs
@@ -0,0 +1,114 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Internal.Utilities;
+
+namespace Microsoft.ML.Runtime.Ensemble
+{
+ internal static class EnsembleUtils
+ {
+ ///
+ /// Return a dataset with non-selected features zeroed out.
+ ///
+ public static RoleMappedData SelectFeatures(IHost host, RoleMappedData data, BitArray features)
+ {
+ Contracts.AssertValue(host);
+ Contracts.AssertValue(data);
+ Contracts.Assert(data.Schema.Feature != null);
+ Contracts.AssertValue(features);
+
+ var type = data.Schema.Feature.Type;
+ Contracts.Assert(features.Length == type.VectorSize);
+ int card = Utils.GetCardinality(features);
+ if (card == type.VectorSize)
+ return data;
+
+ // REVIEW: This doesn't preserve metadata on the features column. Should it?
+ var name = data.Schema.Feature.Name;
+ var view = LambdaColumnMapper.Create(
+ host, "FeatureSelector", data.Data, name, name, type, type,
+ (ref VBuffer src, ref VBuffer dst) => SelectFeatures(ref src, features, card, ref dst));
+
+ var res = new RoleMappedData(view, data.Schema.GetColumnRoleNames());
+ return res;
+ }
+
+ ///
+ /// Fill dst with values selected from src if the indices of the src values are set in includedIndices,
+ /// otherwise assign default(T). The length of dst will be equal to src.Length.
+ ///
+ public static void SelectFeatures(ref VBuffer src, BitArray includedIndices, int cardinality, ref VBuffer dst)
+ {
+ Contracts.Assert(Utils.Size(includedIndices) == src.Length);
+ Contracts.Assert(cardinality == Utils.GetCardinality(includedIndices));
+ Contracts.Assert(cardinality < src.Length);
+
+ var values = dst.Values;
+ var indices = dst.Indices;
+
+ if (src.IsDense)
+ {
+ if (cardinality >= src.Length / 2)
+ {
+ T defaultValue = default;
+ if (Utils.Size(values) < src.Length)
+ values = new T[src.Length];
+ for (int i = 0; i < src.Length; i++)
+ values[i] = !includedIndices[i] ? defaultValue : src.Values[i];
+ dst = new VBuffer(src.Length, values, indices);
+ }
+ else
+ {
+ if (Utils.Size(values) < cardinality)
+ values = new T[cardinality];
+ if (Utils.Size(indices) < cardinality)
+ indices = new int[cardinality];
+
+ int count = 0;
+ for (int i = 0; i < src.Length; i++)
+ {
+ if (includedIndices[i])
+ {
+ Contracts.Assert(count < cardinality);
+ values[count] = src.Values[i];
+ indices[count] = i;
+ count++;
+ }
+ }
+
+ Contracts.Assert(count == cardinality);
+ dst = new VBuffer(src.Length, count, values, indices);
+ }
+ }
+ else
+ {
+ int valuesSize = Utils.Size(values);
+ int indicesSize = Utils.Size(indices);
+ if (valuesSize < src.Count || indicesSize < src.Count)
+ {
+ if (valuesSize < cardinality)
+ values = new T[cardinality];
+ if (indicesSize < cardinality)
+ indices = new int[cardinality];
+ }
+
+ int count = 0;
+ for (int i = 0; i < src.Count; i++)
+ {
+ if (includedIndices[src.Indices[i]])
+ {
+ values[count] = src.Values[i];
+ indices[count] = src.Indices[i];
+ count++;
+ }
+ }
+
+ dst = new VBuffer(src.Length, count, values, indices);
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs b/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs
new file mode 100644
index 0000000000..a9d7983adf
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs
@@ -0,0 +1,406 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.IO.Compression;
+using System.Linq;
+using Microsoft.ML.Ensemble.EntryPoints;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Ensemble;
+using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Internal.Utilities;
+
+[assembly: LoadableClass(typeof(void), typeof(EnsembleCreator), null, typeof(SignatureEntryPointModule), "CreateEnsemble")]
+
+namespace Microsoft.ML.Runtime.EntryPoints
+{
+ ///
+ /// A component to combine given models into an ensemble model.
+ ///
+ public static class EnsembleCreator
+ {
+ ///
+ /// These are the combiner options for binary and multi class classifiers.
+ ///
+ public enum ClassifierCombiner
+ {
+ Median,
+ Average,
+ Vote,
+ }
+
+ ///
+ /// These are the combiner options for regression and anomaly detection.
+ ///
+ public enum ScoreCombiner
+ {
+ Median,
+ Average,
+ }
+
+ public abstract class PipelineInputBase
+ {
+ [Argument(ArgumentType.Required, ShortName = "models", HelpText = "The models to combine into an ensemble", SortOrder = 1)]
+ public IPredictorModel[] Models;
+ }
+
+ public abstract class InputBase
+ {
+ [Argument(ArgumentType.Required, ShortName = "models", HelpText = "The models to combine into an ensemble", SortOrder = 1)]
+ public IPredictorModel[] Models;
+
+ [Argument(ArgumentType.AtMostOnce, ShortName = "validate", HelpText = "Whether to validate that all the pipelines are identical", SortOrder = 5)]
+ public bool ValidatePipelines = true;
+ }
+
+ public sealed class ClassifierInput : InputBase
+ {
+ [Argument(ArgumentType.AtMostOnce, ShortName = "combiner", HelpText = "The combiner used to combine the scores", SortOrder = 2)]
+ public ClassifierCombiner ModelCombiner = ClassifierCombiner.Median;
+ }
+
+ public sealed class PipelineClassifierInput : PipelineInputBase
+ {
+ [Argument(ArgumentType.AtMostOnce, ShortName = "combiner", HelpText = "The combiner used to combine the scores", SortOrder = 2)]
+ public ClassifierCombiner ModelCombiner = ClassifierCombiner.Median;
+ }
+
+ public sealed class RegressionInput : InputBase
+ {
+ [Argument(ArgumentType.AtMostOnce, ShortName = "combiner", HelpText = "The combiner used to combine the scores", SortOrder = 2)]
+ public ScoreCombiner ModelCombiner = ScoreCombiner.Median;
+ }
+
+ public sealed class PipelineRegressionInput : PipelineInputBase
+ {
+ [Argument(ArgumentType.AtMostOnce, ShortName = "combiner", HelpText = "The combiner used to combine the scores", SortOrder = 2)]
+ public ScoreCombiner ModelCombiner = ScoreCombiner.Median;
+ }
+
+ public sealed class PipelineAnomalyInput : PipelineInputBase
+ {
+ [Argument(ArgumentType.AtMostOnce, ShortName = "combiner", HelpText = "The combiner used to combine the scores", SortOrder = 2)]
+ public ScoreCombiner ModelCombiner = ScoreCombiner.Average;
+ }
+
+ private static void GetPipeline(IHostEnvironment env, InputBase input, out IDataView startingData, out RoleMappedData transformedData)
+ {
+ Contracts.AssertValue(env);
+ env.AssertValue(input);
+ env.AssertNonEmpty(input.Models);
+
+ ISchema inputSchema = null;
+ startingData = null;
+ transformedData = null;
+ byte[][] transformedDataSerialized = null;
+ string[] transformedDataZipEntryNames = null;
+ for (int i = 0; i < input.Models.Length; i++)
+ {
+ var model = input.Models[i];
+
+ var inputData = new EmptyDataView(env, model.TransformModel.InputSchema);
+ model.PrepareData(env, inputData, out RoleMappedData transformedDataCur, out IPredictor pred);
+
+ if (inputSchema == null)
+ {
+ env.Assert(i == 0);
+ inputSchema = model.TransformModel.InputSchema;
+ startingData = inputData;
+ transformedData = transformedDataCur;
+ }
+ else if (input.ValidatePipelines)
+ {
+ using (var ch = env.Start("Validating pipeline"))
+ {
+ if (transformedDataSerialized == null)
+ {
+ ch.Assert(transformedDataZipEntryNames == null);
+ SerializeRoleMappedData(env, ch, transformedData, out transformedDataSerialized,
+ out transformedDataZipEntryNames);
+ }
+ CheckSamePipeline(env, ch, transformedDataCur, transformedDataSerialized, transformedDataZipEntryNames);
+ ch.Done();
+ }
+ }
+ }
+ }
+
+ [TlcModule.EntryPoint(Name = "Models.BinaryEnsemble", Desc = "Combine binary classifiers into an ensemble", UserName = EnsembleTrainer.UserNameValue)]
+ public static CommonOutputs.BinaryClassificationOutput CreateBinaryEnsemble(IHostEnvironment env, ClassifierInput input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ var host = env.Register("CombineModels");
+ host.CheckValue(input, nameof(input));
+ host.CheckNonEmpty(input.Models, nameof(input.Models));
+
+ GetPipeline(host, input, out IDataView startingData, out RoleMappedData transformedData);
+
+ var args = new EnsembleTrainer.Arguments();
+ switch (input.ModelCombiner)
+ {
+ case ClassifierCombiner.Median:
+ args.OutputCombiner = new MedianFactory();
+ break;
+ case ClassifierCombiner.Average:
+ args.OutputCombiner = new AverageFactory();
+ break;
+ case ClassifierCombiner.Vote:
+ args.OutputCombiner = new VotingFactory();
+ break;
+ default:
+ throw host.Except("Unknown combiner kind");
+ }
+
+ var trainer = new EnsembleTrainer(host, args);
+ var ensemble = trainer.CombineModels(input.Models.Select(pm => pm.Predictor as IPredictorProducing));
+
+ var predictorModel = new PredictorModel(host, transformedData, startingData, ensemble);
+
+ var output = new CommonOutputs.BinaryClassificationOutput { PredictorModel = predictorModel };
+ return output;
+ }
+
+ [TlcModule.EntryPoint(Name = "Models.RegressionEnsemble", Desc = "Combine regression models into an ensemble", UserName = RegressionEnsembleTrainer.UserNameValue)]
+ public static CommonOutputs.RegressionOutput CreateRegressionEnsemble(IHostEnvironment env, RegressionInput input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ var host = env.Register("CombineModels");
+ host.CheckValue(input, nameof(input));
+ host.CheckNonEmpty(input.Models, nameof(input.Models));
+
+ GetPipeline(host, input, out IDataView startingData, out RoleMappedData transformedData);
+
+ var args = new RegressionEnsembleTrainer.Arguments();
+ switch (input.ModelCombiner)
+ {
+ case ScoreCombiner.Median:
+ args.OutputCombiner = new MedianFactory();
+ break;
+ case ScoreCombiner.Average:
+ args.OutputCombiner = new AverageFactory();
+ break;
+ default:
+ throw host.Except("Unknown combiner kind");
+ }
+
+ var trainer = new RegressionEnsembleTrainer(host, args);
+ var ensemble = trainer.CombineModels(input.Models.Select(pm => pm.Predictor as IPredictorProducing));
+
+ var predictorModel = new PredictorModel(host, transformedData, startingData, ensemble);
+
+ var output = new CommonOutputs.RegressionOutput { PredictorModel = predictorModel };
+ return output;
+ }
+
+ [TlcModule.EntryPoint(Name = "Models.BinaryPipelineEnsemble", Desc = "Combine binary classification models into an ensemble")]
+ public static CommonOutputs.BinaryClassificationOutput CreateBinaryPipelineEnsemble(IHostEnvironment env, PipelineClassifierInput input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ var host = env.Register("CombineModels");
+ host.CheckValue(input, nameof(input));
+ host.CheckNonEmpty(input.Models, nameof(input.Models));
+
+ IBinaryOutputCombiner combiner;
+ switch (input.ModelCombiner)
+ {
+ case ClassifierCombiner.Median:
+ combiner = new Median(host);
+ break;
+ case ClassifierCombiner.Average:
+ combiner = new Average(host);
+ break;
+ case ClassifierCombiner.Vote:
+ combiner = new Voting(host);
+ break;
+ default:
+ throw host.Except("Unknown combiner kind");
+ }
+ var ensemble = SchemaBindablePipelineEnsembleBase.Create(host, input.Models, combiner, MetadataUtils.Const.ScoreColumnKind.BinaryClassification);
+ return CreatePipelineEnsemble(host, input.Models, ensemble);
+ }
+
+ [TlcModule.EntryPoint(Name = "Models.RegressionPipelineEnsemble", Desc = "Combine regression models into an ensemble")]
+ public static CommonOutputs.RegressionOutput CreateRegressionPipelineEnsemble(IHostEnvironment env, PipelineRegressionInput input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ var host = env.Register("CombineModels");
+ host.CheckValue(input, nameof(input));
+ host.CheckNonEmpty(input.Models, nameof(input.Models));
+
+ IRegressionOutputCombiner combiner;
+ switch (input.ModelCombiner)
+ {
+ case ScoreCombiner.Median:
+ combiner = new Median(host);
+ break;
+ case ScoreCombiner.Average:
+ combiner = new Average(host);
+ break;
+ default:
+ throw host.Except("Unknown combiner kind");
+ }
+ var ensemble = SchemaBindablePipelineEnsembleBase.Create(host, input.Models, combiner, MetadataUtils.Const.ScoreColumnKind.Regression);
+ return CreatePipelineEnsemble(host, input.Models, ensemble);
+ }
+
+ [TlcModule.EntryPoint(Name = "Models.MultiClassPipelineEnsemble", Desc = "Combine multiclass classifiers into an ensemble")]
+ public static CommonOutputs.MulticlassClassificationOutput CreateMultiClassPipelineEnsemble(IHostEnvironment env, PipelineClassifierInput input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ var host = env.Register("CombineModels");
+ host.CheckValue(input, nameof(input));
+ host.CheckNonEmpty(input.Models, nameof(input.Models));
+
+ IOutputCombiner> combiner;
+ switch (input.ModelCombiner)
+ {
+ case ClassifierCombiner.Median:
+ combiner = new MultiMedian(host, new MultiMedian.Arguments() { Normalize = true });
+ break;
+ case ClassifierCombiner.Average:
+ combiner = new MultiAverage(host, new MultiAverage.Arguments() { Normalize = true });
+ break;
+ case ClassifierCombiner.Vote:
+ combiner = new MultiVoting(host);
+ break;
+ default:
+ throw host.Except("Unknown combiner kind");
+ }
+ var ensemble = SchemaBindablePipelineEnsembleBase.Create(host, input.Models, combiner, MetadataUtils.Const.ScoreColumnKind.MultiClassClassification);
+ return CreatePipelineEnsemble(host, input.Models, ensemble);
+ }
+
+ [TlcModule.EntryPoint(Name = "Models.AnomalyPipelineEnsemble", Desc = "Combine anomaly detection models into an ensemble")]
+ public static CommonOutputs.AnomalyDetectionOutput CreateAnomalyPipelineEnsemble(IHostEnvironment env, PipelineAnomalyInput input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ var host = env.Register("CombineModels");
+ host.CheckValue(input, nameof(input));
+ host.CheckNonEmpty(input.Models, nameof(input.Models));
+
+ IRegressionOutputCombiner combiner;
+ switch (input.ModelCombiner)
+ {
+ case ScoreCombiner.Median:
+ combiner = new Median(host);
+ break;
+ case ScoreCombiner.Average:
+ combiner = new Average(host);
+ break;
+ default:
+ throw host.Except("Unknown combiner kind");
+ }
+ var ensemble = SchemaBindablePipelineEnsembleBase.Create(host, input.Models, combiner, MetadataUtils.Const.ScoreColumnKind.AnomalyDetection);
+ return CreatePipelineEnsemble(host, input.Models, ensemble);
+ }
+
+ private static TOut CreatePipelineEnsemble(IHostEnvironment env, IPredictorModel[] predictors, SchemaBindablePipelineEnsembleBase ensemble)
+ where TOut : CommonOutputs.TrainerOutput, new()
+ {
+ var inputSchema = predictors[0].TransformModel.InputSchema;
+ var dv = new EmptyDataView(env, inputSchema);
+
+ // The role mappings are specific to the individual predictors.
+ var rmd = new RoleMappedData(dv);
+ var predictorModel = new PredictorModel(env, rmd, dv, ensemble);
+
+ var output = new TOut { PredictorModel = predictorModel };
+ return output;
+ }
+
+ ///
+ /// This method takes a as input, saves it as an in-memory
+ /// and returns two arrays indexed by the entries in the zip:
+ /// 1. An array of byte arrays, containing the byte sequences of each entry.
+ /// 2. An array of strings, containing the name of each entry.
+ ///
+ /// This method is used for comparing pipelines. Its outputs can be passed to
+ /// to check if this pipeline is identical to another pipeline.
+ ///
+ public static void SerializeRoleMappedData(IHostEnvironment env, IChannel ch, RoleMappedData data,
+ out byte[][] dataSerialized, out string[] dataZipEntryNames)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(ch, nameof(ch));
+ ch.CheckValue(data, nameof(data));
+
+ using (var ms = new MemoryStream())
+ {
+ TrainUtils.SaveModel(env, ch, ms, null, data);
+ var zip = new ZipArchive(ms);
+ var entries = zip.Entries.OrderBy(e => e.FullName).ToArray();
+ dataSerialized = new byte[Utils.Size(entries)][];
+ dataZipEntryNames = new string[Utils.Size(entries)];
+ for (int i = 0; i < Utils.Size(entries); i++)
+ {
+ dataZipEntryNames[i] = entries[i].FullName;
+ dataSerialized[i] = new byte[entries[i].Length];
+ using (var s = entries[i].Open())
+ s.Read(dataSerialized[i], 0, (int)entries[i].Length);
+ }
+ }
+ }
+
+ ///
+ /// This method compares two pipelines to make sure they are identical. The first pipeline is passed
+ /// as a , and the second as a double byte array and a string array. The double
+ /// byte array and the string array are obtained by calling on the
+ /// second pipeline.
+ /// The comparison is done by saving as an in-memory ,
+ /// and for each entry in it, comparing its name, and the byte sequence to the corresponding entries in
+ /// and .
+ /// This method throws if for any of the entries the name/byte sequence are not identical.
+ ///
+ public static void CheckSamePipeline(IHostEnvironment env, IChannel ch,
+ RoleMappedData dataToCompare, byte[][] dataSerialized, string[] dataZipEntryNames)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(ch, nameof(ch));
+ ch.CheckValue(dataToCompare, nameof(dataToCompare));
+ ch.CheckValue(dataSerialized, nameof(dataSerialized));
+ ch.CheckValue(dataZipEntryNames, nameof(dataZipEntryNames));
+ if (dataZipEntryNames.Length != dataSerialized.Length)
+ {
+ throw ch.ExceptParam(nameof(dataSerialized),
+ $"The length of {nameof(dataSerialized)} must be equal to the length of {nameof(dataZipEntryNames)}");
+ }
+
+ using (var ms = new MemoryStream())
+ {
+ // REVIEW: This can be done more efficiently by adding a custom type of repository that
+ // doesn't actually save the data, but upon stream closure compares the results to the given repository
+ // and then discards it. Currently, however, this cannot be done because ModelSaveContext does not use
+ // an abstract class/interface, but rather the RepositoryWriter class.
+ TrainUtils.SaveModel(env, ch, ms, null, dataToCompare);
+
+ string errorMsg = "Models contain different pipelines, cannot ensemble them.";
+ var zip = new ZipArchive(ms);
+ var entries = zip.Entries.OrderBy(e => e.FullName).ToArray();
+ ch.Check(dataSerialized.Length == Utils.Size(entries));
+ byte[] buffer = null;
+ for (int i = 0; i < dataSerialized.Length; i++)
+ {
+ ch.Check(dataZipEntryNames[i] == entries[i].FullName, errorMsg);
+ int len = dataSerialized[i].Length;
+ if (Utils.Size(buffer) < len)
+ buffer = new byte[len];
+ using (var s = entries[i].Open())
+ {
+ int bytesRead = s.Read(buffer, 0, len);
+ ch.Check(bytesRead == len, errorMsg);
+ for (int j = 0; j < len; j++)
+ ch.Check(buffer[j] == dataSerialized[i][j], errorMsg);
+ if (s.Read(buffer, 0, 1) > 0)
+ throw env.Except(errorMsg);
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/DiversityMeasure.cs b/src/Microsoft.ML.Ensemble/EntryPoints/DiversityMeasure.cs
new file mode 100644
index 0000000000..b13cff3b35
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/EntryPoints/DiversityMeasure.cs
@@ -0,0 +1,37 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Ensemble.EntryPoints;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Ensemble.Selector;
+using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure;
+using Microsoft.ML.Runtime.EntryPoints;
+
+[assembly: EntryPointModule(typeof(DisagreementDiversityFactory))]
+[assembly: EntryPointModule(typeof(RegressionDisagreementDiversityFactory))]
+[assembly: EntryPointModule(typeof(MultiDisagreementDiversityFactory))]
+
+namespace Microsoft.ML.Ensemble.EntryPoints
+{
+ /// <summary>
+ /// Entry-point component factory producing the binary-classification disagreement diversity measure.
+ /// </summary>
+ [TlcModule.Component(Name = DisagreementDiversityMeasure.LoadName, FriendlyName = DisagreementDiversityMeasure.UserName)]
+ public sealed class DisagreementDiversityFactory : ISupportBinaryDiversityMeasureFactory
+ {
+ public IBinaryDiversityMeasure CreateComponent(IHostEnvironment env) => new DisagreementDiversityMeasure();
+ }
+
+ /// <summary>
+ /// Entry-point component factory producing the regression disagreement diversity measure.
+ /// NOTE(review): FriendlyName reuses DisagreementDiversityMeasure.UserName instead of the
+ /// regression measure's own user name — looks like a copy/paste; confirm intended.
+ /// </summary>
+ [TlcModule.Component(Name = RegressionDisagreementDiversityMeasure.LoadName, FriendlyName = DisagreementDiversityMeasure.UserName)]
+ public sealed class RegressionDisagreementDiversityFactory : ISupportRegressionDiversityMeasureFactory
+ {
+ public IRegressionDiversityMeasure CreateComponent(IHostEnvironment env) => new RegressionDisagreementDiversityMeasure();
+ }
+
+ /// <summary>
+ /// Entry-point component factory producing the multiclass disagreement diversity measure.
+ /// NOTE(review): FriendlyName reuses DisagreementDiversityMeasure.UserName — confirm intended.
+ /// </summary>
+ [TlcModule.Component(Name = MultiDisagreementDiversityMeasure.LoadName, FriendlyName = DisagreementDiversityMeasure.UserName)]
+ public sealed class MultiDisagreementDiversityFactory : ISupportMulticlassDiversityMeasureFactory
+ {
+ public IMulticlassDiversityMeasure CreateComponent(IHostEnvironment env) => new MultiDisagreementDiversityMeasure();
+ }
+
+}
diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs b/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs
new file mode 100644
index 0000000000..728cccb1f6
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs
@@ -0,0 +1,55 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Ensemble.EntryPoints;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Ensemble;
+using Microsoft.ML.Runtime.EntryPoints;
+
+[assembly: LoadableClass(typeof(void), typeof(Ensemble), null, typeof(SignatureEntryPointModule), "TrainEnsemble")]
+
+namespace Microsoft.ML.Ensemble.EntryPoints
+{
+ /// <summary>
+ /// Entry points for training ensemble models. Each method registers a host, validates the
+ /// input arguments, and delegates to LearnerEntryPointsUtils.Train with the appropriate
+ /// ensemble trainer and a resolver for the label column.
+ /// </summary>
+ public static class Ensemble
+ {
+ /// <summary>Trains a binary-classification ensemble from the given arguments.</summary>
+ [TlcModule.EntryPoint(Name = "Trainers.EnsembleBinaryClassifier", Desc = "Train binary ensemble.", UserName = EnsembleTrainer.UserNameValue)]
+ public static CommonOutputs.BinaryClassificationOutput CreateBinaryEnsemble(IHostEnvironment env, EnsembleTrainer.Arguments input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ var host = env.Register("TrainBinaryEnsemble");
+ host.CheckValue(input, nameof(input));
+ EntryPointUtils.CheckInputArgs(host, input);
+
+ return LearnerEntryPointsUtils.Train(host, input,
+ () => new EnsembleTrainer(host, input),
+ () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
+ }
+
+ /// <summary>Trains a multiclass ensemble over data partitions.</summary>
+ [TlcModule.EntryPoint(Name = "Trainers.EnsembleClassification", Desc = "Train multiclass ensemble.", UserName = EnsembleTrainer.UserNameValue)]
+ public static CommonOutputs.MulticlassClassificationOutput CreateMultiClassEnsemble(IHostEnvironment env, MulticlassDataPartitionEnsembleTrainer.Arguments input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ var host = env.Register("TrainMultiClassEnsemble");
+ host.CheckValue(input, nameof(input));
+ EntryPointUtils.CheckInputArgs(host, input);
+
+ return LearnerEntryPointsUtils.Train(host, input,
+ () => new MulticlassDataPartitionEnsembleTrainer(host, input),
+ () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
+ }
+
+ /// <summary>Trains a regression ensemble from the given arguments.</summary>
+ [TlcModule.EntryPoint(Name = "Trainers.EnsembleRegression", Desc = "Train regression ensemble.", UserName = EnsembleTrainer.UserNameValue)]
+ public static CommonOutputs.RegressionOutput CreateRegressionEnsemble(IHostEnvironment env, RegressionEnsembleTrainer.Arguments input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ var host = env.Register("TrainRegressionEnsemble");
+ host.CheckValue(input, nameof(input));
+ EntryPointUtils.CheckInputArgs(host, input);
+
+ return LearnerEntryPointsUtils.Train(host, input,
+ () => new RegressionEnsembleTrainer(host, input),
+ () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/FeatureSelector.cs b/src/Microsoft.ML.Ensemble/EntryPoints/FeatureSelector.cs
new file mode 100644
index 0000000000..65ca5e9d06
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/EntryPoints/FeatureSelector.cs
@@ -0,0 +1,22 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Ensemble.EntryPoints;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Ensemble.Selector;
+using Microsoft.ML.Runtime.Ensemble.Selector.FeatureSelector;
+using Microsoft.ML.Runtime.EntryPoints;
+
+[assembly: EntryPointModule(typeof(AllFeatureSelectorFactory))]
+[assembly: EntryPointModule(typeof(RandomFeatureSelector))]
+
+namespace Microsoft.ML.Ensemble.EntryPoints
+{
+ /// <summary>
+ /// Entry-point component factory for the "all features" selector (no feature subsetting).
+ /// NOTE(review): the assembly-level attributes in this file also register
+ /// RandomFeatureSelector, but no factory for it appears here — confirm it is
+ /// declared elsewhere in the project.
+ /// </summary>
+ [TlcModule.Component(Name = AllFeatureSelector.LoadName, FriendlyName = AllFeatureSelector.UserName)]
+ public sealed class AllFeatureSelectorFactory : ISupportFeatureSelectorFactory
+ {
+ IFeatureSelector IComponentFactory.CreateComponent(IHostEnvironment env) => new AllFeatureSelector(env);
+ }
+
+}
diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/OutputCombiner.cs b/src/Microsoft.ML.Ensemble/EntryPoints/OutputCombiner.cs
new file mode 100644
index 0000000000..537b35f47b
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/EntryPoints/OutputCombiner.cs
@@ -0,0 +1,51 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Ensemble.EntryPoints;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
+using Microsoft.ML.Runtime.EntryPoints;
+
+[assembly: EntryPointModule(typeof(AverageFactory))]
+[assembly: EntryPointModule(typeof(MedianFactory))]
+[assembly: EntryPointModule(typeof(MultiAverage))]
+[assembly: EntryPointModule(typeof(MultiMedian))]
+[assembly: EntryPointModule(typeof(MultiStacking))]
+[assembly: EntryPointModule(typeof(MultiVotingFactory))]
+[assembly: EntryPointModule(typeof(MultiWeightedAverage))]
+[assembly: EntryPointModule(typeof(RegressionStacking))]
+[assembly: EntryPointModule(typeof(Stacking))]
+[assembly: EntryPointModule(typeof(VotingFactory))]
+[assembly: EntryPointModule(typeof(WeightedAverage))]
+
+namespace Microsoft.ML.Ensemble.EntryPoints
+{
+ /// <summary>
+ /// Factory for the Average output combiner; serves both binary and regression ensembles.
+ /// The binary variant is an explicit interface implementation so both factory methods can coexist.
+ /// </summary>
+ [TlcModule.Component(Name = Average.LoadName, FriendlyName = Average.UserName)]
+ public sealed class AverageFactory : ISupportBinaryOutputCombinerFactory, ISupportRegressionOutputCombinerFactory
+ {
+ public IRegressionOutputCombiner CreateComponent(IHostEnvironment env) => new Average(env);
+
+ IBinaryOutputCombiner IComponentFactory.CreateComponent(IHostEnvironment env) => new Average(env);
+ }
+
+ /// <summary>
+ /// Factory for the Median output combiner; serves both binary and regression ensembles.
+ /// </summary>
+ [TlcModule.Component(Name = Median.LoadName, FriendlyName = Median.UserName)]
+ public sealed class MedianFactory : ISupportBinaryOutputCombinerFactory, ISupportRegressionOutputCombinerFactory
+ {
+ public IRegressionOutputCombiner CreateComponent(IHostEnvironment env) => new Median(env);
+
+ IBinaryOutputCombiner IComponentFactory.CreateComponent(IHostEnvironment env) => new Median(env);
+ }
+
+ /// <summary>
+ /// Factory for the binary Voting combiner.
+ /// </summary>
+ [TlcModule.Component(Name = Voting.LoadName, FriendlyName = Voting.UserName)]
+ public sealed class VotingFactory : ISupportBinaryOutputCombinerFactory
+ {
+ IBinaryOutputCombiner IComponentFactory.CreateComponent(IHostEnvironment env) => new Voting(env);
+ }
+
+ /// <summary>
+ /// Factory for the multiclass Voting combiner.
+ /// NOTE(review): FriendlyName reuses Voting.UserName rather than a MultiVoting-specific
+ /// user name — confirm intended.
+ /// </summary>
+ [TlcModule.Component(Name = MultiVoting.LoadName, FriendlyName = Voting.UserName)]
+ public sealed class MultiVotingFactory : ISupportMulticlassOutputCombinerFactory
+ {
+ public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiVoting(env);
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs
new file mode 100644
index 0000000000..bcfaaefb89
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs
@@ -0,0 +1,59 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Ensemble.EntryPoints;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Internal.Calibration;
+
+[assembly: EntryPointModule(typeof(PipelineEnsemble))]
+
+namespace Microsoft.ML.Runtime.Ensemble.EntryPoints
+{
+ /// <summary>
+ /// Entry point that summarizes a pipeline ensemble predictor: it produces one summary
+ /// view and one statistics view per sub-model in the ensemble.
+ /// </summary>
+ public static class PipelineEnsemble
+ {
+ /// <summary>Output bundle: parallel arrays of per-sub-model summaries and stats.</summary>
+ public sealed class SummaryOutput
+ {
+ [TlcModule.Output(Desc = "The summaries of the individual predictors")]
+ public IDataView[] Summaries;
+
+ [TlcModule.Output(Desc = "The model statistics of the individual predictors")]
+ public IDataView[] Stats;
+ }
+
+ [TlcModule.EntryPoint(Name = "Models.EnsembleSummary", Desc = "Summarize a pipeline ensemble predictor.")]
+ public static SummaryOutput Summarize(IHostEnvironment env, SummarizePredictor.Input input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ var host = env.Register("PipelineEnsemblePredictor");
+ host.CheckValue(input, nameof(input));
+ EntryPointUtils.CheckInputArgs(host, input);
+
+ // Materialize the predictor against an empty view of the model's input schema;
+ // only the predictor itself is needed here, not scored data.
+ input.PredictorModel.PrepareData(host,
+ new EmptyDataView(host, input.PredictorModel.TransformModel.InputSchema),
+ out RoleMappedData rmd, out IPredictor predictor
+);
+
+ // Unwrap any (possibly nested) calibrator layers to reach the underlying predictor.
+ var calibrated = predictor as CalibratedPredictorBase;
+ while (calibrated != null)
+ {
+ predictor = calibrated.SubPredictor;
+ calibrated = predictor as CalibratedPredictorBase;
+ }
+ var ensemble = predictor as SchemaBindablePipelineEnsembleBase;
+ host.CheckUserArg(ensemble != null, nameof(input.PredictorModel.Predictor), "Predictor is not a pipeline ensemble predictor");
+
+ var summaries = new IDataView[ensemble.PredictorModels.Length];
+ var stats = new IDataView[ensemble.PredictorModels.Length];
+ for (int i = 0; i < ensemble.PredictorModels.Length; i++)
+ {
+ var pm = ensemble.PredictorModels[i];
+
+ // Each sub-model is summarized against its own input schema.
+ pm.PrepareData(host, new EmptyDataView(host, pm.TransformModel.InputSchema), out rmd, out IPredictor pred);
+ summaries[i] = SummarizePredictor.GetSummaryAndStats(host, pred, rmd.Schema, out stats[i]);
+ }
+ return new SummaryOutput() { Summaries = summaries, Stats = stats };
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/SubModelSelector.cs b/src/Microsoft.ML.Ensemble/EntryPoints/SubModelSelector.cs
new file mode 100644
index 0000000000..57001190ac
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/EntryPoints/SubModelSelector.cs
@@ -0,0 +1,37 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Ensemble.EntryPoints;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Ensemble.Selector;
+using Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector;
+using Microsoft.ML.Runtime.EntryPoints;
+
+[assembly: EntryPointModule(typeof(AllSelectorFactory))]
+[assembly: EntryPointModule(typeof(AllSelectorMultiClassFactory))]
+[assembly: EntryPointModule(typeof(BestDiverseSelectorBinary))]
+[assembly: EntryPointModule(typeof(BestDiverseSelectorMultiClass))]
+[assembly: EntryPointModule(typeof(BestDiverseSelectorRegression))]
+[assembly: EntryPointModule(typeof(BestPerformanceRegressionSelector))]
+[assembly: EntryPointModule(typeof(BestPerformanceSelector))]
+[assembly: EntryPointModule(typeof(BestPerformanceSelectorMultiClass))]
+
+namespace Microsoft.ML.Ensemble.EntryPoints
+{
+ /// <summary>
+ /// Factory for the AllSelector sub-model selector (keeps every trained sub-model),
+ /// usable for both binary and regression ensembles.
+ /// NOTE(review): both members below appear as IComponentFactory.CreateComponent — the
+ /// generic type arguments look stripped in this patch view; presumably each implements
+ /// a distinct generic factory interface. Confirm against the repository source.
+ /// </summary>
+ [TlcModule.Component(Name = AllSelector.LoadName, FriendlyName = AllSelector.UserName)]
+ public sealed class AllSelectorFactory : ISupportBinarySubModelSelectorFactory, ISupportRegressionSubModelSelectorFactory
+ {
+ IBinarySubModelSelector IComponentFactory.CreateComponent(IHostEnvironment env) => new AllSelector(env);
+
+ IRegressionSubModelSelector IComponentFactory.CreateComponent(IHostEnvironment env) => new AllSelector(env);
+ }
+
+ /// <summary>
+ /// Factory for the multiclass AllSelector sub-model selector.
+ /// </summary>
+ [TlcModule.Component(Name = AllSelectorMultiClass.LoadName, FriendlyName = AllSelectorMultiClass.UserName)]
+ public sealed class AllSelectorMultiClassFactory : ISupportMulticlassSubModelSelectorFactory
+ {
+ IMulticlassSubModelSelector IComponentFactory.CreateComponent(IHostEnvironment env) => new AllSelectorMultiClass(env);
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/FeatureSubsetModel.cs b/src/Microsoft.ML.Ensemble/FeatureSubsetModel.cs
new file mode 100644
index 0000000000..4518666d34
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/FeatureSubsetModel.cs
@@ -0,0 +1,32 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections;
+using System.Collections.Generic;
+using Microsoft.ML.Runtime.Internal.Utilities;
+
+namespace Microsoft.ML.Runtime.Ensemble
+{
+ /// <summary>
+ /// Pairs a trained predictor with the (optional) subset of features it was trained on,
+ /// plus any per-model metrics recorded during ensemble training.
+ /// NOTE(review): the generic parameter list appears stripped in this patch view
+ /// ("FeatureSubsetModel where TPredictor : IPredictor") — confirm against source.
+ /// </summary>
+ public sealed class FeatureSubsetModel where TPredictor : IPredictor
+ {
+ public readonly TPredictor Predictor;
+ // Bit i set => feature i was selected. Null when the model uses all features.
+ public readonly BitArray SelectedFeatures;
+ // Number of selected features; only meaningful when SelectedFeatures is non-null.
+ public readonly int Cardinality;
+
+ public KeyValuePair[] Metrics { get; set; }
+
+ public FeatureSubsetModel(TPredictor predictor, BitArray features = null,
+ KeyValuePair[] metrics = null)
+ {
+ Predictor = predictor;
+ int card;
+ // Only record a proper subset; a full feature set is represented by leaving
+ // SelectedFeatures null (and Cardinality at its default of 0).
+ if (features != null && (card = Utils.GetCardinality(features)) < features.Count)
+ {
+ SelectedFeatures = features;
+ Cardinality = card;
+ }
+ Metrics = metrics;
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj b/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj
new file mode 100644
index 0000000000..ddd4557788
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj
@@ -0,0 +1,15 @@
+
+
+
+ netstandard2.0
+ Microsoft.ML.Ensemble
+ CORECLR
+
+
+
+
+
+
+
+
+
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs
new file mode 100644
index 0000000000..45cd764d13
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs
@@ -0,0 +1,62 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
+using Microsoft.ML.Runtime.Model;
+
+[assembly: LoadableClass(typeof(Average), null, typeof(SignatureCombiner), Average.UserName)]
+[assembly: LoadableClass(typeof(Average), null, typeof(SignatureLoadModel), Average.UserName, Average.LoaderSignature)]
+
+namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
+{
+ /// <summary>
+ /// Output combiner that takes the unweighted average of the sub-model scores.
+ /// Serializable via the standard model load/save machinery.
+ /// </summary>
+ public sealed class Average : BaseAverager, ICanSaveModel, IRegressionOutputCombiner
+ {
+ public const string UserName = "Average";
+ public const string LoadName = "Average";
+ public const string LoaderSignature = "AverageCombiner";
+
+ private static VersionInfo GetVersionInfo()
+ {
+ return new VersionInfo(
+ modelSignature: "AVG COMB",
+ verWrittenCur: 0x00010001,
+ verReadableCur: 0x00010001,
+ verWeCanReadBack: 0x00010001,
+ loaderSignature: LoaderSignature);
+ }
+
+ public Average(IHostEnvironment env)
+ : base(env, LoaderSignature)
+ {
+ }
+
+ // Deserialization constructor; use Create for the checked public path.
+ private Average(IHostEnvironment env, ModelLoadContext ctx)
+ : base(env, LoaderSignature, ctx)
+ {
+ }
+
+ /// <summary>Loads an Average combiner from a model context, validating the version info.</summary>
+ public static Average Create(IHostEnvironment env, ModelLoadContext ctx)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel(GetVersionInfo());
+ return new Average(env, ctx);
+ }
+
+ protected override void SaveCore(ModelSaveContext ctx)
+ {
+ base.SaveCore(ctx);
+ ctx.SetVersionInfo(GetVersionInfo());
+ }
+
+ /// <summary>
+ /// Returns a combiner delegate that averages the scores, ignoring any supplied weights.
+ /// </summary>
+ public override Combiner GetCombiner()
+ {
+ // Force the weights to null.
+ return(ref Single dst, Single[] src, Single[] weights) =>
+ CombineCore(ref dst, src, null);
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseAverager.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseAverager.cs
new file mode 100644
index 0000000000..824300e594
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseAverager.cs
@@ -0,0 +1,78 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime.Model;
+
+namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
+{
+ /// <summary>
+ /// Base class for binary output combiners that compute a (possibly weighted) average
+ /// of sub-model scores. Handles the common host registration and the shared
+ /// serialization payload (a sizeof(Single) sanity marker).
+ /// </summary>
+ public abstract class BaseAverager : IBinaryOutputCombiner
+ {
+ protected readonly IHost Host;
+ public BaseAverager(IHostEnvironment env, string name)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckNonWhiteSpace(name, nameof(name));
+ Host = env.Register(name);
+ }
+
+ // Deserialization constructor: reads and validates the shared binary payload.
+ protected BaseAverager(IHostEnvironment env, string name, ModelLoadContext ctx)
+ {
+ Contracts.AssertValue(env);
+ env.AssertNonWhiteSpace(name);
+ Host = env.Register(name);
+ Host.CheckValue(ctx, nameof(ctx));
+
+ // *** Binary format ***
+ // int: sizeof(Single)
+ int cbFloat = ctx.Reader.ReadInt32();
+ Host.CheckDecode(cbFloat == sizeof(Single));
+ }
+
+ public void Save(ModelSaveContext ctx)
+ {
+ Host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel();
+ SaveCore(ctx);
+ }
+
+ protected virtual void SaveCore(ModelSaveContext ctx)
+ {
+ // *** Binary format ***
+ // int: sizeof(Single)
+ ctx.Writer.Write(sizeof(Single));
+ }
+
+ public abstract Combiner GetCombiner();
+
+ /// <summary>
+ /// Computes the (optionally weighted) mean of <paramref name="src"/>, skipping NaN entries.
+ /// Note: if every entry is NaN (or all weights of non-NaN entries sum to zero), the
+ /// result is 0/0 = NaN.
+ /// </summary>
+ protected void CombineCore(ref Single dst, Single[] src, Single[] weights = null)
+ {
+ Single sum = 0;
+ Single weightTotal = 0;
+ if (weights == null)
+ {
+ // Unweighted: each non-NaN score contributes with weight 1.
+ for (int i = 0; i < src.Length; i++)
+ {
+ if (!Single.IsNaN(src[i]))
+ {
+ sum += src[i];
+ weightTotal++;
+ }
+ }
+ }
+ else
+ {
+ for (int i = 0; i < src.Length; i++)
+ {
+ if (!Single.IsNaN(src[i]))
+ {
+ sum += weights[i] * src[i];
+ weightTotal += weights[i];
+ }
+ }
+ }
+ dst = sum / weightTotal;
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs
new file mode 100644
index 0000000000..64ec41d613
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs
@@ -0,0 +1,68 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.Model;
+using Microsoft.ML.Runtime.Numeric;
+
+namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
+{
+ /// <summary>
+ /// Base class for multiclass combiners that average per-class score vectors,
+ /// optionally weighted per model. Builds on BaseMultiCombiner's normalization support.
+ /// </summary>
+ public abstract class BaseMultiAverager : BaseMultiCombiner
+ {
+ internal BaseMultiAverager(IHostEnvironment env, string name, ArgumentsBase args)
+ : base(env, name, args)
+ {
+ }
+
+ internal BaseMultiAverager(IHostEnvironment env, string name, ModelLoadContext ctx)
+ : base(env, name, ctx)
+ {
+ }
+
+ /// <summary>
+ /// Averages the score vectors in <paramref name="src"/> into <paramref name="dst"/>.
+ /// If normalization fails (non-finite L1 norm in any input), outputs an all-NaN vector.
+ /// </summary>
+ protected void CombineCore(ref VBuffer dst, VBuffer[] src, Single[] weights = null)
+ {
+ Host.AssertNonEmpty(src);
+ Host.Assert(weights == null || Utils.Size(weights) == Utils.Size(src));
+
+ // REVIEW: Should this be tolerant of NaNs?
+ int len = GetClassCount(src);
+ if (!TryNormalize(src))
+ {
+ GetNaNOutput(ref dst, len);
+ return;
+ }
+
+ // Reuse dst's backing array when it is large enough; otherwise allocate.
+ var values = dst.Values;
+ if (Utils.Size(values) < len)
+ values = new Single[len];
+ else
+ Array.Clear(values, 0, len);
+
+ // Set the output to values.
+ dst = new VBuffer(len, values, dst.Indices);
+
+ Single weightTotal;
+ if (weights == null)
+ {
+ // Unweighted: simple sum, then divide by the number of models.
+ weightTotal = src.Length;
+ for (int i = 0; i < src.Length; i++)
+ VectorUtils.Add(ref src[i], ref dst);
+ }
+ else
+ {
+ weightTotal = 0;
+ for (int i = 0; i < src.Length; i++)
+ {
+ var w = weights[i];
+ weightTotal += w;
+ VectorUtils.AddMult(ref src[i], w, ref dst);
+ }
+ }
+
+ // Normalize the accumulated sum by the total weight.
+ VectorUtils.ScaleBy(ref dst, 1 / weightTotal);
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs
new file mode 100644
index 0000000000..1258313df1
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs
@@ -0,0 +1,109 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.Model;
+using Microsoft.ML.Runtime.Numeric;
+
+namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
+{
+ /// <summary>
+ /// Base class for multiclass output combiners. Provides the shared Normalize option,
+ /// its serialization, and helpers for sizing, L1-normalizing, and NaN-filling
+ /// per-class score vectors.
+ /// NOTE(review): generic type arguments appear stripped in this patch view
+ /// (e.g. "Combiner>", "VBuffer[]") — confirm against the repository source.
+ /// </summary>
+ public abstract class BaseMultiCombiner : IMultiClassOutputCombiner
+ {
+ protected readonly IHost Host;
+
+ public abstract class ArgumentsBase
+ {
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to normalize the output of base models before combining them",
+ ShortName = "norm", SortOrder = 50)]
+ public bool Normalize = true;
+ }
+
+ protected readonly bool Normalize;
+
+ internal BaseMultiCombiner(IHostEnvironment env, string name, ArgumentsBase args)
+ {
+ Contracts.AssertValue(env);
+ env.AssertNonWhiteSpace(name);
+ Host = env.Register(name);
+ Host.CheckValue(args, nameof(args));
+
+ Normalize = args.Normalize;
+ }
+
+ // Deserialization constructor: reads the shared binary payload.
+ internal BaseMultiCombiner(IHostEnvironment env, string name, ModelLoadContext ctx)
+ {
+ Contracts.AssertValue(env);
+ env.AssertNonWhiteSpace(name);
+ Host = env.Register(name);
+ Host.AssertValue(ctx);
+
+ // *** Binary format ***
+ // int: sizeof(Single)
+ // bool: _normalize
+ int cbFloat = ctx.Reader.ReadInt32();
+ Host.CheckDecode(cbFloat == sizeof(Single));
+ Normalize = ctx.Reader.ReadBoolByte();
+ }
+
+ public void Save(ModelSaveContext ctx)
+ {
+ Host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel();
+ SaveCore(ctx);
+ }
+
+ protected virtual void SaveCore(ModelSaveContext ctx)
+ {
+ // *** Binary format ***
+ // int: sizeof(Single)
+ // bool: _normalize
+ ctx.Writer.Write(sizeof(Single));
+ ctx.Writer.WriteBoolByte(Normalize);
+ }
+
+ public abstract Combiner> GetCombiner();
+
+ /// <summary>Returns the maximum vector length across all inputs (the class count).</summary>
+ protected int GetClassCount(VBuffer[] values)
+ {
+ int len = 0;
+ foreach (var item in values)
+ {
+ if (len < item.Length)
+ len = item.Length;
+ }
+ return len;
+ }
+
+ /// <summary>
+ /// L1-normalizes each input vector in place (no-op when Normalize is false).
+ /// Returns false if any vector's L1 norm is not finite, in which case the caller
+ /// should produce a NaN output.
+ /// </summary>
+ protected bool TryNormalize(VBuffer[] values)
+ {
+ if (!Normalize)
+ return true;
+
+ for (int i = 0; i < values.Length; i++)
+ {
+ // Leave a zero vector as all zeros. Otherwise, make the L1 norm equal to 1.
+ var sum = VectorUtils.L1Norm(ref values[i]);
+ if (!FloatUtils.IsFinite(sum))
+ return false;
+ if (sum > 0)
+ VectorUtils.ScaleBy(ref values[i], 1 / sum);
+ }
+ return true;
+ }
+
+ /// <summary>Fills <paramref name="dst"/> with a dense length-<paramref name="len"/> vector of NaNs.</summary>
+ protected void GetNaNOutput(ref VBuffer dst, int len)
+ {
+ Contracts.Assert(len >= 0);
+ var values = dst.Values;
+ if (Utils.Size(values) < len)
+ values = new Single[len];
+ for (int i = 0; i < len; i++)
+ values[i] = Single.NaN;
+ dst = new VBuffer(len, values, dst.Indices);
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs
new file mode 100644
index 0000000000..a5c9c757a4
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs
@@ -0,0 +1,35 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.Model;
+
+namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
+{
+ /// <summary>
+ /// Stacking base for combiners whose sub-model outputs are scalar scores: the meta-learner's
+ /// feature vector is simply the dense array of per-model scores.
+ /// </summary>
+ public abstract class BaseScalarStacking : BaseStacking
+ {
+ internal BaseScalarStacking(IHostEnvironment env, string name, ArgumentsBase args)
+ : base(env, name, args)
+ {
+ }
+
+ internal BaseScalarStacking(IHostEnvironment env, string name, ModelLoadContext ctx)
+ : base(env, name, ctx)
+ {
+ }
+
+ /// <summary>
+ /// Copies the per-model scores into <paramref name="dst"/> as a dense feature vector,
+ /// reusing dst's backing array when it is large enough.
+ /// </summary>
+ protected override void FillFeatureBuffer(Single[] src, ref VBuffer dst)
+ {
+ Contracts.AssertNonEmpty(src);
+ int len = src.Length;
+ var values = dst.Values;
+ if (Utils.Size(values) < len)
+ values = new Single[len];
+ Array.Copy(src, values, len);
+ dst = new VBuffer(len, values, dst.Indices);
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
new file mode 100644
index 0000000000..f30174a31d
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
@@ -0,0 +1,200 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Threading.Tasks;
+using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Internal.Internallearn;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.Model;
+using Microsoft.ML.Runtime.Training;
+
+namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
+{
+ using ColumnRole = RoleMappedSchema.ColumnRole;
+ /// <summary>
+ /// Base class for stacking combiners: trains a meta-predictor (Meta) over the outputs
+ /// of the ensemble's sub-models, then combines at scoring time by mapping the sub-model
+ /// outputs through that meta-predictor.
+ /// NOTE(review): generic type arguments appear stripped throughout this patch view
+ /// (e.g. "BaseStacking : IStackingTrainer", "SubComponent>, TSigBase>") — confirm
+ /// the exact generic signatures against the repository source.
+ /// </summary>
+ public abstract class BaseStacking : IStackingTrainer
+ {
+ public abstract class ArgumentsBase
+ {
+ [Argument(ArgumentType.AtMostOnce, ShortName = "vp", SortOrder = 50,
+ HelpText = "The proportion of instances to be selected to test the individual base learner. If it is 0, it uses training set")]
+ [TGUI(Label = "Validation Dataset Proportion")]
+ public Single ValidationDatasetProportion = 0.3f;
+
+ [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
+ Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
+ [TGUI(Label = "Base predictor")]
+ public SubComponent>, TSigBase> BasePredictorType;
+ }
+
+ protected readonly SubComponent>, TSigBase> BasePredictorType;
+ protected readonly IHost Host;
+ // The trained meta-predictor; null until Train completes (or the model is loaded).
+ protected IPredictorProducing Meta;
+
+ public Single ValidationDatasetProportion { get; }
+
+ internal BaseStacking(IHostEnvironment env, string name, ArgumentsBase args)
+ {
+ Contracts.AssertValue(env);
+ env.AssertNonWhiteSpace(name);
+ Host = env.Register(name);
+ Host.AssertValue(args, "args");
+ Host.CheckUserArg(0 <= args.ValidationDatasetProportion && args.ValidationDatasetProportion < 1,
+ nameof(args.ValidationDatasetProportion),
+ "The validation proportion for stacking should be greater than or equal to 0 and less than 1");
+ Host.CheckUserArg(args.BasePredictorType.IsGood(), nameof(args.BasePredictorType));
+
+ ValidationDatasetProportion = args.ValidationDatasetProportion;
+ BasePredictorType = args.BasePredictorType;
+ }
+
+ // Deserialization constructor: restores the proportion and loads the meta-predictor.
+ internal BaseStacking(IHostEnvironment env, string name, ModelLoadContext ctx)
+ {
+ Contracts.AssertValue(env);
+ env.AssertNonWhiteSpace(name);
+ Host = env.Register(name);
+ Host.AssertValue(ctx);
+
+ // *** Binary format ***
+ // int: sizeof(Single)
+ // Float: _validationDatasetProportion
+ int cbFloat = ctx.Reader.ReadInt32();
+ env.CheckDecode(cbFloat == sizeof(Single));
+ ValidationDatasetProportion = ctx.Reader.ReadFloat();
+ env.CheckDecode(0 <= ValidationDatasetProportion && ValidationDatasetProportion < 1);
+
+ ctx.LoadModel, SignatureLoadModel>(env, out Meta, "MetaPredictor");
+ CheckMeta();
+ }
+
+ public void Save(ModelSaveContext ctx)
+ {
+ Host.Check(Meta != null, "Can't save an untrained Stacking combiner");
+ Host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel();
+ SaveCore(ctx);
+ }
+
+ protected virtual void SaveCore(ModelSaveContext ctx)
+ {
+ Host.Assert(Meta != null);
+
+ // *** Binary format ***
+ // int: sizeof(Single)
+ // Float: _validationDatasetProportion
+ ctx.Writer.Write(sizeof(Single));
+ ctx.Writer.Write(ValidationDatasetProportion);
+
+ ctx.SaveModel(Meta, "MetaPredictor");
+ }
+
+ /// <summary>
+ /// Returns a delegate that maps the sub-model outputs through the trained meta-predictor.
+ /// The delegate captures a feature buffer, so it is not safe to share across threads.
+ /// </summary>
+ public Combiner GetCombiner()
+ {
+ Contracts.Check(Meta != null, "Training of stacking combiner not complete");
+
+ // Subtle point: We shouldn't get the ValueMapper delegate and cache it in a field
+ // since generally ValueMappers cannot be assumed to be thread safe - they often
+ // capture buffers needed for efficient operation.
+ var mapper = (IValueMapper)Meta;
+ var map = mapper.GetMapper, TOutput>();
+
+ var feat = default(VBuffer);
+ Combiner res =
+ (ref TOutput dst, TOutput[] src, Single[] weights) =>
+ {
+ FillFeatureBuffer(src, ref feat);
+ map(ref feat, ref dst);
+ };
+ return res;
+ }
+
+ // Derived classes define how sub-model outputs become the meta-learner's feature vector.
+ protected abstract void FillFeatureBuffer(TOutput[] src, ref VBuffer dst);
+
+ // Validates that the meta-predictor maps a float vector to the expected output type.
+ private void CheckMeta()
+ {
+ Contracts.Assert(Meta != null);
+
+ var ivm = Meta as IValueMapper;
+ Contracts.Check(ivm != null, "Stacking predictor doesn't implement the expected interface");
+ if (!ivm.InputType.IsVector || ivm.InputType.ItemType != NumberType.Float)
+ throw Contracts.Except("Stacking predictor input type is unsupported: {0}", ivm.InputType);
+ if (ivm.OutputType.RawType != typeof(TOutput))
+ throw Contracts.Except("Stacking predictor output type is unsupported: {0}", ivm.OutputType);
+ }
+
+ /// <summary>
+ /// Trains the meta-predictor: scores every row of <paramref name="data"/> with each
+ /// sub-model (respecting per-model feature subsets), builds a label/features dataset
+ /// from those scores, and trains BasePredictorType on it. May only be called once.
+ /// </summary>
+ public void Train(List>> models, RoleMappedData data, IHostEnvironment env)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ var host = env.Register(Stacking.LoadName);
+ host.CheckValue(models, nameof(models));
+ host.CheckValue(data, nameof(data));
+
+ using (var ch = host.Start("Training stacked model"))
+ {
+ ch.Check(Meta == null, "Train called multiple times");
+ ch.Check(BasePredictorType != null);
+
+ var maps = new ValueMapper, TOutput>[models.Count];
+ for (int i = 0; i < maps.Length; i++)
+ {
+ Contracts.Assert(models[i].Predictor is IValueMapper);
+ var m = (IValueMapper)models[i].Predictor;
+ maps[i] = m.GetMapper, TOutput>();
+ }
+
+ // REVIEW: Should implement this better....
+ var labels = new Single[100];
+ var features = new VBuffer[100];
+ int count = 0;
+ // REVIEW: Should this include bad values or filter them?
+ using (var cursor = new FloatLabelCursor(data, CursOpt.AllFeatures | CursOpt.AllLabels))
+ {
+ TOutput[] predictions = new TOutput[maps.Length];
+ var vBuffers = new VBuffer[maps.Length];
+ while (cursor.MoveNext())
+ {
+ // Score the current row with every sub-model in parallel; each model
+ // writes into its own slot of predictions/vBuffers, so no contention.
+ Parallel.For(0, maps.Length, i =>
+ {
+ var model = models[i];
+ if (model.SelectedFeatures != null)
+ {
+ // Project the row onto the model's feature subset before mapping.
+ EnsembleUtils.SelectFeatures(ref cursor.Features, model.SelectedFeatures, model.Cardinality, ref vBuffers[i]);
+ maps[i](ref vBuffers[i], ref predictions[i]);
+ }
+ else
+ maps[i](ref cursor.Features, ref predictions[i]);
+ });
+
+ Utils.EnsureSize(ref labels, count + 1);
+ Utils.EnsureSize(ref features, count + 1);
+ labels[count] = cursor.Label;
+ FillFeatureBuffer(predictions, ref features[count]);
+ count++;
+ }
+ }
+
+ ch.Info("The number of instances used for stacking trainer is {0}", count);
+
+ // Trim the buffers to the actual row count and build the meta-training view.
+ var bldr = new ArrayDataViewBuilder(host);
+ Array.Resize(ref labels, count);
+ Array.Resize(ref features, count);
+ bldr.AddColumn(DefaultColumnNames.Label, NumberType.Float, labels);
+ bldr.AddColumn(DefaultColumnNames.Features, NumberType.Float, features);
+
+ var view = bldr.GetDataView();
+ var rmd = new RoleMappedData(view, DefaultColumnNames.Label, DefaultColumnNames.Features);
+
+ var trainer = BasePredictorType.CreateInstance(host);
+ if (trainer.Info.NeedNormalization)
+ ch.Warning("The trainer specified for stacking wants normalization, but we do not currently allow this.");
+ Meta = trainer.Train(rmd);
+ CheckMeta();
+
+ ch.Done();
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/IOutputCombiner.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/IOutputCombiner.cs
new file mode 100644
index 0000000000..512974b717
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/IOutputCombiner.cs
@@ -0,0 +1,71 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.EntryPoints;
+
+namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
+{
+    /// <summary>
+    /// Signature for loadable combiner components.
+    /// </summary>
+    public delegate void SignatureCombiner();
+
+    /// <summary>
+    /// Combines the outputs of multiple models into a single output. Implementations
+    /// may use the optional per-model <paramref name="weights"/> (same length as
+    /// <paramref name="src"/>) or ignore them; some implementations mutate "src".
+    /// </summary>
+    public delegate void Combiner<TOutput>(ref TOutput dst, TOutput[] src, Single[] weights);
+
+    /// <summary>
+    /// Non-generic marker interface for output combiners.
+    /// </summary>
+    public interface IOutputCombiner
+    {
+    }
+
+    /// <summary>
+    /// Generic interface for combining outputs of multiple models.
+    /// </summary>
+    public interface IOutputCombiner<TOutput> : IOutputCombiner
+    {
+        Combiner<TOutput> GetCombiner();
+    }
+
+    /// <summary>
+    /// Interface for combiners (such as stacking) that require a training step of their own.
+    /// </summary>
+    public interface IStackingTrainer<TOutput>
+    {
+        void Train(List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models, RoleMappedData data, IHostEnvironment env);
+
+        // Fraction of the data reserved for training the meta-model.
+        Single ValidationDatasetProportion { get; }
+    }
+
+    public interface IRegressionOutputCombiner : IOutputCombiner<Single>
+    {
+    }
+
+    public interface IBinaryOutputCombiner : IOutputCombiner<Single>
+    {
+    }
+
+    public interface IMultiClassOutputCombiner : IOutputCombiner<VBuffer<Single>>
+    {
+    }
+
+    [TlcModule.ComponentKind("EnsembleMulticlassOutputCombiner")]
+    public interface ISupportMulticlassOutputCombinerFactory : IComponentFactory<IMultiClassOutputCombiner>
+    {
+    }
+
+    [TlcModule.ComponentKind("EnsembleBinaryOutputCombiner")]
+    public interface ISupportBinaryOutputCombinerFactory : IComponentFactory<IBinaryOutputCombiner>
+    {
+    }
+
+    [TlcModule.ComponentKind("EnsembleRegressionOutputCombiner")]
+    public interface ISupportRegressionOutputCombinerFactory : IComponentFactory<IRegressionOutputCombiner>
+    {
+    }
+
+    /// <summary>
+    /// Interface for averagers that weight each model's output by an evaluation metric.
+    /// </summary>
+    public interface IWeightedAverager
+    {
+        string WeightageMetricName { get; }
+    }
+}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs
new file mode 100644
index 0000000000..de8d950de4
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs
@@ -0,0 +1,87 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.Model;
+
+[assembly: LoadableClass(typeof(Median), null, typeof(SignatureCombiner), Median.UserName, Median.LoadName)]
+[assembly: LoadableClass(typeof(Median), null, typeof(SignatureLoadModel), Median.UserName, Median.LoaderSignature)]
+
+namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
+{
+    /// <summary>
+    /// Combiner that outputs the median of the models' scores. Usable for both
+    /// regression and binary-classification ensembles. Any weights supplied to the
+    /// combiner are ignored: the plain (unweighted) median is produced.
+    /// </summary>
+    public sealed class Median : IRegressionOutputCombiner, IBinaryOutputCombiner, ICanSaveModel
+    {
+        private readonly IHost _host;
+        public const string UserName = "Median";
+        public const string LoadName = "Median";
+        public const string LoaderSignature = "MedianCombiner";
+
+        private static VersionInfo GetVersionInfo()
+        {
+            return new VersionInfo(
+                modelSignature: "MEDICOMB",
+                verWrittenCur: 0x00010001,
+                verReadableCur: 0x00010001,
+                verWeCanReadBack: 0x00010001,
+                loaderSignature: LoaderSignature);
+        }
+
+        public Median(IHostEnvironment env)
+        {
+            Contracts.CheckValue(env, nameof(env));
+            _host = env.Register(LoaderSignature);
+        }
+
+        private Median(IHostEnvironment env, ModelLoadContext ctx)
+        {
+            Contracts.AssertValue(env);
+            _host = env.Register(LoaderSignature);
+
+            // *** Binary format ***
+            // int: sizeof(Single)
+            int cbFloat = ctx.Reader.ReadInt32();
+            _host.CheckDecode(cbFloat == sizeof(Single));
+        }
+
+        public static Median Create(IHostEnvironment env, ModelLoadContext ctx)
+        {
+            Contracts.CheckValue(env, nameof(env));
+            env.CheckValue(ctx, nameof(ctx));
+            ctx.CheckAtModel(GetVersionInfo());
+            return new Median(env, ctx);
+        }
+
+        public void Save(ModelSaveContext ctx)
+        {
+            _host.CheckValue(ctx, nameof(ctx));
+            ctx.CheckAtModel();
+            ctx.SetVersionInfo(GetVersionInfo());
+
+            // *** Binary format ***
+            // int: sizeof(Single)
+            ctx.Writer.Write(sizeof(Single));
+        }
+
+        public Combiner<Single> GetCombiner()
+        {
+            return CombineCore;
+        }
+
+        private void CombineCore(ref Single dst, Single[] src, Single[] weights)
+        {
+            // REVIEW: This mutates "src". We need to ensure that the documentation of
+            // combiners makes it clear that combiners are allowed to do this. Note that "normalization"
+            // in the multi-class case also mutates.
+            _host.AssertNonEmpty(src);
+            _host.Assert(weights == null || Utils.Size(weights) == Utils.Size(src));
+            dst = MathUtils.GetMedianInPlace(src, src.Length);
+        }
+    }
+}
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs
new file mode 100644
index 0000000000..c147f932f3
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs
@@ -0,0 +1,72 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Model;
+
+[assembly: LoadableClass(typeof(MultiAverage), typeof(MultiAverage.Arguments), typeof(SignatureCombiner),
+ Average.UserName, MultiAverage.LoadName)]
+[assembly: LoadableClass(typeof(MultiAverage), null, typeof(SignatureLoadModel), Average.UserName,
+ MultiAverage.LoadName, MultiAverage.LoaderSignature)]
+
+namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
+{
+    /// <summary>
+    /// Combiner for multi-class outputs that averages the models' per-class score
+    /// vectors. Per-model weights are deliberately discarded (see GetCombiner).
+    /// </summary>
+    public sealed class MultiAverage : BaseMultiAverager, ICanSaveModel
+    {
+        public const string LoadName = "MultiAverage";
+        public const string LoaderSignature = "MultiAverageCombiner";
+
+        private static VersionInfo GetVersionInfo()
+        {
+            return new VersionInfo(
+                modelSignature: "MAVGCOMB",
+                verWrittenCur: 0x00010001,
+                verReadableCur: 0x00010001,
+                verWeCanReadBack: 0x00010001,
+                loaderSignature: LoaderSignature);
+        }
+
+        [TlcModule.Component(Name = LoadName, FriendlyName = Average.UserName)]
+        public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerFactory
+        {
+            public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiAverage(env, this);
+        }
+
+        public MultiAverage(IHostEnvironment env, Arguments args)
+            : base(env, LoaderSignature, args)
+        {
+        }
+
+        private MultiAverage(IHostEnvironment env, ModelLoadContext ctx)
+            : base(env, LoaderSignature, ctx)
+        {
+        }
+
+        public static MultiAverage Create(IHostEnvironment env, ModelLoadContext ctx)
+        {
+            Contracts.CheckValue(env, nameof(env));
+            env.CheckValue(ctx, nameof(ctx));
+            ctx.CheckAtModel(GetVersionInfo());
+            return new MultiAverage(env, ctx);
+        }
+
+        protected override void SaveCore(ModelSaveContext ctx)
+        {
+            base.SaveCore(ctx);
+            ctx.SetVersionInfo(GetVersionInfo());
+        }
+
+        public override Combiner<VBuffer<Single>> GetCombiner()
+        {
+            // Force the weights to null so the base computes a plain (unweighted) average.
+            return
+                (ref VBuffer<Single> dst, VBuffer<Single>[] src, Single[] weights) =>
+                    CombineCore(ref dst, src, null);
+        }
+    }
+}
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs
new file mode 100644
index 0000000000..c3e6869d69
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs
@@ -0,0 +1,102 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.Model;
+
+[assembly: LoadableClass(typeof(MultiMedian), typeof(MultiMedian.Arguments), typeof(SignatureCombiner),
+ Median.UserName, MultiMedian.LoadName)]
+[assembly: LoadableClass(typeof(MultiMedian), null, typeof(SignatureLoadModel), Median.UserName, MultiMedian.LoaderSignature)]
+
+namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
+{
+    /// <summary>
+    /// Combiner for multi-class outputs that takes, for each class, the median of
+    /// the (optionally normalized) scores produced by the individual models.
+    /// Weights are ignored.
+    /// </summary>
+    public sealed class MultiMedian : BaseMultiCombiner, ICanSaveModel
+    {
+        public const string LoadName = "MultiMedian";
+        public const string LoaderSignature = "MultiMedianCombiner";
+
+        private static VersionInfo GetVersionInfo()
+        {
+            return new VersionInfo(
+                modelSignature: "MMEDCOMB",
+                verWrittenCur: 0x00010001,
+                verReadableCur: 0x00010001,
+                verWeCanReadBack: 0x00010001,
+                loaderSignature: LoaderSignature);
+        }
+
+        [TlcModule.Component(Name = LoadName, FriendlyName = Median.UserName)]
+        public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerFactory
+        {
+            public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiMedian(env, this);
+        }
+
+        public MultiMedian(IHostEnvironment env, Arguments args)
+            : base(env, LoaderSignature, args)
+        {
+        }
+
+        private MultiMedian(IHostEnvironment env, ModelLoadContext ctx)
+            : base(env, LoaderSignature, ctx)
+        {
+        }
+
+        public static MultiMedian Create(IHostEnvironment env, ModelLoadContext ctx)
+        {
+            Contracts.CheckValue(env, nameof(env));
+            env.CheckValue(ctx, nameof(ctx));
+            ctx.CheckAtModel(GetVersionInfo());
+            return new MultiMedian(env, ctx);
+        }
+
+        protected override void SaveCore(ModelSaveContext ctx)
+        {
+            base.SaveCore(ctx);
+            ctx.SetVersionInfo(GetVersionInfo());
+        }
+
+        public override Combiner<VBuffer<Single>> GetCombiner()
+        {
+            // Scratch buffer captured by the closure and reused across calls to
+            // avoid reallocating on every invocation.
+            Single[] raw = null;
+            return
+                (ref VBuffer<Single> dst, VBuffer<Single>[] src, Single[] weights) =>
+                {
+                    Host.AssertNonEmpty(src);
+                    Host.Assert(weights == null || Utils.Size(weights) == Utils.Size(src));
+
+                    int len = GetClassCount(src);
+                    if (!TryNormalize(src))
+                    {
+                        GetNaNOutput(ref dst, len);
+                        return;
+                    }
+
+                    var values = dst.Values;
+                    if (Utils.Size(values) < len)
+                        values = new Single[len];
+
+                    int count = src.Length;
+                    if (Utils.Size(raw) < count)
+                        raw = new Single[count];
+                    for (int i = 0; i < len; i++)
+                    {
+                        // Gather the i-th class score from every model, then take the median.
+                        for (int j = 0; j < count; j++)
+                            raw[j] = i < src[j].Length ? src[j].GetItemOrDefault(i) : 0;
+                        values[i] = MathUtils.GetMedianInPlace(raw, count);
+                    }
+
+                    // Set the output to values.
+                    dst = new VBuffer<Single>(len, values, dst.Indices);
+                };
+        }
+    }
+}
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
new file mode 100644
index 0000000000..2ef74c8169
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
@@ -0,0 +1,99 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.Model;
+
+[assembly: LoadableClass(typeof(MultiStacking), typeof(MultiStacking.Arguments), typeof(SignatureCombiner),
+ Stacking.UserName, MultiStacking.LoadName)]
+
+[assembly: LoadableClass(typeof(MultiStacking), null, typeof(SignatureLoadModel),
+ Stacking.UserName, MultiStacking.LoaderSignature)]
+
+namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
+{
+    using TVectorPredictor = IPredictorProducing<VBuffer<Single>>;
+
+    /// <summary>
+    /// Stacking combiner for multi-class ensembles: trains a meta-classifier on the
+    /// concatenated score vectors of the base models.
+    /// </summary>
+    public sealed class MultiStacking : BaseStacking<VBuffer<Single>, SignatureMultiClassClassifierTrainer>, ICanSaveModel, IMultiClassOutputCombiner
+    {
+        public const string LoadName = "MultiStacking";
+        public const string LoaderSignature = "MultiStackingCombiner";
+
+        private static VersionInfo GetVersionInfo()
+        {
+            return new VersionInfo(
+                modelSignature: "MSTACK C",
+                verWrittenCur: 0x00010001, // Initial
+                verReadableCur: 0x00010001,
+                verWeCanReadBack: 0x00010001,
+                loaderSignature: LoaderSignature);
+        }
+
+        [TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)]
+        public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerFactory
+        {
+            public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiStacking(env, this);
+
+            public Arguments()
+            {
+                // REVIEW: Perhaps we can have a better non-parametric learner.
+                BasePredictorType = new SubComponent<ITrainer<RoleMappedData, TVectorPredictor>, SignatureMultiClassClassifierTrainer>(
+                    "OVA", "p=FastTreeBinaryClassification");
+            }
+        }
+
+        public MultiStacking(IHostEnvironment env, Arguments args)
+            : base(env, LoaderSignature, args)
+        {
+        }
+
+        private MultiStacking(IHostEnvironment env, ModelLoadContext ctx)
+            : base(env, LoaderSignature, ctx)
+        {
+        }
+
+        public static MultiStacking Create(IHostEnvironment env, ModelLoadContext ctx)
+        {
+            Contracts.CheckValue(env, nameof(env));
+            env.CheckValue(ctx, nameof(ctx));
+            ctx.CheckAtModel(GetVersionInfo());
+            return new MultiStacking(env, ctx);
+        }
+
+        protected override void SaveCore(ModelSaveContext ctx)
+        {
+            base.SaveCore(ctx);
+            ctx.SetVersionInfo(GetVersionInfo());
+        }
+
+        protected override void FillFeatureBuffer(VBuffer<Single>[] src, ref VBuffer<Single> dst)
+        {
+            Contracts.AssertNonEmpty(src);
+
+            // REVIEW: Would there be any value in ever making dst sparse?
+            // Concatenate all base-model score vectors into one dense feature vector.
+            int len = 0;
+            for (int i = 0; i < src.Length; i++)
+                len += src[i].Length;
+
+            var values = dst.Values;
+            if (Utils.Size(values) < len)
+                values = new Single[len];
+            dst = new VBuffer<Single>(len, values, dst.Indices);
+
+            int iv = 0;
+            for (int i = 0; i < src.Length; i++)
+            {
+                src[i].CopyTo(values, iv);
+                iv += src[i].Length;
+                Contracts.Assert(iv <= len);
+            }
+            Contracts.Assert(iv == len);
+        }
+    }
+}
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs
new file mode 100644
index 0000000000..ee55b94c77
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs
@@ -0,0 +1,109 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.Model;
+using Microsoft.ML.Runtime.Numeric;
+
+[assembly: LoadableClass(typeof(MultiVoting), null, typeof(SignatureCombiner), Voting.UserName, MultiVoting.LoadName)]
+[assembly: LoadableClass(typeof(MultiVoting), null, typeof(SignatureLoadModel), Voting.UserName, MultiVoting.LoaderSignature)]
+
+namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
+{
+    // REVIEW: Why is MultiVoting based on BaseMultiCombiner? Normalizing the model outputs
+    // is senseless, so the base adds no real functionality.
+    /// <summary>
+    /// Combiner for multi-class outputs where each model casts one vote for its
+    /// top-scoring class; the output vector holds the fraction of votes each class received.
+    /// </summary>
+    public sealed class MultiVoting : BaseMultiCombiner, ICanSaveModel
+    {
+        public const string LoadName = "MultiVoting";
+        public const string LoaderSignature = "MultiVotingCombiner";
+
+        private static VersionInfo GetVersionInfo()
+        {
+            return new VersionInfo(
+                modelSignature: "MVOTCOMB",
+                verWrittenCur: 0x00010001,
+                verReadableCur: 0x00010001,
+                verWeCanReadBack: 0x00010001,
+                loaderSignature: LoaderSignature);
+        }
+
+        private sealed class Arguments : ArgumentsBase
+        {
+        }
+
+        public MultiVoting(IHostEnvironment env)
+            : base(env, LoaderSignature, new Arguments() { Normalize = false })
+        {
+            Host.Assert(!Normalize);
+        }
+
+        private MultiVoting(IHostEnvironment env, ModelLoadContext ctx)
+            : base(env, LoaderSignature, ctx)
+        {
+            Host.CheckDecode(!Normalize);
+        }
+
+        public static MultiVoting Create(IHostEnvironment env, ModelLoadContext ctx)
+        {
+            Contracts.CheckValue(env, nameof(env));
+            env.CheckValue(ctx, nameof(ctx));
+            ctx.CheckAtModel(GetVersionInfo());
+            return new MultiVoting(env, ctx);
+        }
+
+        protected override void SaveCore(ModelSaveContext ctx)
+        {
+            Contracts.Assert(!Normalize);
+            base.SaveCore(ctx);
+            ctx.SetVersionInfo(GetVersionInfo());
+        }
+
+        public override Combiner<VBuffer<Single>> GetCombiner()
+        {
+            return CombineCore;
+        }
+
+        private void CombineCore(ref VBuffer<Single> dst, VBuffer<Single>[] src, Single[] weights = null)
+        {
+            Host.AssertNonEmpty(src);
+            Host.Assert(weights == null || Utils.Size(weights) == Utils.Size(src));
+
+            int count = Utils.Size(src);
+            if (count == 0)
+            {
+                dst = new VBuffer<Single>(0, dst.Values, dst.Indices);
+                return;
+            }
+
+            int len = GetClassCount(src);
+            var values = dst.Values;
+            if (Utils.Size(values) < len)
+                values = new Single[len];
+            else
+                Array.Clear(values, 0, len);
+
+            int voteCount = 0;
+            for (int i = 0; i < count; i++)
+            {
+                int index = VectorUtils.ArgMax(ref src[i]);
+                if (index >= 0)
+                {
+                    values[index]++;
+                    voteCount++;
+                }
+            }
+
+            // Normalize by dividing by the number of votes.
+            // NOTE(review): if no model casts a vote (every ArgMax is negative), this
+            // divides by zero and fills the output with NaN -- confirm that is intended.
+            for (int i = 0; i < len; i++)
+                values[i] /= voteCount;
+
+            // Set the output to values.
+            dst = new VBuffer<Single>(len, values, dst.Indices);
+        }
+    }
+}
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs
new file mode 100644
index 0000000000..9bda1d151a
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs
@@ -0,0 +1,104 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Internal.Internallearn;
+using Microsoft.ML.Runtime.Model;
+
+[assembly: LoadableClass(typeof(MultiWeightedAverage), typeof(MultiWeightedAverage.Arguments), typeof(SignatureCombiner),
+ MultiWeightedAverage.UserName, MultiWeightedAverage.LoadName)]
+
+[assembly: LoadableClass(typeof(MultiWeightedAverage), null, typeof(SignatureLoadModel),
+ MultiWeightedAverage.UserName, MultiWeightedAverage.LoaderSignature)]
+
+namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
+{
+    /// <summary>
+    /// Combiner for multi-class outputs that averages the models' per-class score
+    /// vectors, weighting each model according to the metric selected in the arguments.
+    /// </summary>
+    public sealed class MultiWeightedAverage : BaseMultiAverager, IWeightedAverager, ICanSaveModel
+    {
+        public const string UserName = "Multi Weighted Average";
+        public const string LoadName = "MultiWeightedAverage";
+        public const string LoaderSignature = "MultiWeightedAverageComb";
+
+        private static VersionInfo GetVersionInfo()
+        {
+            return new VersionInfo(
+                modelSignature: "MWAVCOMB",
+                verWrittenCur: 0x00010001,
+                verReadableCur: 0x00010001,
+                verWeCanReadBack: 0x00010001,
+                loaderSignature: LoaderSignature);
+        }
+
+        [TlcModule.Component(Name = LoadName, FriendlyName = UserName)]
+        public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerFactory
+        {
+            IMultiClassOutputCombiner IComponentFactory<IMultiClassOutputCombiner>.CreateComponent(IHostEnvironment env) => new MultiWeightedAverage(env, this);
+
+            [Argument(ArgumentType.AtMostOnce, HelpText = "The metric type to be used to find the weights for each model", ShortName = "wn", SortOrder = 50)]
+            [TGUI(Label = "Metric Name", Description = "The weights are calculated according to the selected metric")]
+            public MultiWeightageKind WeightageName = MultiWeightageKind.AccuracyMicroAvg;
+        }
+
+        private readonly MultiWeightageKind _weightageKind;
+
+        /// <summary>Name of the metric used to weight each model.</summary>
+        public string WeightageMetricName { get { return _weightageKind.ToString(); } }
+
+        public MultiWeightedAverage(IHostEnvironment env, Arguments args)
+            : base(env, LoaderSignature, args)
+        {
+            _weightageKind = args.WeightageName;
+            Host.CheckUserArg(Enum.IsDefined(typeof(MultiWeightageKind), _weightageKind), nameof(args.WeightageName));
+        }
+
+        private MultiWeightedAverage(IHostEnvironment env, ModelLoadContext ctx)
+            : base(env, LoaderSignature, ctx)
+        {
+            // *** Binary format ***
+            // int: _weightageKind
+
+            _weightageKind = (MultiWeightageKind)ctx.Reader.ReadInt32();
+            Host.CheckDecode(Enum.IsDefined(typeof(MultiWeightageKind), _weightageKind));
+        }
+
+        public static MultiWeightedAverage Create(IHostEnvironment env, ModelLoadContext ctx)
+        {
+            Contracts.CheckValue(env, nameof(env));
+            env.CheckValue(ctx, nameof(ctx));
+            ctx.CheckAtModel(GetVersionInfo());
+            return new MultiWeightedAverage(env, ctx);
+        }
+
+        protected override void SaveCore(ModelSaveContext ctx)
+        {
+            base.SaveCore(ctx);
+            ctx.SetVersionInfo(GetVersionInfo());
+            // *** Binary format ***
+            // int: _weightageKind
+
+            Host.Assert(Enum.IsDefined(typeof(MultiWeightageKind), _weightageKind));
+            ctx.Writer.Write((int)_weightageKind);
+        }
+
+        public override Combiner<VBuffer<Single>> GetCombiner()
+        {
+            return CombineCore;
+        }
+    }
+
+ // These values are serialized, so should not be changed.
+ /// <summary>
+ /// Metric used by MultiWeightedAverage to derive per-model weights.
+ /// </summary>
+ public enum MultiWeightageKind
+ {
+ // Weight models by micro-averaged accuracy.
+ [TGUI(Label = MultiClassClassifierEvaluator.AccuracyMicro)]
+ AccuracyMicroAvg = 0,
+ // Weight models by macro-averaged accuracy.
+ [TGUI(Label = MultiClassClassifierEvaluator.AccuracyMacro)]
+ AccuracyMacroAvg = 1
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
new file mode 100644
index 0000000000..0b5f8e6057
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
@@ -0,0 +1,72 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+using System;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Model;
+
+[assembly: LoadableClass(typeof(RegressionStacking), typeof(RegressionStacking.Arguments), typeof(SignatureCombiner),
+ Stacking.UserName, RegressionStacking.LoadName)]
+
+[assembly: LoadableClass(typeof(RegressionStacking), null, typeof(SignatureLoadModel),
+ Stacking.UserName, RegressionStacking.LoaderSignature)]
+
+namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
+{
+    using TScalarPredictor = IPredictorProducing<Single>;
+
+    /// <summary>
+    /// Stacking combiner for regression ensembles: trains a meta-regressor on the
+    /// scalar outputs of the base models.
+    /// </summary>
+    public sealed class RegressionStacking : BaseScalarStacking<SignatureRegressorTrainer>, IRegressionOutputCombiner, ICanSaveModel
+    {
+        public const string LoadName = "RegressionStacking";
+        public const string LoaderSignature = "RegressionStacking";
+
+        private static VersionInfo GetVersionInfo()
+        {
+            return new VersionInfo(
+                modelSignature: "RSTACK C",
+                verWrittenCur: 0x00010001, // Initial
+                verReadableCur: 0x00010001,
+                verWeCanReadBack: 0x00010001,
+                loaderSignature: LoaderSignature);
+        }
+
+        [TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)]
+        public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerFactory
+        {
+            public Arguments()
+            {
+                BasePredictorType = new SubComponent<ITrainer<RoleMappedData, TScalarPredictor>, SignatureRegressorTrainer>("FastTreeRegression");
+            }
+
+            public IRegressionOutputCombiner CreateComponent(IHostEnvironment env) => new RegressionStacking(env, this);
+        }
+
+        public RegressionStacking(IHostEnvironment env, Arguments args)
+            : base(env, LoaderSignature, args)
+        {
+        }
+
+        private RegressionStacking(IHostEnvironment env, ModelLoadContext ctx)
+            : base(env, LoaderSignature, ctx)
+        {
+        }
+
+        public static RegressionStacking Create(IHostEnvironment env, ModelLoadContext ctx)
+        {
+            Contracts.CheckValue(env, nameof(env));
+            env.CheckValue(ctx, nameof(ctx));
+            ctx.CheckAtModel(GetVersionInfo());
+            return new RegressionStacking(env, ctx);
+        }
+
+        protected override void SaveCore(ModelSaveContext ctx)
+        {
+            base.SaveCore(ctx);
+            ctx.SetVersionInfo(GetVersionInfo());
+        }
+    }
+}
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
new file mode 100644
index 0000000000..f3481e9936
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
@@ -0,0 +1,70 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Model;
+
+[assembly: LoadableClass(typeof(Stacking), typeof(Stacking.Arguments), typeof(SignatureCombiner), Stacking.UserName, Stacking.LoadName)]
+[assembly: LoadableClass(typeof(Stacking), null, typeof(SignatureLoadModel), Stacking.UserName, Stacking.LoaderSignature)]
+
+namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
+{
+    using TScalarPredictor = IPredictorProducing<Single>;
+
+    /// <summary>
+    /// Stacking combiner for binary-classification ensembles: trains a meta-classifier
+    /// on the scalar outputs of the base models.
+    /// </summary>
+    public sealed class Stacking : BaseScalarStacking<SignatureBinaryClassifierTrainer>, IBinaryOutputCombiner, ICanSaveModel
+    {
+        public const string UserName = "Stacking";
+        public const string LoadName = "Stacking";
+        public const string LoaderSignature = "StackingCombiner";
+
+        private static VersionInfo GetVersionInfo()
+        {
+            return new VersionInfo(
+                modelSignature: " STACK C",
+                verWrittenCur: 0x00010001, // Initial
+                verReadableCur: 0x00010001,
+                verWeCanReadBack: 0x00010001,
+                loaderSignature: LoaderSignature);
+        }
+
+        [TlcModule.Component(Name = LoadName, FriendlyName = UserName)]
+        public sealed class Arguments : ArgumentsBase, ISupportBinaryOutputCombinerFactory
+        {
+            public Arguments()
+            {
+                BasePredictorType = new SubComponent<ITrainer<RoleMappedData, TScalarPredictor>, SignatureBinaryClassifierTrainer>("FastTreeBinaryClassification");
+            }
+
+            public IBinaryOutputCombiner CreateComponent(IHostEnvironment env) => new Stacking(env, this);
+        }
+
+        public Stacking(IHostEnvironment env, Arguments args)
+            : base(env, LoaderSignature, args)
+        {
+        }
+
+        private Stacking(IHostEnvironment env, ModelLoadContext ctx)
+            : base(env, LoaderSignature, ctx)
+        {
+        }
+
+        public static Stacking Create(IHostEnvironment env, ModelLoadContext ctx)
+        {
+            Contracts.CheckValue(env, nameof(env));
+            env.CheckValue(ctx, nameof(ctx));
+            ctx.CheckAtModel(GetVersionInfo());
+            return new Stacking(env, ctx);
+        }
+
+        protected override void SaveCore(ModelSaveContext ctx)
+        {
+            base.SaveCore(ctx);
+            ctx.SetVersionInfo(GetVersionInfo());
+        }
+    }
+}
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs
new file mode 100644
index 0000000000..932f99d93a
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs
@@ -0,0 +1,94 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.Model;
+
+[assembly: LoadableClass(typeof(Voting), null, typeof(SignatureCombiner), Voting.UserName, Voting.LoadName)]
+[assembly: LoadableClass(typeof(Voting), null, typeof(SignatureLoadModel), Voting.UserName, Voting.LoaderSignature)]
+
+namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
+{
+    /// <summary>
+    /// Binary combiner where each model votes by the sign of its score; the output
+    /// is the signed vote margin (pos - neg) / (pos + neg), in [-1, 1]. Weights are ignored.
+    /// </summary>
+    public sealed class Voting : IBinaryOutputCombiner, ICanSaveModel
+    {
+        private readonly IHost _host;
+        public const string UserName = "Voting";
+        public const string LoadName = "Voting";
+        public const string LoaderSignature = "VotingCombiner";
+
+        private static VersionInfo GetVersionInfo()
+        {
+            return new VersionInfo(
+                modelSignature: "VOT COMB",
+                verWrittenCur: 0x00010001,
+                verReadableCur: 0x00010001,
+                verWeCanReadBack: 0x00010001,
+                loaderSignature: LoaderSignature);
+        }
+
+        public Voting(IHostEnvironment env)
+        {
+            Contracts.CheckValue(env, nameof(env));
+            _host = env.Register(LoaderSignature);
+        }
+
+        private Voting(IHostEnvironment env, ModelLoadContext ctx)
+        {
+            Contracts.AssertValue(env);
+            _host = env.Register(LoaderSignature);
+            _host.AssertValue(ctx);
+
+            // *** Binary format ***
+            // int: sizeof(Single)
+            int cbFloat = ctx.Reader.ReadInt32();
+            _host.CheckDecode(cbFloat == sizeof(Single));
+        }
+
+        public static Voting Create(IHostEnvironment env, ModelLoadContext ctx)
+        {
+            Contracts.CheckValue(env, nameof(env));
+            env.CheckValue(ctx, nameof(ctx));
+            ctx.CheckAtModel(GetVersionInfo());
+            return new Voting(env, ctx);
+        }
+
+        public void Save(ModelSaveContext ctx)
+        {
+            _host.CheckValue(ctx, nameof(ctx));
+            ctx.CheckAtModel();
+            ctx.SetVersionInfo(GetVersionInfo());
+
+            // *** Binary format ***
+            // int: sizeof(Single)
+            ctx.Writer.Write(sizeof(Single));
+        }
+
+        public Combiner<Single> GetCombiner()
+        {
+            return CombineCore;
+        }
+
+        private void CombineCore(ref Single dst, Single[] src, Single[] weights)
+        {
+            _host.AssertNonEmpty(src);
+            _host.Assert(weights == null || Utils.Size(weights) == Utils.Size(src));
+
+            int len = Utils.Size(src);
+            int pos = 0;
+            int neg = 0;
+            for (int i = 0; i < len; i++)
+            {
+                var v = src[i];
+                if (v > 0)
+                    pos++;
+                else if (v <= 0)
+                    neg++;
+                // NaN scores fail both comparisons and are excluded from the vote.
+            }
+            dst = (Single)(pos - neg) / (pos + neg);
+        }
+    }
+}
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs
new file mode 100644
index 0000000000..8b16ffd0a2
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs
@@ -0,0 +1,111 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Internal.Internallearn;
+using Microsoft.ML.Runtime.Model;
+
+[assembly: LoadableClass(typeof(WeightedAverage), typeof(WeightedAverage.Arguments), typeof(SignatureCombiner),
+ WeightedAverage.UserName, WeightedAverage.LoadName)]
+
+[assembly: LoadableClass(typeof(WeightedAverage), null, typeof(SignatureLoadModel),
+ WeightedAverage.UserName, WeightedAverage.LoaderSignature)]
+
+namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
+{
+    /// <summary>
+    /// Binary combiner that averages the models' scores, weighting each model by the
+    /// evaluation metric selected in the arguments.
+    /// </summary>
+    public sealed class WeightedAverage : BaseAverager, IWeightedAverager, ICanSaveModel
+    {
+        public const string UserName = "Weighted Average";
+        public const string LoadName = "WeightedAverage";
+        public const string LoaderSignature = "WeightedAverageCombiner";
+
+        private static VersionInfo GetVersionInfo()
+        {
+            return new VersionInfo(
+                modelSignature: "WAVGCOMB",
+                verWrittenCur: 0x00010001,
+                verReadableCur: 0x00010001,
+                verWeCanReadBack: 0x00010001,
+                loaderSignature: LoaderSignature);
+        }
+
+        [TlcModule.Component(Name = LoadName, FriendlyName = UserName)]
+        public sealed class Arguments : ISupportBinaryOutputCombinerFactory
+        {
+            [Argument(ArgumentType.AtMostOnce, HelpText = "The metric type to be used to find the weights for each model", ShortName = "wn", SortOrder = 50)]
+            [TGUI(Label = "Weightage Name", Description = "The weights are calculated according to the selected metric")]
+            public WeightageKind WeightageName = WeightageKind.Auc;
+
+            public IBinaryOutputCombiner CreateComponent(IHostEnvironment env) => new WeightedAverage(env, this);
+        }
+
+        // Only assigned in constructors, so it can be readonly.
+        private readonly WeightageKind _weightageKind;
+
+        /// <summary>Name of the metric used to weight each model.</summary>
+        public string WeightageMetricName { get { return _weightageKind.ToString(); } }
+
+        public WeightedAverage(IHostEnvironment env, Arguments args)
+            : base(env, LoaderSignature)
+        {
+            _weightageKind = args.WeightageName;
+            Host.CheckUserArg(Enum.IsDefined(typeof(WeightageKind), _weightageKind), nameof(args.WeightageName));
+        }
+
+        private WeightedAverage(IHostEnvironment env, ModelLoadContext ctx)
+            : base(env, LoaderSignature, ctx)
+        {
+            // *** Binary format ***
+            // int: _weightageKind
+            _weightageKind = (WeightageKind)ctx.Reader.ReadInt32();
+            Host.CheckDecode(Enum.IsDefined(typeof(WeightageKind), _weightageKind));
+        }
+
+        public static WeightedAverage Create(IHostEnvironment env, ModelLoadContext ctx)
+        {
+            Contracts.CheckValue(env, nameof(env));
+            env.CheckValue(ctx, nameof(ctx));
+            ctx.CheckAtModel(GetVersionInfo());
+            return new WeightedAverage(env, ctx);
+        }
+
+        protected override void SaveCore(ModelSaveContext ctx)
+        {
+            base.SaveCore(ctx);
+            ctx.SetVersionInfo(GetVersionInfo());
+
+            // *** Binary format ***
+            // int: _weightageKind
+
+            Contracts.Assert(Enum.IsDefined(typeof(WeightageKind), _weightageKind));
+            ctx.Writer.Write((int)_weightageKind);
+        }
+
+        public override Combiner<Single> GetCombiner()
+        {
+            return CombineCore;
+        }
+    }
+
+ // These values are serialized, so should not be changed.
+ /// <summary>
+ /// Metric used by WeightedAverage to derive per-model weights.
+ /// </summary>
+ public enum WeightageKind
+ {
+ [TGUI(Label = BinaryClassifierEvaluator.Accuracy)]
+ Accuracy = 0,
+ [TGUI(Label = BinaryClassifierEvaluator.Auc)]
+ Auc = 1,
+ [TGUI(Label = BinaryClassifierEvaluator.PosPrecName)]
+ PosPrecision = 2,
+ [TGUI(Label = BinaryClassifierEvaluator.PosRecallName)]
+ PosRecall = 3,
+ [TGUI(Label = BinaryClassifierEvaluator.NegPrecName)]
+ NegPrecision = 4,
+ [TGUI(Label = BinaryClassifierEvaluator.NegRecallName)]
+ NegRecall = 5,
+ }
+
+}
diff --git a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs
new file mode 100644
index 0000000000..3ac78ed91e
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs
@@ -0,0 +1,749 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Text;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Ensemble;
+using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Internal.Calibration;
+using Microsoft.ML.Runtime.Internal.Internallearn;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.Model;
+
+[assembly: LoadableClass(typeof(SchemaBindablePipelineEnsembleBase), null, typeof(SignatureLoadModel),
+ SchemaBindablePipelineEnsembleBase.UserName, SchemaBindablePipelineEnsembleBase.LoaderSignature)]
+
+namespace Microsoft.ML.Runtime.Ensemble
+{
+ ///
+ /// This class represents an ensemble predictor, where each predictor has its own featurization pipeline. It is
+ /// useful for the distributed training scenario, where the featurization includes trainable transforms (for example,
+ /// categorical transform, or normalization).
+ ///
+ public abstract class SchemaBindablePipelineEnsembleBase : ICanGetTrainingLabelNames, ICanSaveModel,
+ ISchemaBindableMapper, ICanSaveSummary, ICanGetSummaryInKeyValuePairs
+ {
+ private abstract class BoundBase : ISchemaBoundRowMapper
+ {
+ protected readonly SchemaBindablePipelineEnsembleBase Parent;
+ private readonly HashSet _inputColIndices;
+
+ protected readonly ISchemaBoundRowMapper[] Mappers;
+ protected readonly IRowToRowMapper[] BoundPipelines;
+ protected readonly int[] ScoreCols;
+
+ public ISchemaBindableMapper Bindable => Parent;
+ public RoleMappedSchema InputSchema { get; }
+ public ISchema OutputSchema { get; }
+
+ public BoundBase(SchemaBindablePipelineEnsembleBase parent, RoleMappedSchema schema)
+ {
+ Parent = parent;
+ InputSchema = schema;
+ OutputSchema = new ScoreMapperSchema(Parent.ScoreType, Parent._scoreColumnKind);
+ _inputColIndices = new HashSet();
+ for (int i = 0; i < Parent._inputCols.Length; i++)
+ {
+ var name = Parent._inputCols[i];
+ if (!InputSchema.Schema.TryGetColumnIndex(name, out int col))
+ throw Parent.Host.Except("Schema does not contain required input column '{0}'", name);
+ _inputColIndices.Add(col);
+ }
+
+ Mappers = new ISchemaBoundRowMapper[Parent.PredictorModels.Length];
+ BoundPipelines = new IRowToRowMapper[Parent.PredictorModels.Length];
+ ScoreCols = new int[Parent.PredictorModels.Length];
+ for (int i = 0; i < Mappers.Length; i++)
+ {
+ // Get the RoleMappedSchema to pass to the predictor.
+ var emptyDv = new EmptyDataView(Parent.Host, schema.Schema);
+ Parent.PredictorModels[i].PrepareData(Parent.Host, emptyDv, out RoleMappedData rmd, out IPredictor predictor);
+
+ // Get the predictor as a bindable mapper, and bind it to the RoleMappedSchema found above.
+ var bindable = ScoreUtils.GetSchemaBindableMapper(Parent.Host, Parent.PredictorModels[i].Predictor, null);
+ Mappers[i] = bindable.Bind(Parent.Host, rmd.Schema) as ISchemaBoundRowMapper;
+ if (Mappers[i] == null)
+ throw Parent.Host.Except("Predictor {0} is not a row to row mapper", i);
+
+ // Make sure there is a score column, and remember its index.
+ if (!Mappers[i].OutputSchema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out ScoreCols[i]))
+ throw Parent.Host.Except("Predictor {0} does not contain a score column", i);
+
+ // Get the pipeline.
+ var dv = new EmptyDataView(Parent.Host, schema.Schema);
+ var tm = new TransformModel(Parent.Host, dv, dv);
+ var pipeline = Parent.PredictorModels[i].TransformModel.Apply(Parent.Host, tm);
+ BoundPipelines[i] = pipeline.AsRowToRowMapper(Parent.Host);
+ if (BoundPipelines[i] == null)
+ throw Parent.Host.Except("Transform pipeline {0} contains transforms that do not implement IRowToRowMapper", i);
+ }
+ }
+
+ public Func GetDependencies(Func predicate)
+ {
+ for (int i = 0; i < OutputSchema.ColumnCount; i++)
+ {
+ if (predicate(i))
+ return col => _inputColIndices.Contains(col);
+ }
+ return col => false;
+ }
+
+ public IEnumerable> GetInputColumnRoles()
+ {
+ yield break;
+ }
+
+ public IRow GetOutputRow(IRow input, Func predicate, out Action disposer)
+ {
+ return new SimpleRow(OutputSchema, input, new[] { CreateScoreGetter(input, predicate, out disposer) });
+ }
+
+ public abstract Delegate CreateScoreGetter(IRow input, Func mapperPredicate, out Action disposer);
+ }
+
+ // A generic base class for pipeline ensembles. This class contains the combiner.
+ private abstract class SchemaBindablePipelineEnsemble : SchemaBindablePipelineEnsembleBase, IPredictorProducing
+ {
+ protected sealed class Bound : BoundBase
+ {
+ private readonly IOutputCombiner _combiner;
+
+ public Bound(SchemaBindablePipelineEnsemble parent, RoleMappedSchema schema)
+ : base(parent, schema)
+ {
+ _combiner = parent.Combiner;
+ }
+
+ public override Delegate CreateScoreGetter(IRow input, Func mapperPredicate, out Action disposer)
+ {
+ disposer = null;
+
+ if (!mapperPredicate(0))
+ return null;
+
+ var getters = new ValueGetter[Mappers.Length];
+ for (int i = 0; i < Mappers.Length; i++)
+ {
+ // First get the output row from the pipelines. The input predicate of the predictor
+ // is the output predicate of the pipeline.
+ var inputPredicate = Mappers[i].GetDependencies(mapperPredicate);
+ var pipelineRow = BoundPipelines[i].GetRow(input, inputPredicate, out Action disp);
+ disposer += disp;
+
+ // Next we get the output row from the predictors. We activate the score column as output predicate.
+ var predictorRow = Mappers[i].GetOutputRow(pipelineRow, col => col == ScoreCols[i], out disp);
+ disposer += disp;
+ getters[i] = predictorRow.GetGetter(ScoreCols[i]);
+ }
+
+ var comb = _combiner.GetCombiner();
+ var buffer = new T[Mappers.Length];
+ ValueGetter scoreGetter =
+ (ref T dst) =>
+ {
+ for (int i = 0; i < Mappers.Length; i++)
+ getters[i](ref buffer[i]);
+ comb(ref dst, buffer, null);
+ };
+ return scoreGetter;
+ }
+
+ public ValueGetter GetLabelGetter(IRow input, int i, out Action disposer)
+ {
+ Parent.Host.Assert(0 <= i && i < Mappers.Length);
+ Parent.Host.Check(Mappers[i].InputSchema.Label != null, "Mapper was not trained using a label column");
+
+ // The label should be in the output row of the i'th pipeline
+ var pipelineRow = BoundPipelines[i].GetRow(input, col => col == Mappers[i].InputSchema.Label.Index, out disposer);
+ return RowCursorUtils.GetLabelGetter(pipelineRow, Mappers[i].InputSchema.Label.Index);
+ }
+
+ public ValueGetter GetWeightGetter(IRow input, int i, out Action disposer)
+ {
+ Parent.Host.Assert(0 <= i && i < Mappers.Length);
+
+ if (Mappers[i].InputSchema.Weight == null)
+ {
+ ValueGetter weight = (ref Single dst) => dst = 1;
+ disposer = null;
+ return weight;
+ }
+ // The weight should be in the output row of the i'th pipeline if it exists.
+ var inputPredicate = Mappers[i].GetDependencies(col => col == Mappers[i].InputSchema.Weight.Index);
+ var pipelineRow = BoundPipelines[i].GetRow(input, inputPredicate, out disposer);
+ return pipelineRow.GetGetter(Mappers[i].InputSchema.Weight.Index);
+ }
+ }
+
+ protected readonly IOutputCombiner Combiner;
+
+ protected SchemaBindablePipelineEnsemble(IHostEnvironment env, IPredictorModel[] predictors,
+ IOutputCombiner combiner, string registrationName, string scoreColumnKind)
+ : base(env, predictors, registrationName, scoreColumnKind)
+ {
+ Combiner = combiner;
+ }
+
+ protected SchemaBindablePipelineEnsemble(IHostEnvironment env, ModelLoadContext ctx, string scoreColumnKind)
+ : base(env, ctx, scoreColumnKind)
+ {
+ // *** Binary format ***
+ //
+ // The combiner
+
+ ctx.LoadModel, SignatureLoadModel>(Host, out Combiner, "Combiner");
+ }
+
+ protected override void SaveCore(ModelSaveContext ctx)
+ {
+ Host.AssertValue(ctx);
+
+ // *** Binary format ***
+ //
+ // The combiner
+
+ ctx.SaveModel(Combiner, "Combiner");
+ }
+
+ public override ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
+ {
+ return new Bound(this, schema);
+ }
+ }
+
+ // This is an implementation of pipeline ensembles that combines scores of type float (regression and anomaly detection).
+ private sealed class ImplOne : SchemaBindablePipelineEnsemble
+ {
+ protected override ColumnType ScoreType => NumberType.R4;
+
+ public override PredictionKind PredictionKind
+ {
+ get
+ {
+ if (_scoreColumnKind == MetadataUtils.Const.ScoreColumnKind.Regression)
+ return PredictionKind.Regression;
+ if (_scoreColumnKind == MetadataUtils.Const.ScoreColumnKind.AnomalyDetection)
+ return PredictionKind.AnomalyDetection;
+ throw Host.Except("Unknown prediction kind");
+ }
+ }
+
+ public ImplOne(IHostEnvironment env, IPredictorModel[] predictors, IRegressionOutputCombiner combiner, string scoreColumnKind)
+ : base(env, predictors, combiner, LoaderSignature, scoreColumnKind)
+ {
+ }
+
+ public ImplOne(IHostEnvironment env, ModelLoadContext ctx, string scoreColumnKind)
+ : base(env, ctx, scoreColumnKind)
+ {
+ }
+ }
+
+ // This is an implementation of pipeline ensemble that combines scores of type vectors of float (multi-class).
+ private sealed class ImplVec : SchemaBindablePipelineEnsemble>
+ {
+ protected override ColumnType ScoreType { get { return _scoreType; } }
+
+ public override PredictionKind PredictionKind
+ {
+ get
+ {
+ if (_scoreColumnKind == MetadataUtils.Const.ScoreColumnKind.MultiClassClassification)
+ return PredictionKind.MultiClassClassification;
+ throw Host.Except("Unknown prediction kind");
+ }
+ }
+
+ private readonly VectorType _scoreType;
+
+ public ImplVec(IHostEnvironment env, IPredictorModel[] predictors, IMultiClassOutputCombiner combiner)
+ : base(env, predictors, combiner, LoaderSignature, MetadataUtils.Const.ScoreColumnKind.MultiClassClassification)
+ {
+ int classCount = CheckLabelColumn(Host, predictors, false);
+ _scoreType = new VectorType(NumberType.R4, classCount);
+ }
+
+ public ImplVec(IHostEnvironment env, ModelLoadContext ctx, string scoreColumnKind)
+ : base(env, ctx, scoreColumnKind)
+ {
+ int classCount = CheckLabelColumn(Host, PredictorModels, false);
+ _scoreType = new VectorType(NumberType.R4, classCount);
+ }
+ }
+
+ // This is an implementation of pipeline ensembles that combines scores of type float, and also provides calibration (for binary classification).
+ private sealed class ImplOneWithCalibrator : SchemaBindablePipelineEnsemble, ISelfCalibratingPredictor
+ {
+ protected override ColumnType ScoreType { get { return NumberType.R4; } }
+
+ public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } }
+
+ public ImplOneWithCalibrator(IHostEnvironment env, IPredictorModel[] predictors, IBinaryOutputCombiner combiner)
+ : base(env, predictors, combiner, LoaderSignature, MetadataUtils.Const.ScoreColumnKind.BinaryClassification)
+ {
+ Host.Assert(_scoreColumnKind == MetadataUtils.Const.ScoreColumnKind.BinaryClassification);
+ CheckBinaryLabel(true, Host, PredictorModels);
+ }
+
+ public ImplOneWithCalibrator(IHostEnvironment env, ModelLoadContext ctx, string scoreColumnKind)
+ : base(env, ctx, scoreColumnKind)
+ {
+ Host.Assert(_scoreColumnKind == MetadataUtils.Const.ScoreColumnKind.BinaryClassification);
+ CheckBinaryLabel(false, Host, PredictorModels);
+ }
+
+ private static void CheckBinaryLabel(bool user, IHostEnvironment env, IPredictorModel[] predictors)
+ {
+ int classCount = CheckLabelColumn(env, predictors, true);
+ if (classCount != 2)
+ {
+ var error = string.Format("Expected label to have exactly 2 classes, instead has {0}", classCount);
+ throw user ? env.ExceptParam(nameof(predictors), error) : env.ExceptDecode(error);
+ }
+ }
+
+ public IPredictor Calibrate(IChannel ch, IDataView data, ICalibratorTrainer caliTrainer, int maxRows)
+ {
+ Host.CheckValue(ch, nameof(ch));
+ ch.CheckValue(data, nameof(data));
+ ch.CheckValue(caliTrainer, nameof(caliTrainer));
+
+ if (caliTrainer.NeedsTraining)
+ {
+ var bound = new Bound(this, new RoleMappedSchema(data.Schema));
+ using (var curs = data.GetRowCursor(col => true))
+ {
+ var scoreGetter = (ValueGetter)bound.CreateScoreGetter(curs, col => true, out Action disposer);
+
+ // We assume that we can use the label column of the first predictor, since if the labels are not identical
+ // then the whole model is garbage anyway.
+ var labelGetter = bound.GetLabelGetter(curs, 0, out Action disp);
+ disposer += disp;
+ var weightGetter = bound.GetWeightGetter(curs, 0, out disp);
+ disposer += disp;
+ try
+ {
+ int num = 0;
+ while (curs.MoveNext())
+ {
+ Single label = 0;
+ labelGetter(ref label);
+ if (!FloatUtils.IsFinite(label))
+ continue;
+ Single score = 0;
+ scoreGetter(ref score);
+ if (!FloatUtils.IsFinite(score))
+ continue;
+ Single weight = 0;
+ weightGetter(ref weight);
+ if (!FloatUtils.IsFinite(weight))
+ continue;
+
+ caliTrainer.ProcessTrainingExample(score, label > 0, weight);
+
+ if (maxRows > 0 && ++num >= maxRows)
+ break;
+ }
+ }
+ finally
+ {
+ disposer?.Invoke();
+ }
+ }
+ }
+
+ var calibrator = caliTrainer.FinishTraining(ch);
+ return CalibratorUtils.CreateCalibratedPredictor(Host, this, calibrator);
+ }
+ }
+
+ private readonly string[] _inputCols;
+
+ protected readonly IHost Host;
+
+ private static VersionInfo GetVersionInfo()
+ {
+ return new VersionInfo(
+ modelSignature: "PIPELNEN",
+ //verWrittenCur: 0x00010001, // Initial
+ verWrittenCur: 0x00010002, // Save predictor models in a subdirectory
+ verReadableCur: 0x00010002,
+ verWeCanReadBack: 0x00010001,
+ loaderSignature: LoaderSignature);
+ }
+ public const string UserName = "Pipeline Ensemble";
+ public const string LoaderSignature = "PipelineEnsemble";
+
+ private readonly string _scoreColumnKind;
+
+ protected abstract ColumnType ScoreType { get; }
+
+ public abstract PredictionKind PredictionKind { get; }
+
+ internal IPredictorModel[] PredictorModels { get; }
+
+ private SchemaBindablePipelineEnsembleBase(IHostEnvironment env, IPredictorModel[] predictors, string registrationName, string scoreColumnKind)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ Host = env.Register(registrationName);
+ Host.CheckNonEmpty(predictors, nameof(predictors));
+ Host.CheckNonWhiteSpace(scoreColumnKind, nameof(scoreColumnKind));
+
+ PredictorModels = predictors;
+ _scoreColumnKind = scoreColumnKind;
+
+ HashSet inputCols = null;
+ for (int i = 0; i < predictors.Length; i++)
+ {
+ var predModel = predictors[i];
+
+ // Get the input column names.
+ var inputSchema = predModel.TransformModel.InputSchema;
+ if (inputCols == null)
+ {
+ inputCols = new HashSet();
+ for (int j = 0; j < inputSchema.ColumnCount; j++)
+ {
+ if (inputSchema.IsHidden(j))
+ continue;
+ inputCols.Add(inputSchema.GetColumnName(j));
+ }
+ _inputCols = inputCols.ToArray();
+ }
+ else
+ {
+ int nonHiddenCols = 0;
+ for (int j = 0; j < inputSchema.ColumnCount; j++)
+ {
+ if (inputSchema.IsHidden(j))
+ continue;
+ var name = inputSchema.GetColumnName(j);
+ if (!inputCols.Contains(name))
+ throw Host.Except("Inconsistent schemas: Some schemas do not contain the column '{0}'", name);
+ nonHiddenCols++;
+ }
+ Host.Check(nonHiddenCols == _inputCols.Length,
+ "Inconsistent schemas: not all schemas have the same number of columns");
+ }
+ }
+ }
+
+ protected SchemaBindablePipelineEnsembleBase(IHostEnvironment env, ModelLoadContext ctx, string scoreColumnKind)
+ {
+ Host = env.Register(LoaderSignature);
+ Host.AssertNonEmpty(scoreColumnKind);
+
+ _scoreColumnKind = scoreColumnKind;
+
+ // *** Binary format ***
+ // int: id of _scoreColumnKind (loaded in the Create method)
+ // int: number of predictors
+ // The predictor models
+ // int: the number of input columns
+ // for each input column:
+ // int: id of the column name
+
+ var length = ctx.Reader.ReadInt32();
+ Host.CheckDecode(length > 0);
+ PredictorModels = new IPredictorModel[length];
+ for (int i = 0; i < PredictorModels.Length; i++)
+ {
+ string dir =
+ ctx.Header.ModelVerWritten == 0x00010001
+ ? "PredictorModels"
+ : Path.Combine(ctx.Directory, "PredictorModels");
+ using (var ent = ctx.Repository.OpenEntry(dir, $"PredictorModel_{i:000}"))
+ PredictorModels[i] = new PredictorModel(Host, ent.Stream);
+ }
+
+ length = ctx.Reader.ReadInt32();
+ Host.CheckDecode(length >= 0);
+ _inputCols = new string[length];
+ for (int i = 0; i < length; i++)
+ _inputCols[i] = ctx.LoadNonEmptyString();
+ }
+
+ public void Save(ModelSaveContext ctx)
+ {
+ Host.AssertValue(ctx);
+ ctx.SetVersionInfo(GetVersionInfo());
+
+ // *** Binary format ***
+ // int: id of _scoreColumnKind (loaded in the Create method)
+ // int: number of predictors
+ // The predictor models
+ // int: the number of input columns
+ // for each input column:
+ // int: id of the column name
+
+ ctx.SaveNonEmptyString(_scoreColumnKind);
+
+ Host.AssertNonEmpty(PredictorModels);
+ ctx.Writer.Write(PredictorModels.Length);
+
+ for (int i = 0; i < PredictorModels.Length; i++)
+ {
+ var dir = Path.Combine(ctx.Directory, "PredictorModels");
+ using (var ent = ctx.Repository.CreateEntry(dir, $"PredictorModel_{i:000}"))
+ PredictorModels[i].Save(Host, ent.Stream);
+ }
+
+ Contracts.AssertValue(_inputCols);
+ ctx.Writer.Write(_inputCols.Length);
+ foreach (var name in _inputCols)
+ ctx.SaveNonEmptyString(name);
+
+ SaveCore(ctx);
+ }
+
+ protected abstract void SaveCore(ModelSaveContext ctx);
+
+ public static SchemaBindablePipelineEnsembleBase Create(IHostEnvironment env, IPredictorModel[] predictors, IOutputCombiner combiner, string scoreColumnKind)
+ {
+ switch (scoreColumnKind)
+ {
+ case MetadataUtils.Const.ScoreColumnKind.BinaryClassification:
+ var binaryCombiner = combiner as IBinaryOutputCombiner;
+ if (binaryCombiner == null)
+ throw env.Except("Combiner type incompatible with score column kind");
+ return new ImplOneWithCalibrator(env, predictors, binaryCombiner);
+ case MetadataUtils.Const.ScoreColumnKind.Regression:
+ case MetadataUtils.Const.ScoreColumnKind.AnomalyDetection:
+ var regressionCombiner = combiner as IRegressionOutputCombiner;
+ if (regressionCombiner == null)
+ throw env.Except("Combiner type incompatible with score column kind");
+ return new ImplOne(env, predictors, regressionCombiner, scoreColumnKind);
+ case MetadataUtils.Const.ScoreColumnKind.MultiClassClassification:
+ var vectorCombiner = combiner as IMultiClassOutputCombiner;
+ if (vectorCombiner == null)
+ throw env.Except("Combiner type incompatible with score column kind");
+ return new ImplVec(env, predictors, vectorCombiner);
+ default:
+ throw env.Except("Unknown score kind");
+ }
+ }
+
+ public static SchemaBindablePipelineEnsembleBase Create(IHostEnvironment env, ModelLoadContext ctx)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel(GetVersionInfo());
+
+ var scoreColumnKind = ctx.LoadNonEmptyString();
+ switch (scoreColumnKind)
+ {
+ case MetadataUtils.Const.ScoreColumnKind.BinaryClassification:
+ return new ImplOneWithCalibrator(env, ctx, scoreColumnKind);
+ case MetadataUtils.Const.ScoreColumnKind.Regression:
+ case MetadataUtils.Const.ScoreColumnKind.AnomalyDetection:
+ return new ImplOne(env, ctx, scoreColumnKind);
+ case MetadataUtils.Const.ScoreColumnKind.MultiClassClassification:
+ return new ImplVec(env, ctx, scoreColumnKind);
+ default:
+ throw env.Except("Unknown score kind");
+ }
+ }
+
+ public abstract ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema);
+
+ public void SaveSummary(TextWriter writer, RoleMappedSchema schema)
+ {
+ for (int i = 0; i < PredictorModels.Length; i++)
+ {
+ writer.WriteLine("Partition model {0} summary:", i);
+
+ if (!(PredictorModels[i].Predictor is ICanSaveSummary summaryModel))
+ {
+ writer.WriteLine("Model of type {0}", PredictorModels[i].Predictor.GetType().Name);
+ continue;
+ }
+
+ // Load the feature names for the i'th model.
+ var dv = new EmptyDataView(Host, PredictorModels[i].TransformModel.InputSchema);
+ PredictorModels[i].PrepareData(Host, dv, out RoleMappedData rmd, out IPredictor pred);
+ summaryModel.SaveSummary(writer, rmd.Schema);
+ }
+ }
+
+ // Checks that the predictors have matching label columns, and returns the number of classes in all predictors.
+ protected static int CheckLabelColumn(IHostEnvironment env, IPredictorModel[] models, bool isBinary)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckNonEmpty(models, nameof(models));
+
+ var model = models[0];
+ var edv = new EmptyDataView(env, model.TransformModel.InputSchema);
+ model.PrepareData(env, edv, out RoleMappedData rmd, out IPredictor pred);
+ var labelInfo = rmd.Schema.Label;
+ if (labelInfo == null)
+ throw env.Except("Training schema for model 0 does not have a label column");
+
+ var labelType = rmd.Schema.Schema.GetColumnType(rmd.Schema.Label.Index);
+ if (!labelType.IsKey)
+ return CheckNonKeyLabelColumnCore(env, pred, models, isBinary, labelType);
+
+ if (isBinary && labelType.KeyCount != 2)
+ throw env.Except("Label is not binary");
+ var schema = rmd.Schema.Schema;
+ var mdType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, labelInfo.Index);
+ if (mdType == null || !mdType.IsKnownSizeVector)
+ throw env.Except("Label column of type key must have a vector of key values metadata");
+
+ return Utils.MarshalInvoke(CheckKeyLabelColumnCore, mdType.ItemType.RawType, env, models, labelType.AsKey, schema, labelInfo.Index, mdType);
+ }
+
+ // When the label column is not a key, we check that the number of classes is the same for all the predictors, by checking the
+ // OutputType property of the IValueMapper.
+ // If any of the predictors do not implement IValueMapper we throw an exception. Returns the class count.
+ private static int CheckNonKeyLabelColumnCore(IHostEnvironment env, IPredictor pred, IPredictorModel[] models, bool isBinary, ColumnType labelType)
+ {
+ env.Assert(!labelType.IsKey);
+ env.AssertNonEmpty(models);
+
+ if (isBinary)
+ return 2;
+
+ // The label is numeric, we just have to check that the number of classes is the same.
+ if (!(pred is IValueMapper vm))
+ throw env.Except("Cannot determine the number of classes the predictor outputs");
+ var classCount = vm.OutputType.VectorSize;
+
+ for (int i = 1; i < models.Length; i++)
+ {
+ var model = models[i];
+ var edv = new EmptyDataView(env, model.TransformModel.InputSchema);
+ model.PrepareData(env, edv, out RoleMappedData rmd, out pred);
+ vm = pred as IValueMapper;
+ if (vm.OutputType.VectorSize != classCount)
+ throw env.Except("Label of model {0} has different number of classes than model 0", i);
+ }
+ return classCount;
+ }
+
+ // Checks that all the label columns of the model have the same key type as their label column - including the same
+ // cardinality and the same key values, and returns the cardinality of the label column key.
+ private static int CheckKeyLabelColumnCore(IHostEnvironment env, IPredictorModel[] models, KeyType labelType, ISchema schema, int labelIndex, ColumnType keyValuesType)
+ where T : IEquatable
+ {
+ env.Assert(keyValuesType.ItemType.RawType == typeof(T));
+ env.AssertNonEmpty(models);
+ var labelNames = default(VBuffer);
+ schema.GetMetadata(MetadataUtils.Kinds.KeyValues, labelIndex, ref labelNames);
+ var classCount = labelNames.Length;
+
+ var curLabelNames = default(VBuffer);
+ for (int i = 1; i < models.Length; i++)
+ {
+ var model = models[i];
+ var edv = new EmptyDataView(env, model.TransformModel.InputSchema);
+ model.PrepareData(env, edv, out RoleMappedData rmd, out IPredictor pred);
+ var labelInfo = rmd.Schema.Label;
+ if (labelInfo == null)
+ throw env.Except("Training schema for model {0} does not have a label column", i);
+
+ var curLabelType = rmd.Schema.Schema.GetColumnType(rmd.Schema.Label.Index);
+ if (!labelType.Equals(curLabelType.AsKey))
+ throw env.Except("Label column of model {0} has different type than model 0", i);
+
+ var mdType = rmd.Schema.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, labelInfo.Index);
+ if (!mdType.Equals(keyValuesType))
+ throw env.Except("Label column of model {0} has different key value type than model 0", i);
+ rmd.Schema.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, labelInfo.Index, ref curLabelNames);
+ if (!AreEqual(ref labelNames, ref curLabelNames))
+ throw env.Except("Label of model {0} has different values than model 0", i);
+ }
+ return classCount;
+ }
+
+ private static bool AreEqual(ref VBuffer v1, ref VBuffer v2)
+ where T : IEquatable
+ {
+ if (v1.Length != v2.Length)
+ return false;
+ return v1.DenseValues().Zip(v2.DenseValues(), (x1, x2) => x1.Equals(x2)).All(b => b);
+ }
+
+ ///
+ /// This method outputs a Key-Value Pair (kvp) per model in the ensemble.
+ /// * The key is the model number such as "Partition model 0 summary". If the model implements
+ /// then this string is followed by the first line of the model summary (the first line contains a description specific to the
+ /// model kind, such as "Feature gains" for FastTree or "Feature weights" for linear).
+ /// * The value:
+ /// - If the model implements then the value is the list of Key-Value pairs
+ /// containing the detailed summary for that model (for example, linear models have a list containing kvps where the keys
+ /// are the feature names and the values are the weights. FastTree has a similar list with the feature gains as values).
+ /// - If the model does not implement but does implement ,
+ /// the value is a string containing the summary of that model.
+ /// - If neither of those interfaces are implemented then the value is a string containing the name of the type of model.
+ ///
+ ///
+ public IList> GetSummaryInKeyValuePairs(RoleMappedSchema schema)
+ {
+ Host.CheckValueOrNull(schema);
+
+ var list = new List>();
+
+ var sb = new StringBuilder();
+ for (int i = 0; i < PredictorModels.Length; i++)
+ {
+ var key = string.Format("Partition model {0} summary:", i);
+ var summaryKvps = PredictorModels[i].Predictor as ICanGetSummaryInKeyValuePairs;
+ var summaryModel = PredictorModels[i].Predictor as ICanSaveSummary;
+ if (summaryKvps == null && summaryModel == null)
+ {
+ list.Add(new KeyValuePair(key, PredictorModels[i].Predictor.GetType().Name));
+ continue;
+ }
+
+ // Load the feature names for the i'th model.
+ var dv = new EmptyDataView(Host, PredictorModels[i].TransformModel.InputSchema);
+ PredictorModels[i].PrepareData(Host, dv, out RoleMappedData rmd, out IPredictor pred);
+
+ if (summaryModel != null)
+ {
+ sb.Clear();
+ using (StringWriter sw = new StringWriter(sb))
+ summaryModel.SaveSummary(sw, rmd.Schema);
+ }
+
+ if (summaryKvps != null)
+ {
+ var listCur = summaryKvps.GetSummaryInKeyValuePairs(rmd.Schema);
+ if (summaryModel != null)
+ {
+ using (var reader = new StringReader(sb.ToString()))
+ {
+ string firstLine = null;
+ while (string.IsNullOrEmpty(firstLine))
+ firstLine = reader.ReadLine();
+ if (!string.IsNullOrEmpty(firstLine))
+ key += ("\r\n" + firstLine);
+ }
+ }
+ list.Add(new KeyValuePair(key, listCur));
+ }
+ else
+ {
+ Host.AssertValue(summaryModel);
+ list.Add(new KeyValuePair(key, sb.ToString()));
+ }
+
+ }
+ return list;
+ }
+
+ public string[] GetLabelNamesOrNull(out ColumnType labelType)
+ {
+ Host.AssertNonEmpty(PredictorModels);
+ return PredictorModels[0].GetLabelInfo(Host, out labelType);
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/BaseDisagreementDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/BaseDisagreementDiversityMeasure.cs
new file mode 100644
index 0000000000..6a34da757a
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/BaseDisagreementDiversityMeasure.cs
@@ -0,0 +1,45 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+
+namespace Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure
+{
+ public abstract class BaseDisagreementDiversityMeasure : IDiversityMeasure
+ {
+ public List> CalculateDiversityMeasure(IList>> models,
+ ConcurrentDictionary>, TOutput[]> predictions)
+ {
+ Contracts.Assert(models.Count > 1);
+ Contracts.Assert(predictions.Count == models.Count);
+
+ var diversityValues = new List>();
+
+ for (int i = 0; i < (models.Count - 1); i++)
+ {
+ for (int j = i + 1; j < models.Count; j++)
+ {
+ Single differencesCount = 0;
+ var modelXOutputs = predictions[models[i]];
+ var modelYOutputs = predictions[models[j]];
+ for (int k = 0; k < modelXOutputs.Length; k++)
+ {
+ differencesCount += GetDifference(ref modelXOutputs[k], ref modelYOutputs[k]);
+ }
+ diversityValues.Add(new ModelDiversityMetric()
+ {
+ DiversityNumber = differencesCount,
+ ModelX = models[i],
+ ModelY = models[j]
+ });
+ }
+ }
+ return diversityValues;
+ }
+
+ protected abstract Single GetDifference(ref TOutput tOutput1, ref TOutput tOutput2);
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/DisagreementDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/DisagreementDiversityMeasure.cs
new file mode 100644
index 0000000000..7a19237974
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/DisagreementDiversityMeasure.cs
@@ -0,0 +1,25 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Ensemble.Selector;
+using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure;
+
+[assembly: LoadableClass(typeof(DisagreementDiversityMeasure), null, typeof(SignatureEnsembleDiversityMeasure),
+ DisagreementDiversityMeasure.UserName, DisagreementDiversityMeasure.LoadName)]
+
+namespace Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure
+{
+ public class DisagreementDiversityMeasure : BaseDisagreementDiversityMeasure, IBinaryDiversityMeasure
+ {
+ public const string UserName = "Disagreement Diversity Measure";
+ public const string LoadName = "DisagreementDiversityMeasure";
+
+ protected override Single GetDifference(ref Single valueX, ref Single valueY)
+ {
+ return (valueX > 0 && valueY < 0 || valueX < 0 && valueY > 0) ? 1 : 0;
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/ModelDiversityMetric.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/ModelDiversityMetric.cs
new file mode 100644
index 0000000000..1ee03a9489
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/ModelDiversityMetric.cs
@@ -0,0 +1,15 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+
+namespace Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure
+{
+ public class ModelDiversityMetric
+ {
+ public FeatureSubsetModel> ModelX { get; set; }
+ public FeatureSubsetModel> ModelY { get; set; }
+ public Single DiversityNumber { get; set; }
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/MultiDisagreementDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/MultiDisagreementDiversityMeasure.cs
new file mode 100644
index 0000000000..fe4632791d
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/MultiDisagreementDiversityMeasure.cs
@@ -0,0 +1,26 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Ensemble.Selector;
+using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure;
+using Microsoft.ML.Runtime.Numeric;
+
+[assembly: LoadableClass(typeof(MultiDisagreementDiversityMeasure), null, typeof(SignatureEnsembleDiversityMeasure),
+ DisagreementDiversityMeasure.UserName, MultiDisagreementDiversityMeasure.LoadName)]
+
+namespace Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure
+{
+ public class MultiDisagreementDiversityMeasure : BaseDisagreementDiversityMeasure<VBuffer<Single>>, IMulticlassDiversityMeasure
+ {
+ public const string LoadName = "MultiDisagreementDiversityMeasure";
+
+ protected override Single GetDifference(ref VBuffer<Single> valueX, ref VBuffer<Single> valueY)
+ {
+ return (VectorUtils.ArgMax(ref valueX) != VectorUtils.ArgMax(ref valueY)) ? 1 : 0;
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/RegressionDisagreementDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/RegressionDisagreementDiversityMeasure.cs
new file mode 100644
index 0000000000..62724d387e
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/RegressionDisagreementDiversityMeasure.cs
@@ -0,0 +1,24 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Ensemble.Selector;
+using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure;
+
+[assembly: LoadableClass(typeof(RegressionDisagreementDiversityMeasure), null, typeof(SignatureEnsembleDiversityMeasure),
+ DisagreementDiversityMeasure.UserName, RegressionDisagreementDiversityMeasure.LoadName)]
+
+namespace Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure
+{
+ public class RegressionDisagreementDiversityMeasure : BaseDisagreementDiversityMeasure<Single>, IRegressionDiversityMeasure
+ {
+ public const string LoadName = "RegressionDisagreementDiversityMeasure";
+
+ protected override Single GetDifference(ref Single valueX, ref Single valueY)
+ {
+ return Math.Abs(valueX - valueY);
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/AllFeatureSelector.cs b/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/AllFeatureSelector.cs
new file mode 100644
index 0000000000..84a70a6ee4
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/AllFeatureSelector.cs
@@ -0,0 +1,29 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Ensemble.Selector;
+using Microsoft.ML.Runtime.Ensemble.Selector.FeatureSelector;
+
+[assembly: LoadableClass(typeof(AllFeatureSelector), null, typeof(SignatureEnsembleFeatureSelector),
+ AllFeatureSelector.UserName, AllFeatureSelector.LoadName)]
+
+namespace Microsoft.ML.Runtime.Ensemble.Selector.FeatureSelector
+{
+ public sealed class AllFeatureSelector : IFeatureSelector
+ {
+ public const string UserName = "All Feature Selector";
+ public const string LoadName = "AllFeatureSelector";
+
+ public AllFeatureSelector(IHostEnvironment env)
+ {
+ }
+
+ public Subset SelectFeatures(RoleMappedData data, IRandom rand)
+ {
+ return new Subset(data);
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/RandomFeatureSelector.cs b/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/RandomFeatureSelector.cs
new file mode 100644
index 0000000000..c0c9b8968f
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/RandomFeatureSelector.cs
@@ -0,0 +1,62 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Ensemble.Selector;
+using Microsoft.ML.Runtime.Ensemble.Selector.FeatureSelector;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Training;
+
+[assembly: LoadableClass(typeof(RandomFeatureSelector), typeof(RandomFeatureSelector.Arguments),
+ typeof(SignatureEnsembleFeatureSelector), RandomFeatureSelector.UserName, RandomFeatureSelector.LoadName)]
+
+namespace Microsoft.ML.Runtime.Ensemble.Selector.FeatureSelector
+{
+ public class RandomFeatureSelector : IFeatureSelector
+ {
+ public const string UserName = "Random Feature Selector";
+ public const string LoadName = "RandomFeatureSelector";
+
+ [TlcModule.Component(Name = LoadName, FriendlyName = UserName)]
+ public sealed class Arguments : ISupportFeatureSelectorFactory
+ {
+ [Argument(ArgumentType.AtMostOnce, HelpText = "The proportion of features to be selected. The range is 0.0-1.0", ShortName = "fp", SortOrder = 50)]
+ public Single FeaturesSelectionProportion = 0.8f;
+
+ public IFeatureSelector CreateComponent(IHostEnvironment env) => new RandomFeatureSelector(env, this);
+ }
+
+ private readonly Arguments _args;
+ private readonly IHost _host;
+
+ public RandomFeatureSelector(IHostEnvironment env, Arguments args)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(args, nameof(args));
+
+ _host = env.Register(LoadName);
+ _args = args;
+ _host.Check(0 < _args.FeaturesSelectionProportion && _args.FeaturesSelectionProportion < 1,
+ "The feature proportion for RandomFeatureSelector should be greater than 0 and lesser than 1");
+ }
+
+ public Subset SelectFeatures(RoleMappedData data, IRandom rand)
+ {
+ _host.CheckValue(data, nameof(data));
+ data.CheckFeatureFloatVector();
+
+ var type = data.Schema.Feature.Type;
+ int len = type.VectorSize;
+ var features = new BitArray(len);
+ for (int j = 0; j < len; j++)
+ features[j] = rand.NextDouble() < _args.FeaturesSelectionProportion;
+ var dataNew = EnsembleUtils.SelectFeatures(_host, data, features);
+ return new Subset(dataNew, features);
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/Selector/IDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/IDiversityMeasure.cs
new file mode 100644
index 0000000000..01589c4714
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/Selector/IDiversityMeasure.cs
@@ -0,0 +1,43 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure;
+using Microsoft.ML.Runtime.EntryPoints;
+
+namespace Microsoft.ML.Runtime.Ensemble.Selector
+{
+ public interface IDiversityMeasure<TOutput>
+ {
+ List<ModelDiversityMetric<TOutput>> CalculateDiversityMeasure(IList<FeatureSubsetModel<IPredictorProducing<TOutput>>> models,
+ ConcurrentDictionary<FeatureSubsetModel<IPredictorProducing<TOutput>>, TOutput[]> predictions);
+ }
+
+ public delegate void SignatureEnsembleDiversityMeasure();
+
+ public interface IBinaryDiversityMeasure : IDiversityMeasure<Single>
+ { }
+ public interface IRegressionDiversityMeasure : IDiversityMeasure<Single>
+ { }
+ public interface IMulticlassDiversityMeasure : IDiversityMeasure<VBuffer<Single>>
+ { }
+
+ [TlcModule.ComponentKind("EnsembleBinaryDiversityMeasure")]
+ public interface ISupportBinaryDiversityMeasureFactory : IComponentFactory<IBinaryDiversityMeasure>
+ {
+ }
+
+ [TlcModule.ComponentKind("EnsembleRegressionDiversityMeasure")]
+ public interface ISupportRegressionDiversityMeasureFactory : IComponentFactory<IRegressionDiversityMeasure>
+ {
+ }
+
+ [TlcModule.ComponentKind("EnsembleMulticlassDiversityMeasure")]
+ public interface ISupportMulticlassDiversityMeasureFactory : IComponentFactory<IMulticlassDiversityMeasure>
+ {
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/Selector/IFeatureSelector.cs b/src/Microsoft.ML.Ensemble/Selector/IFeatureSelector.cs
new file mode 100644
index 0000000000..99e90c5c01
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/Selector/IFeatureSelector.cs
@@ -0,0 +1,21 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.EntryPoints;
+
+namespace Microsoft.ML.Runtime.Ensemble.Selector
+{
+ public interface IFeatureSelector
+ {
+ Subset SelectFeatures(RoleMappedData data, IRandom rand);
+ }
+
+ public delegate void SignatureEnsembleFeatureSelector();
+
+ [TlcModule.ComponentKind("EnsembleFeatureSelector")]
+ public interface ISupportFeatureSelectorFactory : IComponentFactory<IFeatureSelector>
+ {
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/Selector/ISubModelSelector.cs b/src/Microsoft.ML.Ensemble/Selector/ISubModelSelector.cs
new file mode 100644
index 0000000000..96e9f5b886
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/Selector/ISubModelSelector.cs
@@ -0,0 +1,52 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.EntryPoints;
+using System;
+using System.Collections.Generic;
+
+namespace Microsoft.ML.Runtime.Ensemble.Selector
+{
+ public interface ISubModelSelector<TOutput>
+ {
+ IList<FeatureSubsetModel<IPredictorProducing<TOutput>>> Prune(IList<FeatureSubsetModel<IPredictorProducing<TOutput>>> models);
+
+ void CalculateMetrics(FeatureSubsetModel<IPredictorProducing<TOutput>> model, ISubsetSelector subsetSelector, Subset subset,
+ Batch batch, bool needMetrics);
+
+ Single ValidationDatasetProportion { get; }
+ }
+
+ public interface IRegressionSubModelSelector : ISubModelSelector<Single>
+ {
+ }
+
+ public interface IBinarySubModelSelector : ISubModelSelector<Single>
+ {
+ }
+
+ public interface IMulticlassSubModelSelector : ISubModelSelector<VBuffer<Single>>
+ {
+ }
+
+ public delegate void SignatureEnsembleSubModelSelector();
+
+ [TlcModule.ComponentKind("EnsembleMulticlassSubModelSelector")]
+ public interface ISupportMulticlassSubModelSelectorFactory : IComponentFactory<IMulticlassSubModelSelector>
+ {
+ }
+
+ [TlcModule.ComponentKind("EnsembleBinarySubModelSelector")]
+ public interface ISupportBinarySubModelSelectorFactory : IComponentFactory<IBinarySubModelSelector>
+ {
+
+ }
+
+ [TlcModule.ComponentKind("EnsembleRegressionSubModelSelector")]
+ public interface ISupportRegressionSubModelSelectorFactory : IComponentFactory<IRegressionSubModelSelector>
+ {
+
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/Selector/ISubsetSelector.cs b/src/Microsoft.ML.Ensemble/Selector/ISubsetSelector.cs
new file mode 100644
index 0000000000..2a0f088219
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/Selector/ISubsetSelector.cs
@@ -0,0 +1,26 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.EntryPoints;
+
+namespace Microsoft.ML.Runtime.Ensemble.Selector
+{
+ public interface ISubsetSelector
+ {
+ void Initialize(RoleMappedData data, int size, int batchSize, Single validationDatasetProportion);
+ IEnumerable GetBatches(IRandom rand);
+ IEnumerable GetSubsets(Batch batch, IRandom rand);
+ RoleMappedData GetTestData(Subset subset, Batch batch);
+ }
+
+ public delegate void SignatureEnsembleDataSelector();
+
+ [TlcModule.ComponentKind("EnsembleSubsetSelector")]
+ public interface ISupportSubsetSelectorFactory : IComponentFactory<ISubsetSelector>
+ {
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelector.cs
new file mode 100644
index 0000000000..4196ab3558
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelector.cs
@@ -0,0 +1,28 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Ensemble.Selector;
+using Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector;
+
+[assembly: LoadableClass(typeof(AllSelector), null, typeof(SignatureEnsembleSubModelSelector), AllSelector.UserName, AllSelector.LoadName)]
+
+namespace Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector
+{
+ public class AllSelector : BaseSubModelSelector<Single>, IBinarySubModelSelector, IRegressionSubModelSelector
+ {
+ public const string UserName = "All Selector";
+ public const string LoadName = "AllSelector";
+
+ public override Single ValidationDatasetProportion => 0;
+
+ protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
+
+ public AllSelector(IHostEnvironment env)
+ : base(env, LoadName)
+ {
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelectorMultiClass.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelectorMultiClass.cs
new file mode 100644
index 0000000000..6c82fc25f5
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelectorMultiClass.cs
@@ -0,0 +1,30 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Ensemble.Selector;
+using Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector;
+
+[assembly: LoadableClass(typeof(AllSelectorMultiClass), null, typeof(SignatureEnsembleSubModelSelector),
+ AllSelectorMultiClass.UserName, AllSelectorMultiClass.LoadName)]
+
+namespace Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector
+{
+ public class AllSelectorMultiClass : BaseSubModelSelector<VBuffer<Single>>, IMulticlassSubModelSelector
+ {
+ public const string UserName = "All Selector";
+ public const string LoadName = "AllSelectorMultiClass";
+
+ public override Single ValidationDatasetProportion => 0;
+
+ protected override PredictionKind PredictionKind => PredictionKind.MultiClassClassification;
+
+ public AllSelectorMultiClass(IHostEnvironment env)
+ : base(env, LoadName)
+ {
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseBestPerformanceSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseBestPerformanceSelector.cs
new file mode 100644
index 0000000000..8701e2833c
--- /dev/null
+++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseBestPerformanceSelector.cs
@@ -0,0 +1,121 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Reflection;
+using Microsoft.ML.Runtime.CommandLine;
+
+namespace Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector
+{
+ public abstract class BaseBestPerformanceSelector<TOutput> : SubModelDataSelector<TOutput>
+ {
+ protected abstract string MetricName { get; }
+
+ protected virtual bool IsAscMetric => true;
+
+ protected BaseBestPerformanceSelector(ArgumentsBase args, IHostEnvironment env, string name)
+ : base(args, env, name)
+ {
+ }
+
+ public override void CalculateMetrics(FeatureSubsetModel<IPredictorProducing<TOutput>> model,
+ ISubsetSelector subsetSelector, Subset subset, Batch batch, bool needMetrics)
+ {
+ base.CalculateMetrics(model, subsetSelector, subset, batch, true);
+ }
+
+ public override IList<FeatureSubsetModel<IPredictorProducing<TOutput>>> Prune(IList<FeatureSubsetModel<IPredictorProducing<TOutput>>> models)
+ {
+ using (var ch = Host.Start("Pruning"))
+ {
+ var sortedModels = models.ToArray();
+ Array.Sort(sortedModels, new ModelPerformanceComparer(MetricName, IsAscMetric));
+ Print(ch, sortedModels, MetricName);
+ int modelCountToBeSelected = (int)(models.Count * LearnersSelectionProportion);
+ if (modelCountToBeSelected == 0)
+ modelCountToBeSelected = 1;
+
+ var retval = sortedModels.Where(m => m != null).Take(modelCountToBeSelected).ToList();
+ ch.Done();
+ return retval;
+ }
+ }
+
+ protected static string FindMetricName(Type type, object value)
+ {
+ Contracts.Assert(type.IsEnum);
+ Contracts.Assert(value.GetType() == type);
+
+ foreach (var field in type.GetFields(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly))
+ {
+ if (field.FieldType != type)
+ continue;
+ if (field.GetCustomAttribute<HideEnumValueAttribute>() != null)
+ continue;
+ var displayAttr = field.GetCustomAttribute<EnumValueDisplayAttribute>();
+ if (displayAttr != null)
+ {
+ var valCur = field.GetValue(null);
+ if (value.Equals(valCur))
+ return displayAttr.Name;
+ }
+ }
+ Contracts.Assert(false);
+ return null;
+ }
+
+ private sealed class ModelPerformanceComparer : IComparer<FeatureSubsetModel<IPredictorProducing<TOutput>>>
+ {
+ private readonly string _metricName;
+ private readonly bool _isAscMetric;
+
+ public ModelPerformanceComparer(string metricName, bool isAscMetric)
+ {
+ Contracts.AssertValue(metricName);
+
+ _metricName = metricName;
+ _isAscMetric = isAscMetric;
+ }
+
+ public int Compare(FeatureSubsetModel> x, FeatureSubsetModel