Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ public final class Messages {
" (a-z and 0-9), hyphens or underscores; must start and end with alphanumeric";

// Message templates for storing and looking up trained (inference) models.
// The {0}/{1} tokens are positional placeholders to be substituted by the caller.
// NOTE(review): for the "chunked doc" message, {0} is presumably the model id and {1} the
// chunk/doc number — confirm against the call sites.
public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists";
public static final String INFERENCE_TRAINED_MODEL_DOC_EXISTS = "Trained machine learning model chunked doc [{0}][{1}] already exists";
public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]";
public static final String INFERENCE_NOT_FOUND_MULTIPLE = "Could not find trained models {0}";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package org.elasticsearch.xpack.ml.integration;

import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.LuceneTestCase;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionModule;
Expand Down Expand Up @@ -67,7 +66,6 @@
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.startsWith;

@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1349")
public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {

private static final String BOOLEAN_FIELD = "boolean-field";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
*/
package org.elasticsearch.xpack.ml.integration;

import org.apache.lucene.util.LuceneTestCase;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionModule;
import org.elasticsearch.action.DocWriteRequest;
Expand Down Expand Up @@ -45,7 +44,6 @@
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.not;

@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1349")
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {

private static final String NUMERICAL_FEATURE_FIELD = "feature";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.integration;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
import org.elasticsearch.xpack.ml.dataframe.process.ChunkedTrainedModelPersister;
import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
import org.elasticsearch.xpack.ml.extractor.DocValueField;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;

import static org.hamcrest.Matchers.equalTo;

public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {

// Provider handed to the persister under test and also used to read the stored model back.
private TrainedModelProvider trainedModelProvider;

/**
 * Test fixture setup: creates a {@link TrainedModelProvider} from the test client and this
 * class's {@link #xContentRegistry()}, then calls {@code waitForMlTemplates()}.
 */
@Before
public void createComponents() throws Exception {
trainedModelProvider = new TrainedModelProvider(client(), xContentRegistry());
// NOTE(review): presumably blocks until the ML index templates are installed — confirm in the base class.
waitForMlTemplates();
}

/**
 * Feeds a compressed model definition to {@link ChunkedTrainedModelPersister} chunk by chunk and
 * checks that the model read back through {@link TrainedModelProvider} has the same compressed
 * definition and the size estimates taken from the supplied {@link ModelSizeInfo}.
 */
public void testStoreModelViaChunkedPersister() throws IOException {
    final String modelId = "stored-chunked-model";

    // Minimal regression analytics config; the job id doubles as the model id prefix queried below.
    DataFrameAnalyticsSource source = new DataFrameAnalyticsSource(new String[] {"my_source"}, null, null);
    DataFrameAnalyticsDest dest = new DataFrameAnalyticsDest("my_dest", null);
    DataFrameAnalyticsConfig analyticsConfig = new DataFrameAnalyticsConfig.Builder()
        .setId(modelId)
        .setSource(source)
        .setDest(dest)
        .setAnalysis(new Regression("foo"))
        .build();
    List<ExtractedField> extractedFieldList = Collections.singletonList(new DocValueField("foo", Collections.emptySet()));

    // Split the compressed definition using a chunk size of one third of its length.
    TrainedModelConfig.Builder configBuilder = buildTrainedModelConfigBuilder(modelId);
    final String compressedDefinition = configBuilder.build().getCompressedDefinition();
    final int chunkSize = compressedDefinition.length() / 3;
    final List<String> chunks = chunkStringWithSize(compressedDefinition, chunkSize);

    ChunkedTrainedModelPersister persister = new ChunkedTrainedModelPersister(
        trainedModelProvider,
        analyticsConfig,
        new DataFrameAnalyticsAuditor(client(), "test-node"),
        failure -> { throw new ElasticsearchException(failure); },
        new ExtractedFields(extractedFieldList, Collections.emptyMap())
    );

    // Accuracy for size is not tested here
    final ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom();
    persister.createAndIndexInferenceModelMetadata(modelSizeInfo);

    // Only the final chunk is flagged as end-of-stream.
    final int lastChunkIndex = chunks.size() - 1;
    int docNum = 0;
    for (String chunk : chunks) {
        persister.createAndIndexInferenceModelDoc(new TrainedModelDefinitionChunk(chunk, docNum, docNum == lastChunkIndex));
        docNum++;
    }

    // Exactly one model id is expected to match the "<modelId>*" pattern.
    PlainActionFuture<Tuple<Long, Set<String>>> getIdsFuture = new PlainActionFuture<>();
    trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), getIdsFuture);
    Tuple<Long, Set<String>> ids = getIdsFuture.actionGet();
    assertThat(ids.v1(), equalTo(1L));

    String storedModelId = ids.v2().iterator().next();
    PlainActionFuture<TrainedModelConfig> getTrainedModelFuture = new PlainActionFuture<>();
    trainedModelProvider.getTrainedModel(storedModelId, true, getTrainedModelFuture);

    TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet();
    assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition));
    assertThat(storedConfig.getEstimatedOperations(), equalTo((long)modelSizeInfo.numOperations()));
    assertThat(storedConfig.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed()));
}

