From 3bb1ccda3c544e371669c483675625c123d5eb48 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 25 Aug 2020 15:06:51 -0400 Subject: [PATCH 1/6] [ML] adds new n_gram_encoding custom processor --- .../MlInferenceNamedXContentProvider.java | 3 + .../ml/inference/preprocessing/NGram.java | 213 +++++++++++++ .../client/RestHighLevelClientTests.java | 8 +- .../inference/preprocessing/NGramTests.java | 55 ++++ .../MlInferenceNamedXContentProvider.java | 19 +- .../ml/inference/preprocessing/NGram.java | 300 ++++++++++++++++++ .../inference/preprocessing/PreProcessor.java | 2 +- .../inference/preprocessing/NGramTests.java | 144 +++++++++ 8 files changed, 734 insertions(+), 10 deletions(-) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/NGram.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/preprocessing/NGramTests.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGram.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGramTests.java diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java index 30c7598d72f9d..aab08237dbfd9 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java @@ -19,6 +19,7 @@ package org.elasticsearch.client.ml.inference; import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding; +import org.elasticsearch.client.ml.inference.preprocessing.NGram; import 
org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig; @@ -57,6 +58,8 @@ public List getNamedXContentParsers() { FrequencyEncoding::fromXContent)); namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(CustomWordEmbedding.NAME), CustomWordEmbedding::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(NGram.NAME), + NGram::fromXContent)); // Model namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Tree.NAME), Tree::fromXContent)); diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/NGram.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/NGram.java new file mode 100644 index 0000000000000..f2bc9ee79c2f9 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/NGram.java @@ -0,0 +1,213 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.inference.preprocessing; + +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + + +/** + * PreProcessor for n-gram encoding a string + */ +public class NGram implements PreProcessor { + + public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NGram.class); + public static final String NAME = "n_gram_encoding"; + public static final ParseField FIELD = new ParseField("field"); + public static final ParseField FEATURE_PREFIX = new ParseField("feature_prefix"); + public static final ParseField NGRAMS = new ParseField("n_grams"); + public static final ParseField START = new ParseField("start"); + public static final ParseField LENGTH = new ParseField("length"); + public static final ParseField CUSTOM = new ParseField("custom"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser( + NAME, + true, + a -> new NGram((String)a[0], + (List)a[1], + (Integer)a[2], + (Integer)a[3], + (Boolean)a[4], + (String)a[5])); + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD); + PARSER.declareIntArray(ConstructingObjectParser.constructorArg(), NGRAMS); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), START); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LENGTH); + PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), FEATURE_PREFIX); + } + + public static NGram fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final String field; + 
private final String featurePrefix; + private final List nGrams; + private final Integer start; + private final Integer length; + private final Boolean custom; + + NGram(String field, List nGrams, Integer start, Integer length, Boolean custom, String featurePrefix) { + this.field = field; + this.featurePrefix = featurePrefix; + this.nGrams = nGrams; + this.start = start; + this.length = length; + this.custom = custom; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (field != null) { + builder.field(FIELD.getPreferredName(), field); + } + if (featurePrefix != null) { + builder.field(FEATURE_PREFIX.getPreferredName(), featurePrefix); + } + if (nGrams != null) { + builder.field(NGRAMS.getPreferredName(), nGrams); + } + if (start != null) { + builder.field(START.getPreferredName(), start); + } + if (length != null) { + builder.field(LENGTH.getPreferredName(), length); + } + if (custom != null) { + builder.field(CUSTOM.getPreferredName(), custom); + } + builder.endObject(); + return builder; + } + + public String getField() { + return field; + } + + public String getFeaturePrefix() { + return featurePrefix; + } + + public List getnGrams() { + return nGrams; + } + + public Integer getStart() { + return start; + } + + public Integer getLength() { + return length; + } + + public Boolean getCustom() { + return custom; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + NGram nGram = (NGram) o; + return Objects.equals(field, nGram.field) && + Objects.equals(featurePrefix, nGram.featurePrefix) && + Objects.equals(nGrams, nGram.nGrams) && + Objects.equals(start, nGram.start) && + Objects.equals(length, nGram.length) && + Objects.equals(custom, nGram.custom); + } + + @Override + public int hashCode() { + return 
Objects.hash(field, featurePrefix, start, length, custom, nGrams); + } + + public static Builder builder(String field) { + return new Builder(field); + } + + public static class Builder { + + private String field; + private String featurePrefix; + private List nGrams; + private Integer start; + private Integer length; + private Boolean custom; + + public Builder(String field) { + this.field = field; + } + + public Builder setField(String field) { + this.field = field; + return this; + } + + public Builder setCustom(boolean custom) { + this.custom = custom; + return this; + } + + public Builder setFeaturePrefix(String featurePrefix) { + this.featurePrefix = featurePrefix; + return this; + } + + public Builder setnGrams(List nGrams) { + this.nGrams = nGrams; + return this; + } + + public Builder setStart(Integer start) { + this.start = start; + return this; + } + + public Builder setLength(Integer length) { + this.length = length; + return this; + } + + public Builder setCustom(Boolean custom) { + this.custom = custom; + return this; + } + + public NGram build() { + return new NGram(field, nGrams, start, length, custom, featurePrefix); + } + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 03065248efca7..b6d722301251f 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -74,6 +74,7 @@ import org.elasticsearch.client.ml.dataframe.stats.regression.RegressionStats; import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding; import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding; +import org.elasticsearch.client.ml.inference.preprocessing.NGram; import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding; 
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding; import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig; @@ -704,7 +705,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(69, namedXContents.size()); + assertEquals(70, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -785,8 +786,9 @@ public void testProvidedNamedXContents() { registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME), registeredMetricName(Regression.NAME, HuberMetric.NAME), registeredMetricName(Regression.NAME, RSquaredMetric.NAME))); - assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class)); - assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME)); + assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class)); + assertThat(names, + hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME, NGram.NAME)); assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class)); assertThat(names, hasItems(Tree.NAME, Ensemble.NAME, LangIdentNeuralNetwork.NAME)); assertEquals(Integer.valueOf(4), diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/preprocessing/NGramTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/preprocessing/NGramTests.java new file mode 100644 index 0000000000000..477d01e9f57be --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/preprocessing/NGramTests.java @@ -0,0 +1,55 
@@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.preprocessing; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + + +public class NGramTests extends AbstractXContentTestCase { + + @Override + protected NGram doParseInstance(XContentParser parser) throws IOException { + return NGram.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected NGram createTestInstance() { + return createRandom(); + } + + public static NGram createRandom() { + return new NGram(randomAlphaOfLength(10), + IntStream.range(1, 5).limit(5).boxed().collect(Collectors.toList()), + randomBoolean() ? null : randomIntBetween(0, 10), + randomBoolean() ? null : randomIntBetween(1, 10), + randomBoolean() ? null : randomBoolean(), + randomBoolean() ? 
null : randomAlphaOfLength(10)); + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 2ba7f114e8b41..69a6b611affa8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -9,6 +9,13 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.plugins.spi.NamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.NGram; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; @@ -39,12 +46,6 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.TreeInferenceModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; -import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding; -import 
org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; -import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; -import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; -import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor; -import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding; import java.util.ArrayList; import java.util.List; @@ -64,6 +65,8 @@ public List getNamedXContentParsers() { (p, c) -> FrequencyEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c))); namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, CustomWordEmbedding.NAME, (p, c) -> CustomWordEmbedding.fromXContentLenient(p))); + namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, NGram.NAME, + (p, c) -> NGram.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c))); // PreProcessing Strict namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, OneHotEncoding.NAME, @@ -74,6 +77,8 @@ public List getNamedXContentParsers() { (p, c) -> FrequencyEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c))); namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, CustomWordEmbedding.NAME, (p, c) -> CustomWordEmbedding.fromXContentStrict(p))); + namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, NGram.NAME, + (p, c) -> NGram.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c))); // Model Lenient namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentLenient)); @@ -154,6 +159,8 @@ public List getNamedWriteables() { FrequencyEncoding::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(PreProcessor.class, CustomWordEmbedding.NAME.getPreferredName(), CustomWordEmbedding::new)); + 
namedWriteables.add(new NamedWriteableRegistry.Entry(PreProcessor.class, NGram.NAME.getPreferredName(), + NGram::new)); // Model namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModel.class, Tree.NAME.getPreferredName(), Tree::new)); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGram.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGram.java new file mode 100644 index 0000000000000..e37ee8c65d884 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGram.java @@ -0,0 +1,300 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.preprocessing; + +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.mapper.TextFieldMapper; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.Function; +import java.util.function.IntFunction; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.apache.lucene.util.RamUsageEstimator.sizeOf; + +/** + * 
PreProcessor for n-gram encoding a string + */ +public class NGram implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor { + + private static final int DEFAULT_START = 0; + private static final int DEFAULT_LENGTH = 50; + private static final int MIN_GRAM = 1; + private static final int MAX_GRAM = 5; + + private static String defaultPrefix(Integer start, Integer length) { + return "ngram_" + + (start == null ? DEFAULT_START : start) + + "_" + + (length == null ? DEFAULT_LENGTH : length); + } + + public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NGram.class); + public static final ParseField NAME = new ParseField("n_gram_encoding"); + public static final ParseField FIELD = new ParseField("field"); + public static final ParseField FEATURE_PREFIX = new ParseField("feature_prefix"); + public static final ParseField NGRAMS = new ParseField("n_grams"); + public static final ParseField START = new ParseField("start"); + public static final ParseField LENGTH = new ParseField("length"); + public static final ParseField CUSTOM = new ParseField("custom"); + + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + + @SuppressWarnings("unchecked") + private static ConstructingObjectParser createParser(boolean lenient) { + ConstructingObjectParser parser = new ConstructingObjectParser<>( + NAME.getPreferredName(), + lenient, + (a, c) -> new NGram((String)a[0], + (List)a[1], + (Integer)a[2], + (Integer)a[3], + a[4] == null ? 
c.isCustomByDefault() : (Boolean)a[4], + (String)a[5])); + parser.declareString(ConstructingObjectParser.constructorArg(), FIELD); + parser.declareIntArray(ConstructingObjectParser.constructorArg(), NGRAMS); + parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), START); + parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), LENGTH); + parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM); + parser.declareString(ConstructingObjectParser.optionalConstructorArg(), FEATURE_PREFIX); + return parser; + } + + public static NGram fromXContentStrict(XContentParser parser, PreProcessorParseContext context) { + return STRICT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context); + } + + public static NGram fromXContentLenient(XContentParser parser, PreProcessorParseContext context) { + return LENIENT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context); + } + + private final String field; + private final String featurePrefix; + private final int[] nGrams; + private final int start; + private final int length; + private final boolean custom; + + NGram(String field, + List nGrams, + Integer start, + Integer length, + Boolean custom, + String featurePrefix) { + this(field, + featurePrefix == null ? defaultPrefix(start, length) : featurePrefix, + Sets.newHashSet(nGrams).stream().mapToInt(Integer::intValue).toArray(), + start == null ? DEFAULT_START : start, + length == null ? 
DEFAULT_LENGTH : length, + custom != null && custom); + } + + public NGram(String field, String featurePrefix, int[] nGrams, int start, int length, boolean custom) { + this.field = ExceptionsHelper.requireNonNull(field, FIELD); + this.featurePrefix = ExceptionsHelper.requireNonNull(featurePrefix, FEATURE_PREFIX); + this.nGrams = ExceptionsHelper.requireNonNull(nGrams, NGRAMS); + if (Arrays.stream(this.nGrams).anyMatch(i -> i < 1)) { + throw ExceptionsHelper.badRequestException( + "[{}] is invalid [{}]; minimum supported value is [{}]; maximum supported value is [{}]", + NGRAMS.getPreferredName(), + Arrays.stream(nGrams).mapToObj(String::valueOf).collect(Collectors.joining(", ")), + MIN_GRAM, + MAX_GRAM); + } + this.start = start; + if (start < 0 && length + start > 0) { + throw ExceptionsHelper.badRequestException( + "if [start] is negative, [length] + [start] must be less than 0"); + } + this.length = length; + if (length <= 0) { + throw ExceptionsHelper.badRequestException("[length] must be a positive integer"); + } + this.custom = custom; + } + + public NGram(StreamInput in) throws IOException { + this.field = in.readString(); + this.featurePrefix = in.readString(); + this.nGrams = in.readIntArray(); + this.start = in.readInt(); + this.length = in.readVInt(); + this.custom = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeString(featurePrefix); + out.writeIntArray(nGrams); + out.writeInt(start); + out.writeVInt(length); + out.writeBoolean(custom); + } + + @Override + public String toString() { + return Strings.toString(this); + } + + @Override + public List inputFields() { + return Collections.singletonList(field); + } + + @Override + public List outputFields() { + return allPossibleNGramOutputFeatureNames(); + } + + @Override + public void process(Map fields) { + Object value = fields.get(field); + if (value == null) { + return; + } + final String stringValue = 
value.toString(); + // String is too small for the starting point + if (start > stringValue.length() || stringValue.length() + start < 0) { + return; + } + final int startPos = start < 0 ? (stringValue.length() + start) : start; + final int len = Math.min(startPos + length, stringValue.length()); + for (int i = 0; i < len; i++) { + for (int nGram : nGrams) { + if (startPos + i + nGram - 1 >= len) { + break; + } + fields.put(nGramFeature(nGram, i), stringValue.substring(startPos + i, startPos + i + nGram)); + } + } + } + + @Override + public Map reverseLookup() { + return outputFields().stream().collect(Collectors.toMap(Function.identity(), ignored -> field)); + } + + @Override + public String getOutputFieldType(String outputField) { + return TextFieldMapper.CONTENT_TYPE; + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += sizeOf(field); + size += sizeOf(featurePrefix); + size += sizeOf(nGrams); + return size; + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FIELD.getPreferredName(), field); + builder.field(FEATURE_PREFIX.getPreferredName(), featurePrefix); + builder.field(NGRAMS.getPreferredName(), nGrams); + builder.field(START.getPreferredName(), start); + builder.field(LENGTH.getPreferredName(), length); + builder.field(CUSTOM.getPreferredName(), custom); + builder.endObject(); + return builder; + } + + public String getField() { + return field; + } + + public String getFeaturePrefix() { + return featurePrefix; + } + + public int[] getnGrams() { + return nGrams; + } + + public int getStart() { + return start; + } + + public int getLength() { + return length; + } + + @Override + public boolean isCustom() { + return custom; + } + + @Override + public boolean 
equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + NGram nGram = (NGram) o; + return start == nGram.start && + length == nGram.length && + custom == nGram.custom && + Objects.equals(field, nGram.field) && + Objects.equals(featurePrefix, nGram.featurePrefix) && + Arrays.equals(nGrams, nGram.nGrams); + } + + @Override + public int hashCode() { + int result = Objects.hash(field, featurePrefix, start, length, custom); + result = 31 * result + Arrays.hashCode(nGrams); + return result; + } + + private String nGramFeature(int nGram, int pos) { + return featurePrefix + + "." + + nGram + + pos; + } + + private List allPossibleNGramOutputFeatureNames() { + int totalNgrams = 0; + for (int nGram : nGrams) { + totalNgrams += (length - (nGram - 1)); + } + List ngramOutputs = new ArrayList<>(totalNgrams); + + for (int nGram : nGrams) { + IntFunction func = i -> nGramFeature(nGram, i); + IntStream.range(0, (length - (nGram - 1))).mapToObj(func).forEach(ngramOutputs::add); + } + return ngramOutputs; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java index 596664773704c..a32a51b6a04fb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java @@ -36,7 +36,7 @@ public boolean isCustomByDefault() { List inputFields(); /** - * @return The resulting output fields + * @return The resulting output fields. It is imperative that the order is consistent between calls. 
*/ List outputFields();
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGramTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGramTests.java
new file mode 100644
index 0000000000000..9fe21a4fa5c85
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGramTests.java
@@ -0,0 +1,144 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.preprocessing;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.hamcrest.Matcher;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.nullValue;
+import static org.hamcrest.Matchers.contains;
+
+
+public class NGramTests extends PreProcessingTests<NGram> {
+
+    @Override
+    protected NGram doParseInstance(XContentParser parser) throws IOException {
+        return lenient ?
+            NGram.fromXContentLenient(parser, PreProcessor.PreProcessorParseContext.DEFAULT) :
+            NGram.fromXContentStrict(parser, PreProcessor.PreProcessorParseContext.DEFAULT);
+    }
+
+    @Override
+    protected NGram createTestInstance() {
+        return createRandom();
+    }
+
+    public static NGram createRandom() {
+        return createRandom(randomBoolean() ? randomBoolean() : null);
+    }
+
+    public static NGram createRandom(Boolean isCustom) {
+        return new NGram(
+            randomAlphaOfLength(10),
+            IntStream.generate(() -> randomIntBetween(1, 5)).limit(5).boxed().collect(Collectors.toList()),
+            randomBoolean() ? null : randomIntBetween(0, 10),
+            randomBoolean() ? null : randomIntBetween(1, 10),
+            isCustom,
+            randomBoolean() ? null : randomAlphaOfLength(10));
+    }
+
+    @Override
+    protected Writeable.Reader<NGram> instanceReader() {
+        return NGram::new;
+    }
+
+    public void testProcessNGramPrefix() {
+        String field = "text";
+        String fieldValue = "this is the value";
+        NGram encoding = new NGram(field, "f", new int[]{1, 4}, 0, 5, false);
+        Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
+
+        Map<String, Matcher<? super Object>> matchers = new HashMap<>();
+        matchers.put("f.10", equalTo("t"));
+        matchers.put("f.11", equalTo("h"));
+        matchers.put("f.12", equalTo("i"));
+        matchers.put("f.13", equalTo("s"));
+        matchers.put("f.14", equalTo(" "));
+        matchers.put("f.40", equalTo("this"));
+        matchers.put("f.41", equalTo("his "));
+        testProcess(encoding, fieldValues, matchers);
+    }
+
+    public void testProcessNGramSuffix() {
+        String field = "text";
+        String fieldValue = "this is the value";
+
+        NGram encoding = new NGram(field, "f", new int[]{1, 3}, -3, 3, false);
+        Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
+        Map<String, Matcher<? super Object>> matchers = new HashMap<>();
+        matchers.put("f.10", equalTo("l"));
+        matchers.put("f.11", equalTo("u"));
+        matchers.put("f.12", equalTo("e"));
+        matchers.put("f.30", equalTo("lue"));
+        matchers.put("f.31", is(nullValue()));
+        testProcess(encoding, fieldValues, matchers);
+    }
+
+    public void testProcessNGramInfix() {
+        String field = "text";
+        String fieldValue = "this is the value";
+
+        NGram encoding = new NGram(field, "f", new int[]{1, 3}, 3, 3, false);
+        Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
+        Map<String, Matcher<? super Object>> matchers = new HashMap<>();
+        matchers.put("f.10", equalTo("s"));
+        matchers.put("f.11", equalTo(" "));
+        matchers.put("f.12", equalTo("i"));
+        matchers.put("f.30", equalTo("s i"));
+        matchers.put("f.31", is(nullValue()));
+        testProcess(encoding, fieldValues, matchers);
+    }
+
+    public void testProcessNGramLengthOverrun() {
+        String field = "text";
+        String fieldValue = "this is the value";
+
+        NGram encoding = new NGram(field, "f", new int[]{1, 3}, 12, 10, false);
+        Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
+        Map<String, Matcher<? super Object>> matchers = new HashMap<>();
+        matchers.put("f.10", equalTo("v"));
+        matchers.put("f.11", equalTo("a"));
+        matchers.put("f.12", equalTo("l"));
+        matchers.put("f.13", equalTo("u"));
+        matchers.put("f.14", equalTo("e"));
+        matchers.put("f.30", equalTo("val"));
+        matchers.put("f.31", equalTo("alu"));
+        matchers.put("f.32", equalTo("lue"));
+        testProcess(encoding, fieldValues, matchers);
+    }
+
+    public void testInputOutputFields() {
+        String field = randomAlphaOfLength(10);
+        NGram encoding = new NGram(field, "f", new int[]{1, 4}, 0, 5, false);
+        assertThat(encoding.inputFields(), containsInAnyOrder(field));
+        assertThat(encoding.outputFields(),
+            contains("f.10", "f.11","f.12","f.13","f.14","f.40", "f.41"));
+
+        encoding = new NGram(field, Arrays.asList(1, 4), 0, 5, false, null);
+        assertThat(encoding.inputFields(), containsInAnyOrder(field));
+        assertThat(encoding.outputFields(),
+            contains(
+                "ngram_0_5.10",
+                "ngram_0_5.11",
+                "ngram_0_5.12",
+                "ngram_0_5.13",
+                "ngram_0_5.14",
+                "ngram_0_5.40",
+                "ngram_0_5.41"));
+    }
+
+}

From 7de22cde934117533204ff21d5331166cc17862a Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Thu, 27 Aug 2020 11:23:19 -0400
Subject: [PATCH 2/6] adding tests

---
 .../ml/inference/preprocessing/NGram.java     |   6 +-
 .../DataFrameAnalysisCustomFeatureIT.java     | 275 ++++++++++++++++++
 ...NativeDataFrameAnalyticsIntegTestCase.java |   2 +-
 3 files changed, 281 insertions(+), 2 deletions(-)
 create mode 100644
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalysisCustomFeatureIT.java

diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGram.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGram.java
index e37ee8c65d884..b8b40d419b744 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGram.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGram.java
@@ -38,6 +38,7 @@ public class NGram implements LenientlyParsedPreProcessor, StrictlyParsedPreProc
 
     private static final int DEFAULT_START = 0;
     private static final int DEFAULT_LENGTH = 50;
+    private static final int MAX_LENGTH = 100;
     private static final int MIN_GRAM = 1;
     private static final int MAX_GRAM = 5;
 
@@ -128,7 +129,10 @@ public NGram(String field, String featurePrefix, int[] nGrams, int start, int le
         }
         this.length = length;
         if (length <= 0) {
-            throw ExceptionsHelper.badRequestException("[length] must be a positive integer");
+            throw ExceptionsHelper.badRequestException("[{}] must be a positive integer", LENGTH.getPreferredName());
+        }
+        if (length > MAX_LENGTH) {
+            throw ExceptionsHelper.badRequestException("[{}] must be not be greater than [{}]", LENGTH.getPreferredName(), MAX_LENGTH);
         }
         this.custom = custom;
     }
diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalysisCustomFeatureIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalysisCustomFeatureIT.java
new file mode 100644
index 0000000000000..1aff4bda9f959
--- /dev/null
+++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalysisCustomFeatureIT.java
@@ -0,0 +1,275 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.integration;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.DocWriteRequest;
+import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
+import org.elasticsearch.action.bulk.BulkRequestBuilder;
+import org.elasticsearch.action.bulk.BulkResponse;
+import org.elasticsearch.action.get.GetResponse;
+import org.elasticsearch.action.index.IndexRequest;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.support.WriteRequest;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.SearchModule;
+import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.NGram;
+import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.everyItem;
+import static org.hamcrest.Matchers.hasKey;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.startsWith;
+
+public class DataFrameAnalysisCustomFeatureIT extends MlNativeDataFrameAnalyticsIntegTestCase {
+
+    private static final String BOOLEAN_FIELD = "boolean-field";
+    private static final String NUMERICAL_FIELD = "numerical-field";
+    private static final String DISCRETE_NUMERICAL_FIELD = "discrete-numerical-field";
+    private static final String TEXT_FIELD = "text-field";
+    private static final String KEYWORD_FIELD = "keyword-field";
+    private static final String NESTED_FIELD = "outer-field.inner-field";
+    private static final String ALIAS_TO_KEYWORD_FIELD = "alias-to-keyword-field";
+    private static final String ALIAS_TO_NESTED_FIELD = "alias-to-nested-field";
+    private static final List<Boolean> BOOLEAN_FIELD_VALUES = List.of(false, true);
+    private static final List<Double> NUMERICAL_FIELD_VALUES = List.of(1.0, 2.0);
+    private static final List<Integer> DISCRETE_NUMERICAL_FIELD_VALUES = List.of(10, 20);
+    private static final List<String> KEYWORD_FIELD_VALUES = List.of("cat", "dog");
+
+    private String jobId;
+    private String sourceIndex;
+    private String destIndex;
+    private boolean analysisUsesExistingDestIndex;
+
+    @Before
+    public void setupLogging() {
+        client().admin().cluster()
+            .prepareUpdateSettings()
+            .setTransientSettings(Settings.builder()
+                .put("logger.org.elasticsearch.xpack.ml.dataframe", "DEBUG")
+                .put("logger.org.elasticsearch.xpack.core.ml.inference", "DEBUG"))
+            .get();
+    }
+
+    @After
+    public void cleanup() {
+        cleanUp();
+        client().admin().cluster()
+            .prepareUpdateSettings()
+            .setTransientSettings(Settings.builder()
+                .putNull("logger.org.elasticsearch.xpack.ml.dataframe")
+                .putNull("logger.org.elasticsearch.xpack.core.ml.inference"))
+            .get();
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        SearchModule searchModule = new SearchModule(Settings.EMPTY, Collections.emptyList());
+        List<NamedXContentRegistry.Entry> entries = new ArrayList<>(searchModule.getNamedXContents());
+        entries.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+        entries.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
+        return new NamedXContentRegistry(entries);
+    }
+
+    public void testNGramCustomFeature() throws Exception {
+        initialize("test_ngram_feature_processor");
+        String predictedClassField = NUMERICAL_FIELD + "_prediction";
+        indexData(sourceIndex, 300, 50, NUMERICAL_FIELD);
+
+        DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
+            .setId(jobId)
+            .setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex },
+                QueryProvider.fromParsedQuery(QueryBuilders.matchAllQuery()), null))
+            .setDest(new DataFrameAnalyticsDest(destIndex, null))
+            .setAnalysis(new Regression(NUMERICAL_FIELD,
+                BoostedTreeParams.builder().setNumTopFeatureImportanceValues(6).build(),
+                null,
+                null,
+                42L,
+                null,
+                null,
+                Collections.singletonList(new NGram(TEXT_FIELD, "f", new int[]{1, 2}, 0, 2, true))))
+            .setAnalyzedFields(new FetchSourceContext(true, new String[]{TEXT_FIELD, NUMERICAL_FIELD}, new String[]{}))
+            .build();
+        putAnalytics(config);
+
+        assertIsStopped(jobId);
+        assertProgressIsZero(jobId);
+
+        startAnalytics(jobId);
+        waitUntilAnalyticsIsStopped(jobId);
+
+        client().admin().indices().refresh(new RefreshRequest(destIndex));
+        SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
+        for (SearchHit hit : sourceData.getHits()) {
+            Map<String, Object> destDoc = getDestDoc(config, hit);
+            Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
+            @SuppressWarnings("unchecked")
+            List<Map<String, Object>> importanceArray = (List<Map<String, Object>>)resultsObject.get("feature_importance");
+            assertThat(importanceArray.stream().map(m -> m.get("feature_name").toString()).collect(Collectors.toSet()),
+                everyItem(startsWith("f.")));
+        }
+
+        assertProgressComplete(jobId);
+        assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
+        assertInferenceModelPersisted(jobId);
+        assertModelStatePersisted(stateDocId());
+    }
+
+    private void initialize(String jobId) {
+        initialize(jobId, false);
+    }
+
+    private void initialize(String jobId, boolean isDatastream) {
+        this.jobId = jobId;
+        this.sourceIndex = jobId + "_source_index";
+        this.destIndex = sourceIndex + "_results";
+        this.analysisUsesExistingDestIndex = randomBoolean();
+        createIndex(sourceIndex, isDatastream);
+        if (analysisUsesExistingDestIndex) {
+            createIndex(destIndex, false);
+        }
+    }
+
+    private static void createIndex(String index, boolean isDatastream) {
+        String mapping = "{\n" +
+            " \"properties\": {\n" +
+            " \"@timestamp\": {\n" +
+            " \"type\": \"date\"\n" +
+            " }," +
+            " \""+ BOOLEAN_FIELD + "\": {\n" +
+            " \"type\": \"boolean\"\n" +
+            " }," +
+            " \""+ NUMERICAL_FIELD + "\": {\n" +
+            " \"type\": \"double\"\n" +
+            " }," +
+            " \""+ DISCRETE_NUMERICAL_FIELD + "\": {\n" +
+            " \"type\": \"integer\"\n" +
+            " }," +
+            " \""+ TEXT_FIELD + "\": {\n" +
+            " \"type\": \"text\"\n" +
+            " }," +
+            " \""+ KEYWORD_FIELD + "\": {\n" +
+            " \"type\": \"keyword\"\n" +
+            " }," +
+            " \""+ NESTED_FIELD + "\": {\n" +
+            " \"type\": \"keyword\"\n" +
+            " }," +
+            " \""+ ALIAS_TO_KEYWORD_FIELD + "\": {\n" +
+            " \"type\": \"alias\",\n" +
+            " \"path\": \"" + KEYWORD_FIELD + "\"\n" +
+            " }," +
+            " \""+ ALIAS_TO_NESTED_FIELD + "\": {\n" +
+            " \"type\": \"alias\",\n" +
+            " \"path\": \"" + NESTED_FIELD + "\"\n" +
+            " }" +
+            " }\n" +
+            " }";
+        if (isDatastream) {
+            try {
+                createDataStreamAndTemplate(index, mapping);
+            } catch (IOException ex) {
+                throw new ElasticsearchException(ex);
+            }
+        } else {
+            client().admin().indices().prepareCreate(index)
+                .setMapping(mapping)
+                .get();
+        }
+    }
+
+    private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows, String dependentVariable) {
+        BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
+            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
+        for (int i = 0; i < numTrainingRows; i++) {
+            List<Object> source = List.of(
+                "@timestamp", "2020-12-12",
+                BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size()),
+                NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size()),
+                DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size()),
+                TEXT_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()),
+                KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()),
+                NESTED_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()));
+            IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray()).opType(DocWriteRequest.OpType.CREATE);
+            bulkRequestBuilder.add(indexRequest);
+        }
+        for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
+            List<Object> source = new ArrayList<>();
+            if (BOOLEAN_FIELD.equals(dependentVariable) == false) {
+                source.addAll(List.of(BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size())));
+            }
+            if (NUMERICAL_FIELD.equals(dependentVariable) == false) {
+                source.addAll(List.of(NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size())));
+            }
+            if (DISCRETE_NUMERICAL_FIELD.equals(dependentVariable) == false) {
+                source.addAll(
+                    List.of(DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size())));
+            }
+            if (TEXT_FIELD.equals(dependentVariable) == false) {
+                source.addAll(List.of(TEXT_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
+            }
+            if (KEYWORD_FIELD.equals(dependentVariable) == false) {
+                source.addAll(List.of(KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
+            }
+            if (NESTED_FIELD.equals(dependentVariable) == false) {
+                source.addAll(List.of(NESTED_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
+            }
+            source.addAll(List.of("@timestamp", "2020-12-12"));
+            IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray()).opType(DocWriteRequest.OpType.CREATE);
+            bulkRequestBuilder.add(indexRequest);
+        }
+        BulkResponse bulkResponse = bulkRequestBuilder.get();
+        if (bulkResponse.hasFailures()) {
+            fail("Failed to index data: " + bulkResponse.buildFailureMessage());
+        }
+    }
+
+    private static Map<String, Object> getDestDoc(DataFrameAnalyticsConfig config, SearchHit hit) {
+        GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
+        assertThat(destDocGetResponse.isExists(), is(true));
+        Map<String, Object> sourceDoc = hit.getSourceAsMap();
+        Map<String, Object> destDoc = destDocGetResponse.getSource();
+        for (String field : sourceDoc.keySet()) {
+            assertThat(destDoc, hasKey(field));
+            assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
+        }
+        return destDoc;
+    }
+
+    private String stateDocId() {
+        return jobId + "_classification_state#1";
+    }
+
+    private String expectedDestIndexAuditMessage() {
+        return (analysisUsesExistingDestIndex ? "Using existing" : "Creating") + " destination index [" + destIndex + "]";
+    }
+
+    @Override
+    boolean supportsInference() {
+        return true;
+    }
+}
diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java
index 4b45f96941f40..0a1ba7792a9bc 100644
--- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java
+++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java
@@ -258,7 +258,7 @@ protected void assertInferenceModelPersisted(String jobId) {
         SearchResponse searchResponse = client().prepareSearch(InferenceIndexConstants.LATEST_INDEX_NAME)
             .setQuery(QueryBuilders.boolQuery().filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), jobId)))
             .get();
