Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/88945.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 88945
summary: Address potential bug where trained models get stuck in the starting state
  after being allocated to a node
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,13 @@ void loadQueuedModels() {
} catch (Exception ex) {
logger.warn(() -> "[" + modelId + "] Start deployment failed", ex);
if (ExceptionsHelper.unwrapCause(ex) instanceof ResourceNotFoundException) {
logger.warn(() -> "[" + modelId + "] Start deployment failed", ex);
logger.debug(() -> "[" + modelId + "] Start deployment failed as model was not found", ex);
handleLoadFailure(loadingTask, ExceptionsHelper.missingTrainedModel(modelId, ex));
} else if (ExceptionsHelper.unwrapCause(ex) instanceof SearchPhaseExecutionException) {
logger.trace(() -> "[" + modelId + "] Start deployment failed, will retry", ex);
logger.debug(() -> "[" + modelId + "] Start deployment failed, will retry", ex);
// A search phase execution failure should be retried, push task back to the queue
loadingToRetry.add(loadingTask);
} else {
logger.warn(() -> "[" + modelId + "] Start deployment failed", ex);
handleLoadFailure(loadingTask, ex);
}
}
Expand Down Expand Up @@ -413,7 +412,7 @@ private void updateNumberOfAllocations(TrainedModelAssignmentMetadata assignment
for (TrainedModelAssignment assignment : modelsToUpdate) {
TrainedModelDeploymentTask task = modelIdToTask.get(assignment.getModelId());
if (task == null) {
logger.debug(() -> format("[%s] task was removed whilst updating number of allocations", task.getModelId()));
logger.debug(() -> format("[%s] task was removed whilst updating number of allocations", assignment.getModelId()));
continue;
}
RoutingInfo routingInfo = assignment.getNodeRoutingTable().get(nodeId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;

Expand Down Expand Up @@ -149,34 +150,46 @@ private void doStartDeployment(TrainedModelDeploymentTask task, ActionListener<T
TrainedModelConfig modelConfig = getModelResponse.getResources().results().get(0);
processContext.modelInput.set(modelConfig.getInput());

assert modelConfig.getInferenceConfig() instanceof NlpConfig;
NlpConfig nlpConfig = (NlpConfig) modelConfig.getInferenceConfig();
task.init(nlpConfig);

SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId());
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchVocabResponse -> {
if (searchVocabResponse.getHits().getHits().length == 0) {
listener.onFailure(
new ResourceNotFoundException(
Messages.getMessage(
Messages.VOCABULARY_NOT_FOUND,
task.getModelId(),
VocabularyConfig.docId(modelConfig.getModelId())
if (modelConfig.getInferenceConfig()instanceof NlpConfig nlpConfig) {
task.init(nlpConfig);

SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId());
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchVocabResponse -> {
if (searchVocabResponse.getHits().getHits().length == 0) {
listener.onFailure(
new ResourceNotFoundException(
Messages.getMessage(
Messages.VOCABULARY_NOT_FOUND,
task.getModelId(),
VocabularyConfig.docId(modelConfig.getModelId())
)
)
)
);
return;
}

Vocabulary vocabulary = parseVocabularyDocLeniently(searchVocabResponse.getHits().getAt(0));
NlpTask nlpTask = new NlpTask(nlpConfig, vocabulary);
NlpTask.Processor processor = nlpTask.createProcessor();
processContext.nlpTaskProcessor.set(processor);
// here, we are being called back on the searching thread, which MAY be a network thread
// `startAndLoad` creates named pipes, blocking the calling thread, better to execute that in our utility
// executor.
executorServiceForDeployment.execute(
() -> startAndLoad(processContext, modelConfig.getLocation(), modelLoadedListener)
);
return;
}

Vocabulary vocabulary = parseVocabularyDocLeniently(searchVocabResponse.getHits().getAt(0));
NlpTask nlpTask = new NlpTask(nlpConfig, vocabulary);
NlpTask.Processor processor = nlpTask.createProcessor();
processContext.nlpTaskProcessor.set(processor);
// here, we are being called back on the searching thread, which MAY be a network thread
// `startAndLoad` creates named pipes, blocking the calling thread, better to execute that in our utility
// executor.
executorServiceForDeployment.execute(() -> startAndLoad(processContext, modelConfig.getLocation(), modelLoadedListener));
}, listener::onFailure));
}, listener::onFailure));
} else {
listener.onFailure(
new IllegalArgumentException(
format(
"[%s] must be a pytorch model; found inference config of kind [%s]",
modelConfig.getModelId(),
modelConfig.getInferenceConfig().getWriteableName()
)
)
);
}
}, listener::onFailure);

executeAsyncWithOrigin(
Expand Down Expand Up @@ -404,10 +417,12 @@ private Consumer<String> onProcessCrash() {
}

void loadModel(TrainedModelLocation modelLocation, ActionListener<Boolean> listener) {
if (modelLocation instanceof IndexLocation) {
process.get().loadModel(task.getModelId(), ((IndexLocation) modelLocation).getIndexName(), stateStreamer, listener);
if (modelLocation instanceof IndexLocation indexLocation) {
process.get().loadModel(task.getModelId(), indexLocation.getIndexName(), stateStreamer, listener);
} else {
throw new IllegalStateException("unsupported trained model location [" + modelLocation.getClass().getSimpleName() + "]");
listener.onFailure(
new IllegalStateException("unsupported trained model location [" + modelLocation.getClass().getSimpleName() + "]")
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.core.CheckedFunction;
Expand All @@ -38,8 +38,10 @@
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
import static org.elasticsearch.xpack.ml.MachineLearning.NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME;
import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;

/**
* Searches for and emits {@link TrainedModelDefinitionDoc}s in
Expand Down Expand Up @@ -71,7 +73,7 @@ public ChunkedTrainedModelRestorer(
ExecutorService executorService,
NamedXContentRegistry xContentRegistry
) {
this.client = client;
this.client = new OriginSettingClient(client, ML_ORIGIN);
this.executorService = executorService;
this.xContentRegistry = xContentRegistry;
this.modelId = modelId;
Expand Down Expand Up @@ -122,7 +124,6 @@ public void restoreModelDefinition(

logger.debug("[{}] restoring model", modelId);
SearchRequest searchRequest = buildSearch(client, modelId, index, searchSize, null);

executorService.execute(() -> doSearch(searchRequest, modelConsumer, successConsumer, errorConsumer));
}

Expand All @@ -132,8 +133,16 @@ private void doSearch(
Consumer<Boolean> successConsumer,
Consumer<Exception> errorConsumer
) {

executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchResponse -> {
try {
assert Thread.currentThread().getName().contains(NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME)
|| Thread.currentThread().getName().contains(UTILITY_THREAD_POOL_NAME)
: format(
"Must execute from [%s] or [%s] but thread is [%s]",
NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME,
UTILITY_THREAD_POOL_NAME,
Thread.currentThread().getName()
);
SearchResponse searchResponse = client.search(searchRequest).actionGet();
if (searchResponse.getHits().getHits().length == 0) {
errorConsumer.accept(new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
return;
Expand Down Expand Up @@ -182,13 +191,13 @@ private void doSearch(
searchRequestBuilder.searchAfter(new Object[] { lastHit.getIndex(), lastNum });
executorService.execute(() -> doSearch(searchRequestBuilder.request(), modelConsumer, successConsumer, errorConsumer));
}
}, e -> {
} catch (Exception e) {
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
errorConsumer.accept(new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
} else {
errorConsumer.accept(e);
}
}));
}
}

private static SearchRequestBuilder buildSearchBuilder(Client client, String modelId, String index, int searchSize) {
Expand Down