From b4e409114d8f2b1127f19661e3fb33c0aaf8c82c Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Fri, 15 Jan 2021 17:24:54 +0200 Subject: [PATCH 1/3] [ML] Improve resuming a DFA job stopped during inference If a DFA job is stopped while in the inference phase, after resuming we should start inference immediately. However, this is currently not the case. Inference is tied into `AnalyticsProcessManager`, so on resume we start a process, load data, restore state, etc., before we finally get to start inference. This commit gets rid of this unnecessary delay by factoring inference out as an independent step and ensuring we can resume straight from that phase upon restarting a job. --- .../xpack/ml/MachineLearning.java | 6 +- ...ransportStartDataFrameAnalyticsAction.java | 1 + .../dataframe/DataFrameAnalyticsManager.java | 59 +++++++- .../ml/dataframe/DataFrameAnalyticsTask.java | 5 +- .../extractor/DataFrameDataExtractor.java | 10 +- .../process/AnalyticsProcessManager.java | 72 +--------- .../ml/dataframe/stats/ProgressTracker.java | 29 ++++ .../xpack/ml/dataframe/stats/StatsHolder.java | 19 ++- .../steps/AbstractDataFrameAnalyticsStep.java | 2 +- .../steps/DataFrameAnalyticsStep.java | 2 +- .../xpack/ml/dataframe/steps/FinalStep.java | 116 ++++++++++++++++ .../ml/dataframe/steps/InferenceStep.java | 127 ++++++++++++++++++ .../process/AnalyticsProcessManagerTests.java | 3 +- .../DataFrameAnalyticsManagerTests.java | 10 +- .../dataframe/stats/ProgressTrackerTests.java | 71 ++++++++++ .../ml/dataframe/stats/StatsHolderTests.java | 26 ++++ 16 files changed, 467 insertions(+), 91 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/FinalStep.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/InferenceStep.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java 
index 454aba14611d6..4c19a1da631d8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -764,7 +764,6 @@ public Collection createComponents(Client client, ClusterService cluster analyticsProcessFactory, dataFrameAnalyticsAuditor, trainedModelProvider, - modelLoadingService, resultsPersisterService, EsExecutors.allocatedProcessors(settings)); MemoryUsageEstimationProcessManager memoryEstimationProcessManager = @@ -773,8 +772,9 @@ public Collection createComponents(Client client, ClusterService cluster DataFrameAnalyticsConfigProvider dataFrameAnalyticsConfigProvider = new DataFrameAnalyticsConfigProvider(client, xContentRegistry, dataFrameAnalyticsAuditor); assert client instanceof NodeClient; - DataFrameAnalyticsManager dataFrameAnalyticsManager = new DataFrameAnalyticsManager((NodeClient) client, clusterService, - dataFrameAnalyticsConfigProvider, analyticsProcessManager, dataFrameAnalyticsAuditor, indexNameExpressionResolver); + DataFrameAnalyticsManager dataFrameAnalyticsManager = new DataFrameAnalyticsManager(settings, (NodeClient) client, threadPool, + clusterService, dataFrameAnalyticsConfigProvider, analyticsProcessManager, dataFrameAnalyticsAuditor, + indexNameExpressionResolver, resultsPersisterService, modelLoadingService); this.dataFrameAnalyticsManager.set(dataFrameAnalyticsManager); // Components shared by anomaly detection and data frame analytics diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java index 93d1fc543d4fc..9f6b6894e7905 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java @@ -273,6 +273,7 @@ private void getStartContext(String id, Task task, ActionListener break; case RESUMING_REINDEXING: case RESUMING_ANALYZING: + case RESUMING_INFERENCE: toValidateMappingsListener.onResponse(startContext); break; case FINISHED: diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java index acb1b27aa3a18..648d691ed3d84 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java @@ -19,20 +19,30 @@ import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.MappingMetadata; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.MlStatsIndex; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetector; +import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetectorFactory; +import org.elasticsearch.xpack.ml.dataframe.inference.InferenceRunner; import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager; import 
org.elasticsearch.xpack.ml.dataframe.steps.AnalysisStep; import org.elasticsearch.xpack.ml.dataframe.steps.DataFrameAnalyticsStep; +import org.elasticsearch.xpack.ml.dataframe.steps.FinalStep; +import org.elasticsearch.xpack.ml.dataframe.steps.InferenceStep; import org.elasticsearch.xpack.ml.dataframe.steps.ReindexingStep; import org.elasticsearch.xpack.ml.dataframe.steps.StepResponse; +import org.elasticsearch.xpack.ml.extractor.ExtractedFields; +import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; +import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; @@ -43,27 +53,36 @@ public class DataFrameAnalyticsManager { private static final Logger LOGGER = LogManager.getLogger(DataFrameAnalyticsManager.class); + private final Settings settings; /** * We need a {@link NodeClient} to get the reindexing task and be able to report progress */ private final NodeClient client; + private final ThreadPool threadPool; private final ClusterService clusterService; private final DataFrameAnalyticsConfigProvider configProvider; private final AnalyticsProcessManager processManager; private final DataFrameAnalyticsAuditor auditor; private final IndexNameExpressionResolver expressionResolver; + private final ResultsPersisterService resultsPersisterService; + private final ModelLoadingService modelLoadingService; /** Indicates whether the node is shutting down. 
*/ private final AtomicBoolean nodeShuttingDown = new AtomicBoolean(); - public DataFrameAnalyticsManager(NodeClient client, ClusterService clusterService, DataFrameAnalyticsConfigProvider configProvider, - AnalyticsProcessManager processManager, DataFrameAnalyticsAuditor auditor, - IndexNameExpressionResolver expressionResolver) { + public DataFrameAnalyticsManager(Settings settings, NodeClient client, ThreadPool threadPool, ClusterService clusterService, + DataFrameAnalyticsConfigProvider configProvider, AnalyticsProcessManager processManager, + DataFrameAnalyticsAuditor auditor, IndexNameExpressionResolver expressionResolver, + ResultsPersisterService resultsPersisterService, ModelLoadingService modelLoadingService) { + this.settings = Objects.requireNonNull(settings); this.client = Objects.requireNonNull(client); + this.threadPool = Objects.requireNonNull(threadPool); this.clusterService = Objects.requireNonNull(clusterService); this.configProvider = Objects.requireNonNull(configProvider); this.processManager = Objects.requireNonNull(processManager); this.auditor = Objects.requireNonNull(auditor); this.expressionResolver = Objects.requireNonNull(expressionResolver); + this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService); + this.modelLoadingService = Objects.requireNonNull(modelLoadingService); } public void execute(DataFrameAnalyticsTask task, ClusterState clusterState) { @@ -141,6 +160,12 @@ private void determineProgressAndResume(DataFrameAnalyticsTask task, DataFrameAn case RESUMING_ANALYZING: executeStep(task, config, new AnalysisStep(client, task, auditor, config, processManager)); break; + case RESUMING_INFERENCE: + buildInferenceStep(task, config, ActionListener.wrap( + inferenceStep -> executeStep(task, config, inferenceStep), + task::setFailed + )); + break; case FINISHED: default: task.setFailed(ExceptionsHelper.serverError("Unexpected starting state [" + startingState + "]")); @@ -162,7 +187,15 @@ private void 
executeStep(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig c executeStep(task, config, new AnalysisStep(client, task, auditor, config, processManager)); break; case ANALYSIS: - // This is the last step + buildInferenceStep(task, config, ActionListener.wrap( + inferenceStep -> executeStep(task, config, inferenceStep), + task::setFailed + )); + break; + case INFERENCE: + executeStep(task, config, new FinalStep(client, task, auditor, config)); + break; + case FINAL: LOGGER.info("[{}] Marking task completed", config.getId()); task.markAsCompleted(); break; @@ -199,6 +232,24 @@ private void executeJobInMiddleOfReindexing(DataFrameAnalyticsTask task, DataFra )); } + private void buildInferenceStep(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, ActionListener listener) { + ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, task.getParentTaskId()); + + ActionListener extractedFieldsDetectorListener = ActionListener.wrap( + extractedFieldsDetector -> { + ExtractedFields extractedFields = extractedFieldsDetector.detect().v1(); + InferenceRunner inferenceRunner = new InferenceRunner(settings, parentTaskClient, modelLoadingService, + resultsPersisterService, task.getParentTaskId(), config, extractedFields, task.getStatsHolder().getProgressTracker(), + task.getStatsHolder().getDataCountsTracker()); + InferenceStep inferenceStep = new InferenceStep(client, task, auditor, config, threadPool, inferenceRunner); + listener.onResponse(inferenceStep); + }, + listener::onFailure + ); + + new ExtractedFieldsDetectorFactory(parentTaskClient).createFromDest(config, extractedFieldsDetectorListener); + } + public boolean isNodeShuttingDown() { return nodeShuttingDown.get(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java index 579a1750fe1ed..3d4222d730ea1 100644 --- 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java @@ -287,7 +287,7 @@ public void updateTaskProgress(ActionListener updateProgressListener) { * {@code FINISHED} means the job had finished. */ public enum StartingState { - FIRST_TIME, RESUMING_REINDEXING, RESUMING_ANALYZING, FINISHED + FIRST_TIME, RESUMING_REINDEXING, RESUMING_ANALYZING, RESUMING_INFERENCE, FINISHED } public StartingState determineStartingState() { @@ -313,6 +313,9 @@ public static StartingState determineStartingState(String jobId, List new ParameterizedMessage("[{}] Data extractor was cancelled", context.jobId)); isCancelled = true; } @@ -127,7 +127,7 @@ private List tryRequestWithSearchResponse(Supplier request) // We've set allow_partial_search_results to false which means if something // goes wrong the request will throw. SearchResponse searchResponse = request.get(); - LOGGER.debug("[{}] Search response was obtained", context.jobId); + LOGGER.trace(() -> new ParameterizedMessage("[{}] Search response was obtained", context.jobId)); List rows = processSearchResponse(searchResponse); @@ -153,7 +153,7 @@ private SearchRequestBuilder buildSearchRequest() { long from = lastSortKey + 1; long to = from + context.scrollSize; - LOGGER.debug(() -> new ParameterizedMessage( + LOGGER.trace(() -> new ParameterizedMessage( "[{}] Searching docs with [{}] in [{}, {})", context.jobId, DestinationIndex.INCREMENTAL_ID, from, to)); SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE) @@ -283,7 +283,7 @@ private Row createRow(SearchHit hit) { } boolean isTraining = trainTestSplitter.get().isTraining(extractedValues); Row row = new Row(extractedValues, hit, isTraining); - LOGGER.debug(() -> new ParameterizedMessage("[{}] Extracted row: sort key = [{}], is_training = [{}], values = {}", + LOGGER.trace(() -> new 
ParameterizedMessage("[{}] Extracted row: sort key = [{}], is_training = [{}], values = {}", context.jobId, row.getSortKey(), isTraining, Arrays.toString(row.values))); return row; } @@ -306,7 +306,7 @@ public DataSummary collectDataSummary() { SearchRequestBuilder searchRequestBuilder = buildDataSummarySearchRequestBuilder(); SearchResponse searchResponse = executeSearchRequest(searchRequestBuilder); long rows = searchResponse.getHits().getTotalHits().value; - LOGGER.debug("[{}] Data summary rows [{}]", context.jobId, rows); + LOGGER.debug(() -> new ParameterizedMessage("[{}] Data summary rows [{}]", context.jobId, rows)); return new DataSummary(rows, organicFeatures.length + processedFeatures.length); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index 51d81624aa3f5..cf0b93ef80872 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -10,21 +10,14 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.admin.indices.refresh.RefreshAction; -import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.Client; -import org.elasticsearch.client.ParentTaskAssigningClient; import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.threadpool.ThreadPool; -import 
org.elasticsearch.xpack.core.ClientHelper; -import org.elasticsearch.xpack.core.ml.MlStatsIndex; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; -import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -32,20 +25,17 @@ import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; -import org.elasticsearch.xpack.ml.dataframe.inference.InferenceRunner; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.stats.DataCountsTracker; import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker; import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister; import org.elasticsearch.xpack.ml.dataframe.steps.StepResponse; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; -import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; import java.io.IOException; -import java.util.Arrays; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -69,7 +59,6 @@ public class AnalyticsProcessManager { private final ConcurrentMap processContextByAllocation = new ConcurrentHashMap<>(); private final DataFrameAnalyticsAuditor auditor; private final TrainedModelProvider trainedModelProvider; - private final ModelLoadingService modelLoadingService; private final ResultsPersisterService resultsPersisterService; private final 
int numAllocatedProcessors; @@ -79,7 +68,6 @@ public AnalyticsProcessManager(Settings settings, AnalyticsProcessFactory analyticsProcessFactory, DataFrameAnalyticsAuditor auditor, TrainedModelProvider trainedModelProvider, - ModelLoadingService modelLoadingService, ResultsPersisterService resultsPersisterService, int numAllocatedProcessors) { this( @@ -90,7 +78,6 @@ public AnalyticsProcessManager(Settings settings, analyticsProcessFactory, auditor, trainedModelProvider, - modelLoadingService, resultsPersisterService, numAllocatedProcessors); } @@ -103,7 +90,6 @@ public AnalyticsProcessManager(Settings settings, AnalyticsProcessFactory analyticsProcessFactory, DataFrameAnalyticsAuditor auditor, TrainedModelProvider trainedModelProvider, - ModelLoadingService modelLoadingService, ResultsPersisterService resultsPersisterService, int numAllocatedProcessors) { this.settings = Objects.requireNonNull(settings); @@ -113,7 +99,6 @@ public AnalyticsProcessManager(Settings settings, this.processFactory = Objects.requireNonNull(analyticsProcessFactory); this.auditor = Objects.requireNonNull(auditor); this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider); - this.modelLoadingService = Objects.requireNonNull(modelLoadingService); this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService); this.numAllocatedProcessors = numAllocatedProcessors; } @@ -183,7 +168,6 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont LOGGER.info("[{}] Started loading data", processContext.config.getId()); auditor.info(processContext.config.getId(), Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_STARTED_LOADING_DATA)); - ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, task.getParentTaskId()); DataFrameAnalyticsConfig config = processContext.config; DataFrameDataExtractor dataExtractor = processContext.dataExtractor.get(); AnalyticsProcess process = processContext.process.get(); @@ -203,14 
+187,6 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont resultProcessor.awaitForCompletion(); processContext.setFailureReason(resultProcessor.getFailure()); LOGGER.info("[{}] Result processor has completed", config.getId()); - - runInference(parentTaskClient, task, processContext, dataExtractor.getExtractedFields()); - - processContext.statsPersister.persistWithRetry(task.getStatsHolder().getDataCountsTracker().report(config.getId()), - DataCounts::documentId); - - refreshDest(parentTaskClient, config); - refreshIndices(parentTaskClient, config.getId()); } catch (Exception e) { if (task.isStopping()) { // Errors during task stopping are expected but we still want to log them just in case. @@ -338,43 +314,6 @@ private Consumer onProcessCrash(DataFrameAnalyticsTask task) { }; } - private void runInference(ParentTaskAssigningClient parentTaskClient, DataFrameAnalyticsTask task, ProcessContext processContext, - ExtractedFields extractedFields) { - if (task.isStopping() || processContext.failureReason.get() != null) { - // If the task is stopping or there has been an error thus far let's not run inference at all - return; - } - - if (processContext.config.getAnalysis().supportsInference()) { - refreshDest(parentTaskClient, processContext.config); - InferenceRunner inferenceRunner = new InferenceRunner(settings, parentTaskClient, modelLoadingService, resultsPersisterService, - task.getParentTaskId(), processContext.config, extractedFields, task.getStatsHolder().getProgressTracker(), - task.getStatsHolder().getDataCountsTracker()); - processContext.setInferenceRunner(inferenceRunner); - inferenceRunner.run(processContext.resultProcessor.get().getLatestModelId()); - } - } - - private void refreshDest(ParentTaskAssigningClient parentTaskClient, DataFrameAnalyticsConfig config) { - ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, parentTaskClient, - () -> parentTaskClient.execute(RefreshAction.INSTANCE, new 
RefreshRequest(config.getDest().getIndex())).actionGet()); - } - - private void refreshIndices(ParentTaskAssigningClient parentTaskClient, String jobId) { - RefreshRequest refreshRequest = new RefreshRequest( - AnomalyDetectorsIndex.jobStateIndexPattern(), - MlStatsIndex.indexPattern() - ); - refreshRequest.indicesOptions(IndicesOptions.lenientExpandOpen()); - - LOGGER.debug(() -> new ParameterizedMessage("[{}] Refreshing indices {}", - jobId, Arrays.toString(refreshRequest.indices()))); - - try (ThreadContext.StoredContext ignore = parentTaskClient.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) { - parentTaskClient.admin().indices().refresh(refreshRequest).actionGet(); - } - } - private void closeProcess(DataFrameAnalyticsTask task) { String configId = task.getParams().getId(); LOGGER.info("[{}] Closing process", configId); @@ -415,13 +354,10 @@ class ProcessContext { private final SetOnce> process = new SetOnce<>(); private final SetOnce dataExtractor = new SetOnce<>(); private final SetOnce resultProcessor = new SetOnce<>(); - private final SetOnce inferenceRunner = new SetOnce<>(); private final SetOnce failureReason = new SetOnce<>(); - private final StatsPersister statsPersister; ProcessContext(DataFrameAnalyticsConfig config) { this.config = Objects.requireNonNull(config); - this.statsPersister = new StatsPersister(config.getId(), resultsPersisterService, auditor); } String getFailureReason() { @@ -436,10 +372,6 @@ void setFailureReason(String failureReason) { this.failureReason.trySet(failureReason); } - void setInferenceRunner(InferenceRunner inferenceRunner) { - this.inferenceRunner.set(inferenceRunner); - } - synchronized void stop() { LOGGER.debug("[{}] Stopping process", config.getId()); if (dataExtractor.get() != null) { @@ -448,9 +380,6 @@ synchronized void stop() { if (resultProcessor.get() != null) { resultProcessor.get().cancel(); } - if (inferenceRunner.get() != null) { - inferenceRunner.get().cancel(); - } if (process.get() != null) 
{ try { process.get().kill(true); @@ -507,6 +436,7 @@ private AnalyticsResultProcessor createResultProcessor(DataFrameAnalyticsTask ta DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), settings, task.getParentTaskId(), dataExtractorFactory.newExtractor(true), resultsPersisterService); + StatsPersister statsPersister = new StatsPersister(config.getId(), resultsPersisterService, auditor); return new AnalyticsResultProcessor( config, dataFrameRowsJoiner, task.getStatsHolder(), trainedModelProvider, auditor, statsPersister, dataExtractor.get().getExtractedFields()); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTracker.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTracker.java index 691836d26a438..cb8a1966f9277 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTracker.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTracker.java @@ -54,6 +54,9 @@ public ProgressTracker(List phaseProgresses) { assert progressPercentPerPhase.containsKey(REINDEXING); assert progressPercentPerPhase.containsKey(LOADING_DATA); assert progressPercentPerPhase.containsKey(WRITING_RESULTS); + // If there is inference it should be the last phase otherwise there + // are assumptions that do not hold. + assert progressPercentPerPhase.containsKey(INFERENCE) == false || INFERENCE.equals(phasesInOrder[phasesInOrder.length - 1]); } public void updateReindexingProgress(int progressPercent) { @@ -96,6 +99,32 @@ private void updatePhase(String phase, int progress) { progressPercentPerPhase.computeIfPresent(phase, (k, v) -> Math.max(v, progress)); } + /** + * Resets progress to reflect all phases are complete except for inference + * which is set to zero. 
+ */ + public void resetForInference() { + for (Map.Entry phaseProgress : progressPercentPerPhase.entrySet()) { + if (phaseProgress.getKey().equals(INFERENCE)) { + progressPercentPerPhase.put(phaseProgress.getKey(), 0); + } else { + progressPercentPerPhase.put(phaseProgress.getKey(), 100); + } + } + } + + /** + * Returns whether all phases before inference are complete + */ + public boolean areAllPhasesExceptInferenceComplete() { + for (Map.Entry phaseProgress : progressPercentPerPhase.entrySet()) { + if (phaseProgress.getKey().equals(INFERENCE) == false && phaseProgress.getValue() < 100) { + return false; + } + } + return true; + } + public List report() { return Arrays.stream(phasesInOrder) .map(phase -> new PhaseProgress(phase, progressPercentPerPhase.get(phase))) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java index 1207821fff37e..a0bfbad92ff02 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java @@ -34,15 +34,30 @@ public void setProgressTracker(List progress) { progressTracker = new ProgressTracker(progress); } + /** + * Updates the progress tracker with potentially new in-between phases + * that were introduced in a later version while making sure progress indicators + * are correct. 
+ * @param analysisPhases the new analysis phases + * @param hasInferencePhase whether the analysis supports inference + */ public void adjustProgressTracker(List analysisPhases, boolean hasInferencePhase) { int reindexingProgressPercent = progressTracker.getReindexingProgressPercent(); + boolean areAllPhasesBeforeInferenceComplete = progressTracker.areAllPhasesExceptInferenceComplete(); progressTracker = ProgressTracker.fromZeroes(analysisPhases, hasInferencePhase); // If reindexing progress was more than 0 and less than 100 (ie not complete) we reset it to 1 // as we will have to do reindexing from scratch and at the same time we want // to differentiate from a job that has never started before. - progressTracker.updateReindexingProgress( - (reindexingProgressPercent > 0 && reindexingProgressPercent < 100) ? 1 : reindexingProgressPercent); + if (reindexingProgressPercent > 0 && reindexingProgressPercent < 100) { + progressTracker.updateReindexingProgress(1); + } else { + progressTracker.updateReindexingProgress(reindexingProgressPercent); + } + + if (hasInferencePhase && areAllPhasesBeforeInferenceComplete) { + progressTracker.resetForInference(); + } } public void resetProgressTracker(List analysisPhases, boolean hasInferencePhase) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AbstractDataFrameAnalyticsStep.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AbstractDataFrameAnalyticsStep.java index 150745a624dc2..db9db30c2b0d8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AbstractDataFrameAnalyticsStep.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AbstractDataFrameAnalyticsStep.java @@ -58,7 +58,7 @@ protected TaskId getParentTaskId() { public final void execute(ActionListener listener) { logger.debug(() -> new ParameterizedMessage("[{}] Executing step [{}]", config.getId(), name())); if (task.isStopping()) { - 
logger.debug(() -> new ParameterizedMessage("[{}] task is stopping before starting [{}]", config.getId(), name())); + logger.debug(() -> new ParameterizedMessage("[{}] task is stopping before starting [{}] step", config.getId(), name())); listener.onResponse(new StepResponse(true)); return; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/DataFrameAnalyticsStep.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/DataFrameAnalyticsStep.java index c5883a4bb3e3c..78d31c686a9ee 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/DataFrameAnalyticsStep.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/DataFrameAnalyticsStep.java @@ -14,7 +14,7 @@ public interface DataFrameAnalyticsStep { enum Name { - REINDEXING, ANALYSIS; + REINDEXING, ANALYSIS, INFERENCE, FINAL; @Override public String toString() { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/FinalStep.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/FinalStep.java new file mode 100644 index 0000000000000..1132b31c111d1 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/FinalStep.java @@ -0,0 +1,116 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.xpack.ml.dataframe.steps; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; +import org.elasticsearch.action.admin.indices.refresh.RefreshResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.index.IndexResponse; +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.client.ParentTaskAssigningClient; +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.xpack.core.ml.MlStatsIndex; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; +import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; +import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask; +import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; + +/** + * The final step of a data frame analytics job. + * Allows the job to perform finalizing tasks like refresh indices, + * persist stats, etc. 
+ */ +public class FinalStep extends AbstractDataFrameAnalyticsStep { + + private static final Logger LOGGER = LogManager.getLogger(FinalStep.class); + + public FinalStep(NodeClient client, DataFrameAnalyticsTask task, DataFrameAnalyticsAuditor auditor, DataFrameAnalyticsConfig config) { + super(client, task, auditor, config); + } + + @Override + public Name name() { + return Name.FINAL; + } + + @Override + protected void doExecute(ActionListener listener) { + + ActionListener refreshListener = ActionListener.wrap( + refreshResponse -> listener.onResponse(new StepResponse(true)), + listener::onFailure + ); + + ActionListener dataCountsIndexedListener = ActionListener.wrap( + indexResponse -> refreshIndices(refreshListener), + listener::onFailure + ); + + indexDataCounts(dataCountsIndexedListener); + } + + private void indexDataCounts(ActionListener listener) { + DataCounts dataCounts = task.getStatsHolder().getDataCountsTracker().report(config.getId()); + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + dataCounts.toXContent(builder, new ToXContent.MapParams( + Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"))); + IndexRequest indexRequest = new IndexRequest(MlStatsIndex.writeAlias()) + .id(DataCounts.documentId(config.getId())) + .setRequireAlias(true) + .source(builder); + parentTaskClient().index(indexRequest, listener); + } catch (IOException e) { + listener.onFailure(ExceptionsHelper.serverError("[{}] Error persisting final data counts", e, config.getId())); + } + } + + private void refreshIndices(ActionListener listener) { + RefreshRequest refreshRequest = new RefreshRequest( + AnomalyDetectorsIndex.jobStateIndexPattern(), + MlStatsIndex.indexPattern(), + config.getDest().getIndex() + ); + refreshRequest.indicesOptions(IndicesOptions.lenientExpandOpen()); + + LOGGER.debug(() -> new ParameterizedMessage("[{}] Refreshing indices {}", config.getId(), + Arrays.toString(refreshRequest.indices()))); + + 
ParentTaskAssigningClient parentTaskClient = parentTaskClient(); + try (ThreadContext.StoredContext ignore = parentTaskClient.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) { + parentTaskClient.admin().indices().refresh(refreshRequest, listener); + } + } + + @Override + public void cancel(String reason, TimeValue timeout) { + // Not cancellable + } + + @Override + public void updateProgress(ActionListener listener) { + // No progress to update + listener.onResponse(null); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/InferenceStep.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/InferenceStep.java new file mode 100644 index 0000000000000..381f6090c907a --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/InferenceStep.java @@ -0,0 +1,127 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.xpack.ml.dataframe.steps; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.admin.indices.refresh.RefreshResponse; +import org.elasticsearch.action.search.SearchAction; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask; +import org.elasticsearch.xpack.ml.dataframe.inference.InferenceRunner; +import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; + +import java.util.Objects; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; + +public class InferenceStep extends AbstractDataFrameAnalyticsStep { + + private static final Logger LOGGER = LogManager.getLogger(InferenceStep.class); + + private final ThreadPool threadPool; + private final InferenceRunner inferenceRunner; + + public InferenceStep(NodeClient client, DataFrameAnalyticsTask task, DataFrameAnalyticsAuditor auditor, DataFrameAnalyticsConfig config, + ThreadPool threadPool, InferenceRunner inferenceRunner) { + super(client, task, 
auditor, config); + this.threadPool = Objects.requireNonNull(threadPool); + this.inferenceRunner = Objects.requireNonNull(inferenceRunner); + } + + @Override + public Name name() { + return Name.INFERENCE; + } + + @Override + protected void doExecute(ActionListener listener) { + if (config.getAnalysis().supportsInference() == false) { + LOGGER.debug(() -> new ParameterizedMessage( + "[{}] Inference step completed immediately as analysis does not support inference", config.getId())); + listener.onResponse(new StepResponse(false)); + return; + } + + ActionListener modelIdListener = ActionListener.wrap( + modelId -> runInference(modelId, listener), + listener::onFailure + ); + + ActionListener refreshDestListener = ActionListener.wrap( + refreshResponse -> getModelId(modelIdListener), + listener::onFailure + ); + + refreshDestAsync(refreshDestListener); + } + + private void runInference(String modelId, ActionListener listener) { + threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> { + try { + inferenceRunner.run(modelId); + listener.onResponse(new StepResponse(isTaskStopping())); + } catch (Exception e) { + if (task.isStopping()) { + listener.onResponse(new StepResponse(false)); + } else { + listener.onFailure(e); + } + } + }); + } + + private void getModelId(ActionListener listener) { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.size(1); + searchSourceBuilder.fetchSource(false); + searchSourceBuilder.query(QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), config.getId())) + ); + searchSourceBuilder.sort(TrainedModelConfig.CREATE_TIME.getPreferredName(), SortOrder.DESC); + SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN); + searchRequest.source(searchSourceBuilder); + + executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap( + searchResponse -> { + SearchHit[] 
hits = searchResponse.getHits().getHits(); + if (hits.length == 0) { + listener.onFailure(new ResourceNotFoundException("No model could be found to perform inference")); + } else { + listener.onResponse(hits[0].getId()); + } + }, + listener::onFailure + )); + } + + @Override + public void cancel(String reason, TimeValue timeout) { + inferenceRunner.cancel(); + } + + @Override + public void updateProgress(ActionListener listener) { + // Inference runner updates progress directly + listener.onResponse(null); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java index 252d6bc10aada..77105f213ea26 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java @@ -113,9 +113,8 @@ public void setUpMocks() { when(dataExtractorFactory.getExtractedFields()).thenReturn(mock(ExtractedFields.class)); resultsPersisterService = mock(ResultsPersisterService.class); - modelLoadingService = mock(ModelLoadingService.class); processManager = new AnalyticsProcessManager(Settings.EMPTY, client, executorServiceForJob, executorServiceForProcess, - processFactory, auditor, trainedModelProvider, modelLoadingService, resultsPersisterService, 1); + processFactory, auditor, trainedModelProvider, resultsPersisterService, 1); } public void testRunJob_TaskIsStopping() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameAnalyticsManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameAnalyticsManagerTests.java index 2ecad9ed376d5..1110e31cbc513 100644 --- 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameAnalyticsManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameAnalyticsManagerTests.java @@ -8,10 +8,14 @@ import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsManager; import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; +import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; +import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; @@ -21,12 +25,16 @@ public class DataFrameAnalyticsManagerTests extends ESTestCase { public void testNodeShuttingDown() { DataFrameAnalyticsManager manager = new DataFrameAnalyticsManager( + Settings.EMPTY, mock(NodeClient.class), + mock(ThreadPool.class), mock(ClusterService.class), mock(DataFrameAnalyticsConfigProvider.class), mock(AnalyticsProcessManager.class), mock(DataFrameAnalyticsAuditor.class), - mock(IndexNameExpressionResolver.class)); + mock(IndexNameExpressionResolver.class), + mock(ResultsPersisterService.class), + mock(ModelLoadingService.class)); assertThat(manager.isNodeShuttingDown(), is(false)); manager.markNodeAsShuttingDown(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTrackerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTrackerTests.java index 8e52301ee0b0c..3b9150684df23 100644 --- 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTrackerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTrackerTests.java @@ -11,7 +11,9 @@ import java.util.Arrays; import java.util.Collections; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import static org.hamcrest.Matchers.contains; @@ -148,6 +150,75 @@ public void testUpdatePhase_GivenLowerValueThanCurrentProgress() { assertThat(getProgressForPhase(progressTracker, "foo"), equalTo(41)); } + public void testResetForInference_GivenInference() { + ProgressTracker progressTracker = ProgressTracker.fromZeroes(Arrays.asList("a", "b"), true); + progressTracker.updateReindexingProgress(10); + progressTracker.updateLoadingDataProgress(20); + progressTracker.updatePhase(new PhaseProgress("a", 30)); + progressTracker.updatePhase(new PhaseProgress("b", 40)); + progressTracker.updateWritingResultsProgress(50); + progressTracker.updateInferenceProgress(60); + + progressTracker.resetForInference(); + + List progress = progressTracker.report(); + assertThat(progress, contains( + new PhaseProgress(ProgressTracker.REINDEXING, 100), + new PhaseProgress(ProgressTracker.LOADING_DATA, 100), + new PhaseProgress("a", 100), + new PhaseProgress("b", 100), + new PhaseProgress(ProgressTracker.WRITING_RESULTS, 100), + new PhaseProgress(ProgressTracker.INFERENCE, 0) + )); + } + + public void testResetForInference_GivenNoInference() { + ProgressTracker progressTracker = ProgressTracker.fromZeroes(Arrays.asList("a", "b"), false); + progressTracker.updateReindexingProgress(10); + progressTracker.updateLoadingDataProgress(20); + progressTracker.updatePhase(new PhaseProgress("a", 30)); + progressTracker.updatePhase(new PhaseProgress("b", 40)); + progressTracker.updateWritingResultsProgress(50); + + progressTracker.resetForInference(); + + List progress = progressTracker.report(); + 
assertThat(progress, contains( + new PhaseProgress(ProgressTracker.REINDEXING, 100), + new PhaseProgress(ProgressTracker.LOADING_DATA, 100), + new PhaseProgress("a", 100), + new PhaseProgress("b", 100), + new PhaseProgress(ProgressTracker.WRITING_RESULTS, 100) + )); + } + + public void testAreAllPhasesExceptInferenceComplete_GivenComplete() { + ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("a"), true); + progressTracker.updateReindexingProgress(100); + progressTracker.updateLoadingDataProgress(100); + progressTracker.updatePhase(new PhaseProgress("a", 100)); + progressTracker.updateWritingResultsProgress(100); + progressTracker.updateInferenceProgress(50); + + assertThat(progressTracker.areAllPhasesExceptInferenceComplete(), is(true)); + } + + public void testAreAllPhasesExceptInferenceComplete_GivenNotComplete() { + Map phasePerProgress = new LinkedHashMap<>(); + phasePerProgress.put(ProgressTracker.REINDEXING, 100); + phasePerProgress.put(ProgressTracker.LOADING_DATA, 100); + phasePerProgress.put("a", 100); + phasePerProgress.put(ProgressTracker.WRITING_RESULTS, 100); + String nonCompletePhase = randomFrom(phasePerProgress.keySet()); + phasePerProgress.put(ProgressTracker.INFERENCE, 50); + phasePerProgress.put(nonCompletePhase, randomIntBetween(0, 99)); + + ProgressTracker progressTracker = new ProgressTracker(phasePerProgress.entrySet().stream() + .map(entry -> new PhaseProgress(entry.getKey(), entry.getValue())).collect(Collectors.toList())); + + assertThat(progressTracker.areAllPhasesExceptInferenceComplete(), is(false)); + } + private static int getProgressForPhase(ProgressTracker progressTracker, String phase) { return progressTracker.report().stream().filter(p -> p.getPhase().equals(phase)).findFirst().get().getProgressPercent(); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolderTests.java 
b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolderTests.java index 3320250659f28..7a5cf841d582f 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolderTests.java @@ -123,6 +123,32 @@ public void testAdjustProgressTracker_GivenReindexingProgressIncomplete() { assertThat(phaseProgresses.get(4).getProgressPercent(), equalTo(0)); } + public void testAdjustProgressTracker_GivenAllPhasesCompleteExceptInference() { + List phases = Collections.unmodifiableList( + Arrays.asList( + new PhaseProgress("reindexing", 100), + new PhaseProgress("loading_data", 100), + new PhaseProgress("a", 100), + new PhaseProgress("writing_results", 100), + new PhaseProgress("inference", 20) + ) + ); + StatsHolder statsHolder = new StatsHolder(phases); + + statsHolder.adjustProgressTracker(Arrays.asList("a", "b"), true); + + List phaseProgresses = statsHolder.getProgressTracker().report(); + + assertThat(phaseProgresses, contains( + new PhaseProgress("reindexing", 100), + new PhaseProgress("loading_data", 100), + new PhaseProgress("a", 100), + new PhaseProgress("b", 100), + new PhaseProgress("writing_results", 100), + new PhaseProgress("inference", 0) + )); + } + public void testResetProgressTracker() { List phases = Collections.unmodifiableList( Arrays.asList( From 5d1fa3506b94c03610d04a15958118885a5468ea Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Sat, 16 Jan 2021 14:35:19 +0200 Subject: [PATCH 2/3] Add unit test for determineStartingState --- .../ml/dataframe/DataFrameAnalyticsTaskTests.java | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java index be81e4c0b863b..aae85f27e65ef 100644 --- 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java @@ -118,6 +118,18 @@ public void testDetermineStartingState_GivenWritingResultsIsIncomplete() { assertThat(startingState, equalTo(StartingState.RESUMING_ANALYZING)); } + public void testDetermineStartingState_GivenInferenceIsIncomplete() { + List progress = Arrays.asList(new PhaseProgress("reindexing", 100), + new PhaseProgress("loading_data", 100), + new PhaseProgress("analyzing", 100), + new PhaseProgress("writing_results", 100), + new PhaseProgress("inference", 40)); + + StartingState startingState = DataFrameAnalyticsTask.determineStartingState("foo", progress); + + assertThat(startingState, equalTo(StartingState.RESUMING_INFERENCE)); + } + public void testDetermineStartingState_GivenFinished() { List progress = Arrays.asList(new PhaseProgress("reindexing", 100), new PhaseProgress("loading_data", 100), From 3676214987164bb7c12302530357a87806bfe96b Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Mon, 18 Jan 2021 16:42:43 +0200 Subject: [PATCH 3/3] Clarify comment --- .../org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java index a0bfbad92ff02..1b74ee6ec47d1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java @@ -38,7 +38,7 @@ public void setProgressTracker(List progress) { * Updates the progress tracker with potentially new in-between phases * that were introduced in a later version while making sure progress indicators * are correct. 
- * @param analysisPhases the new analysis phases + * @param analysisPhases the full set of phases of the analysis in the current version * @param hasInferencePhase whether the analysis supports inference */ public void adjustProgressTracker(List analysisPhases, boolean hasInferencePhase) {