Skip to content

Commit 61376a4

Browse files
Sentence Similarity (#6390)
1 parent 66b362a commit 61376a4

File tree

9 files changed

+1386
-841
lines changed

9 files changed

+1386
-841
lines changed

src/Microsoft.ML.TorchSharp/NasBert/Models/BaseModel.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ namespace Microsoft.ML.TorchSharp.NasBert.Models
1313
{
1414
internal abstract class BaseModel : BaseModule
1515
{
16-
protected readonly TextClassificationTrainer.Options Options;
16+
protected readonly NasBertTrainer.Options Options;
1717
public BertTaskType HeadType => Options.TaskType;
1818

1919
protected readonly TransformerEncoder Encoder;
@@ -24,7 +24,7 @@ internal abstract class BaseModel : BaseModule
2424
public abstract BaseHead GetHead();
2525
#pragma warning restore CA1024 // Use properties where appropriate
2626

27-
protected BaseModel(TextClassificationTrainer.Options options, int padIndex, int symbolsCount)
27+
protected BaseModel(NasBertTrainer.Options options, int padIndex, int symbolsCount)
2828
: base(nameof(BaseModel))
2929
{
3030
Options = options ?? throw new ArgumentNullException(nameof(options));

src/Microsoft.ML.TorchSharp/NasBert/Models/TextClassificationModel.cs renamed to src/Microsoft.ML.TorchSharp/NasBert/Models/NasBertModel.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99

1010
namespace Microsoft.ML.TorchSharp.NasBert.Models
1111
{
12-
internal sealed class TextClassificationModel : BaseModel
12+
internal sealed class NasBertModel : BaseModel
1313
{
1414
private readonly PredictionHead _predictionHead;
1515

1616
public override BaseHead GetHead() => _predictionHead;
1717

18-
public TextClassificationModel(TextClassificationTrainer.Options options, int padIndex, int symbolsCount, int numClasses)
18+
public NasBertModel(NasBertTrainer.Options options, int padIndex, int symbolsCount, int numClasses)
1919
: base(options, padIndex, symbolsCount)
2020
{
2121
_predictionHead = new PredictionHead(

0 commit comments

Comments
 (0)