Skip to content

Commit cc05a8c

Browse files
add wrapper
1 parent cb2e495 commit cc05a8c

File tree

5 files changed

+108
-4
lines changed

5 files changed

+108
-4
lines changed

src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@
7272
"Naive",
7373
"ForecastBySsa",
7474
"TextClassifcation",
75-
"SentenceSimilarity"
75+
"SentenceSimilarity",
76+
"ObjectDetection"
7677
]
7778
},
7879
"nugetDependencies": {
@@ -187,7 +188,15 @@
187188
"confidenceLevel",
188189
"variableHorizon",
189190
"modelFactory",
190-
"sentence1ColumnName"
191+
"sentence1ColumnName",
192+
"boundingBoxColumnName",
193+
"imageColumnName",
194+
"maxEpoch",
195+
"iOUThreshold",
196+
"scoreThreshold",
197+
"steps",
198+
"initLearningRate",
199+
"weightDecay"
191200
]
192201
},
193202
"argumentType": {
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
{
2+
"$schema": "./search-space-schema.json#",
3+
"name": "object_detection_option",
4+
"search_space": [
5+
{
6+
"name": "LabelColumnName",
7+
"type": "string",
8+
"default": "Label"
9+
},
10+
{
11+
"name": "PredictedLabelColumnName",
12+
"type": "string",
13+
"default": "PredictedLabel"
14+
},
15+
{
16+
"name": "BoundingBoxColumnName",
17+
"type": "string",
18+
"default": "BoundingBox"
19+
},
20+
{
21+
"name": "ImageColumnName",
22+
"type": "string",
23+
"default": "Image"
24+
},
25+
{
26+
"name": "ScoreColumnName",
27+
"type": "string",
28+
"default": "Score"
29+
},
30+
{
31+
"name": "MaxEpoch",
32+
"type": "integer",
33+
"default": 10
34+
},
35+
{
36+
"name": "InitLearningRate",
37+
"type": "float",
38+
"default": 1.0
39+
},
40+
{
41+
"name": "WeightDecay",
42+
"type": "float",
43+
"default": 0.0
44+
}
45+
]
46+
}

src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@
145145
"matrix_factorization_option",
146146
"dnn_featurizer_image_option",
147147
"text_classification_option",
148-
"sentence_similarity_option"
148+
"sentence_similarity_option",
149+
"object_detection_option"
149150
]
150151
},
151152
"option_name": {
@@ -198,7 +199,16 @@
198199
"Epoch",
199200
"Architecture",
200201
"AddKeyValueAnnotationsAsText",
201-
"Arch"
202+
"Arch",
203+
"PredictedLabelColumnName",
204+
"BoundingBoxColumnName",
205+
"ImageColumnName",
206+
"IOUThreshold",
207+
"ScoreThreshold",
208+
"Steps",
209+
"MaxEpoch",
210+
"InitLearningRate",
211+
"WeightDecay"
202212
]
203213
},
204214
"option_type": {

src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,13 @@
525525
"usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers", "Microsoft.ML.TorchSharp" ],
526526
"searchOption": "sentence_similarity_option"
527527
},
528+
{
529+
"functionName": "ObjectDetection",
530+
"estimatorTypes": [ "MultiClassification" ],
531+
"nugetDependencies": [ "Microsoft.ML", "Microsoft.ML.TorchSharp" ],
532+
"usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers", "Microsoft.ML.TorchSharp" ],
533+
"searchOption": "object_detection_option"
534+
},
528535
{
529536
"functionName": "ForecastBySsa",
530537
"estimatorTypes": [ "Forecasting" ],
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Text;
8+
using Microsoft.ML.TorchSharp;
9+
using Microsoft.ML.TorchSharp.AutoFormerV2;
10+
11+
namespace Microsoft.ML.AutoML.CodeGen
12+
{
13+
internal partial class ObjectDetectionMulti
14+
{
15+
public override IEstimator<ITransformer> BuildFromOption(MLContext context, ObjectDetectionOption param)
16+
{
17+
var option = new ObjectDetectionTrainer.Options
18+
{
19+
LabelColumnName = param.LabelColumnName,
20+
PredictedLabelColumnName = param.PredictedLabelColumnName,
21+
BoundingBoxColumnName = param.BoundingBoxColumnName,
22+
ImageColumnName = param.ImageColumnName,
23+
ScoreColumnName = param.ScoreColumnName,
24+
MaxEpoch = param.MaxEpoch,
25+
InitLearningRate = param.InitLearningRate,
26+
WeightDecay = param.WeightDecay,
27+
};
28+
29+
return context.MulticlassClassification.Trainers.ObjectDetection(option);
30+
}
31+
}
32+
}

0 commit comments

Comments
 (0)