Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,8 @@ public class TrainedModelDefinition implements ToXContentObject {
true,
TrainedModelDefinition.Builder::new);
static {
PARSER.declareNamedObjects(TrainedModelDefinition.Builder::setTrainedModel,
PARSER.declareNamedObject(TrainedModelDefinition.Builder::setTrainedModel,
(p, c, n) -> p.namedObject(TrainedModel.class, n, null),
(modelDocBuilder) -> { /* Noop does not matter client side*/ },
TRAINED_MODEL);
PARSER.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors,
(p, c, n) -> p.namedObject(PreProcessor.class, n, null),
Expand Down Expand Up @@ -124,11 +123,6 @@ public Builder setTrainedModel(TrainedModel trainedModel) {
return this;
}

private Builder setTrainedModel(List<TrainedModel> trainedModel) {
assert trainedModel.size() == 1;
return setTrainedModel(trainedModel.get(0));
}

public TrainedModelDefinition build() {
return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,8 @@ public class Ensemble implements TrainedModel {
p.namedObject(TrainedModel.class, n, null),
(ensembleBuilder) -> { /* Noop does not matter client side */ },
TRAINED_MODELS);
PARSER.declareNamedObjects(Ensemble.Builder::setOutputAggregatorFromParser,
PARSER.declareNamedObject(Ensemble.Builder::setOutputAggregator,
(p, c, n) -> p.namedObject(OutputAggregator.class, n, null),
(ensembleBuilder) -> { /* Noop does not matter client side */ },
AGGREGATE_OUTPUT);
PARSER.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE);
PARSER.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
Expand Down Expand Up @@ -194,9 +193,6 @@ public Builder setClassificationWeights(List<Double> classificationWeights) {
return this;
}

private void setOutputAggregatorFromParser(List<OutputAggregator> outputAggregators) {
this.setOutputAggregator(outputAggregators.get(0));
}

private void setTargetType(String targetType) {
this.targetType = TargetType.fromString(targetType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,10 @@ private static ObjectParser<TrainedModelDefinition.Builder, Void> createParser(b
ObjectParser<TrainedModelDefinition.Builder, Void> parser = new ObjectParser<>(NAME,
ignoreUnknownFields,
TrainedModelDefinition.Builder::builderForParser);
parser.declareNamedObjects(TrainedModelDefinition.Builder::setTrainedModel,
parser.declareNamedObject(TrainedModelDefinition.Builder::setTrainedModel,
(p, c, n) -> ignoreUnknownFields ?
p.namedObject(LenientlyParsedTrainedModel.class, n, null) :
p.namedObject(StrictlyParsedTrainedModel.class, n, null),
(modelDocBuilder) -> { /* Noop does not matter as we will throw if more than one is defined */ },
TRAINED_MODEL);
parser.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors,
(p, c, n) -> ignoreUnknownFields ?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,10 @@ private static ObjectParser<Ensemble.Builder, Void> createParser(boolean lenient
p.namedObject(StrictlyParsedTrainedModel.class, n, null),
(ensembleBuilder) -> ensembleBuilder.setModelsAreOrdered(true),
TRAINED_MODELS);
parser.declareNamedObjects(Ensemble.Builder::setOutputAggregatorFromParser,
parser.declareNamedObject(Ensemble.Builder::setOutputAggregator,
(p, c, n) ->
lenient ? p.namedObject(LenientlyParsedOutputAggregator.class, n, null) :
p.namedObject(StrictlyParsedOutputAggregator.class, n, null),
(ensembleBuilder) -> {/*Noop as it could be an array or object, it just has to be a one*/},
AGGREGATE_OUTPUT);
parser.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE);
parser.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
Expand Down Expand Up @@ -414,14 +413,6 @@ public Builder setClassificationWeights(List<Double> classificationWeights) {
return this;
}

private void setOutputAggregatorFromParser(List<OutputAggregator> outputAggregators) {
if (outputAggregators.size() != 1) {
throw ExceptionsHelper.badRequestException("[{}] must have exactly one aggregator defined.",
AGGREGATE_OUTPUT.getPreferredName());
}
this.setOutputAggregator(outputAggregators.get(0));
}

private void setTargetType(String targetType) {
this.targetType = TargetType.fromString(targetType);
}
Expand Down