@@ -39,6 +39,12 @@ public static FeatureImportance forRegression(String featureName, double importa
return new FeatureImportance(featureName, importance, null);
}

public static FeatureImportance forBinaryClassification(String featureName, double importance, List<ClassImportance> classImportance) {
return new FeatureImportance(featureName,
importance,
classImportance);
}

public static FeatureImportance forClassification(String featureName, List<ClassImportance> classImportance) {
return new FeatureImportance(featureName,
classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(),
@@ -170,27 +176,27 @@ private static List<ClassImportance> fromMap(Map<String, Double> classImportance
}

private static Map<String, Double> toMap(List<ClassImportance> importances) {
return importances.stream().collect(Collectors.toMap(i -> i.className, i -> i.importance));
return importances.stream().collect(Collectors.toMap(i -> i.className.toString(), i -> i.importance));
}

public static ClassImportance fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private final String className;
private final Object className;
private final double importance;

public ClassImportance(String className, double importance) {
public ClassImportance(Object className, double importance) {
this.className = className;
this.importance = importance;
}

public ClassImportance(StreamInput in) throws IOException {
this.className = in.readString();
this.className = in.readGenericValue();
this.importance = in.readDouble();
}

public String getClassName() {
public Object getClassName() {
return className;
}

@@ -207,7 +213,7 @@ public Map<String, Object> toMap() {

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(className);
out.writeGenericValue(className);
out.writeDouble(importance);
}

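Note on the file above: ClassImportance.className widens from String to Object, and its transport serialization moves from readString/writeString to readGenericValue/writeGenericValue so that numeric and boolean class names keep their runtime type on the wire. A minimal sketch of the round trip this relies on, assuming Elasticsearch's org.elasticsearch.common.io.stream classes (the values are illustrative):

import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;

class GenericClassNameRoundTrip {
    static void demo() throws java.io.IOException {
        BytesStreamOutput out = new BytesStreamOutput();
        out.writeGenericValue(Boolean.TRUE); // class name of a boolean prediction field
        out.writeDouble(0.42);               // its importance (illustrative value)
        StreamInput in = out.bytes().streamInput();
        Object className = in.readGenericValue(); // Boolean.TRUE, not the string "true"
        double importance = in.readDouble();      // 0.42
    }
}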
@@ -12,6 +12,7 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
@@ -129,21 +130,46 @@ public static Map<String, double[]> decodeFeatureImportances(Map<String, String>
return originalFeatureImportance;
}

public static List<FeatureImportance> transformFeatureImportance(Map<String, double[]> featureImportance,
@Nullable List<String> classificationLabels) {
public static List<FeatureImportance> transformFeatureImportanceRegression(Map<String, double[]> featureImportance) {
List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
featureImportance.forEach((k, v) -> importances.add(FeatureImportance.forRegression(k, v[0])));
return importances;
}

public static List<FeatureImportance> transformFeatureImportanceClassification(Map<String, double[]> featureImportance,
final int predictedValue,
@Nullable List<String> classificationLabels,
@Nullable PredictionFieldType predictionFieldType) {
List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
final PredictionFieldType fieldType = predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType;
featureImportance.forEach((k, v) -> {
// This indicates regression, or logistic regression
// This indicates logistic regression (binary classification)
// If the length > 1, we assume multi-class classification.
if (v.length == 1) {
importances.add(FeatureImportance.forRegression(k, v[0]));
assert predictedValue == 1 || predictedValue == 0;
// If predicted value is `1`, then the other class is `0`
// If predicted value is `0`, then the other class is `1`
final int otherClass = 1 - predictedValue;
String predictedLabel = classificationLabels == null ? null : classificationLabels.get(predictedValue);
String otherLabel = classificationLabels == null ? null : classificationLabels.get(otherClass);
importances.add(FeatureImportance.forBinaryClassification(k,
v[0],
Arrays.asList(
new FeatureImportance.ClassImportance(
fieldType.transformPredictedValue((double)predictedValue, predictedLabel),
v[0]),
new FeatureImportance.ClassImportance(
fieldType.transformPredictedValue((double)otherClass, otherLabel),
-v[0])
)));
} else {
List<FeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length);
// If the classificationLabels exist, their length must match leaf_value length
assert classificationLabels == null || classificationLabels.size() == v.length;
for (int i = 0; i < v.length; i++) {
String label = classificationLabels == null ? null : classificationLabels.get(i);
classImportance.add(new FeatureImportance.ClassImportance(
classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i),
fieldType.transformPredictedValue((double)i, label),
v[i]));
}
importances.add(FeatureImportance.forClassification(k, classImportance));
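A worked example of the new binary branch (v.length == 1) above, with hypothetical labels and values: for predictedValue == 1 the other class index is 1 - 1 == 0, and it receives the mirrored importance -v[0].

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

class BinaryImportanceSketch {
    // Sketch only; assumes the x-pack classes shown in the diff above
    // (FeatureImportance, InferenceHelpers, PredictionFieldType).
    static void demo() {
        Map<String, double[]> decoded = Collections.singletonMap("col2", new double[] {0.944});
        List<FeatureImportance> importances = InferenceHelpers.transformFeatureImportanceClassification(
            decoded,
            1,                                 // predicted class index
            Arrays.asList("first", "second"),  // classification labels
            PredictionFieldType.STRING);
        // One FeatureImportance for "col2": overall importance 0.944, with class importances
        // "second" -> 0.944 (the predicted class) and "first" -> -0.944 (the other class).
    }
}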
@@ -43,7 +43,8 @@
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.decodeFeatureImportances;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.sumDoubleArrays;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportance;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportanceClassification;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportanceRegression;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.AGGREGATE_OUTPUT;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_LABELS;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS;
@@ -154,14 +155,7 @@ private InferenceResults innerInfer(double[] features, InferenceConfig config, M
RawInferenceResults inferenceResult = (RawInferenceResults) result;
inferenceResults[i++] = inferenceResult.getValue();
if (config.requestingImportance()) {
double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
assert modelFeatureImportance.length == featureInfluence.length;
for (int j = 0; j < modelFeatureImportance.length; j++) {
if (featureInfluence[j] == null) {
featureInfluence[j] = new double[modelFeatureImportance[j].length];
}
featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
}
addFeatureImportance(featureInfluence, inferenceResult);
}
}
double[] processed = outputAggregator.processValues(inferenceResults);
@@ -176,18 +170,22 @@ private InferenceResults innerInfer(double[] features, InferenceConfig config, M
InferenceResults result = model.infer(features, subModelInferenceConfig);
assert result instanceof RawInferenceResults;
RawInferenceResults inferenceResult = (RawInferenceResults) result;
double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
assert modelFeatureImportance.length == featureInfluence.length;
for (int j = 0; j < modelFeatureImportance.length; j++) {
if (featureInfluence[j] == null) {
featureInfluence[j] = new double[modelFeatureImportance[j].length];
}
featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
}
addFeatureImportance(featureInfluence, inferenceResult);
}
return featureInfluence;
}

private void addFeatureImportance(double[][] featureInfluence, RawInferenceResults inferenceResult) {
double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
assert modelFeatureImportance.length == featureInfluence.length;
for (int j = 0; j < modelFeatureImportance.length; j++) {
if (featureInfluence[j] == null) {
featureInfluence[j] = new double[modelFeatureImportance[j].length];
}
featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
}
}

private InferenceResults buildResults(double[] processedInferences,
double[][] featureImportance,
Map<String, String> featureDecoderMap,
@@ -208,7 +206,7 @@ private InferenceResults buildResults(double[] processedInferences,
case REGRESSION:
return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences),
config,
transformFeatureImportance(decodedFeatureImportance, null));
transformFeatureImportanceRegression(decodedFeatureImportance));
case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config;
assert classificationWeights == null || processedInferences.length == classificationWeights.length;
@@ -220,10 +218,13 @@ private InferenceResults buildResults(double[] processedInferences,
classificationConfig.getNumTopClasses(),
classificationConfig.getPredictionFieldType());
final InferenceHelpers.TopClassificationValue value = topClasses.v1();
return new ClassificationInferenceResults((double)value.getValue(),
return new ClassificationInferenceResults(value.getValue(),
classificationLabel(topClasses.v1().getValue(), classificationLabels),
topClasses.v2(),
transformFeatureImportance(decodedFeatureImportance, classificationLabels),
transformFeatureImportanceClassification(decodedFeatureImportance,
value.getValue(),
classificationLabels,
classificationConfig.getPredictionFieldType()),
config,
value.getProbability(),
value.getScore());
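The two duplicated accumulation loops in innerInfer now live in addFeatureImportance. A sketch of the accumulation it performs, with hypothetical two-feature, two-class matrices, assuming a static import of InferenceHelpers.sumDoubleArrays (element-wise addition):

class ImportanceAccumulationSketch {
    static double[][] accumulate(double[][][] perModelImportance) {
        double[][] featureInfluence = new double[perModelImportance[0].length][]; // starts as nulls
        for (double[][] model : perModelImportance) {
            for (int j = 0; j < model.length; j++) {
                if (featureInfluence[j] == null) {
                    featureInfluence[j] = new double[model[j].length]; // lazily sized per feature
                }
                featureInfluence[j] = sumDoubleArrays(featureInfluence[j], model[j]);
            }
        }
        return featureInfluence;
    }
    // accumulate(new double[][][] {{{0.1, -0.1}, {0.3, -0.3}}, {{0.2, -0.2}, {0.1, -0.1}}})
    // returns {{0.3, -0.3}, {0.4, -0.4}}.
}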
@@ -188,14 +188,17 @@ private InferenceResults buildResult(double[] value,
return new ClassificationInferenceResults(classificationValue.getValue(),
classificationLabel(classificationValue.getValue(), classificationLabels),
topClasses.v2(),
InferenceHelpers.transformFeatureImportance(decodedFeatureImportance, classificationLabels),
InferenceHelpers.transformFeatureImportanceClassification(decodedFeatureImportance,
classificationValue.getValue(),
classificationLabels,
classificationConfig.getPredictionFieldType()),
config,
classificationValue.getProbability(),
classificationValue.getScore());
case REGRESSION:
return new RegressionInferenceResults(value[0],
config,
InferenceHelpers.transformFeatureImportance(decodedFeatureImportance, null));
InferenceHelpers.transformFeatureImportanceRegression(decodedFeatureImportance));
default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
}
@@ -12,8 +12,10 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
@@ -185,8 +187,17 @@ public static class ClassImportance implements ToXContentObject, Writeable {
private static ConstructingObjectParser<ClassImportance, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<ClassImportance, Void> parser = new ConstructingObjectParser<>(NAME,
ignoreUnknownFields,
a -> new ClassImportance((String)a[0], (Importance)a[1]));
parser.declareString(ConstructingObjectParser.constructorArg(), CLASS_NAME);
a -> new ClassImportance(a[0], (Importance)a[1]));
parser.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return p.text();
} else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
return p.numberValue();
} else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
return p.booleanValue();
}
throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
}, CLASS_NAME, ObjectParser.ValueType.VALUE);
parser.declareObject(ConstructingObjectParser.constructorArg(),
ignoreUnknownFields ? Importance.LENIENT_PARSER : Importance.STRICT_PARSER,
IMPORTANCE);
@@ -197,22 +208,22 @@ public static ClassImportance fromXContent(XContentParser parser, boolean lenien
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
}

public final String className;
public final Object className;
public final Importance importance;

public ClassImportance(StreamInput in) throws IOException {
this.className = in.readString();
this.className = in.readGenericValue();
this.importance = new Importance(in);
}

ClassImportance(String className, Importance importance) {
ClassImportance(Object className, Importance importance) {
this.className = className;
this.importance = importance;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(className);
out.writeGenericValue(className);
importance.writeTo(out);
}

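With the declareField change above, class_name may arrive as a string, number, or boolean. A sketch of the token-to-type mapping, with illustrative documents:

// Illustrative class_name values and how the new parser branch materializes them:
//   {"class_name": "male", "importance": { ... }}  -> p.text()         -> String
//   {"class_name": 3,      "importance": { ... }}  -> p.numberValue()  -> Number
//   {"class_name": true,   "importance": { ... }}  -> p.booleanValue() -> Boolean
// Any other token type (object, array, null) throws XContentParseException.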
@@ -17,6 +17,7 @@
import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;

import java.io.IOException;
@@ -154,10 +155,26 @@ public void testComplexInferenceDefinitionInferWithCustomPreProcessor() throws I

ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config);
assertThat(results.valueAsString(), equalTo("second"));
assertThat(results.getFeatureImportance().get(0).getFeatureName(), equalTo("col2"));
assertThat(results.getFeatureImportance().get(0).getImportance(), closeTo(0.944, 0.001));
assertThat(results.getFeatureImportance().get(1).getFeatureName(), equalTo("col1_male"));
assertThat(results.getFeatureImportance().get(1).getImportance(), closeTo(0.199, 0.001));
FeatureImportance featureImportance1 = results.getFeatureImportance().get(0);
assertThat(featureImportance1.getFeatureName(), equalTo("col2"));
assertThat(featureImportance1.getImportance(), closeTo(0.944, 0.001));
for (FeatureImportance.ClassImportance classImportance : featureImportance1.getClassImportance()) {
if (classImportance.getClassName().equals("second")) {
assertThat(classImportance.getImportance(), closeTo(0.944, 0.001));
} else {
assertThat(classImportance.getImportance(), closeTo(-0.944, 0.001));
}
}
FeatureImportance featureImportance2 = results.getFeatureImportance().get(1);
assertThat(featureImportance2.getFeatureName(), equalTo("col1_male"));
assertThat(featureImportance2.getImportance(), closeTo(0.199, 0.001));
for (FeatureImportance.ClassImportance classImportance : featureImportance2.getClassImportance()) {
if (classImportance.getClassName().equals("second")) {
assertThat(classImportance.getImportance(), closeTo(0.199, 0.001));
} else {
assertThat(classImportance.getImportance(), closeTo(-0.199, 0.001));
}
}
}

public static String getClassificationDefinition(boolean customPreprocessor) {
Expand Down