diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs
index 46002c5abf..ecea73a495 100644
--- a/src/Microsoft.ML/CSharpApi.cs
+++ b/src/Microsoft.ML/CSharpApi.cs
@@ -12300,13 +12300,13 @@ namespace Runtime
{
public abstract class CalibratorTrainer : ComponentKind {}
+
+
///
///
///
public sealed class FixedPlattCalibratorCalibratorTrainer : CalibratorTrainer
{
-
-
///
/// The slope parameter of f(x) = 1 / (1 + exp(-slope * x + offset)
///
@@ -12320,45 +12320,45 @@ public sealed class FixedPlattCalibratorCalibratorTrainer : CalibratorTrainer
internal override string ComponentName => "FixedPlattCalibrator";
}
+
+
///
///
///
public sealed class NaiveCalibratorCalibratorTrainer : CalibratorTrainer
{
-
-
internal override string ComponentName => "NaiveCalibrator";
}
+
+
///
///
///
public sealed class PavCalibratorCalibratorTrainer : CalibratorTrainer
{
-
-
internal override string ComponentName => "PavCalibrator";
}
+
+
///
/// Platt calibration.
///
public sealed class PlattCalibratorCalibratorTrainer : CalibratorTrainer
{
-
-
internal override string ComponentName => "PlattCalibrator";
}
public abstract class ClassificationLossFunction : ComponentKind {}
+
+
///
/// Exponential loss.
///
public sealed class ExpLossClassificationLossFunction : ClassificationLossFunction
{
-
-
///
/// Beta (dilation)
///
@@ -12367,13 +12367,13 @@ public sealed class ExpLossClassificationLossFunction : ClassificationLossFuncti
internal override string ComponentName => "ExpLoss";
}
+
+
///
/// Hinge loss.
///
public sealed class HingeLossClassificationLossFunction : ClassificationLossFunction
{
-
-
///
/// Margin value
///
@@ -12382,23 +12382,23 @@ public sealed class HingeLossClassificationLossFunction : ClassificationLossFunc
internal override string ComponentName => "HingeLoss";
}
+
+
///
/// Log loss.
///
public sealed class LogLossClassificationLossFunction : ClassificationLossFunction
{
-
-
internal override string ComponentName => "LogLoss";
}
+
+
///
/// Smoothed Hinge loss.
///
public sealed class SmoothedHingeLossClassificationLossFunction : ClassificationLossFunction
{
-
-
///
/// Smoothing constant
///
@@ -12409,13 +12409,13 @@ public sealed class SmoothedHingeLossClassificationLossFunction : Classification
public abstract class EarlyStoppingCriterion : ComponentKind {}
+
+
///
/// Stop in case of loss of generality.
///
public sealed class GLEarlyStoppingCriterion : EarlyStoppingCriterion
{
-
-
///
/// Threshold in range [0,1].
///
@@ -12425,13 +12425,13 @@ public sealed class GLEarlyStoppingCriterion : EarlyStoppingCriterion
internal override string ComponentName => "GL";
}
+
+
///
/// Stops in case of low progress.
///
public sealed class LPEarlyStoppingCriterion : EarlyStoppingCriterion
{
-
-
///
/// Threshold in range [0,1].
///
@@ -12447,13 +12447,13 @@ public sealed class LPEarlyStoppingCriterion : EarlyStoppingCriterion
internal override string ComponentName => "LP";
}
+
+
///
/// Stops in case of generality to progress ration exceeds threshold.
///
public sealed class PQEarlyStoppingCriterion : EarlyStoppingCriterion
{
-
-
///
/// Threshold in range [0,1].
///
@@ -12469,13 +12469,13 @@ public sealed class PQEarlyStoppingCriterion : EarlyStoppingCriterion
internal override string ComponentName => "PQ";
}
+
+
///
/// Stop if validation score exceeds threshold value.
///
public sealed class TREarlyStoppingCriterion : EarlyStoppingCriterion
{
-
-
///
/// Tolerance threshold. (Non negative value)
///
@@ -12485,13 +12485,13 @@ public sealed class TREarlyStoppingCriterion : EarlyStoppingCriterion
internal override string ComponentName => "TR";
}
+
+
///
/// Stops in case of consecutive loss in generality.
///
public sealed class UPEarlyStoppingCriterion : EarlyStoppingCriterion
{
-
-
///
/// The window size.
///
@@ -12503,13 +12503,13 @@ public sealed class UPEarlyStoppingCriterion : EarlyStoppingCriterion
public abstract class FastTreeTrainer : ComponentKind {}
+
+
///
/// Uses a logit-boost boosted tree learner to perform binary classification.
///
public sealed class FastTreeBinaryClassificationFastTreeTrainer : FastTreeTrainer
{
-
-
///
/// Should we use derivatives optimized for unbalanced sets
///
@@ -12856,13 +12856,13 @@ public sealed class FastTreeBinaryClassificationFastTreeTrainer : FastTreeTraine
internal override string ComponentName => "FastTreeBinaryClassification";
}
+
+
///
/// Trains gradient boosted decision trees to the LambdaRank quasi-gradient.
///
public sealed class FastTreeRankingFastTreeTrainer : FastTreeTrainer
{
-
-
///
/// Comma seperated list of gains associated to each relevance label.
///
@@ -13244,13 +13244,13 @@ public sealed class FastTreeRankingFastTreeTrainer : FastTreeTrainer
internal override string ComponentName => "FastTreeRanking";
}
+
+
///
/// Trains gradient boosted decision trees to fit target values using least-squares.
///
public sealed class FastTreeRegressionFastTreeTrainer : FastTreeTrainer
{
-
-
///
/// Use best regression step trees?
///
@@ -13592,13 +13592,13 @@ public sealed class FastTreeRegressionFastTreeTrainer : FastTreeTrainer
internal override string ComponentName => "FastTreeRegression";
}
+
+
///
/// Trains gradient boosted decision trees to fit target values using a Tweedie loss function. This learner is a generalization of Poisson, compound Poisson, and gamma regression.
///
public sealed class FastTreeTweedieRegressionFastTreeTrainer : FastTreeTrainer
{
-
-
///
/// Index parameter for the Tweedie distribution, in the range [1, 2]. 1 is Poisson loss, 2 is gamma loss, and intermediate values are compound Poisson loss.
///
@@ -13947,13 +13947,13 @@ public sealed class FastTreeTweedieRegressionFastTreeTrainer : FastTreeTrainer
public abstract class NgramExtractor : ComponentKind {}
+
+
///
/// Extracts NGrams from text and convert them to vector using dictionary.
///
public sealed class NGramNgramExtractor : NgramExtractor
{
-
-
///
/// Ngram length
///
@@ -13982,13 +13982,13 @@ public sealed class NGramNgramExtractor : NgramExtractor
internal override string ComponentName => "NGram";
}
+
+
///
/// Extracts NGrams from text and convert them to vector using hashing trick.
///
public sealed class NGramHashNgramExtractor : NgramExtractor
{
-
-
///
/// Ngram length
///
@@ -14029,45 +14029,45 @@ public sealed class NGramHashNgramExtractor : NgramExtractor
public abstract class ParallelTraining : ComponentKind {}
+
+
///
/// Single node machine learning process.
///
public sealed class SingleParallelTraining : ParallelTraining
{
-
-
internal override string ComponentName => "Single";
}
public abstract class RegressionLossFunction : ComponentKind {}
+
+
///
/// Poisson loss.
///
public sealed class PoissonLossRegressionLossFunction : RegressionLossFunction
{
-
-
internal override string ComponentName => "PoissonLoss";
}
+
+
///
/// Squared loss.
///
public sealed class SquaredLossRegressionLossFunction : RegressionLossFunction
{
-
-
internal override string ComponentName => "SquaredLoss";
}
+
+
///
/// Tweedie loss.
///
public sealed class TweedieLossRegressionLossFunction : RegressionLossFunction
{
-
-
///
/// Index parameter for the Tweedie distribution, in the range [1, 2]. 1 is Poisson loss, 2 is gamma loss, and intermediate values are compound Poisson loss.
///
@@ -14078,13 +14078,13 @@ public sealed class TweedieLossRegressionLossFunction : RegressionLossFunction
public abstract class SDCAClassificationLossFunction : ComponentKind {}
+
+
///
/// Hinge loss.
///
public sealed class HingeLossSDCAClassificationLossFunction : SDCAClassificationLossFunction
{
-
-
///
/// Margin value
///
@@ -14093,23 +14093,23 @@ public sealed class HingeLossSDCAClassificationLossFunction : SDCAClassification
internal override string ComponentName => "HingeLoss";
}
+
+
///
/// Log loss.
///
public sealed class LogLossSDCAClassificationLossFunction : SDCAClassificationLossFunction
{
-
-
internal override string ComponentName => "LogLoss";
}
+
+
///
/// Smoothed Hinge loss.
///
public sealed class SmoothedHingeLossSDCAClassificationLossFunction : SDCAClassificationLossFunction
{
-
-
///
/// Smoothing constant
///
@@ -14120,25 +14120,25 @@ public sealed class SmoothedHingeLossSDCAClassificationLossFunction : SDCAClassi
public abstract class SDCARegressionLossFunction : ComponentKind {}
+
+
///
/// Squared loss.
///
public sealed class SquaredLossSDCARegressionLossFunction : SDCARegressionLossFunction
{
-
-
internal override string ComponentName => "SquaredLoss";
}
public abstract class StopWordsRemover : ComponentKind {}
+
+
///
/// Remover with list of stopwords specified by the user.
///
public sealed class CustomStopWordsRemover : StopWordsRemover
{
-
-
///
/// List of stopwords
///
@@ -14147,13 +14147,13 @@ public sealed class CustomStopWordsRemover : StopWordsRemover
internal override string ComponentName => "Custom";
}
+
+
///
/// Remover with predefined list of stop words.
///
public sealed class PredefinedStopWordsRemover : StopWordsRemover
{
-
-
internal override string ComponentName => "Predefined";
}
diff --git a/src/Microsoft.ML/Runtime/Experiment/Experiment.cs b/src/Microsoft.ML/Runtime/Experiment/Experiment.cs
index aa5736bd14..9fb0560701 100644
--- a/src/Microsoft.ML/Runtime/Experiment/Experiment.cs
+++ b/src/Microsoft.ML/Runtime/Experiment/Experiment.cs
@@ -5,7 +5,6 @@
using System;
using System.Collections.Generic;
using System.IO;
-using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.EntryPoints.JsonUtils;
using Newtonsoft.Json;
@@ -165,7 +164,7 @@ private string Serialize(string name, object input, object output)
{
using (var jw = new JsonTextWriter(sw))
{
- jw.Formatting = Formatting.Indented;
+ jw.Formatting = Newtonsoft.Json.Formatting.Indented;
_serializer.Serialize(jw, _helper);
}
return sw.ToString();
diff --git a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs
index 17643518ca..f1e45fa446 100644
--- a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs
+++ b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs
@@ -74,7 +74,7 @@ public static string GetOutputType(Type outputType)
return $"Var<{GetCSharpTypeName(outputType)}>";
}
- public static string GetInputType(ModuleCatalog catalog, Type inputType,
+ public static string GetInputType(ModuleCatalog catalog, Type inputType,
Dictionary typesSymbolTable, string rootNameSpace = "")
{
if (inputType.IsGenericType && inputType.GetGenericTypeDefinition() == typeof(Var<>))
@@ -136,13 +136,13 @@ public static string GetInputType(ModuleCatalog catalog, Type inputType,
return $"{enumName}";
default:
if (isNullable)
- return rootNameSpace+typesSymbolTable[type.FullName];
+ return rootNameSpace + typesSymbolTable[type.FullName];
if (isOptional)
- return $"Optional<{rootNameSpace+typesSymbolTable[type.FullName]}>";
+ return $"Optional<{rootNameSpace + typesSymbolTable[type.FullName]}>";
if (typesSymbolTable.ContainsKey(type.FullName))
return rootNameSpace + typesSymbolTable[type.FullName];
else
- return GetSymbolFromType(typesSymbolTable, type.FullName, rootNameSpace);
+ return GetSymbolFromType(typesSymbolTable, type, rootNameSpace);
}
}
@@ -177,7 +177,7 @@ public static string Capitalize(string s)
return char.ToUpperInvariant(s[0]) + s.Substring(1);
}
- public static string GetValue(ModuleCatalog catalog, Type fieldType, object fieldValue,
+ public static string GetValue(ModuleCatalog catalog, Type fieldType, object fieldValue,
Dictionary typesSymbolTable, string rootNameSpace = "")
{
if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Var<>))
@@ -299,7 +299,7 @@ public static string GetValue(ModuleCatalog catalog, Type fieldType, object fiel
var properties = propertyBag.Count > 0 ? $" {{ {string.Join(", ", propertyBag)} }}" : "";
return $"new {GetComponentName(componentInfo)}(){properties}";
case TlcModule.DataKind.Unknown:
- return $"new {rootNameSpace+typesSymbolTable[fieldType.FullName]}()";
+ return $"new {rootNameSpace + typesSymbolTable[fieldType.FullName]}()";
default:
return fieldValue.ToString();
}
@@ -321,7 +321,7 @@ public static string GetEnumName(Type type, Dictionary typesSymb
if (typesSymbolTable.ContainsKey(type.FullName))
return rootNamespace + typesSymbolTable[type.FullName];
else
- return GetSymbolFromType(typesSymbolTable, type.FullName, rootNamespace);
+ return GetSymbolFromType(typesSymbolTable, type, rootNamespace);
}
public static string GetJsonFromField(string fieldName, Type fieldType)
@@ -495,16 +495,72 @@ private void GenerateInputOutput(IndentingTextWriter writer,
writer.WriteLine();
}
- static string GetSymbolFromType(Dictionary typesSymbolTable, string fullTypeName, string currentNamespace)
+ ///
+ /// This methods creates a unique name for a class/struct/enum, given a type and a namespace.
+ /// It generates the name based on the property of the type
+ /// (see description here https://msdn.microsoft.com/en-us/library/system.type.fullname(v=vs.110).aspx).
+ /// Example: Assume we have the following structure in namespace X.Y:
+ /// class A {
+ /// class B {
+ /// enum C {
+ /// Value1,
+ /// Value2
+ /// }
+ /// }
+ /// }
+ /// The full name of C would be X.Y.A+B+C. This method will generate the name "ABC" from it. In case
+ /// A is generic with one generic type, then the full name of typeof(A<float>.B.C) would be X.Y.A`1+B+C[[System.Single]].
+ /// In this case, this method will generate the name "ASingleBC".
+ ///
+ /// A dictionary containing the names of the classes already generated.
+ /// This parameter is only used to ensure that the newly generated name is unique.
+ /// The type for which to generate the new name.
+ /// The namespace prefix to the new name.
+ /// A unique name derived from the given type and namespace.
+ private static string GetSymbolFromType(Dictionary typesSymbolTable, Type type, string currentNamespace)
{
- var names = typesSymbolTable.Select(kvp => kvp.Value);
- char dim = fullTypeName.Contains('+') ? '+' : '.';
+ var fullTypeName = type.FullName;
string name = currentNamespace != "" ? currentNamespace + '.' : "";
- if (fullTypeName.Contains('+'))
- name += fullTypeName.Substring(0, fullTypeName.LastIndexOf('+')).Substring(fullTypeName.LastIndexOf('.') + 1);
+ int bracketIndex = fullTypeName.IndexOf('[');
+ Type[] genericTypes = null;
+ if (type.IsGenericType)
+ genericTypes = type.GetGenericArguments();
+ if (bracketIndex > 0)
+ {
+ Contracts.AssertValue(genericTypes);
+ fullTypeName = fullTypeName.Substring(0, bracketIndex);
+ }
+
+ // When the type is nested, the names of the outer types are concatenated with a '+'.
+ var nestedNames = fullTypeName.Split('+');
+ var baseName = nestedNames[0];
+
+ // We currently only handle generic types in the outer most class, support for generic inner classes
+ // can be added if needed.
+ int backTickIndex = baseName.LastIndexOf('`');
+ int dotIndex = baseName.LastIndexOf('.');
+ Contracts.Assert(dotIndex >= 0);
+ if (backTickIndex < 0)
+ name += baseName.Substring(dotIndex + 1);
+ else
+ {
+ name += baseName.Substring(dotIndex + 1, backTickIndex - dotIndex - 1);
+ Contracts.AssertValue(genericTypes);
+ if (genericTypes != null)
+ {
+ foreach (var genType in genericTypes)
+ {
+ var splitNames = genType.FullName.Split('+');
+ if (splitNames[0].LastIndexOf('.') >= 0)
+ splitNames[0] = splitNames[0].Substring(splitNames[0].LastIndexOf('.') + 1);
+ name += string.Join("", splitNames);
+ }
+ }
+ }
- name += fullTypeName.Substring(fullTypeName.LastIndexOf(dim) + 1); ;
+ for (int i = 1; i < nestedNames.Length; i++)
+ name += nestedNames[i];
Contracts.Assert(typesSymbolTable.Select(kvp => kvp.Value).All(str => string.Compare(str, name) != 0));
@@ -538,7 +594,7 @@ private void GenerateEnums(IndentingTextWriter writer, Type inputType, string cu
var enumType = Enum.GetUnderlyingType(type);
- _typesSymbolTable[type.FullName] = GetSymbolFromType(_typesSymbolTable, type.FullName, currentNamespace);
+ _typesSymbolTable[type.FullName] = GetSymbolFromType(_typesSymbolTable, type, currentNamespace);
if (enumType == typeof(int))
writer.WriteLine($"public enum {_typesSymbolTable[type.FullName].Substring(_typesSymbolTable[type.FullName].LastIndexOf('.') + 1)}");
else
@@ -623,7 +679,7 @@ private void GenerateStructs(IndentingTextWriter writer,
if (_typesSymbolTable.ContainsKey(type.FullName))
continue;
- _typesSymbolTable[type.FullName] = GetSymbolFromType(_typesSymbolTable, type.FullName, currentNamespace);
+ _typesSymbolTable[type.FullName] = GetSymbolFromType(_typesSymbolTable, type, currentNamespace);
string classBase = "";
if (type.IsSubclassOf(typeof(OneToOneColumn)))
classBase = $" : OneToOneColumn<{_typesSymbolTable[type.FullName].Substring(_typesSymbolTable[type.FullName].LastIndexOf('.') + 1)}>, IOneToOneColumn";
@@ -889,7 +945,7 @@ private static void GenerateApplyFunction(IndentingTextWriter writer, ModuleCata
writer.WriteLine("}");
}
- private static void GenerateInputFields(IndentingTextWriter writer,
+ private static void GenerateInputFields(IndentingTextWriter writer,
Type inputType, ModuleCatalog catalog, Dictionary typesSymbolTable, string rootNameSpace = "")
{
var defaults = Activator.CreateInstance(inputType);
@@ -936,7 +992,7 @@ private static void GenerateInputFields(IndentingTextWriter writer,
sweepableParamAttr.Name = fieldInfo.Name;
writer.WriteLine(sweepableParamAttr.ToString());
}
-
+
writer.Write($"public {inputTypeString} {GeneratorUtils.Capitalize(inputAttr.Name ?? fieldInfo.Name)} {{ get; set; }}");
var defaultValue = GeneratorUtils.GetValue(catalog, fieldInfo.FieldType, fieldInfo.GetValue(defaults), typesSymbolTable, rootNameSpace);
if (defaultValue != null)
@@ -1013,16 +1069,16 @@ private void GenerateComponentKind(IndentingTextWriter writer, string kind)
private void GenerateComponent(IndentingTextWriter writer, ModuleCatalog.ComponentInfo component, ModuleCatalog catalog)
{
+ GenerateEnums(writer, component.ArgumentType, "Runtime");
+ writer.WriteLine();
+ GenerateStructs(writer, component.ArgumentType, catalog, "Runtime");
+ writer.WriteLine();
writer.WriteLine("/// ");
writer.WriteLine($"/// {component.Description}");
writer.WriteLine("/// ");
writer.WriteLine($"public sealed class {GeneratorUtils.GetComponentName(component)} : {component.Kind}");
writer.WriteLine("{");
writer.Indent();
- GenerateEnums(writer, component.ArgumentType, "");
- writer.WriteLine();
- GenerateStructs(writer, component.ArgumentType, catalog, "");
- writer.WriteLine();
GenerateInputFields(writer, component.ArgumentType, catalog, _typesSymbolTable, "Microsoft.ML.");
writer.WriteLine($"internal override string ComponentName => \"{component.Name}\";");
writer.Outdent();