/**
 * Creates a {@link TrainedModelConfig.Builder} for the given model id backed by a random
 * regression model definition.
 *
 * <p>The heap-memory and operation estimates are derived from the same definition that is set on
 * the config, so the estimates describe the model that is actually stored.
 *
 * @param modelId id to assign to the trained model config
 * @return a populated builder; the caller decides when to {@code build()} it
 */
private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) {
    // One definition is used for both the estimates and the stored model; previously the
    // estimates came from an unrelated random definition.
    TrainedModelDefinition.Builder definitionBuilder = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION);
    TrainedModelDefinition definition = definitionBuilder.build();
    long bytesUsed = definition.ramBytesUsed();
    long operations = definition.getTrainedModel().estimatedNumOperations();
    return TrainedModelConfig.builder()
        .setCreatedBy("ml_test")
        .setParsedDefinition(definitionBuilder)
        .setDescription("trained model config for test")
        .setModelId(modelId)
        .setVersion(Version.CURRENT)
        .setLicenseLevel(License.OperationMode.PLATINUM.description())
        .setEstimatedHeapMemory(bytesUsed)
        .setEstimatedOperations(operations)
        .setInput(TrainedModelInputTests.createRandomInput());
}

/**
 * Splits {@code str} into consecutive substrings of at most {@code chunkSize} characters.
 * Concatenating the returned chunks in order yields the original string; only the last chunk may
 * be shorter than {@code chunkSize}. An empty string produces an empty list.
 *
 * @param str       the string to split
 * @param chunkSize maximum number of characters per chunk; must be positive
 * @return the chunks in order
 * @throws IllegalArgumentException if {@code chunkSize} is zero or negative
 */
public static List<String> chunkStringWithSize(String str, int chunkSize) {
    // Without this guard a zero size fails with a bare "/ by zero" and a negative size with an
    // unrelated exception; callers that pass length / 3 hit this for very short strings.
    if (chunkSize <= 0) {
        throw new IllegalArgumentException("chunkSize must be positive but was [" + chunkSize + "]");
    }
    final int length = str.length();
    // Ceiling division: number of chunks needed to cover the whole string.
    List<String> subStrings = new ArrayList<>((length + chunkSize - 1) / chunkSize);
    for (int start = 0; start < length; start += chunkSize) {
        subStrings.add(str.substring(start, Math.min(start + chunkSize, length)));
    }
    return subStrings;
}

/**
 * Registry combining the named x-content parsers of the ML inference provider and the ML
 * model-size provider, in that order.
 */
@Override
public NamedXContentRegistry xContentRegistry() {
    // Seed the list with the inference parsers, then append the model-size parsers.
    List<NamedXContentRegistry.Entry> entries =
        new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
    entries.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers());
    return new NamedXContentRegistry(entries);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
