diff --git a/build/vsts-ci.yml b/build/vsts-ci.yml index 4c31f2b91f..b3eb3e61d2 100644 --- a/build/vsts-ci.yml +++ b/build/vsts-ci.yml @@ -100,7 +100,7 @@ stages: pool: vmImage: macOS-12 steps: - - script: export HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=1 && brew update && rm '/usr/local/bin/2to3-3.11' && brew unlink libomp && brew install $(Build.SourcesDirectory)/build/libomp.rb --build-from-source --formula + - script: export HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=1 && rm '/usr/local/bin/2to3-3.11' && brew unlink libomp && brew install $(Build.SourcesDirectory)/build/libomp.rb --build-from-source --formula displayName: Install build dependencies # Only build native assets to avoid conflicts. - script: ./build.sh -projects $(Build.SourcesDirectory)/src/Native/Native.proj -configuration $(BuildConfig) /p:TargetArchitecture=x64 /p:CopyPackageAssets=true diff --git a/src/Microsoft.ML.TorchSharp/NasBert/SentenceSimilarityTrainer.cs b/src/Microsoft.ML.TorchSharp/NasBert/SentenceSimilarityTrainer.cs index 0be90385e5..a56e43060d 100644 --- a/src/Microsoft.ML.TorchSharp/NasBert/SentenceSimilarityTrainer.cs +++ b/src/Microsoft.ML.TorchSharp/NasBert/SentenceSimilarityTrainer.cs @@ -58,7 +58,19 @@ namespace Microsoft.ML.TorchSharp.NasBert /// public class SentenceSimilarityTrainer : NasBertTrainer { - internal SentenceSimilarityTrainer(IHostEnvironment env, Options options) : base(env, options) + + public class SentenceSimilarityOptions : NasBertOptions + { + public SentenceSimilarityOptions() + { + BatchSize = 32; + MaxEpoch = 10; + TaskType = BertTaskType.SentenceRegression; + LearningRate = new List() { .0002 }; + WeightDecay = .01; + } + } + internal SentenceSimilarityTrainer(IHostEnvironment env, SentenceSimilarityOptions options) : base(env, options) { } @@ -71,7 +83,7 @@ internal SentenceSimilarityTrainer(IHostEnvironment env, int maxEpochs = 10, IDataView validationSet = null, BertArchitecture architecture = BertArchitecture.Roberta) : - this(env, new NasBertOptions + this(env, new SentenceSimilarityOptions { ScoreColumnName = scoreColumnName, Sentence1ColumnName = sentence1ColumnName, diff --git a/src/Microsoft.ML.TorchSharp/NasBert/TextClassificationTrainer.cs b/src/Microsoft.ML.TorchSharp/NasBert/TextClassificationTrainer.cs index b0c7672b6c..2ee0074308 100644 --- a/src/Microsoft.ML.TorchSharp/NasBert/TextClassificationTrainer.cs +++ b/src/Microsoft.ML.TorchSharp/NasBert/TextClassificationTrainer.cs @@ -60,7 +60,17 @@ namespace Microsoft.ML.TorchSharp.NasBert /// public class TextClassificationTrainer : NasBertTrainer { - internal TextClassificationTrainer(IHostEnvironment env, NasBertOptions options) : base(env, options) + public class TextClassificationOptions : NasBertTrainer.NasBertOptions + { + public TextClassificationOptions() + { + TaskType = BertTaskType.TextClassification; + BatchSize = 32; + MaxEpoch = 10; + } + } + + internal TextClassificationTrainer(IHostEnvironment env, TextClassificationOptions options) : base(env, options) { } @@ -74,7 +84,7 @@ internal TextClassificationTrainer(IHostEnvironment env, int maxEpochs = 10, IDataView validationSet = null, BertArchitecture architecture = BertArchitecture.Roberta) : - this(env, new NasBertOptions + this(env, new TextClassificationOptions { PredictionColumnName = predictionColumnName, ScoreColumnName = scoreColumnName, diff --git a/src/Microsoft.ML.TorchSharp/TorchSharpCatalog.cs b/src/Microsoft.ML.TorchSharp/TorchSharpCatalog.cs index 06d4350233..a0d9970ea1 100644 --- a/src/Microsoft.ML.TorchSharp/TorchSharpCatalog.cs +++ b/src/Microsoft.ML.TorchSharp/TorchSharpCatalog.cs @@ -59,7 +59,7 @@ public static TextClassificationTrainer TextClassification( /// public static TextClassificationTrainer TextClassification( this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, - NasBertTrainer.NasBertOptions options) + TextClassificationTrainer.TextClassificationOptions options) => new TextClassificationTrainer(CatalogUtils.GetEnvironment(catalog), options); /// @@ -99,7 +99,7 @@ public static SentenceSimilarityTrainer SentenceSimilarity( /// public static SentenceSimilarityTrainer SentenceSimilarity( this RegressionCatalog.RegressionTrainers catalog, - NasBertTrainer.NasBertOptions options) + SentenceSimilarityTrainer.SentenceSimilarityOptions options) => new SentenceSimilarityTrainer(CatalogUtils.GetEnvironment(catalog), options); diff --git a/test/Microsoft.ML.Tests/TextClassificationTests.cs b/test/Microsoft.ML.Tests/TextClassificationTests.cs index 92c0d4f2a4..0a00cb1047 100644 --- a/test/Microsoft.ML.Tests/TextClassificationTests.cs +++ b/test/Microsoft.ML.Tests/TextClassificationTests.cs @@ -432,14 +432,12 @@ public void TestSentenceSimilarityLargeFileGpu() var dataSplit = ML.Data.TrainTestSplit(dataView, testFraction: 0.2); - var options = new NasBertTrainer.NasBertOptions() + var options = new SentenceSimilarityTrainer.SentenceSimilarityOptions() { TaskType = BertTaskType.SentenceRegression, Sentence1ColumnName = "search_term", Sentence2ColumnName = "product_description", LabelColumnName = "relevance", - LearningRate = new List() { .0002 }, - WeightDecay = .01 }; var estimator = ML.Regression.Trainers.SentenceSimilarity(options);