-        assertThat("Hits were: " + Strings.toString(searchResponse.getHits()), searchResponse.getHits().getHits(), arrayWithSize(1));
+        assertThat("Hits were: " + Strings.toString(searchResponse.getHits()), searchResponse.getHits().getHits(), arrayWithSize(2));
     }
 
     protected Collection<PersistentTasksCustomMetadata.PersistentTask<?>> analyticsTaskList() {

From 9481ffe782b103d7ae3c72830a5a8decd7388299 Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Thu, 27 Aug 2020 11:27:44 -0400
Subject: [PATCH 3/6] removing debug

---
 .../ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java
index 0a1ba7792a9bc..4b45f96941f40 100644
--- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java
+++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java
@@ -258,7 +258,7 @@ protected void assertInferenceModelPersisted(String jobId) {
         SearchResponse searchResponse = client().prepareSearch(InferenceIndexConstants.LATEST_INDEX_NAME)
             .setQuery(QueryBuilders.boolQuery().filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), jobId)))
             .get();
-        assertThat("Hits were: " + Strings.toString(searchResponse.getHits()), searchResponse.getHits().getHits(), arrayWithSize(2));
+        assertThat("Hits were: " + Strings.toString(searchResponse.getHits()), searchResponse.getHits().getHits(), arrayWithSize(1));
     }
 
     protected Collection<PersistentTasksCustomMetadata.PersistentTask<?>> analyticsTaskList() {

From 6c4507b420a56bf334c1d2142a1b83a0e0c91771 Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Thu, 27 Aug 2020 14:07:17 -0400
Subject: [PATCH 4/6] fixing test

---
 .../ml/integration/DataFrameAnalysisCustomFeatureIT.java | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalysisCustomFeatureIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalysisCustomFeatureIT.java
index 1aff4bda9f959..0eacab902d63d 100644
--- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalysisCustomFeatureIT.java
+++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalysisCustomFeatureIT.java
@@ -63,7 +63,6 @@ public class DataFrameAnalysisCustomFeatureIT extends MlNativeDataFrameAnalytics
     private String jobId;
     private String sourceIndex;
     private String destIndex;
-    private boolean analysisUsesExistingDestIndex;
 
     @Before
     public void setupLogging() {
@@ -148,7 +147,7 @@ private void initialize(String jobId, boolean isDatastream) {
         this.jobId = jobId;
         this.sourceIndex = jobId + "_source_index";
         this.destIndex = sourceIndex + "_results";
-        this.analysisUsesExistingDestIndex = randomBoolean();
+        boolean analysisUsesExistingDestIndex = randomBoolean();
         createIndex(sourceIndex, isDatastream);
         if (analysisUsesExistingDestIndex) {
             createIndex(destIndex, false);
@@ -261,11 +260,7 @@ private static Map<String, Object> getDestDoc(DataFrameAnalyticsConfig config, S
     }
 
     private String stateDocId() {
-        return jobId + "_classification_state#1";
-    }
-
-    private String expectedDestIndexAuditMessage() {
-        return (analysisUsesExistingDestIndex ? "Using existing" : "Creating") + " destination index [" + destIndex + "]";
+        return jobId + "_regression_state#1";
     }
 
     @Override

From 3150398bdb0457d57851061febdd0773a88a62f3 Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Thu, 3 Sep 2020 08:55:03 -0400
Subject: [PATCH 5/6] addressing pr comments

---
 .../client/ml/inference/preprocessing/NGram.java     | 2 --
 .../xpack/core/ml/inference/preprocessing/NGram.java | 8 ++++----
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/NGram.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/NGram.java
index f2bc9ee79c2f9..5fa829d9e67ea 100644
--- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/NGram.java
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/NGram.java
@@ -18,7 +18,6 @@
  */
 package org.elasticsearch.client.ml.inference.preprocessing;
 
-import org.apache.lucene.util.RamUsageEstimator;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -34,7 +33,6 @@
  */
 public class NGram implements PreProcessor {
 
-    public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NGram.class);
     public static final String NAME = "n_gram_encoding";
     public static final ParseField FIELD = new ParseField("field");
     public static final ParseField FEATURE_PREFIX = new ParseField("feature_prefix");
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGram.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGram.java
index b8b40d419b744..ba9737780cfbe 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGram.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGram.java
@@ -114,7 +114,7 @@ public NGram(String field, String featurePrefix, int[] nGrams, int start, int le
         this.field = ExceptionsHelper.requireNonNull(field, FIELD);
         this.featurePrefix = ExceptionsHelper.requireNonNull(featurePrefix, FEATURE_PREFIX);
         this.nGrams = ExceptionsHelper.requireNonNull(nGrams, NGRAMS);
-        if (Arrays.stream(this.nGrams).anyMatch(i -> i < 1)) {
+        if (Arrays.stream(this.nGrams).anyMatch(i -> i < MIN_GRAM || i > MAX_GRAM)) {
             throw ExceptionsHelper.badRequestException(
                 "[{}] is invalid [{}]; minimum supported value is [{}]; maximum supported value is [{}]",
                 NGRAMS.getPreferredName(),
@@ -140,7 +140,7 @@ public NGram(String field, String featurePrefix, int[] nGrams, int start, int le
     public NGram(StreamInput in) throws IOException {
         this.field = in.readString();
         this.featurePrefix = in.readString();
-        this.nGrams = in.readIntArray();
+        this.nGrams = in.readVIntArray();
         this.start = in.readInt();
         this.length = in.readVInt();
         this.custom = in.readBoolean();
@@ -150,7 +150,7 @@ public NGram(StreamInput in) throws IOException {
     public void writeTo(StreamOutput out) throws IOException {
         out.writeString(field);
         out.writeString(featurePrefix);
-        out.writeIntArray(nGrams);
+        out.writeVIntArray(nGrams);
         out.writeInt(start);
         out.writeVInt(length);
         out.writeBoolean(custom);
@@ -186,7 +186,7 @@ public void process(Map<String, Object> fields) {
         final int len = Math.min(startPos + length, stringValue.length());
         for (int i = 0; i < len; i++) {
             for (int nGram : nGrams) {
-                if (startPos + i + nGram - 1 >= len) {
+                if (startPos + i + nGram > len) {
                     break;
                 }
                 fields.put(nGramFeature(nGram, i), stringValue.substring(startPos + i, startPos + i + nGram));

From 0baae5cbbb8a228530e2e1e80cb65a4e130dd6c0 Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Thu, 3 Sep 2020 10:22:46 -0400
Subject: [PATCH 6/6] moving integration test

---
 .../xpack/ml/integration/DataFrameAnalysisCustomFeatureIT.java | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename x-pack/plugin/ml/qa/native-multi-node-tests/src/{test => javaRestTest}/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalysisCustomFeatureIT.java (100%)

diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalysisCustomFeatureIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalysisCustomFeatureIT.java
similarity index 100%
rename from x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalysisCustomFeatureIT.java
rename to x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalysisCustomFeatureIT.java