import static org.elasticsearch.xpack.ml.integration.ChunkedTrainedModelPersisterIT.chunkStringWithSize;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
Expand Down Expand Up @@ -157,8 +160,8 @@ public void testGetMissingTrainingModelConfigDefinition() throws Exception {
equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
}

public void testGetTruncatedModelDefinition() throws Exception {
String modelId = "test-get-truncated-model-config";
public void testGetTruncatedModelDeprecatedDefinition() throws Exception {
String modelId = "test-get-truncated-legacy-model-config";
TrainedModelConfig config = buildTrainedModelConfig(modelId);
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
Expand Down Expand Up @@ -196,6 +199,51 @@ public void testGetTruncatedModelDefinition() throws Exception {
assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
}

/**
 * Stores a model config, then indexes chunked definition docs that are deliberately incomplete:
 * either the final chunk is not flagged as end-of-stream, or the first chunk is left out. Fetching
 * the model with its definition is then expected to fail with
 * {@code Messages.MODEL_DEFINITION_TRUNCATED}.
 */
public void testGetTruncatedModelDefinition() throws Exception {
    final String modelId = "test-get-truncated-model-config";
    final TrainedModelConfig config = buildTrainedModelConfig(modelId);
    AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
    AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

    blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
    assertThat(putConfigHolder.get(), is(true));
    assertThat(exceptionHolder.get(), is(nullValue()));

    // Split the compressed definition using a chunk size of one third of its length.
    final String compressedDefinition = config.getCompressedDefinition();
    final List<String> chunks = chunkStringWithSize(compressedDefinition, compressedDefinition.length() / 3);
    final int lastChunkIndex = chunks.size() - 1;

    List<TrainedModelDefinitionDoc.Builder> docBuilders = IntStream.range(0, chunks.size())
        .mapToObj(docNum -> {
            String chunk = chunks.get(docNum);
            return new TrainedModelDefinitionDoc.Builder()
                .setDocNum(docNum)
                .setCompressedString(chunk)
                .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION)
                .setDefinitionLength(chunk.length())
                .setEos(docNum == lastChunkIndex)
                .setModelId(modelId);
        })
        .collect(Collectors.toList());

    // Two ways to truncate: drop the end-of-stream flag on the last doc (and index every doc),
    // or keep the flag and skip the first doc.
    final boolean missingEos = randomBoolean();
    docBuilders.get(docBuilders.size() - 1).setEos(missingEos == false);
    final int firstIndexedDoc = missingEos ? 0 : 1;

    for (TrainedModelDefinitionDoc.Builder docBuilder : docBuilders.subList(firstIndexedDoc, docBuilders.size())) {
        TrainedModelDefinitionDoc doc = docBuilder.build();
        ToXContent.Params storageParams = new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true"));
        try (XContentBuilder xContentBuilder = doc.toXContent(XContentFactory.jsonBuilder(), storageParams)) {
            AtomicReference<IndexResponse> putDocHolder = new AtomicReference<>();
            // NOTE(review): every chunk is indexed under the id for doc number 0, so each write
            // targets the same document id — confirm this is intended rather than using the
            // chunk's own doc number.
            blockingCall(
                listener -> client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
                    .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
                    .setSource(xContentBuilder)
                    .setId(TrainedModelDefinitionDoc.docId(modelId, 0))
                    .execute(listener),
                putDocHolder,
                exceptionHolder);
            assertThat(exceptionHolder.get(), is(nullValue()));
        }
    }

    // Loading the model including its definition must report the truncation.
    AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
    blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
    assertThat(getConfigHolder.get(), is(nullValue()));
    assertThat(exceptionHolder.get(), is(not(nullValue())));
    assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
}

private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) {
return TrainedModelConfig.builder()
.setCreatedBy("ml_test")
Expand Down
Loading