From 67dc2439517321d4c70937c33c9e7681b42ad1c1 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Fri, 29 Jul 2022 10:35:42 -0400 Subject: [PATCH 1/3] [ML] address potential bug where trained models get stuck in starting --- .../TrainedModelAssignmentNodeService.java | 7 +- .../deployment/DeploymentManager.java | 73 +++++++++++-------- .../ChunkedTrainedModelRestorer.java | 27 ++++--- 3 files changed, 65 insertions(+), 42 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java index 1d48f1d1f2297..8c46427f6d249 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java @@ -225,14 +225,13 @@ void loadQueuedModels() { } catch (Exception ex) { logger.warn(() -> "[" + modelId + "] Start deployment failed", ex); if (ExceptionsHelper.unwrapCause(ex) instanceof ResourceNotFoundException) { - logger.warn(() -> "[" + modelId + "] Start deployment failed", ex); + logger.debug(() -> "[" + modelId + "] Start deployment failed as model was not found", ex); handleLoadFailure(loadingTask, ExceptionsHelper.missingTrainedModel(modelId, ex)); } else if (ExceptionsHelper.unwrapCause(ex) instanceof SearchPhaseExecutionException) { - logger.trace(() -> "[" + modelId + "] Start deployment failed, will retry", ex); + logger.debug(() -> "[" + modelId + "] Start deployment failed, will retry", ex); // A search phase execution failure should be retried, push task back to the queue loadingToRetry.add(loadingTask); } else { - logger.warn(() -> "[" + modelId + "] Start deployment failed", ex); handleLoadFailure(loadingTask, ex); } } @@ -413,7 +412,7 @@ private void updateNumberOfAllocations(TrainedModelAssignmentMetadata assignment for (TrainedModelAssignment assignment : modelsToUpdate) { TrainedModelDeploymentTask task = modelIdToTask.get(assignment.getModelId()); if (task == null) { - logger.debug(() -> format("[%s] task was removed whilst updating number of allocations", task.getModelId())); + logger.debug(() -> format("[%s] task was removed whilst updating number of allocations", assignment.getModelId())); continue; } RoutingInfo routingInfo = assignment.getNodeRoutingTable().get(nodeId); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index 0d917debe3d02..2b9e63cc2595d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -61,6 +61,7 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; +import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; @@ -149,34 +150,46 @@ private void doStartDeployment(TrainedModelDeploymentTask task, ActionListener { - if (searchVocabResponse.getHits().getHits().length == 0) { - listener.onFailure( - new ResourceNotFoundException( - Messages.getMessage( - Messages.VOCABULARY_NOT_FOUND, - task.getModelId(), - VocabularyConfig.docId(modelConfig.getModelId()) + if (modelConfig.getInferenceConfig()instanceof NlpConfig nlpConfig) { + task.init(nlpConfig); + + SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId()); + executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchVocabResponse -> { + if (searchVocabResponse.getHits().getHits().length == 0) { + listener.onFailure( + new ResourceNotFoundException( + Messages.getMessage( + Messages.VOCABULARY_NOT_FOUND, + task.getModelId(), + VocabularyConfig.docId(modelConfig.getModelId()) + ) ) - ) + ); + return; + } + + Vocabulary vocabulary = parseVocabularyDocLeniently(searchVocabResponse.getHits().getAt(0)); + NlpTask nlpTask = new NlpTask(nlpConfig, vocabulary); + NlpTask.Processor processor = nlpTask.createProcessor(); + processContext.nlpTaskProcessor.set(processor); + // here, we are being called back on the searching thread, which MAY be a network thread + // `startAndLoad` creates named pipes, blocking the calling thread, better to execute that in our utility + // executor. + executorServiceForDeployment.execute( + () -> startAndLoad(processContext, modelConfig.getLocation(), modelLoadedListener) ); - return; - } - - Vocabulary vocabulary = parseVocabularyDocLeniently(searchVocabResponse.getHits().getAt(0)); - NlpTask nlpTask = new NlpTask(nlpConfig, vocabulary); - NlpTask.Processor processor = nlpTask.createProcessor(); - processContext.nlpTaskProcessor.set(processor); - // here, we are being called back on the searching thread, which MAY be a network thread - // `startAndLoad` creates named pipes, blocking the calling thread, better to execute that in our utility - // executor. - executorServiceForDeployment.execute(() -> startAndLoad(processContext, modelConfig.getLocation(), modelLoadedListener)); - }, listener::onFailure)); + }, listener::onFailure)); + } else { + listener.onFailure( + new IllegalArgumentException( + format( + "[%s] must be an pytorch model found inference config of kind [%s]", + modelConfig.getModelId(), + modelConfig.getInferenceConfig().getWriteableName() + ) + ) + ); + } }, listener::onFailure); executeAsyncWithOrigin( @@ -404,10 +417,12 @@ private Consumer onProcessCrash() { } void loadModel(TrainedModelLocation modelLocation, ActionListener listener) { - if (modelLocation instanceof IndexLocation) { - process.get().loadModel(task.getModelId(), ((IndexLocation) modelLocation).getIndexName(), stateStreamer, listener); + if (modelLocation instanceof IndexLocation indexLocation) { + process.get().loadModel(task.getModelId(), indexLocation.getIndexName(), stateStreamer, listener); } else { - throw new IllegalStateException("unsupported trained model location [" + modelLocation.getClass().getSimpleName() + "]"); + listener.onFailure( + new IllegalStateException("unsupported trained model location [" + modelLocation.getClass().getSimpleName() + "]") + ); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/ChunkedTrainedModelRestorer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/ChunkedTrainedModelRestorer.java index 40d0162e15911..2c440941b5224 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/ChunkedTrainedModelRestorer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/ChunkedTrainedModelRestorer.java @@ -10,11 +10,11 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.ResourceNotFoundException; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchRequestBuilder; +import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.core.CheckedFunction; @@ -38,8 +38,10 @@ import java.util.concurrent.ExecutorService; import java.util.function.Consumer; +import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; -import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; +import static org.elasticsearch.xpack.ml.MachineLearning.NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME; +import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME; /** * Searches for and emits {@link TrainedModelDefinitionDoc}s in @@ -71,7 +73,7 @@ public ChunkedTrainedModelRestorer( ExecutorService executorService, NamedXContentRegistry xContentRegistry ) { - this.client = client; + this.client = new OriginSettingClient(client, ML_ORIGIN); this.executorService = executorService; this.xContentRegistry = xContentRegistry; this.modelId = modelId; @@ -122,7 +124,6 @@ public void restoreModelDefinition( logger.debug("[{}] restoring model", modelId); SearchRequest searchRequest = buildSearch(client, modelId, index, searchSize, null); - executorService.execute(() -> doSearch(searchRequest, modelConsumer, successConsumer, errorConsumer)); } @@ -132,8 +133,16 @@ private void doSearch( Consumer successConsumer, Consumer errorConsumer ) { - - executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchResponse -> { + try { + assert Thread.currentThread().getName().contains(NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME) + || Thread.currentThread().getName().contains(UTILITY_THREAD_POOL_NAME) + : format( + "Must execute from [%s] or [%s] but thread is [%s]", + NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME, + UTILITY_THREAD_POOL_NAME, + Thread.currentThread().getName() + ); + SearchResponse searchResponse = client.search(searchRequest).actionGet(); if (searchResponse.getHits().getHits().length == 0) { errorConsumer.accept(new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); return; @@ -182,13 +191,13 @@ private void doSearch( searchRequestBuilder.searchAfter(new Object[] { lastHit.getIndex(), lastNum }); executorService.execute(() -> doSearch(searchRequestBuilder.request(), modelConsumer, successConsumer, errorConsumer)); } - }, e -> { + } catch (Exception e) { if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { errorConsumer.accept(new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); } else { errorConsumer.accept(e); } - })); + } } private static SearchRequestBuilder buildSearchBuilder(Client client, String modelId, String index, int searchSize) { From 9c30dfaa786e4af80098281141cbedb6126256aa Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Fri, 29 Jul 2022 10:43:39 -0400 Subject: [PATCH 2/3] Update docs/changelog/88945.yaml --- docs/changelog/88945.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 docs/changelog/88945.yaml diff --git a/docs/changelog/88945.yaml b/docs/changelog/88945.yaml new file mode 100644 index 0000000000000..a6cb5ed952d6d --- /dev/null +++ b/docs/changelog/88945.yaml @@ -0,0 +1,6 @@ +pr: 88945 +summary: Address potential bug where trained models get stuck in starting after being + allocated to node +area: Machine Learning +type: bug +issues: [] From be745a424a6825d68f7697d92896e558fc340fc5 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Mon, 1 Aug 2022 07:46:21 -0400 Subject: [PATCH 3/3] Update x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java Co-authored-by: Dimitris Athanasiou --- .../xpack/ml/inference/deployment/DeploymentManager.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index ce92b95b7a5a4..4e6fe4fc0ca2e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -183,7 +183,7 @@ private void doStartDeployment(TrainedModelDeploymentTask task, ActionListener