Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.Locale;
import java.util.Objects;

public class Classification implements DataFrameAnalysis {
Expand All @@ -49,6 +51,7 @@ public static Builder builder(String dependentVariable) {
static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective");
static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");

Expand All @@ -67,7 +70,8 @@ public static Builder builder(String dependentVariable) {
(String) a[7],
(Double) a[8],
(Integer) a[9],
(Long) a[10]));
(Long) a[10],
(ClassAssignmentObjective) a[11]));

static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
Expand All @@ -81,6 +85,12 @@ public static Builder builder(String dependentVariable) {
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(), p -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return ClassAssignmentObjective.fromString(p.text());
}
throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]");
}, CLASS_ASSIGNMENT_OBJECTIVE, ObjectParser.ValueType.STRING);
}

private final String dependentVariable;
Expand All @@ -92,13 +102,15 @@ public static Builder builder(String dependentVariable) {
private final Integer numTopFeatureImportanceValues;
private final String predictionFieldName;
private final Double trainingPercent;
private final ClassAssignmentObjective classAssignmentObjective;
private final Integer numTopClasses;
private final Long randomizeSeed;

private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
@Nullable Integer maxTrees, @Nullable Double featureBagFraction,
@Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
@Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) {
@Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed,
@Nullable ClassAssignmentObjective classAssignmentObjective) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
this.lambda = lambda;
this.gamma = gamma;
Expand All @@ -108,6 +120,7 @@ private Classification(String dependentVariable, @Nullable Double lambda, @Nulla
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
this.predictionFieldName = predictionFieldName;
this.trainingPercent = trainingPercent;
this.classAssignmentObjective = classAssignmentObjective;
this.numTopClasses = numTopClasses;
this.randomizeSeed = randomizeSeed;
}
Expand Down Expand Up @@ -157,6 +170,10 @@ public Long getRandomizeSeed() {
return randomizeSeed;
}

/**
 * Returns the objective used when assigning class labels.
 *
 * @return the configured {@link ClassAssignmentObjective}, or {@code null} if none was supplied
 */
public ClassAssignmentObjective getClassAssignmentObjective() {
    return this.classAssignmentObjective;
}

/**
 * Returns the configured number of top classes.
 *
 * @return the {@code num_top_classes} value, or {@code null} if none was supplied
 */
public Integer getNumTopClasses() {
    return this.numTopClasses;
}
Expand Down Expand Up @@ -192,6 +209,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (randomizeSeed != null) {
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
}
if (classAssignmentObjective != null) {
builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
}
if (numTopClasses != null) {
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
}
Expand All @@ -202,7 +222,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
@Override
public int hashCode() {
return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
predictionFieldName, trainingPercent, randomizeSeed, numTopClasses);
predictionFieldName, trainingPercent, randomizeSeed, numTopClasses, classAssignmentObjective);
}

@Override
Expand All @@ -220,14 +240,28 @@ public boolean equals(Object o) {
&& Objects.equals(predictionFieldName, that.predictionFieldName)
&& Objects.equals(trainingPercent, that.trainingPercent)
&& Objects.equals(randomizeSeed, that.randomizeSeed)
&& Objects.equals(numTopClasses, that.numTopClasses);
&& Objects.equals(numTopClasses, that.numTopClasses)
&& Objects.equals(classAssignmentObjective, that.classAssignmentObjective);
}

// Delegates the string form to Strings.toString(this).
// NOTE(review): presumably this renders the same fields that toXContent writes — confirm against Strings.
@Override
public String toString() {
return Strings.toString(this);
}

/**
 * Objective to optimize when assigning class labels.
 * <p>
 * Constants are written in lower case (see {@link #toString()}) and are parsed
 * case-insensitively (see {@link #fromString(String)}).
 */
public enum ClassAssignmentObjective {
    MAXIMIZE_ACCURACY,
    MAXIMIZE_MINIMUM_RECALL;

    /**
     * Resolves a constant from its textual name, ignoring case.
     *
     * @param value the name of the objective, e.g. {@code "maximize_accuracy"}
     * @return the matching constant
     * @throws IllegalArgumentException if no constant has the given name
     * @throws NullPointerException if {@code value} is {@code null}
     */
    public static ClassAssignmentObjective fromString(String value) {
        // Locale.ROOT keeps the case conversion independent of the JVM's default locale.
        final String constantName = value.toUpperCase(Locale.ROOT);
        return valueOf(constantName);
    }

    /**
     * Returns the lower-case form of the constant name, e.g. {@code "maximize_minimum_recall"}.
     */
    @Override
    public String toString() {
        final String constantName = name();
        return constantName.toLowerCase(Locale.ROOT);
    }
}

