3232import java .util .Iterator ;
3333import java .util .List ;
3434import java .util .concurrent .ConcurrentHashMap ;
35+ import java .util .concurrent .Phaser ;
3536import java .util .stream .Collectors ;
3637
3738/**
@@ -55,6 +56,7 @@ public class MlMemoryTracker implements LocalNodeMasterListener {
5556 private final ClusterService clusterService ;
5657 private final JobManager jobManager ;
5758 private final JobResultsProvider jobResultsProvider ;
59+ private final Phaser stopPhaser ;
5860 private volatile boolean isMaster ;
5961 private volatile Instant lastUpdateTime ;
6062 private volatile Duration reassignmentRecheckInterval ;
@@ -65,6 +67,7 @@ public MlMemoryTracker(Settings settings, ClusterService clusterService, ThreadP
6567 this .clusterService = clusterService ;
6668 this .jobManager = jobManager ;
6769 this .jobResultsProvider = jobResultsProvider ;
70+ this .stopPhaser = new Phaser (1 );
6871 setReassignmentRecheckInterval (PersistentTasksClusterService .CLUSTER_TASKS_ALLOCATION_RECHECK_INTERVAL_SETTING .get (settings ));
6972 clusterService .addLocalNodeMasterListener (this );
7073 clusterService .getClusterSettings ().addSettingsUpdateConsumer (
@@ -89,6 +92,23 @@ public void offMaster() {
8992 lastUpdateTime = null ;
9093 }
9194
95+ /**
96+ * Wait for all outstanding searches to complete.
97+ * After returning, no new searches can be started.
98+ */
99+ public void stop () {
100+ logger .trace ("ML memory tracker stop called" );
101+ // We never terminate the phaser
102+ assert stopPhaser .isTerminated () == false ;
103+ // If there are no registered parties or no unarrived parties then there is a flaw
104+ // in the register/arrive/unregister logic in another method that uses the phaser
105+ assert stopPhaser .getRegisteredParties () > 0 ;
106+ assert stopPhaser .getUnarrivedParties () > 0 ;
107+ stopPhaser .arriveAndAwaitAdvance ();
108+ assert stopPhaser .getPhase () > 0 ;
109+ logger .debug ("ML memory tracker stopped" );
110+ }
111+
92112 @ Override
93113 public String executorName () {
94114 return MachineLearning .UTILITY_THREAD_POOL_NAME ;
@@ -146,13 +166,13 @@ public boolean asyncRefresh() {
146166 try {
147167 ActionListener <Void > listener = ActionListener .wrap (
148168 aVoid -> logger .trace ("Job memory requirement refresh request completed successfully" ),
149- e -> logger .error ("Failed to refresh job memory requirements" , e )
169+ e -> logger .warn ("Failed to refresh job memory requirements" , e )
150170 );
151171 threadPool .executor (executorName ()).execute (
152172 () -> refresh (clusterService .state ().getMetaData ().custom (PersistentTasksCustomMetaData .TYPE ), listener ));
153173 return true ;
154174 } catch (EsRejectedExecutionException e ) {
155- logger .debug ("Couldn't schedule ML memory update - node might be shutting down" , e );
175+ logger .warn ("Couldn't schedule ML memory update - node might be shutting down" , e );
156176 }
157177 }
158178
@@ -246,25 +266,43 @@ public void refreshJobMemory(String jobId, ActionListener<Long> listener) {
246266 return ;
247267 }
248268
269+ // The phaser prevents searches being started after the memory tracker's stop() method has returned
270+ if (stopPhaser .register () != 0 ) {
271+ // Phases above 0 mean we've been stopped, so don't do any operations that involve external interaction
272+ stopPhaser .arriveAndDeregister ();
273+ listener .onFailure (new EsRejectedExecutionException ("Couldn't run ML memory update - node is shutting down" ));
274+ return ;
275+ }
276+ ActionListener <Long > phaserListener = ActionListener .wrap (
277+ r -> {
278+ stopPhaser .arriveAndDeregister ();
279+ listener .onResponse (r );
280+ },
281+ e -> {
282+ stopPhaser .arriveAndDeregister ();
283+ listener .onFailure (e );
284+ }
285+ );
286+
249287 try {
250288 jobResultsProvider .getEstablishedMemoryUsage (jobId , null , null ,
251289 establishedModelMemoryBytes -> {
252290 if (establishedModelMemoryBytes <= 0L ) {
253- setJobMemoryToLimit (jobId , listener );
291+ setJobMemoryToLimit (jobId , phaserListener );
254292 } else {
255293 Long memoryRequirementBytes = establishedModelMemoryBytes + Job .PROCESS_MEMORY_OVERHEAD .getBytes ();
256294 memoryRequirementByJob .put (jobId , memoryRequirementBytes );
257- listener .onResponse (memoryRequirementBytes );
295+ phaserListener .onResponse (memoryRequirementBytes );
258296 }
259297 },
260298 e -> {
261299 logger .error ("[" + jobId + "] failed to calculate job established model memory requirement" , e );
262- setJobMemoryToLimit (jobId , listener );
300+ setJobMemoryToLimit (jobId , phaserListener );
263301 }
264302 );
265303 } catch (Exception e ) {
266304 logger .error ("[" + jobId + "] failed to calculate job established model memory requirement" , e );
267- setJobMemoryToLimit (jobId , listener );
305+ setJobMemoryToLimit (jobId , phaserListener );
268306 }
269307 }
270308
0 commit comments