Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
public class InferModelAction extends ActionType<InferModelAction.Response> {

public static final InferModelAction INSTANCE = new InferModelAction();
public static final String NAME = "cluster:admin/xpack/ml/infer";
public static final String NAME = "cluster:admin/xpack/ml/inference/infer";

private InferModelAction() {
super(NAME, Response::new);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,38 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

public class ClassificationConfig implements InferenceConfig {

public static final String NAME = "classification";

public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
private static final Version MIN_SUPPORTED_VERSION = Version.V_8_0_0;

public static ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0);

private final int numTopClasses;

public static ClassificationConfig fromMap(Map<String, Object> map) {
Map<String, Object> options = new HashMap<>(map);
Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName());
if (options.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
}
return new ClassificationConfig(numTopClasses);
}

public ClassificationConfig(Integer numTopClasses) {
this.numTopClasses = numTopClasses == null ? 0 : numTopClasses;
}
Expand Down Expand Up @@ -78,4 +92,9 @@ public boolean isTargetTypeSupported(TargetType targetType) {
return TargetType.CLASSIFICATION.equals(targetType);
}

@Override
public Version getMinimalSupportedVersion() {
return MIN_SUPPORTED_VERSION;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;

Expand All @@ -13,4 +14,8 @@ public interface InferenceConfig extends NamedXContentObject, NamedWriteable {

boolean isTargetTypeSupported(TargetType targetType);

/**
* All nodes in the cluster must be at least this version
*/
Version getMinimalSupportedVersion();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment explaining the need for this?

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,27 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

public class RegressionConfig implements InferenceConfig {

public static final String NAME = "regression";
private static final Version MIN_SUPPORTED_VERSION = Version.V_8_0_0;

public static RegressionConfig fromMap(Map<String, Object> map) {
if (map.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet());
}
return new RegressionConfig();
}

public RegressionConfig() {
}
Expand Down Expand Up @@ -61,4 +72,9 @@ public boolean isTargetTypeSupported(TargetType targetType) {
return TargetType.REGRESSION.equals(targetType);
}

@Override
public Version getMinimalSupportedVersion() {
return MIN_SUPPORTED_VERSION;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
Expand All @@ -26,6 +27,11 @@ public boolean isTargetTypeSupported(TargetType targetType) {
return true;
}

@Override
public Version getMinimalSupportedVersion() {
return Version.CURRENT;
}

@Override
public String getWriteableName() {
return "null";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ public final class Messages {
public static final String INFERENCE_FAILED_TO_SERIALIZE_MODEL =
"Failed to serialize the trained model [{0}] for storage";
public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]";
public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION =
"Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]";

public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
public static final String JOB_AUDIT_CREATED = "Job created";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,35 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;

import java.util.Collections;

import static org.hamcrest.Matchers.equalTo;

public class ClassificationConfigTests extends AbstractWireSerializingTestCase<ClassificationConfig> {

public static ClassificationConfig randomClassificationConfig() {
return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10));
}

public void testFromMap() {
ClassificationConfig expected = new ClassificationConfig(0);
assertThat(ClassificationConfig.fromMap(Collections.emptyMap()), equalTo(expected));

expected = new ClassificationConfig(3);
assertThat(ClassificationConfig.fromMap(Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3)),
equalTo(expected));
}

public void testFromMapWithUnknownField() {
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
() -> ClassificationConfig.fromMap(Collections.singletonMap("some_key", 1)));
assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
}

@Override
protected ClassificationConfig createTestInstance() {
return randomClassificationConfig();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,31 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;

import java.util.Collections;

import static org.hamcrest.Matchers.equalTo;

public class RegressionConfigTests extends AbstractWireSerializingTestCase<RegressionConfig> {

public static RegressionConfig randomRegressionConfig() {
return new RegressionConfig();
}

public void testFromMap() {
RegressionConfig expected = new RegressionConfig();
assertThat(RegressionConfig.fromMap(Collections.emptyMap()), equalTo(expected));
}

public void testFromMapWithUnknownField() {
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
() -> RegressionConfig.fromMap(Collections.singletonMap("some_key", 1)));
assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
}

@Override
protected RegressionConfig createTestInstance() {
return randomRegressionConfig();
Expand Down
Loading