public static class Builder {
private String dependentVariable;
private Double lambda;
Expand All @@ -240,6 +274,7 @@ public static class Builder {
private Double trainingPercent;
private Integer numTopClasses;
private Long randomizeSeed;
private ClassAssignmentObjective classAssignmentObjective;

private Builder(String dependentVariable) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
Expand Down Expand Up @@ -295,9 +330,15 @@ public Builder setNumTopClasses(Integer numTopClasses) {
return this;
}

/**
 * Sets the objective to use when assigning class labels.
 *
 * @param classAssignmentObjective the objective to store on this builder; stored as given, without validation
 * @return this builder, to allow call chaining
 */
public Builder setClassAssignmentObjective(ClassAssignmentObjective classAssignmentObjective) {
this.classAssignmentObjective = classAssignmentObjective;
return this;
}

public Classification build() {
return new Classification(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed);
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed,
classAssignmentObjective);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1336,6 +1336,8 @@ public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Excepti
.setPredictionFieldName("my_dependent_variable_prediction")
.setTrainingPercent(80.0)
.setRandomizeSeed(42L)
.setClassAssignmentObjective(
org.elasticsearch.client.ml.dataframe.Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY)
.setNumTopClasses(1)
.setLambda(1.0)
.setGamma(1.0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
import org.elasticsearch.client.ml.datafeed.DatafeedStats;
import org.elasticsearch.client.ml.datafeed.DatafeedUpdate;
import org.elasticsearch.client.ml.datafeed.DelayedDataCheckConfig;
import org.elasticsearch.client.ml.dataframe.Classification;
import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsDest;
Expand Down Expand Up @@ -2969,7 +2970,7 @@ public void testPutDataFrameAnalytics() throws Exception {
// end::put-data-frame-analytics-outlier-detection-customized

// tag::put-data-frame-analytics-classification
DataFrameAnalysis classification = org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable") // <1>
DataFrameAnalysis classification = Classification.builder("my_dependent_variable") // <1>
.setLambda(1.0) // <2>
.setGamma(5.5) // <3>
.setEta(5.5) // <4>
Expand All @@ -2979,7 +2980,8 @@ public void testPutDataFrameAnalytics() throws Exception {
.setPredictionFieldName("my_prediction_field_name") // <8>
.setTrainingPercent(50.0) // <9>
.setRandomizeSeed(1234L) // <10>
.setNumTopClasses(1) // <11>
.setClassAssignmentObjective(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY) // <11>
.setNumTopClasses(1) // <12>
.build();
// end::put-data-frame-analytics-classification

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public static Classification randomClassification() {
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
.setRandomizeSeed(randomBoolean() ? null : randomLong())
.setClassAssignmentObjective(randomBoolean() ? null : randomFrom(Classification.ClassAssignmentObjective.values()))
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10))
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ include-tagged::{doc-tests-file}[{api}-classification]
<8> The name of the prediction field in the results object.
<9> The percentage of training-eligible rows to be used in training. Defaults to 100%.
<10> The seed to be used by the random generator that picks which rows are used in training.
<11> The number of top classes to be reported in the results. Defaults to 2.
<11> The optimization objective to target when assigning class labels. Defaults to maximize_minimum_recall.
<12> The number of top classes to be reported in the results. Defaults to 2.

===== Regression

Expand Down
4 changes: 4 additions & 0 deletions docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=gamma]
(Optional, double)
include::{docdir}/ml/ml-shared.asciidoc[tag=lambda]

`analysis`.`classification`.`class_assignment_objective`::::
(Optional, string)
include::{docdir}/ml/ml-shared.asciidoc[tag=class-assignment-objective]

`analysis`.`classification`.`num_top_classes`::::
(Optional, integer)
include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-classes]
Expand Down
8 changes: 8 additions & 0 deletions docs/reference/ml/ml-shared.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,14 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=mode]
include::{docdir}/ml/ml-shared.asciidoc[tag=time-span]
end::chunking-config[]

tag::class-assignment-objective[]
Defines the objective to optimize when assigning class labels. Available
objectives are `maximize_accuracy` and `maximize_minimum_recall`. When maximizing
accuracy, class labels are chosen to maximize the number of correct predictions.
When maximizing minimum recall, labels are chosen to maximize the minimum recall
for any class. Defaults to `maximize_minimum_recall`.
end::class-assignment-objective[]

tag::custom-rules[]
An array of custom rule objects, which enable you to customize the way detectors
operate. For example, a rule may dictate to the detector conditions under which
Expand Down
Loading