From 56823fc3a804099dc62d938a8a1d29d23fb1cc0c Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Wed, 26 Jul 2023 10:22:15 -0600 Subject: [PATCH] fixed mac build and minor torch sharp changes --- build/vsts-ci.yml | 2 +- .../NasBert/SentenceSimilarityTrainer.cs | 16 ++++++++++++++-- .../NasBert/TextClassificationTrainer.cs | 14 ++++++++++++-- src/Microsoft.ML.TorchSharp/TorchSharpCatalog.cs | 4 ++-- .../TextClassificationTests.cs | 4 +--- 5 files changed, 30 insertions(+), 10 deletions(-) diff --git a/build/vsts-ci.yml b/build/vsts-ci.yml index 4c31f2b91f..b3eb3e61d2 100644 --- a/build/vsts-ci.yml +++ b/build/vsts-ci.yml @@ -100,7 +100,7 @@ stages: pool: vmImage: macOS-12 steps: - - script: export HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=1 && brew update && rm '/usr/local/bin/2to3-3.11' && brew unlink libomp && brew install $(Build.SourcesDirectory)/build/libomp.rb --build-from-source --formula + - script: export HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=1 && rm '/usr/local/bin/2to3-3.11' && brew unlink libomp && brew install $(Build.SourcesDirectory)/build/libomp.rb --build-from-source --formula displayName: Install build dependencies # Only build native assets to avoid conflicts. - script: ./build.sh -projects $(Build.SourcesDirectory)/src/Native/Native.proj -configuration $(BuildConfig) /p:TargetArchitecture=x64 /p:CopyPackageAssets=true diff --git a/src/Microsoft.ML.TorchSharp/NasBert/SentenceSimilarityTrainer.cs b/src/Microsoft.ML.TorchSharp/NasBert/SentenceSimilarityTrainer.cs index 0be90385e5..a56e43060d 100644 --- a/src/Microsoft.ML.TorchSharp/NasBert/SentenceSimilarityTrainer.cs +++ b/src/Microsoft.ML.TorchSharp/NasBert/SentenceSimilarityTrainer.cs @@ -58,7 +58,19 @@ namespace Microsoft.ML.TorchSharp.NasBert /// public class SentenceSimilarityTrainer : NasBertTrainer { - internal SentenceSimilarityTrainer(IHostEnvironment env, Options options) : base(env, options) + + public class SentenceSimilarityOptions : NasBertOptions + { + public SentenceSimilarityOptions() + { + BatchSize = 32; + MaxEpoch = 10; + TaskType = BertTaskType.SentenceRegression; + LearningRate = new List() { .0002 }; + WeightDecay = .01; + } + } + internal SentenceSimilarityTrainer(IHostEnvironment env, SentenceSimilarityOptions options) : base(env, options) { } @@ -71,7 +83,7 @@ internal SentenceSimilarityTrainer(IHostEnvironment env, int maxEpochs = 10, IDataView validationSet = null, BertArchitecture architecture = BertArchitecture.Roberta) : - this(env, new NasBertOptions + this(env, new SentenceSimilarityOptions { ScoreColumnName = scoreColumnName, Sentence1ColumnName = sentence1ColumnName, diff --git a/src/Microsoft.ML.TorchSharp/NasBert/TextClassificationTrainer.cs b/src/Microsoft.ML.TorchSharp/NasBert/TextClassificationTrainer.cs index b0c7672b6c..2ee0074308 100644 --- a/src/Microsoft.ML.TorchSharp/NasBert/TextClassificationTrainer.cs +++ b/src/Microsoft.ML.TorchSharp/NasBert/TextClassificationTrainer.cs @@ -60,7 +60,17 @@ namespace Microsoft.ML.TorchSharp.NasBert /// public class TextClassificationTrainer : NasBertTrainer { - internal TextClassificationTrainer(IHostEnvironment env, NasBertOptions options) : base(env, options) + public class TextClassificationOptions : NasBertTrainer.NasBertOptions + { + public TextClassificationOptions() + { + TaskType = BertTaskType.TextClassification; + BatchSize = 32; + MaxEpoch = 10; + } + } + + internal TextClassificationTrainer(IHostEnvironment env, TextClassificationOptions options) : base(env, options) { } @@ -74,7 +84,7 @@ internal TextClassificationTrainer(IHostEnvironment env, int maxEpochs = 10, IDataView validationSet = null, BertArchitecture architecture = BertArchitecture.Roberta) : - this(env, new NasBertOptions + this(env, new TextClassificationOptions { PredictionColumnName = predictionColumnName, ScoreColumnName = scoreColumnName, diff --git a/src/Microsoft.ML.TorchSharp/TorchSharpCatalog.cs b/src/Microsoft.ML.TorchSharp/TorchSharpCatalog.cs index 06d4350233..a0d9970ea1 100644 --- a/src/Microsoft.ML.TorchSharp/TorchSharpCatalog.cs +++ b/src/Microsoft.ML.TorchSharp/TorchSharpCatalog.cs @@ -59,7 +59,7 @@ public static TextClassificationTrainer TextClassification( /// public static TextClassificationTrainer TextClassification( this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, - NasBertTrainer.NasBertOptions options) + TextClassificationTrainer.TextClassificationOptions options) => new TextClassificationTrainer(CatalogUtils.GetEnvironment(catalog), options); /// @@ -99,7 +99,7 @@ public static SentenceSimilarityTrainer SentenceSimilarity( /// public static SentenceSimilarityTrainer SentenceSimilarity( this RegressionCatalog.RegressionTrainers catalog, - NasBertTrainer.NasBertOptions options) + SentenceSimilarityTrainer.SentenceSimilarityOptions options) => new SentenceSimilarityTrainer(CatalogUtils.GetEnvironment(catalog), options); diff --git a/test/Microsoft.ML.Tests/TextClassificationTests.cs b/test/Microsoft.ML.Tests/TextClassificationTests.cs index 92c0d4f2a4..0a00cb1047 100644 --- a/test/Microsoft.ML.Tests/TextClassificationTests.cs +++ b/test/Microsoft.ML.Tests/TextClassificationTests.cs @@ -432,14 +432,12 @@ public void TestSentenceSimilarityLargeFileGpu() var dataSplit = ML.Data.TrainTestSplit(dataView, testFraction: 0.2); - var options = new NasBertTrainer.NasBertOptions() + var options = new SentenceSimilarityTrainer.SentenceSimilarityOptions() { TaskType = BertTaskType.SentenceRegression, Sentence1ColumnName = "search_term", Sentence2ColumnName = "product_description", LabelColumnName = "relevance", - LearningRate = new List() { .0002 }, - WeightDecay = .01 }; var estimator = ML.Regression.Trainers.SentenceSimilarity(options);