From bd0148feed7a40d8b13e417157c8b605fd73f4cd Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Wed, 9 May 2018 10:48:35 -0700 Subject: [PATCH 1/7] Handle generic types and types with multiple '+' signs in them. --- .../ComponentModel/ComponentCatalog.cs | 4 +- .../LogisticRegression/LogisticRegression.cs | 2 +- .../WhiteningTransform.cs | 2 +- src/Microsoft.ML/CSharpApi.cs | 116 +++++++++--------- .../Runtime/Experiment/Experiment.cs | 3 +- .../Internal/Tools/CSharpApiGenerator.cs | 79 +++++++++--- 6 files changed, 124 insertions(+), 82 deletions(-) diff --git a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs index 28666e7f44..3a66ce27c8 100644 --- a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs +++ b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs @@ -342,8 +342,8 @@ private static bool ShouldSkipPath(string path) case "libiomp5md.dll": case "libvw.dll": case "matrixinterf.dll": - case "Microsoft.ML.neuralnetworks.gpucuda.dll": - case "Microsoft.ML.mklimports.dll": + case "Microsoft.MachineLearning.neuralnetworks.gpucuda.dll": + case "Microsoft.MachineLearning.mklimports.dll": case "microsoft.research.controls.decisiontrees.dll": case "Microsoft.ML.neuralnetworks.sse.dll": case "neuraltreeevaluator.dll": diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs index f1ea7c4a4b..24426290fa 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs @@ -300,7 +300,7 @@ protected override void ComputeTrainingStatistics(IChannel ch, FloatLabelCursor. } catch (DllNotFoundException) { - throw ch.ExceptNotSupp("The MKL library (Microsoft.ML.MklImports.dll) or one of its dependencies is missing."); + throw ch.ExceptNotSupp("The MKL library (Microsoft.MachineLearning.MklImports.dll) or one of its dependencies is missing."); } Float[] stdErrorValues = new Float[numParams]; diff --git a/src/Microsoft.ML.Transforms/WhiteningTransform.cs b/src/Microsoft.ML.Transforms/WhiteningTransform.cs index 2ae3824ba8..4176eafae5 100644 --- a/src/Microsoft.ML.Transforms/WhiteningTransform.cs +++ b/src/Microsoft.ML.Transforms/WhiteningTransform.cs @@ -571,7 +571,7 @@ private static Float DotProduct(Float[] a, int aOffset, Float[] b, int[] indices private static class Mkl { - private const string DllName = "Microsoft.ML.MklImports.dll"; + private const string DllName = "Microsoft.MachineLearning.MklImports.dll"; public enum Layout { diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index 46002c5abf..ecea73a495 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -12300,13 +12300,13 @@ namespace Runtime { public abstract class CalibratorTrainer : ComponentKind {} + + /// /// /// public sealed class FixedPlattCalibratorCalibratorTrainer : CalibratorTrainer { - - /// /// The slope parameter of f(x) = 1 / (1 + exp(-slope * x + offset) /// @@ -12320,45 +12320,45 @@ public sealed class FixedPlattCalibratorCalibratorTrainer : CalibratorTrainer internal override string ComponentName => "FixedPlattCalibrator"; } + + /// /// /// public sealed class NaiveCalibratorCalibratorTrainer : CalibratorTrainer { - - internal override string ComponentName => "NaiveCalibrator"; } + + /// /// /// public sealed class PavCalibratorCalibratorTrainer : CalibratorTrainer { - - internal override string ComponentName => "PavCalibrator"; } + + /// /// Platt calibration. /// public sealed class PlattCalibratorCalibratorTrainer : CalibratorTrainer { - - internal override string ComponentName => "PlattCalibrator"; } public abstract class ClassificationLossFunction : ComponentKind {} + + /// /// Exponential loss. /// public sealed class ExpLossClassificationLossFunction : ClassificationLossFunction { - - /// /// Beta (dilation) /// @@ -12367,13 +12367,13 @@ public sealed class ExpLossClassificationLossFunction : ClassificationLossFuncti internal override string ComponentName => "ExpLoss"; } + + /// /// Hinge loss. /// public sealed class HingeLossClassificationLossFunction : ClassificationLossFunction { - - /// /// Margin value /// @@ -12382,23 +12382,23 @@ public sealed class HingeLossClassificationLossFunction : ClassificationLossFunc internal override string ComponentName => "HingeLoss"; } + + /// /// Log loss. /// public sealed class LogLossClassificationLossFunction : ClassificationLossFunction { - - internal override string ComponentName => "LogLoss"; } + + /// /// Smoothed Hinge loss. /// public sealed class SmoothedHingeLossClassificationLossFunction : ClassificationLossFunction { - - /// /// Smoothing constant /// @@ -12409,13 +12409,13 @@ public sealed class SmoothedHingeLossClassificationLossFunction : Classification public abstract class EarlyStoppingCriterion : ComponentKind {} + + /// /// Stop in case of loss of generality. /// public sealed class GLEarlyStoppingCriterion : EarlyStoppingCriterion { - - /// /// Threshold in range [0,1]. /// @@ -12425,13 +12425,13 @@ public sealed class GLEarlyStoppingCriterion : EarlyStoppingCriterion internal override string ComponentName => "GL"; } + + /// /// Stops in case of low progress. /// public sealed class LPEarlyStoppingCriterion : EarlyStoppingCriterion { - - /// /// Threshold in range [0,1]. /// @@ -12447,13 +12447,13 @@ public sealed class LPEarlyStoppingCriterion : EarlyStoppingCriterion internal override string ComponentName => "LP"; } + + /// /// Stops in case of generality to progress ration exceeds threshold. /// public sealed class PQEarlyStoppingCriterion : EarlyStoppingCriterion { - - /// /// Threshold in range [0,1]. /// @@ -12469,13 +12469,13 @@ public sealed class PQEarlyStoppingCriterion : EarlyStoppingCriterion internal override string ComponentName => "PQ"; } + + /// /// Stop if validation score exceeds threshold value. /// public sealed class TREarlyStoppingCriterion : EarlyStoppingCriterion { - - /// /// Tolerance threshold. (Non negative value) /// @@ -12485,13 +12485,13 @@ public sealed class TREarlyStoppingCriterion : EarlyStoppingCriterion internal override string ComponentName => "TR"; } + + /// /// Stops in case of consecutive loss in generality. /// public sealed class UPEarlyStoppingCriterion : EarlyStoppingCriterion { - - /// /// The window size. /// @@ -12503,13 +12503,13 @@ public sealed class UPEarlyStoppingCriterion : EarlyStoppingCriterion public abstract class FastTreeTrainer : ComponentKind {} + + /// /// Uses a logit-boost boosted tree learner to perform binary classification. /// public sealed class FastTreeBinaryClassificationFastTreeTrainer : FastTreeTrainer { - - /// /// Should we use derivatives optimized for unbalanced sets /// @@ -12856,13 +12856,13 @@ public sealed class FastTreeBinaryClassificationFastTreeTrainer : FastTreeTraine internal override string ComponentName => "FastTreeBinaryClassification"; } + + /// /// Trains gradient boosted decision trees to the LambdaRank quasi-gradient. /// public sealed class FastTreeRankingFastTreeTrainer : FastTreeTrainer { - - /// /// Comma seperated list of gains associated to each relevance label. /// @@ -13244,13 +13244,13 @@ public sealed class FastTreeRankingFastTreeTrainer : FastTreeTrainer internal override string ComponentName => "FastTreeRanking"; } + + /// /// Trains gradient boosted decision trees to fit target values using least-squares. /// public sealed class FastTreeRegressionFastTreeTrainer : FastTreeTrainer { - - /// /// Use best regression step trees? /// @@ -13592,13 +13592,13 @@ public sealed class FastTreeRegressionFastTreeTrainer : FastTreeTrainer internal override string ComponentName => "FastTreeRegression"; } + + /// /// Trains gradient boosted decision trees to fit target values using a Tweedie loss function. This learner is a generalization of Poisson, compound Poisson, and gamma regression. /// public sealed class FastTreeTweedieRegressionFastTreeTrainer : FastTreeTrainer { - - /// /// Index parameter for the Tweedie distribution, in the range [1, 2]. 1 is Poisson loss, 2 is gamma loss, and intermediate values are compound Poisson loss. /// @@ -13947,13 +13947,13 @@ public sealed class FastTreeTweedieRegressionFastTreeTrainer : FastTreeTrainer public abstract class NgramExtractor : ComponentKind {} + + /// /// Extracts NGrams from text and convert them to vector using dictionary. /// public sealed class NGramNgramExtractor : NgramExtractor { - - /// /// Ngram length /// @@ -13982,13 +13982,13 @@ public sealed class NGramNgramExtractor : NgramExtractor internal override string ComponentName => "NGram"; } + + /// /// Extracts NGrams from text and convert them to vector using hashing trick. /// public sealed class NGramHashNgramExtractor : NgramExtractor { - - /// /// Ngram length /// @@ -14029,45 +14029,45 @@ public sealed class NGramHashNgramExtractor : NgramExtractor public abstract class ParallelTraining : ComponentKind {} + + /// /// Single node machine learning process. /// public sealed class SingleParallelTraining : ParallelTraining { - - internal override string ComponentName => "Single"; } public abstract class RegressionLossFunction : ComponentKind {} + + /// /// Poisson loss. /// public sealed class PoissonLossRegressionLossFunction : RegressionLossFunction { - - internal override string ComponentName => "PoissonLoss"; } + + /// /// Squared loss. /// public sealed class SquaredLossRegressionLossFunction : RegressionLossFunction { - - internal override string ComponentName => "SquaredLoss"; } + + /// /// Tweedie loss. /// public sealed class TweedieLossRegressionLossFunction : RegressionLossFunction { - - /// /// Index parameter for the Tweedie distribution, in the range [1, 2]. 1 is Poisson loss, 2 is gamma loss, and intermediate values are compound Poisson loss. /// @@ -14078,13 +14078,13 @@ public sealed class TweedieLossRegressionLossFunction : RegressionLossFunction public abstract class SDCAClassificationLossFunction : ComponentKind {} + + /// /// Hinge loss. /// public sealed class HingeLossSDCAClassificationLossFunction : SDCAClassificationLossFunction { - - /// /// Margin value /// @@ -14093,23 +14093,23 @@ public sealed class HingeLossSDCAClassificationLossFunction : SDCAClassification internal override string ComponentName => "HingeLoss"; } + + /// /// Log loss. /// public sealed class LogLossSDCAClassificationLossFunction : SDCAClassificationLossFunction { - - internal override string ComponentName => "LogLoss"; } + + /// /// Smoothed Hinge loss. /// public sealed class SmoothedHingeLossSDCAClassificationLossFunction : SDCAClassificationLossFunction { - - /// /// Smoothing constant /// @@ -14120,25 +14120,25 @@ public sealed class SmoothedHingeLossSDCAClassificationLossFunction : SDCAClassi public abstract class SDCARegressionLossFunction : ComponentKind {} + + /// /// Squared loss. /// public sealed class SquaredLossSDCARegressionLossFunction : SDCARegressionLossFunction { - - internal override string ComponentName => "SquaredLoss"; } public abstract class StopWordsRemover : ComponentKind {} + + /// /// Remover with list of stopwords specified by the user. /// public sealed class CustomStopWordsRemover : StopWordsRemover { - - /// /// List of stopwords /// @@ -14147,13 +14147,13 @@ public sealed class CustomStopWordsRemover : StopWordsRemover internal override string ComponentName => "Custom"; } + + /// /// Remover with predefined list of stop words. /// public sealed class PredefinedStopWordsRemover : StopWordsRemover { - - internal override string ComponentName => "Predefined"; } diff --git a/src/Microsoft.ML/Runtime/Experiment/Experiment.cs b/src/Microsoft.ML/Runtime/Experiment/Experiment.cs index aa5736bd14..9fb0560701 100644 --- a/src/Microsoft.ML/Runtime/Experiment/Experiment.cs +++ b/src/Microsoft.ML/Runtime/Experiment/Experiment.cs @@ -5,7 +5,6 @@ using System; using System.Collections.Generic; using System.IO; -using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.EntryPoints.JsonUtils; using Newtonsoft.Json; @@ -165,7 +164,7 @@ private string Serialize(string name, object input, object output) { using (var jw = new JsonTextWriter(sw)) { - jw.Formatting = Formatting.Indented; + jw.Formatting = Newtonsoft.Json.Formatting.Indented; _serializer.Serialize(jw, _helper); } return sw.ToString(); diff --git a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs index 17643518ca..9b1eedbc04 100644 --- a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs +++ b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs @@ -74,7 +74,7 @@ public static string GetOutputType(Type outputType) return $"Var<{GetCSharpTypeName(outputType)}>"; } - public static string GetInputType(ModuleCatalog catalog, Type inputType, + public static string GetInputType(ModuleCatalog catalog, Type inputType, Dictionary typesSymbolTable, string rootNameSpace = "") { if (inputType.IsGenericType && inputType.GetGenericTypeDefinition() == typeof(Var<>)) @@ -136,13 +136,13 @@ public static string GetInputType(ModuleCatalog catalog, Type inputType, return $"{enumName}"; default: if (isNullable) - return rootNameSpace+typesSymbolTable[type.FullName]; + return rootNameSpace + typesSymbolTable[type.FullName]; if (isOptional) - return $"Optional<{rootNameSpace+typesSymbolTable[type.FullName]}>"; + return $"Optional<{rootNameSpace + typesSymbolTable[type.FullName]}>"; if (typesSymbolTable.ContainsKey(type.FullName)) return rootNameSpace + typesSymbolTable[type.FullName]; else - return GetSymbolFromType(typesSymbolTable, type.FullName, rootNameSpace); + return GetSymbolFromType(typesSymbolTable, type, rootNameSpace); } } @@ -177,7 +177,7 @@ public static string Capitalize(string s) return char.ToUpperInvariant(s[0]) + s.Substring(1); } - public static string GetValue(ModuleCatalog catalog, Type fieldType, object fieldValue, + public static string GetValue(ModuleCatalog catalog, Type fieldType, object fieldValue, Dictionary typesSymbolTable, string rootNameSpace = "") { if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Var<>)) @@ -299,7 +299,7 @@ public static string GetValue(ModuleCatalog catalog, Type fieldType, object fiel var properties = propertyBag.Count > 0 ? $" {{ {string.Join(", ", propertyBag)} }}" : ""; return $"new {GetComponentName(componentInfo)}(){properties}"; case TlcModule.DataKind.Unknown: - return $"new {rootNameSpace+typesSymbolTable[fieldType.FullName]}()"; + return $"new {rootNameSpace + typesSymbolTable[fieldType.FullName]}()"; default: return fieldValue.ToString(); } @@ -321,7 +321,7 @@ public static string GetEnumName(Type type, Dictionary typesSymb if (typesSymbolTable.ContainsKey(type.FullName)) return rootNamespace + typesSymbolTable[type.FullName]; else - return GetSymbolFromType(typesSymbolTable, type.FullName, rootNamespace); + return GetSymbolFromType(typesSymbolTable, type, rootNamespace); } public static string GetJsonFromField(string fieldName, Type fieldType) @@ -495,16 +495,59 @@ private void GenerateInputOutput(IndentingTextWriter writer, writer.WriteLine(); } - static string GetSymbolFromType(Dictionary typesSymbolTable, string fullTypeName, string currentNamespace) + static string GetSymbolFromType(Dictionary typesSymbolTable, Type type, string currentNamespace) { + var fullTypeName = type.FullName; var names = typesSymbolTable.Select(kvp => kvp.Value); char dim = fullTypeName.Contains('+') ? '+' : '.'; string name = currentNamespace != "" ? currentNamespace + '.' : ""; + Type[] genericTypes = null; + if (type.IsGenericType) + genericTypes = type.GetGenericArguments(); + if (fullTypeName.Contains('+')) - name += fullTypeName.Substring(0, fullTypeName.LastIndexOf('+')).Substring(fullTypeName.LastIndexOf('.') + 1); + { + int bracketIndex = fullTypeName.IndexOf('['); + if (bracketIndex > 0) + { + Contracts.AssertValue(genericTypes); + fullTypeName = fullTypeName.Substring(0, bracketIndex); + } + + var nestedNames = fullTypeName.Split('+'); + var baseName = nestedNames[0]; + //nestedNames = nestedNames.Select(n => n.IndexOf('`') < 0 ? n : n.Substring(0, n.IndexOf('`'))).ToArray(); - name += fullTypeName.Substring(fullTypeName.LastIndexOf(dim) + 1); ; + //var substr = baseName.Substring(0, fullTypeName.LastIndexOf('+')); + int backTickIndex = baseName.LastIndexOf('`'); + int dotIndex = baseName.LastIndexOf('.'); + if (backTickIndex < 0) + name += baseName.Substring(dotIndex + 1); + else + { + name += baseName.Substring(dotIndex + 1, backTickIndex - dotIndex - 1); + Contracts.AssertValue(genericTypes); + if (genericTypes != null) + { + foreach (var genType in genericTypes) + { + var splitNames = genType.FullName.Split('+'); + if (splitNames[0].LastIndexOf('.') >= 0) + splitNames[0] = splitNames[0].Substring(splitNames[0].LastIndexOf('.') + 1); + name += string.Join("", splitNames); + } + } + } + + for (int i = 1; i < nestedNames.Length; i++) + name += nestedNames[i]; + } + else + { + int dimIndex = fullTypeName.LastIndexOf(dim); + name += fullTypeName.Substring(dimIndex + 1); + } Contracts.Assert(typesSymbolTable.Select(kvp => kvp.Value).All(str => string.Compare(str, name) != 0)); @@ -538,7 +581,7 @@ private void GenerateEnums(IndentingTextWriter writer, Type inputType, string cu var enumType = Enum.GetUnderlyingType(type); - _typesSymbolTable[type.FullName] = GetSymbolFromType(_typesSymbolTable, type.FullName, currentNamespace); + _typesSymbolTable[type.FullName] = GetSymbolFromType(_typesSymbolTable, type, currentNamespace); if (enumType == typeof(int)) writer.WriteLine($"public enum {_typesSymbolTable[type.FullName].Substring(_typesSymbolTable[type.FullName].LastIndexOf('.') + 1)}"); else @@ -623,7 +666,7 @@ private void GenerateStructs(IndentingTextWriter writer, if (_typesSymbolTable.ContainsKey(type.FullName)) continue; - _typesSymbolTable[type.FullName] = GetSymbolFromType(_typesSymbolTable, type.FullName, currentNamespace); + _typesSymbolTable[type.FullName] = GetSymbolFromType(_typesSymbolTable, type, currentNamespace); string classBase = ""; if (type.IsSubclassOf(typeof(OneToOneColumn))) classBase = $" : OneToOneColumn<{_typesSymbolTable[type.FullName].Substring(_typesSymbolTable[type.FullName].LastIndexOf('.') + 1)}>, IOneToOneColumn"; @@ -889,7 +932,7 @@ private static void GenerateApplyFunction(IndentingTextWriter writer, ModuleCata writer.WriteLine("}"); } - private static void GenerateInputFields(IndentingTextWriter writer, + private static void GenerateInputFields(IndentingTextWriter writer, Type inputType, ModuleCatalog catalog, Dictionary typesSymbolTable, string rootNameSpace = "") { var defaults = Activator.CreateInstance(inputType); @@ -936,7 +979,7 @@ private static void GenerateInputFields(IndentingTextWriter writer, sweepableParamAttr.Name = fieldInfo.Name; writer.WriteLine(sweepableParamAttr.ToString()); } - + writer.Write($"public {inputTypeString} {GeneratorUtils.Capitalize(inputAttr.Name ?? fieldInfo.Name)} {{ get; set; }}"); var defaultValue = GeneratorUtils.GetValue(catalog, fieldInfo.FieldType, fieldInfo.GetValue(defaults), typesSymbolTable, rootNameSpace); if (defaultValue != null) @@ -1013,16 +1056,16 @@ private void GenerateComponentKind(IndentingTextWriter writer, string kind) private void GenerateComponent(IndentingTextWriter writer, ModuleCatalog.ComponentInfo component, ModuleCatalog catalog) { + GenerateEnums(writer, component.ArgumentType, ""); + writer.WriteLine(); + GenerateStructs(writer, component.ArgumentType, catalog, ""); + writer.WriteLine(); writer.WriteLine("/// "); writer.WriteLine($"/// {component.Description}"); writer.WriteLine("/// "); writer.WriteLine($"public sealed class {GeneratorUtils.GetComponentName(component)} : {component.Kind}"); writer.WriteLine("{"); writer.Indent(); - GenerateEnums(writer, component.ArgumentType, ""); - writer.WriteLine(); - GenerateStructs(writer, component.ArgumentType, catalog, ""); - writer.WriteLine(); GenerateInputFields(writer, component.ArgumentType, catalog, _typesSymbolTable, "Microsoft.ML."); writer.WriteLine($"internal override string ComponentName => \"{component.Name}\";"); writer.Outdent(); From 8886ab753778ba7108e9345a8883d78837267906 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Wed, 9 May 2018 11:22:54 -0700 Subject: [PATCH 2/7] Delete commented code. --- src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs index 9b1eedbc04..0e620b2983 100644 --- a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs +++ b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs @@ -517,9 +517,6 @@ static string GetSymbolFromType(Dictionary typesSymbolTable, Typ var nestedNames = fullTypeName.Split('+'); var baseName = nestedNames[0]; - //nestedNames = nestedNames.Select(n => n.IndexOf('`') < 0 ? n : n.Substring(0, n.IndexOf('`'))).ToArray(); - - //var substr = baseName.Substring(0, fullTypeName.LastIndexOf('+')); int backTickIndex = baseName.LastIndexOf('`'); int dotIndex = baseName.LastIndexOf('.'); if (backTickIndex < 0) From 01df6e8d62566a62fbe6fbf2ce6b7eb06499b5be Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Thu, 10 May 2018 11:06:39 -0700 Subject: [PATCH 3/7] Address PR comments. --- .../ComponentModel/ComponentCatalog.cs | 4 +- .../LogisticRegression/LogisticRegression.cs | 2 +- .../WhiteningTransform.cs | 2 +- .../Internal/Tools/CSharpApiGenerator.cs | 78 ++++++++++--------- test/Microsoft.ML.Tests/CSharpCodeGen.cs | 3 +- 5 files changed, 46 insertions(+), 43 deletions(-) diff --git a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs index 3a66ce27c8..28666e7f44 100644 --- a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs +++ b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs @@ -342,8 +342,8 @@ private static bool ShouldSkipPath(string path) case "libiomp5md.dll": case "libvw.dll": case "matrixinterf.dll": - case "Microsoft.MachineLearning.neuralnetworks.gpucuda.dll": - case "Microsoft.MachineLearning.mklimports.dll": + case "Microsoft.ML.neuralnetworks.gpucuda.dll": + case "Microsoft.ML.mklimports.dll": case "microsoft.research.controls.decisiontrees.dll": case "Microsoft.ML.neuralnetworks.sse.dll": case "neuraltreeevaluator.dll": diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs index 24426290fa..f1ea7c4a4b 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs @@ -300,7 +300,7 @@ protected override void ComputeTrainingStatistics(IChannel ch, FloatLabelCursor. } catch (DllNotFoundException) { - throw ch.ExceptNotSupp("The MKL library (Microsoft.MachineLearning.MklImports.dll) or one of its dependencies is missing."); + throw ch.ExceptNotSupp("The MKL library (Microsoft.ML.MklImports.dll) or one of its dependencies is missing."); } Float[] stdErrorValues = new Float[numParams]; diff --git a/src/Microsoft.ML.Transforms/WhiteningTransform.cs b/src/Microsoft.ML.Transforms/WhiteningTransform.cs index 4176eafae5..2ae3824ba8 100644 --- a/src/Microsoft.ML.Transforms/WhiteningTransform.cs +++ b/src/Microsoft.ML.Transforms/WhiteningTransform.cs @@ -571,7 +571,7 @@ private static Float DotProduct(Float[] a, int aOffset, Float[] b, int[] indices private static class Mkl { - private const string DllName = "Microsoft.MachineLearning.MklImports.dll"; + private const string DllName = "Microsoft.ML.MklImports.dll"; public enum Layout { diff --git a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs index 0e620b2983..43f2cfbbb1 100644 --- a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs +++ b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs @@ -495,57 +495,59 @@ private void GenerateInputOutput(IndentingTextWriter writer, writer.WriteLine(); } - static string GetSymbolFromType(Dictionary typesSymbolTable, Type type, string currentNamespace) + /// + /// This methods creates a unique name for a class/struct/enum, given a type and a namespace. + /// It generates the name based on the property of the type + /// (see description here https://msdn.microsoft.com/en-us/library/system.type.fullname(v=vs.110).aspx). + /// + /// A dictionary containing the names of the classes already generated. + /// This parameter is only used to ensure that the newly generated name is unique. + /// The type for which to generate the new name. + /// The namespace prefix to the new name. + /// A unique name derived from the given type and namespace. + private static string GetSymbolFromType(Dictionary typesSymbolTable, Type type, string currentNamespace) { var fullTypeName = type.FullName; - var names = typesSymbolTable.Select(kvp => kvp.Value); - char dim = fullTypeName.Contains('+') ? '+' : '.'; string name = currentNamespace != "" ? currentNamespace + '.' : ""; + int bracketIndex = fullTypeName.IndexOf('['); Type[] genericTypes = null; if (type.IsGenericType) genericTypes = type.GetGenericArguments(); - - if (fullTypeName.Contains('+')) + if (bracketIndex > 0) { - int bracketIndex = fullTypeName.IndexOf('['); - if (bracketIndex > 0) - { - Contracts.AssertValue(genericTypes); - fullTypeName = fullTypeName.Substring(0, bracketIndex); - } + Contracts.AssertValue(genericTypes); + fullTypeName = fullTypeName.Substring(0, bracketIndex); + } - var nestedNames = fullTypeName.Split('+'); - var baseName = nestedNames[0]; - int backTickIndex = baseName.LastIndexOf('`'); - int dotIndex = baseName.LastIndexOf('.'); - if (backTickIndex < 0) - name += baseName.Substring(dotIndex + 1); - else + // When the type is nested, the names of the outer types are concatenated with a '+'. + var nestedNames = fullTypeName.Split('+'); + var baseName = nestedNames[0]; + // We currently only handle generic types in the outer most class, support for generic inner classes + // can be added if needed. + int backTickIndex = baseName.LastIndexOf('`'); + int dotIndex = baseName.LastIndexOf('.'); + if (backTickIndex < 0) + name += baseName.Substring(dotIndex + 1); + else + { + name += baseName.Substring(dotIndex + 1, backTickIndex - dotIndex - 1); + Contracts.AssertValue(genericTypes); + if (genericTypes != null) { - name += baseName.Substring(dotIndex + 1, backTickIndex - dotIndex - 1); - Contracts.AssertValue(genericTypes); - if (genericTypes != null) + foreach (var genType in genericTypes) { - foreach (var genType in genericTypes) - { - var splitNames = genType.FullName.Split('+'); - if (splitNames[0].LastIndexOf('.') >= 0) - splitNames[0] = splitNames[0].Substring(splitNames[0].LastIndexOf('.') + 1); - name += string.Join("", splitNames); - } + var splitNames = genType.FullName.Split('+'); + if (splitNames[0].LastIndexOf('.') >= 0) + splitNames[0] = splitNames[0].Substring(splitNames[0].LastIndexOf('.') + 1); + name += string.Join("", splitNames); } } - - for (int i = 1; i < nestedNames.Length; i++) - name += nestedNames[i]; - } - else - { - int dimIndex = fullTypeName.LastIndexOf(dim); - name += fullTypeName.Substring(dimIndex + 1); } + for (int i = 1; i < nestedNames.Length; i++) + name += nestedNames[i]; + Contracts.Assert(typesSymbolTable.Select(kvp => kvp.Value).All(str => string.Compare(str, name) != 0)); return name; @@ -1053,9 +1055,9 @@ private void GenerateComponentKind(IndentingTextWriter writer, string kind) private void GenerateComponent(IndentingTextWriter writer, ModuleCatalog.ComponentInfo component, ModuleCatalog catalog) { - GenerateEnums(writer, component.ArgumentType, ""); + GenerateEnums(writer, component.ArgumentType, "Runtime"); writer.WriteLine(); - GenerateStructs(writer, component.ArgumentType, catalog, ""); + GenerateStructs(writer, component.ArgumentType, catalog, "Runtime"); writer.WriteLine(); writer.WriteLine("/// "); writer.WriteLine($"/// {component.Description}"); diff --git a/test/Microsoft.ML.Tests/CSharpCodeGen.cs b/test/Microsoft.ML.Tests/CSharpCodeGen.cs index c647110702..964ebe3532 100644 --- a/test/Microsoft.ML.Tests/CSharpCodeGen.cs +++ b/test/Microsoft.ML.Tests/CSharpCodeGen.cs @@ -15,7 +15,8 @@ public CSharpCodeGen(ITestOutputHelper output) : base(output) { } - [Fact(Skip = "Temporary solution(Windows ONLY) to regenerate codegenerated CSharpAPI.cs")] + [Fact] + //[Fact(Skip = "Temporary solution(Windows ONLY) to regenerate codegenerated CSharpAPI.cs")] public void GenerateCSharpAPI() { var cSharpAPIPath = Path.Combine(RootDir, @"src\\Microsoft.ML\\CSharpApi.cs"); From 0c95eed1bd8c31717863ff29d39ff83ac745f883 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Thu, 10 May 2018 11:12:21 -0700 Subject: [PATCH 4/7] Skip codegen unit test. --- test/Microsoft.ML.Tests/CSharpCodeGen.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/Microsoft.ML.Tests/CSharpCodeGen.cs b/test/Microsoft.ML.Tests/CSharpCodeGen.cs index 964ebe3532..c647110702 100644 --- a/test/Microsoft.ML.Tests/CSharpCodeGen.cs +++ b/test/Microsoft.ML.Tests/CSharpCodeGen.cs @@ -15,8 +15,7 @@ public CSharpCodeGen(ITestOutputHelper output) : base(output) { } - [Fact] - //[Fact(Skip = "Temporary solution(Windows ONLY) to regenerate codegenerated CSharpAPI.cs")] + [Fact(Skip = "Temporary solution(Windows ONLY) to regenerate codegenerated CSharpAPI.cs")] public void GenerateCSharpAPI() { var cSharpAPIPath = Path.Combine(RootDir, @"src\\Microsoft.ML\\CSharpApi.cs"); From a3ddef9c4a5346bfdb0f741961f380d8386476c4 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Thu, 10 May 2018 16:49:12 -0700 Subject: [PATCH 5/7] Add example to comment. --- .../Runtime/Internal/Tools/CSharpApiGenerator.cs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs index 43f2cfbbb1..0a3c9db6cd 100644 --- a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs +++ b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs @@ -499,6 +499,18 @@ private void GenerateInputOutput(IndentingTextWriter writer, /// This methods creates a unique name for a class/struct/enum, given a type and a namespace. /// It generates the name based on the property of the type /// (see description here https://msdn.microsoft.com/en-us/library/system.type.fullname(v=vs.110).aspx). + /// Example: Assume we have the following structure in namespace X.Y: + /// class A { + /// class B { + /// enum C { + /// Value1, + /// Value2 + /// } + /// } + /// } + /// The full name of C would be X.Y.A+B+C. This method will generate the name "ABC" from it. In case + /// A is generic with one generic type, then the full name of typeof(A<float>.B.C) would be X.Y.A`1+B+C[[System.Single]]. + /// In this case, this method will generate the name "ASingleBC". /// /// A dictionary containing the names of the classes already generated. /// This parameter is only used to ensure that the newly generated name is unique. From 34f32af0ed7fc041b00aa6624d593782fad282a5 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Thu, 10 May 2018 16:52:55 -0700 Subject: [PATCH 6/7] Assert that the full type name has a dot in it. --- src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs index 0a3c9db6cd..73a38d83ab 100644 --- a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs +++ b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs @@ -539,6 +539,7 @@ private static string GetSymbolFromType(Dictionary typesSymbolTa // can be added if needed. int backTickIndex = baseName.LastIndexOf('`'); int dotIndex = baseName.LastIndexOf('.'); + Contracts.Assert(dotIndex >= 0); if (backTickIndex < 0) name += baseName.Substring(dotIndex + 1); else From 2eaae1430e5c589ba53674c93d4bb06c0bdfe1f7 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Fri, 11 May 2018 08:17:41 -0700 Subject: [PATCH 7/7] Trigger build. --- src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs index 73a38d83ab..f1e45fa446 100644 --- a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs +++ b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs @@ -535,6 +535,7 @@ private static string GetSymbolFromType(Dictionary typesSymbolTa // When the type is nested, the names of the outer types are concatenated with a '+'. var nestedNames = fullTypeName.Split('+'); var baseName = nestedNames[0]; + // We currently only handle generic types in the outer most class, support for generic inner classes // can be added if needed. int backTickIndex = baseName.LastIndexOf('`');