diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
index 454aba14611d6..4c19a1da631d8 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
@@ -764,7 +764,6 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
             analyticsProcessFactory,
             dataFrameAnalyticsAuditor,
             trainedModelProvider,
-            modelLoadingService,
             resultsPersisterService,
             EsExecutors.allocatedProcessors(settings));
         MemoryUsageEstimationProcessManager memoryEstimationProcessManager =
@@ -773,8 +772,9 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
         DataFrameAnalyticsConfigProvider dataFrameAnalyticsConfigProvider = new DataFrameAnalyticsConfigProvider(client, xContentRegistry,
             dataFrameAnalyticsAuditor);
         assert client instanceof NodeClient;
-        DataFrameAnalyticsManager dataFrameAnalyticsManager = new DataFrameAnalyticsManager((NodeClient) client, clusterService,
-            dataFrameAnalyticsConfigProvider, analyticsProcessManager, dataFrameAnalyticsAuditor, indexNameExpressionResolver);
+        DataFrameAnalyticsManager dataFrameAnalyticsManager = new DataFrameAnalyticsManager(settings, (NodeClient) client, threadPool,
+            clusterService, dataFrameAnalyticsConfigProvider, analyticsProcessManager, dataFrameAnalyticsAuditor,
+            indexNameExpressionResolver, resultsPersisterService, modelLoadingService);
         this.dataFrameAnalyticsManager.set(dataFrameAnalyticsManager);

         // Components shared by anomaly detection and data frame analytics
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java
index 93d1fc543d4fc..9f6b6894e7905 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java
@@ -273,6 +273,7 @@ private void getStartContext(String id, Task task, ActionListener<StartContext>
                 break;
             case RESUMING_REINDEXING:
             case RESUMING_ANALYZING:
+            case RESUMING_INFERENCE:
                 toValidateMappingsListener.onResponse(startContext);
                 break;
             case FINISHED:
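The sketch below is not part of the change; it only summarizes the step order that the DataFrameAnalyticsManager diff further down establishes, using the names added to DataFrameAnalyticsStep.Name:

    // Illustrative only: the order in which DataFrameAnalyticsManager#executeStep (below)
    // advances a job once this change is applied.
    static DataFrameAnalyticsStep.Name next(DataFrameAnalyticsStep.Name current) {
        switch (current) {
            case REINDEXING: return DataFrameAnalyticsStep.Name.ANALYSIS;
            case ANALYSIS:   return DataFrameAnalyticsStep.Name.INFERENCE; // new step
            case INFERENCE:  return DataFrameAnalyticsStep.Name.FINAL;     // new step
            case FINAL:
            default:         return null; // the manager marks the task completed
        }
    }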
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java
index acb1b27aa3a18..648d691ed3d84 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java
@@ -19,20 +19,30 @@
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.metadata.MappingMetadata;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.IndexNotFoundException;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ClientHelper;
 import org.elasticsearch.xpack.core.ml.MlStatsIndex;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
 import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetector;
+import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetectorFactory;
+import org.elasticsearch.xpack.ml.dataframe.inference.InferenceRunner;
 import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
 import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager;
 import org.elasticsearch.xpack.ml.dataframe.steps.AnalysisStep;
 import org.elasticsearch.xpack.ml.dataframe.steps.DataFrameAnalyticsStep;
+import org.elasticsearch.xpack.ml.dataframe.steps.FinalStep;
+import org.elasticsearch.xpack.ml.dataframe.steps.InferenceStep;
 import org.elasticsearch.xpack.ml.dataframe.steps.ReindexingStep;
 import org.elasticsearch.xpack.ml.dataframe.steps.StepResponse;
+import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
+import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
+import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

 import java.util.Objects;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -43,27 +53,36 @@ public class DataFrameAnalyticsManager {

     private static final Logger LOGGER = LogManager.getLogger(DataFrameAnalyticsManager.class);

+    private final Settings settings;
     /**
      * We need a {@link NodeClient} to get the reindexing task and be able to report progress
      */
     private final NodeClient client;
+    private final ThreadPool threadPool;
     private final ClusterService clusterService;
     private final DataFrameAnalyticsConfigProvider configProvider;
     private final AnalyticsProcessManager processManager;
     private final DataFrameAnalyticsAuditor auditor;
     private final IndexNameExpressionResolver expressionResolver;
+    private final ResultsPersisterService resultsPersisterService;
+    private final ModelLoadingService modelLoadingService;
     /** Indicates whether the node is shutting down. */
     private final AtomicBoolean nodeShuttingDown = new AtomicBoolean();

-    public DataFrameAnalyticsManager(NodeClient client, ClusterService clusterService, DataFrameAnalyticsConfigProvider configProvider,
-                                     AnalyticsProcessManager processManager, DataFrameAnalyticsAuditor auditor,
-                                     IndexNameExpressionResolver expressionResolver) {
+    public DataFrameAnalyticsManager(Settings settings, NodeClient client, ThreadPool threadPool, ClusterService clusterService,
+                                     DataFrameAnalyticsConfigProvider configProvider, AnalyticsProcessManager processManager,
+                                     DataFrameAnalyticsAuditor auditor, IndexNameExpressionResolver expressionResolver,
+                                     ResultsPersisterService resultsPersisterService, ModelLoadingService modelLoadingService) {
+        this.settings = Objects.requireNonNull(settings);
         this.client = Objects.requireNonNull(client);
+        this.threadPool = Objects.requireNonNull(threadPool);
         this.clusterService = Objects.requireNonNull(clusterService);
         this.configProvider = Objects.requireNonNull(configProvider);
         this.processManager = Objects.requireNonNull(processManager);
         this.auditor = Objects.requireNonNull(auditor);
         this.expressionResolver = Objects.requireNonNull(expressionResolver);
+        this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
+        this.modelLoadingService = Objects.requireNonNull(modelLoadingService);
     }

     public void execute(DataFrameAnalyticsTask task, ClusterState clusterState) {
@@ -141,6 +160,12 @@ private void determineProgressAndResume(DataFrameAnalyticsTask task, DataFrameAn
             case RESUMING_ANALYZING:
                 executeStep(task, config, new AnalysisStep(client, task, auditor, config, processManager));
                 break;
+            case RESUMING_INFERENCE:
+                buildInferenceStep(task, config, ActionListener.wrap(
+                    inferenceStep -> executeStep(task, config, inferenceStep),
+                    task::setFailed
+                ));
+                break;
             case FINISHED:
             default:
                 task.setFailed(ExceptionsHelper.serverError("Unexpected starting state [" + startingState + "]"));
@@ -162,7 +187,15 @@ private void executeStep(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig c
                 executeStep(task, config, new AnalysisStep(client, task, auditor, config, processManager));
                 break;
             case ANALYSIS:
-                // This is the last step
+                buildInferenceStep(task, config, ActionListener.wrap(
+                    inferenceStep -> executeStep(task, config, inferenceStep),
+                    task::setFailed
+                ));
+                break;
+            case INFERENCE:
+                executeStep(task, config, new FinalStep(client, task, auditor, config));
+                break;
+            case FINAL:
                 LOGGER.info("[{}] Marking task completed", config.getId());
                 task.markAsCompleted();
                 break;
@@ -199,6 +232,24 @@ private void executeJobInMiddleOfReindexing(DataFrameAnalyticsTask task, DataFra
         ));
     }

+    private void buildInferenceStep(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, ActionListener<InferenceStep> listener) {
+        ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, task.getParentTaskId());
+
+        ActionListener<ExtractedFieldsDetector> extractedFieldsDetectorListener = ActionListener.wrap(
+            extractedFieldsDetector -> {
+                ExtractedFields extractedFields = extractedFieldsDetector.detect().v1();
+                InferenceRunner inferenceRunner = new InferenceRunner(settings, parentTaskClient, modelLoadingService,
+                    resultsPersisterService, task.getParentTaskId(), config, extractedFields, task.getStatsHolder().getProgressTracker(),
+                    task.getStatsHolder().getDataCountsTracker());
+                InferenceStep inferenceStep = new InferenceStep(client, task, auditor, config, threadPool, inferenceRunner);
+                listener.onResponse(inferenceStep);
+            },
+            listener::onFailure
+        );
+
+        new ExtractedFieldsDetectorFactory(parentTaskClient).createFromDest(config, extractedFieldsDetectorListener);
+    }
+
     public boolean isNodeShuttingDown() {
         return nodeShuttingDown.get();
     }
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java
index 579a1750fe1ed..3d4222d730ea1 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java
@@ -287,7 +287,7 @@ public void updateTaskProgress(ActionListener<Void> updateProgressListener) {
      * {@code FINISHED} means the job had finished.
      */
     public enum StartingState {
-        FIRST_TIME, RESUMING_REINDEXING, RESUMING_ANALYZING, FINISHED
+        FIRST_TIME, RESUMING_REINDEXING, RESUMING_ANALYZING, RESUMING_INFERENCE, FINISHED
     }

     public StartingState determineStartingState() {
@@ -313,6 +313,9 @@ public static StartingState determineStartingState(String jobId, List<PhaseProgress>
                 new ParameterizedMessage("[{}] Data extractor was cancelled", context.jobId));
             isCancelled = true;
         }
@@ -127,7 +127,7 @@ private List<Row> tryRequestWithSearchResponse(Supplier<SearchResponse> request)
         // We've set allow_partial_search_results to false which means if something
         // goes wrong the request will throw.
         SearchResponse searchResponse = request.get();
-        LOGGER.debug("[{}] Search response was obtained", context.jobId);
+        LOGGER.trace(() -> new ParameterizedMessage("[{}] Search response was obtained", context.jobId));

         List<Row> rows = processSearchResponse(searchResponse);
@@ -153,7 +153,7 @@ private SearchRequestBuilder buildSearchRequest() {
         long from = lastSortKey + 1;
         long to = from + context.scrollSize;

-        LOGGER.debug(() -> new ParameterizedMessage(
+        LOGGER.trace(() -> new ParameterizedMessage(
             "[{}] Searching docs with [{}] in [{}, {})", context.jobId, DestinationIndex.INCREMENTAL_ID, from, to));

         SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE)
@@ -283,7 +283,7 @@ private Row createRow(SearchHit hit) {
         }
         boolean isTraining = trainTestSplitter.get().isTraining(extractedValues);
         Row row = new Row(extractedValues, hit, isTraining);
-        LOGGER.debug(() -> new ParameterizedMessage("[{}] Extracted row: sort key = [{}], is_training = [{}], values = {}",
+        LOGGER.trace(() -> new ParameterizedMessage("[{}] Extracted row: sort key = [{}], is_training = [{}], values = {}",
             context.jobId, row.getSortKey(), isTraining, Arrays.toString(row.values)));
         return row;
     }
@@ -306,7 +306,7 @@ public DataSummary collectDataSummary() {
         SearchRequestBuilder searchRequestBuilder = buildDataSummarySearchRequestBuilder();
         SearchResponse searchResponse = executeSearchRequest(searchRequestBuilder);
         long rows = searchResponse.getHits().getTotalHits().value;
-        LOGGER.debug("[{}] Data summary rows [{}]", context.jobId, rows);
+        LOGGER.debug(() -> new ParameterizedMessage("[{}] Data summary rows [{}]", context.jobId, rows));
         return new DataSummary(rows, organicFeatures.length + processedFeatures.length);
     }
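The StartingState enum above gains RESUMING_INFERENCE, derived from the persisted phase progress. The changed body of determineStartingState is not included in this excerpt, so the following is an assumption about the shape of that mapping (the helper name and structure are hypothetical), not a copy of the hunk:

    // Hypothetical helper, not the actual hunk body: map the last incomplete phase to a StartingState.
    static StartingState stateFor(PhaseProgress lastIncompletePhase) {
        String phase = lastIncompletePhase.getPhase();
        if (ProgressTracker.REINDEXING.equals(phase)) {
            return lastIncompletePhase.getProgressPercent() == 0 ? StartingState.FIRST_TIME : StartingState.RESUMING_REINDEXING;
        }
        if (ProgressTracker.INFERENCE.equals(phase)) {
            return StartingState.RESUMING_INFERENCE; // the case this change adds
        }
        return StartingState.RESUMING_ANALYZING; // loading_data, analysis phases, writing_results
    }

The new test in DataFrameAnalyticsTaskTests further below (inference stuck at 40%) pins this behaviour down.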
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java
index 51d81624aa3f5..cf0b93ef80872 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java
@@ -10,21 +10,14 @@
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
-import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
 import org.elasticsearch.action.search.SearchResponse;
-import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.client.Client;
-import org.elasticsearch.client.ParentTaskAssigningClient;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.threadpool.ThreadPool;
-import org.elasticsearch.xpack.core.ClientHelper;
-import org.elasticsearch.xpack.core.ml.MlStatsIndex;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
-import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -32,20 +25,17 @@
 import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
 import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
 import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
-import org.elasticsearch.xpack.ml.dataframe.inference.InferenceRunner;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.dataframe.stats.DataCountsTracker;
 import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
 import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister;
 import org.elasticsearch.xpack.ml.dataframe.steps.StepResponse;
 import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
-import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
 import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

 import java.io.IOException;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
@@ -69,7 +59,6 @@ public class AnalyticsProcessManager {
     private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
     private final DataFrameAnalyticsAuditor auditor;
     private final TrainedModelProvider trainedModelProvider;
-    private final ModelLoadingService modelLoadingService;
     private final ResultsPersisterService resultsPersisterService;
     private final int numAllocatedProcessors;

@@ -79,7 +68,6 @@ public AnalyticsProcessManager(Settings settings,
                                    AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
                                    DataFrameAnalyticsAuditor auditor,
                                    TrainedModelProvider trainedModelProvider,
-                                   ModelLoadingService modelLoadingService,
                                    ResultsPersisterService resultsPersisterService,
                                    int numAllocatedProcessors) {
         this(
@@ -90,7 +78,6 @@ public AnalyticsProcessManager(Settings settings,
             analyticsProcessFactory,
             auditor,
             trainedModelProvider,
-            modelLoadingService,
             resultsPersisterService,
             numAllocatedProcessors);
     }

@@ -103,7 +90,6 @@ public AnalyticsProcessManager(Settings settings,
                                    AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
                                    DataFrameAnalyticsAuditor auditor,
                                    TrainedModelProvider trainedModelProvider,
-                                   ModelLoadingService modelLoadingService,
                                    ResultsPersisterService resultsPersisterService,
                                    int numAllocatedProcessors) {
         this.settings = Objects.requireNonNull(settings);
@@ -113,7 +99,6 @@ public AnalyticsProcessManager(Settings settings,
         this.processFactory = Objects.requireNonNull(analyticsProcessFactory);
         this.auditor = Objects.requireNonNull(auditor);
         this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
-        this.modelLoadingService = Objects.requireNonNull(modelLoadingService);
         this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
         this.numAllocatedProcessors = numAllocatedProcessors;
     }

@@ -183,7 +168,6 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont
         LOGGER.info("[{}] Started loading data", processContext.config.getId());
         auditor.info(processContext.config.getId(), Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_STARTED_LOADING_DATA));

-        ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, task.getParentTaskId());
         DataFrameAnalyticsConfig config = processContext.config;
         DataFrameDataExtractor dataExtractor = processContext.dataExtractor.get();
         AnalyticsProcess<AnalyticsResult> process = processContext.process.get();
@@ -203,14 +187,6 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont
             resultProcessor.awaitForCompletion();
             processContext.setFailureReason(resultProcessor.getFailure());
             LOGGER.info("[{}] Result processor has completed", config.getId());
-
-            runInference(parentTaskClient, task, processContext, dataExtractor.getExtractedFields());
-
-            processContext.statsPersister.persistWithRetry(task.getStatsHolder().getDataCountsTracker().report(config.getId()),
-                DataCounts::documentId);
-
-            refreshDest(parentTaskClient, config);
-            refreshIndices(parentTaskClient, config.getId());
         } catch (Exception e) {
             if (task.isStopping()) {
                 // Errors during task stopping are expected but we still want to log them just in case.
@@ -338,43 +314,6 @@ private Consumer<String> onProcessCrash(DataFrameAnalyticsTask task) {
         };
     }

-    private void runInference(ParentTaskAssigningClient parentTaskClient, DataFrameAnalyticsTask task, ProcessContext processContext,
-                              ExtractedFields extractedFields) {
-        if (task.isStopping() || processContext.failureReason.get() != null) {
-            // If the task is stopping or there has been an error thus far let's not run inference at all
-            return;
-        }
-
-        if (processContext.config.getAnalysis().supportsInference()) {
-            refreshDest(parentTaskClient, processContext.config);
-            InferenceRunner inferenceRunner = new InferenceRunner(settings, parentTaskClient, modelLoadingService, resultsPersisterService,
-                task.getParentTaskId(), processContext.config, extractedFields, task.getStatsHolder().getProgressTracker(),
-                task.getStatsHolder().getDataCountsTracker());
-            processContext.setInferenceRunner(inferenceRunner);
-            inferenceRunner.run(processContext.resultProcessor.get().getLatestModelId());
-        }
-    }
-
-    private void refreshDest(ParentTaskAssigningClient parentTaskClient, DataFrameAnalyticsConfig config) {
-        ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, parentTaskClient,
-            () -> parentTaskClient.execute(RefreshAction.INSTANCE, new RefreshRequest(config.getDest().getIndex())).actionGet());
-    }
-
-    private void refreshIndices(ParentTaskAssigningClient parentTaskClient, String jobId) {
-        RefreshRequest refreshRequest = new RefreshRequest(
-            AnomalyDetectorsIndex.jobStateIndexPattern(),
-            MlStatsIndex.indexPattern()
-        );
-        refreshRequest.indicesOptions(IndicesOptions.lenientExpandOpen());
-
-        LOGGER.debug(() -> new ParameterizedMessage("[{}] Refreshing indices {}",
-            jobId, Arrays.toString(refreshRequest.indices())));
-
-        try (ThreadContext.StoredContext ignore = parentTaskClient.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
-            parentTaskClient.admin().indices().refresh(refreshRequest).actionGet();
-        }
-    }
-
     private void closeProcess(DataFrameAnalyticsTask task) {
         String configId = task.getParams().getId();
         LOGGER.info("[{}] Closing process", configId);
@@ -415,13 +354,10 @@ class ProcessContext {
         private final SetOnce<AnalyticsProcess<AnalyticsResult>> process = new SetOnce<>();
         private final SetOnce<DataFrameDataExtractor> dataExtractor = new SetOnce<>();
         private final SetOnce<AnalyticsResultProcessor> resultProcessor = new SetOnce<>();
-        private final SetOnce<InferenceRunner> inferenceRunner = new SetOnce<>();
         private final SetOnce<String> failureReason = new SetOnce<>();
-        private final StatsPersister statsPersister;

         ProcessContext(DataFrameAnalyticsConfig config) {
             this.config = Objects.requireNonNull(config);
-            this.statsPersister = new StatsPersister(config.getId(), resultsPersisterService, auditor);
         }

         String getFailureReason() {
@@ -436,10 +372,6 @@ void setFailureReason(String failureReason) {
             this.failureReason.trySet(failureReason);
         }

-        void setInferenceRunner(InferenceRunner inferenceRunner) {
-            this.inferenceRunner.set(inferenceRunner);
-        }
-
         synchronized void stop() {
             LOGGER.debug("[{}] Stopping process", config.getId());
             if (dataExtractor.get() != null) {
@@ -448,9 +380,6 @@ synchronized void stop() {
             if (resultProcessor.get() != null) {
                 resultProcessor.get().cancel();
             }
-            if (inferenceRunner.get() != null) {
-                inferenceRunner.get().cancel();
-            }
             if (process.get() != null) {
                 try {
                     process.get().kill(true);
@@ -507,6 +436,7 @@ private AnalyticsResultProcessor createResultProcessor(DataFrameAnalyticsTask ta
             DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), settings, task.getParentTaskId(),
                 dataExtractorFactory.newExtractor(true), resultsPersisterService);
+            StatsPersister statsPersister = new StatsPersister(config.getId(), resultsPersisterService, auditor);
             return new AnalyticsResultProcessor(
                 config, dataFrameRowsJoiner, task.getStatsHolder(), trainedModelProvider, auditor, statsPersister,
                 dataExtractor.get().getExtractedFields());
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTracker.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTracker.java
index 691836d26a438..cb8a1966f9277 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTracker.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTracker.java
@@ -54,6 +54,9 @@ public ProgressTracker(List<PhaseProgress> phaseProgresses) {
         assert progressPercentPerPhase.containsKey(REINDEXING);
         assert progressPercentPerPhase.containsKey(LOADING_DATA);
         assert progressPercentPerPhase.containsKey(WRITING_RESULTS);
+        // If there is inference it should be the last phase otherwise there
+        // are assumptions that do not hold.
+        assert progressPercentPerPhase.containsKey(INFERENCE) == false || INFERENCE.equals(phasesInOrder[phasesInOrder.length - 1]);
     }

     public void updateReindexingProgress(int progressPercent) {
@@ -96,6 +99,32 @@ private void updatePhase(String phase, int progress) {
         progressPercentPerPhase.computeIfPresent(phase, (k, v) -> Math.max(v, progress));
     }

+    /**
+     * Resets progress to reflect all phases are complete except for inference
+     * which is set to zero.
+     */
+    public void resetForInference() {
+        for (Map.Entry<String, Integer> phaseProgress : progressPercentPerPhase.entrySet()) {
+            if (phaseProgress.getKey().equals(INFERENCE)) {
+                progressPercentPerPhase.put(phaseProgress.getKey(), 0);
+            } else {
+                progressPercentPerPhase.put(phaseProgress.getKey(), 100);
+            }
+        }
+    }
+
+    /**
+     * Returns whether all phases before inference are complete
+     */
+    public boolean areAllPhasesExceptInferenceComplete() {
+        for (Map.Entry<String, Integer> phaseProgress : progressPercentPerPhase.entrySet()) {
+            if (phaseProgress.getKey().equals(INFERENCE) == false && phaseProgress.getValue() < 100) {
+                return false;
+            }
+        }
+        return true;
+    }
+
     public List<PhaseProgress> report() {
         return Arrays.stream(phasesInOrder)
             .map(phase -> new PhaseProgress(phase, progressPercentPerPhase.get(phase)))
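A minimal usage sketch of the two new ProgressTracker methods; the single custom phase "a" is illustrative and mirrors the tests added further below:

    ProgressTracker tracker = ProgressTracker.fromZeroes(Collections.singletonList("a"), true);
    tracker.updateReindexingProgress(100);
    tracker.updateLoadingDataProgress(100);
    tracker.updatePhase(new PhaseProgress("a", 100));
    tracker.updateWritingResultsProgress(100);
    tracker.updateInferenceProgress(40);

    boolean resumable = tracker.areAllPhasesExceptInferenceComplete(); // true: only inference is incomplete
    tracker.resetForInference(); // report() now shows every phase at 100 and inference back at 0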
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java
index 1207821fff37e..1b74ee6ec47d1 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java
@@ -34,15 +34,30 @@ public void setProgressTracker(List<PhaseProgress> progress) {
         progressTracker = new ProgressTracker(progress);
     }

+    /**
+     * Updates the progress tracker with potentially new in-between phases
+     * that were introduced in a later version while making sure progress indicators
+     * are correct.
+     * @param analysisPhases the full set of phases of the analysis in current version
+     * @param hasInferencePhase whether the analysis supports inference
+     */
     public void adjustProgressTracker(List<String> analysisPhases, boolean hasInferencePhase) {
         int reindexingProgressPercent = progressTracker.getReindexingProgressPercent();
+        boolean areAllPhasesBeforeInferenceComplete = progressTracker.areAllPhasesExceptInferenceComplete();
         progressTracker = ProgressTracker.fromZeroes(analysisPhases, hasInferencePhase);
         // If reindexing progress was more than 0 and less than 100 (ie not complete) we reset it to 1
         // as we will have to do reindexing from scratch and at the same time we want
         // to differentiate from a job that has never started before.
-        progressTracker.updateReindexingProgress(
-            (reindexingProgressPercent > 0 && reindexingProgressPercent < 100) ? 1 : reindexingProgressPercent);
+        if (reindexingProgressPercent > 0 && reindexingProgressPercent < 100) {
+            progressTracker.updateReindexingProgress(1);
+        } else {
+            progressTracker.updateReindexingProgress(reindexingProgressPercent);
+        }
+
+        if (hasInferencePhase && areAllPhasesBeforeInferenceComplete) {
+            progressTracker.resetForInference();
+        }
     }

     public void resetProgressTracker(List<String> analysisPhases, boolean hasInferencePhase) {
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AbstractDataFrameAnalyticsStep.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AbstractDataFrameAnalyticsStep.java
index 150745a624dc2..db9db30c2b0d8 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AbstractDataFrameAnalyticsStep.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AbstractDataFrameAnalyticsStep.java
@@ -58,7 +58,7 @@ protected TaskId getParentTaskId() {
     public final void execute(ActionListener<StepResponse> listener) {
         logger.debug(() -> new ParameterizedMessage("[{}] Executing step [{}]", config.getId(), name()));
         if (task.isStopping()) {
-            logger.debug(() -> new ParameterizedMessage("[{}] task is stopping before starting [{}]", config.getId(), name()));
+            logger.debug(() -> new ParameterizedMessage("[{}] task is stopping before starting [{}] step", config.getId(), name()));
             listener.onResponse(new StepResponse(true));
             return;
         }
     }
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/DataFrameAnalyticsStep.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/DataFrameAnalyticsStep.java
index c5883a4bb3e3c..78d31c686a9ee 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/DataFrameAnalyticsStep.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/DataFrameAnalyticsStep.java
@@ -14,7 +14,7 @@ public interface DataFrameAnalyticsStep {

     enum Name {
-        REINDEXING, ANALYSIS;
+        REINDEXING, ANALYSIS, INFERENCE, FINAL;

         @Override
         public String toString() {
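The two new steps below (FinalStep, InferenceStep) implement the step contract the manager drives. The interface itself is not part of this diff, so the outline below is an assumption pieced together from the members visible elsewhere in this change, not the actual source:

    // Assumed outline; only members visible in this diff are listed.
    public interface DataFrameAnalyticsStep {

        enum Name { REINDEXING, ANALYSIS, INFERENCE, FINAL }

        Name name();                                          // which phase the step represents
        void execute(ActionListener<StepResponse> listener);  // async execution (see AbstractDataFrameAnalyticsStep)
        void cancel(String reason, TimeValue timeout);        // best-effort cancellation
        void updateProgress(ActionListener<Void> listener);   // flush progress before moving to the next step
    }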
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/FinalStep.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/FinalStep.java
new file mode 100644
index 0000000000000..1132b31c111d1
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/FinalStep.java
@@ -0,0 +1,116 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.dataframe.steps;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
+import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
+import org.elasticsearch.action.index.IndexRequest;
+import org.elasticsearch.action.index.IndexResponse;
+import org.elasticsearch.action.support.IndicesOptions;
+import org.elasticsearch.client.ParentTaskAssigningClient;
+import org.elasticsearch.client.node.NodeClient;
+import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentFactory;
+import org.elasticsearch.xpack.core.ml.MlStatsIndex;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
+import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
+import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
+import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
+import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+
+import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
+
+/**
+ * The final step of a data frame analytics job.
+ * Allows the job to perform finalizing tasks like refresh indices,
+ * persist stats, etc.
+ */
+public class FinalStep extends AbstractDataFrameAnalyticsStep {
+
+    private static final Logger LOGGER = LogManager.getLogger(FinalStep.class);
+
+    public FinalStep(NodeClient client, DataFrameAnalyticsTask task, DataFrameAnalyticsAuditor auditor, DataFrameAnalyticsConfig config) {
+        super(client, task, auditor, config);
+    }
+
+    @Override
+    public Name name() {
+        return Name.FINAL;
+    }
+
+    @Override
+    protected void doExecute(ActionListener<StepResponse> listener) {
+
+        ActionListener<RefreshResponse> refreshListener = ActionListener.wrap(
+            refreshResponse -> listener.onResponse(new StepResponse(true)),
+            listener::onFailure
+        );
+
+        ActionListener<IndexResponse> dataCountsIndexedListener = ActionListener.wrap(
+            indexResponse -> refreshIndices(refreshListener),
+            listener::onFailure
+        );
+
+        indexDataCounts(dataCountsIndexedListener);
+    }
+
+    private void indexDataCounts(ActionListener<IndexResponse> listener) {
+        DataCounts dataCounts = task.getStatsHolder().getDataCountsTracker().report(config.getId());
+        try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
+            dataCounts.toXContent(builder, new ToXContent.MapParams(
+                Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
+            IndexRequest indexRequest = new IndexRequest(MlStatsIndex.writeAlias())
+                .id(DataCounts.documentId(config.getId()))
+                .setRequireAlias(true)
+                .source(builder);
+            parentTaskClient().index(indexRequest, listener);
+        } catch (IOException e) {
+            listener.onFailure(ExceptionsHelper.serverError("[{}] Error persisting final data counts", e, config.getId()));
+        }
+    }
+
+    private void refreshIndices(ActionListener<RefreshResponse> listener) {
+        RefreshRequest refreshRequest = new RefreshRequest(
+            AnomalyDetectorsIndex.jobStateIndexPattern(),
+            MlStatsIndex.indexPattern(),
+            config.getDest().getIndex()
+        );
+        refreshRequest.indicesOptions(IndicesOptions.lenientExpandOpen());
+
+        LOGGER.debug(() -> new ParameterizedMessage("[{}] Refreshing indices {}", config.getId(),
+            Arrays.toString(refreshRequest.indices())));
+
+        ParentTaskAssigningClient parentTaskClient = parentTaskClient();
+        try (ThreadContext.StoredContext ignore = parentTaskClient.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
+            parentTaskClient.admin().indices().refresh(refreshRequest, listener);
+        }
+    }
+
+    @Override
+    public void cancel(String reason, TimeValue timeout) {
+        // Not cancellable
+    }
+
+    @Override
+    public void updateProgress(ActionListener<Void> listener) {
+        // No progress to update
+        listener.onResponse(null);
+    }
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/InferenceStep.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/InferenceStep.java
new file mode 100644
index 0000000000000..381f6090c907a
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/InferenceStep.java
@@ -0,0 +1,127 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.dataframe.steps;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.ResourceNotFoundException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
+import org.elasticsearch.action.search.SearchAction;
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.client.node.NodeClient;
+import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.search.sort.SortOrder;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
+import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
+import org.elasticsearch.xpack.ml.dataframe.inference.InferenceRunner;
+import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
+
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
+import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
+
+public class InferenceStep extends AbstractDataFrameAnalyticsStep {
+
+    private static final Logger LOGGER = LogManager.getLogger(InferenceStep.class);
+
+    private final ThreadPool threadPool;
+    private final InferenceRunner inferenceRunner;
+
+    public InferenceStep(NodeClient client, DataFrameAnalyticsTask task, DataFrameAnalyticsAuditor auditor, DataFrameAnalyticsConfig config,
+                         ThreadPool threadPool, InferenceRunner inferenceRunner) {
+        super(client, task, auditor, config);
+        this.threadPool = Objects.requireNonNull(threadPool);
+        this.inferenceRunner = Objects.requireNonNull(inferenceRunner);
+    }
+
+    @Override
+    public Name name() {
+        return Name.INFERENCE;
+    }
+
+    @Override
+    protected void doExecute(ActionListener<StepResponse> listener) {
+        if (config.getAnalysis().supportsInference() == false) {
+            LOGGER.debug(() -> new ParameterizedMessage(
+                "[{}] Inference step completed immediately as analysis does not support inference", config.getId()));
+            listener.onResponse(new StepResponse(false));
+            return;
+        }
+
+        ActionListener<String> modelIdListener = ActionListener.wrap(
+            modelId -> runInference(modelId, listener),
+            listener::onFailure
+        );
+
+        ActionListener<RefreshResponse> refreshDestListener = ActionListener.wrap(
+            refreshResponse -> getModelId(modelIdListener),
+            listener::onFailure
+        );
+
+        refreshDestAsync(refreshDestListener);
+    }
+
+    private void runInference(String modelId, ActionListener<StepResponse> listener) {
+        threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
+            try {
+                inferenceRunner.run(modelId);
+                listener.onResponse(new StepResponse(isTaskStopping()));
+            } catch (Exception e) {
+                if (task.isStopping()) {
+                    listener.onResponse(new StepResponse(false));
+                } else {
+                    listener.onFailure(e);
+                }
+            }
+        });
+    }
+
+    private void getModelId(ActionListener<String> listener) {
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
+        searchSourceBuilder.size(1);
+        searchSourceBuilder.fetchSource(false);
+        searchSourceBuilder.query(QueryBuilders.boolQuery()
+            .filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), config.getId()))
+        );
+        searchSourceBuilder.sort(TrainedModelConfig.CREATE_TIME.getPreferredName(), SortOrder.DESC);
+        SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN);
+        searchRequest.source(searchSourceBuilder);
+
+        executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(
+            searchResponse -> {
+                SearchHit[] hits = searchResponse.getHits().getHits();
+                if (hits.length == 0) {
+                    listener.onFailure(new ResourceNotFoundException("No model could be found to perform inference"));
+                } else {
+                    listener.onResponse(hits[0].getId());
+                }
+            },
+            listener::onFailure
+        ));
+    }
+
+    @Override
+    public void cancel(String reason, TimeValue timeout) {
+        inferenceRunner.cancel();
+    }
+
+    @Override
+    public void updateProgress(ActionListener<Void> listener) {
+        // Inference runner updates progress directly
+        listener.onResponse(null);
+    }
+}
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java
index be81e4c0b863b..aae85f27e65ef 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java
@@ -118,6 +118,18 @@ public void testDetermineStartingState_GivenWritingResultsIsIncomplete() {
         assertThat(startingState, equalTo(StartingState.RESUMING_ANALYZING));
     }

+    public void testDetermineStartingState_GivenInferenceIsIncomplete() {
+        List<PhaseProgress> progress = Arrays.asList(new PhaseProgress("reindexing", 100),
+            new PhaseProgress("loading_data", 100),
+            new PhaseProgress("analyzing", 100),
+            new PhaseProgress("writing_results", 100),
+            new PhaseProgress("inference", 40));
+
+        StartingState startingState = DataFrameAnalyticsTask.determineStartingState("foo", progress);
+
+        assertThat(startingState, equalTo(StartingState.RESUMING_INFERENCE));
+    }
+
     public void testDetermineStartingState_GivenFinished() {
         List<PhaseProgress> progress = Arrays.asList(new PhaseProgress("reindexing", 100),
             new PhaseProgress("loading_data", 100),
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java
index 252d6bc10aada..77105f213ea26 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java
@@ -113,9 +113,8 @@ public void setUpMocks() {
         when(dataExtractorFactory.getExtractedFields()).thenReturn(mock(ExtractedFields.class));

         resultsPersisterService = mock(ResultsPersisterService.class);
-        modelLoadingService = mock(ModelLoadingService.class);
         processManager = new AnalyticsProcessManager(Settings.EMPTY, client, executorServiceForJob, executorServiceForProcess,
-            processFactory, auditor, trainedModelProvider, modelLoadingService, resultsPersisterService, 1);
+            processFactory, auditor, trainedModelProvider, resultsPersisterService, 1);
     }

     public void testRunJob_TaskIsStopping() {
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameAnalyticsManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameAnalyticsManagerTests.java
index 2ecad9ed376d5..1110e31cbc513 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameAnalyticsManagerTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameAnalyticsManagerTests.java
@@ -8,10 +8,14 @@
 import org.elasticsearch.client.node.NodeClient;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsManager;
 import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
+import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
+import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

 import static org.hamcrest.Matchers.is;
 import static org.mockito.Mockito.mock;
@@ -21,12 +25,16 @@ public class DataFrameAnalyticsManagerTests extends ESTestCase {

     public void testNodeShuttingDown() {
         DataFrameAnalyticsManager manager = new DataFrameAnalyticsManager(
+            Settings.EMPTY,
             mock(NodeClient.class),
+            mock(ThreadPool.class),
             mock(ClusterService.class),
             mock(DataFrameAnalyticsConfigProvider.class),
             mock(AnalyticsProcessManager.class),
             mock(DataFrameAnalyticsAuditor.class),
-            mock(IndexNameExpressionResolver.class));
+            mock(IndexNameExpressionResolver.class),
+            mock(ResultsPersisterService.class),
+            mock(ModelLoadingService.class));

         assertThat(manager.isNodeShuttingDown(), is(false));
         manager.markNodeAsShuttingDown();
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTrackerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTrackerTests.java
index 8e52301ee0b0c..3b9150684df23 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTrackerTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTrackerTests.java
@@ -11,7 +11,9 @@
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.LinkedHashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.stream.Collectors;

 import static org.hamcrest.Matchers.contains;
@@ -148,6 +150,75 @@ public void testUpdatePhase_GivenLowerValueThanCurrentProgress() {
         assertThat(getProgressForPhase(progressTracker, "foo"), equalTo(41));
     }

+    public void testResetForInference_GivenInference() {
+        ProgressTracker progressTracker = ProgressTracker.fromZeroes(Arrays.asList("a", "b"), true);
+        progressTracker.updateReindexingProgress(10);
+        progressTracker.updateLoadingDataProgress(20);
+        progressTracker.updatePhase(new PhaseProgress("a", 30));
+        progressTracker.updatePhase(new PhaseProgress("b", 40));
+        progressTracker.updateWritingResultsProgress(50);
+        progressTracker.updateInferenceProgress(60);
+
+        progressTracker.resetForInference();
+
+        List<PhaseProgress> progress = progressTracker.report();
+        assertThat(progress, contains(
+            new PhaseProgress(ProgressTracker.REINDEXING, 100),
+            new PhaseProgress(ProgressTracker.LOADING_DATA, 100),
+            new PhaseProgress("a", 100),
+            new PhaseProgress("b", 100),
+            new PhaseProgress(ProgressTracker.WRITING_RESULTS, 100),
+            new PhaseProgress(ProgressTracker.INFERENCE, 0)
+        ));
+    }
+
+    public void testResetForInference_GivenNoInference() {
+        ProgressTracker progressTracker = ProgressTracker.fromZeroes(Arrays.asList("a", "b"), false);
+        progressTracker.updateReindexingProgress(10);
+        progressTracker.updateLoadingDataProgress(20);
+        progressTracker.updatePhase(new PhaseProgress("a", 30));
+        progressTracker.updatePhase(new PhaseProgress("b", 40));
+        progressTracker.updateWritingResultsProgress(50);
+
+        progressTracker.resetForInference();
+
+        List<PhaseProgress> progress = progressTracker.report();
+        assertThat(progress, contains(
+            new PhaseProgress(ProgressTracker.REINDEXING, 100),
+            new PhaseProgress(ProgressTracker.LOADING_DATA, 100),
+            new PhaseProgress("a", 100),
+            new PhaseProgress("b", 100),
+            new PhaseProgress(ProgressTracker.WRITING_RESULTS, 100)
+        ));
+    }
+
+    public void testAreAllPhasesExceptInferenceComplete_GivenComplete() {
+        ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("a"), true);
+        progressTracker.updateReindexingProgress(100);
+        progressTracker.updateLoadingDataProgress(100);
+        progressTracker.updatePhase(new PhaseProgress("a", 100));
+        progressTracker.updateWritingResultsProgress(100);
+        progressTracker.updateInferenceProgress(50);
+
+        assertThat(progressTracker.areAllPhasesExceptInferenceComplete(), is(true));
+    }
+
+    public void testAreAllPhasesExceptInferenceComplete_GivenNotComplete() {
+        Map<String, Integer> phasePerProgress = new LinkedHashMap<>();
+        phasePerProgress.put(ProgressTracker.REINDEXING, 100);
+        phasePerProgress.put(ProgressTracker.LOADING_DATA, 100);
+        phasePerProgress.put("a", 100);
+        phasePerProgress.put(ProgressTracker.WRITING_RESULTS, 100);
+        String nonCompletePhase = randomFrom(phasePerProgress.keySet());
+        phasePerProgress.put(ProgressTracker.INFERENCE, 50);
+        phasePerProgress.put(nonCompletePhase, randomIntBetween(0, 99));
+
+        ProgressTracker progressTracker = new ProgressTracker(phasePerProgress.entrySet().stream()
+            .map(entry -> new PhaseProgress(entry.getKey(), entry.getValue())).collect(Collectors.toList()));
+
+        assertThat(progressTracker.areAllPhasesExceptInferenceComplete(), is(false));
+    }
+
     private static int getProgressForPhase(ProgressTracker progressTracker, String phase) {
         return progressTracker.report().stream().filter(p -> p.getPhase().equals(phase)).findFirst().get().getProgressPercent();
     }
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolderTests.java
index 3320250659f28..7a5cf841d582f 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolderTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolderTests.java
@@ -123,6 +123,32 @@ public void testAdjustProgressTracker_GivenReindexingProgressIncomplete() {
         assertThat(phaseProgresses.get(4).getProgressPercent(), equalTo(0));
     }

+    public void testAdjustProgressTracker_GivenAllPhasesCompleteExceptInference() {
+        List<PhaseProgress> phases = Collections.unmodifiableList(
+            Arrays.asList(
+                new PhaseProgress("reindexing", 100),
+                new PhaseProgress("loading_data", 100),
+                new PhaseProgress("a", 100),
+                new PhaseProgress("writing_results", 100),
+                new PhaseProgress("inference", 20)
+            )
+        );
+        StatsHolder statsHolder = new StatsHolder(phases);
+
+        statsHolder.adjustProgressTracker(Arrays.asList("a", "b"), true);
+
+        List<PhaseProgress> phaseProgresses = statsHolder.getProgressTracker().report();
+
+        assertThat(phaseProgresses, contains(
+            new PhaseProgress("reindexing", 100),
+            new PhaseProgress("loading_data", 100),
+            new PhaseProgress("a", 100),
+            new PhaseProgress("b", 100),
+            new PhaseProgress("writing_results", 100),
+            new PhaseProgress("inference", 0)
+        ));
+    }
+
     public void testResetProgressTracker() {
         List<PhaseProgress> phases = Collections.unmodifiableList(
             Arrays.asList(