diff --git a/src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json b/src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json
index 79853140c9..68d731a6a1 100644
--- a/src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json
+++ b/src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json
@@ -71,7 +71,8 @@
"DnnFeaturizerImage",
"Naive",
"ForecastBySsa",
- "TextClassification"
+ "TextClassification",
+ "SentenceSimilarity"
]
},
"nugetDependencies": {
@@ -109,6 +110,7 @@
"Microsoft.ML.Vision",
"Microsoft.ML.Transforms.Image",
"Microsoft.ML.Trainers.FastTree",
+ "Microsoft.ML.TorchSharp",
"Microsoft.ML.Trainers.LightGbm"
]
}
diff --git a/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json b/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json
index c7ef37e7be..22432ebf80 100644
--- a/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json
+++ b/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json
@@ -126,7 +126,8 @@
"image_classification_option",
"matrix_factorization_option",
"dnn_featurizer_image_option",
- "text_classification_option"
+ "text_classification_option",
+ "sentence_similarity_option"
]
},
"option_name": {
diff --git a/src/Microsoft.ML.AutoML/CodeGen/sentence_similarity_search_space.json b/src/Microsoft.ML.AutoML/CodeGen/sentence_similarity_search_space.json
new file mode 100644
index 0000000000..0ad05a1858
--- /dev/null
+++ b/src/Microsoft.ML.AutoML/CodeGen/sentence_similarity_search_space.json
@@ -0,0 +1,40 @@
+{
+ "$schema": "./search-space-schema.json#",
+ "name": "sentence_similarity_option",
+ "search_space": [
+ {
+ "name": "LabelColumnName",
+ "type": "string",
+ "default": "Label"
+ },
+ {
+ "name": "Sentence1ColumnName",
+ "type": "string",
+ "default": "Sentence1"
+ },
+ {
+ "name": "Sentence2ColumnName",
+ "type": "string"
+ },
+ {
+ "name": "ScoreColumnName",
+ "type": "string",
+ "default": "Score"
+ },
+ {
+ "name": "BatchSize",
+ "type": "integer",
+ "default": 32
+ },
+ {
+ "name": "MaxEpochs",
+ "type": "integer",
+ "default": 10
+ },
+ {
+ "name": "Architecture",
+ "type": "bertArchitecture",
+ "default": "BertArchitecture.Roberta"
+ }
+ ]
+}
diff --git a/src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json b/src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json
index cfd6c55c07..fae486ca09 100644
--- a/src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json
+++ b/src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json
@@ -306,7 +306,7 @@
"argumentType": "boolean"
}
],
- "nugetDependencies": ["Microsoft.ML"],
+ "nugetDependencies": [ "Microsoft.ML" ],
"usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers" ],
"searchOption": "lbfgs_option"
},
@@ -514,20 +514,17 @@
{
"functionName": "TextClassification",
"estimatorTypes": [ "MultiClassification" ],
- "arguments": [
- {
- "argumentName": "labelColumnName",
- "argumentType": "string"
- },
- {
- "argumentName": "sentence1ColumnName",
- "argumentType": "string"
- }
- ],
"nugetDependencies": [ "Microsoft.ML", "Microsoft.ML.TorchSharp" ],
"usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers", "Microsoft.ML.TorchSharp" ],
"searchOption": "text_classification_option"
},
+ {
+ "functionName": "SentenceSimilarity",
+ "estimatorTypes": [ "Regression" ],
+ "nugetDependencies": [ "Microsoft.ML", "Microsoft.ML.TorchSharp" ],
+ "usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers", "Microsoft.ML.TorchSharp" ],
+ "searchOption": "sentence_similarity_option"
+ },
{
"functionName": "ForecastBySsa",
"estimatorTypes": [ "Forecasting" ],
diff --git a/src/Microsoft.ML.AutoML/Microsoft.ML.AutoML.csproj b/src/Microsoft.ML.AutoML/Microsoft.ML.AutoML.csproj
index 1c03c24792..562ee0410e 100644
--- a/src/Microsoft.ML.AutoML/Microsoft.ML.AutoML.csproj
+++ b/src/Microsoft.ML.AutoML/Microsoft.ML.AutoML.csproj
@@ -66,7 +66,6 @@
-
diff --git a/src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/SentenceSimilarity.cs b/src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/SentenceSimilarity.cs
new file mode 100644
index 0000000000..24ae6cb3a3
--- /dev/null
+++ b/src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/SentenceSimilarity.cs
@@ -0,0 +1,27 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Reflection;
+using System.Text;
+using Microsoft.ML.TorchSharp;
+
+namespace Microsoft.ML.AutoML.CodeGen
+{
+ /// <summary>Builds the sentence-similarity regression trainer from a <c>SentenceSimilarityOption</c>.</summary>
+ internal partial class SentenceSimilarityRegression
+ {
+     /// <summary>Creates the trainer on <paramref name="context"/> using the values in <paramref name="param"/>.</summary>
+     public override IEstimator<ITransformer> BuildFromOption(MLContext context, SentenceSimilarityOption param)
+     {
+         // Named arguments are used throughout so the call does not depend on the trainer's parameter order.
+         return context.Regression.Trainers.SentenceSimilarity(
+             labelColumnName: param.LabelColumnName, sentence1ColumnName: param.Sentence1ColumnName,
+             scoreColumnName: param.ScoreColumnName, sentence2ColumnName: param.Sentence2ColumnName,
+             batchSize: param.BatchSize, maxEpochs: param.MaxEpochs,
+             architecture: param.Architecture);
+     }
+ }
+}