3333import java .util .Iterator ;
3434import java .util .List ;
3535import java .util .concurrent .ConcurrentHashMap ;
36+ import java .util .concurrent .Phaser ;
3637import java .util .stream .Collectors ;
3738
3839/**
@@ -56,6 +57,7 @@ public class MlMemoryTracker implements LocalNodeMasterListener {
5657 private final ClusterService clusterService ;
5758 private final JobManager jobManager ;
5859 private final JobResultsProvider jobResultsProvider ;
60+ private final Phaser stopPhaser ;
5961 private volatile boolean isMaster ;
6062 private volatile Instant lastUpdateTime ;
6163 private volatile Duration reassignmentRecheckInterval ;
@@ -66,6 +68,7 @@ public MlMemoryTracker(Settings settings, ClusterService clusterService, ThreadP
6668 this .clusterService = clusterService ;
6769 this .jobManager = jobManager ;
6870 this .jobResultsProvider = jobResultsProvider ;
71+ this .stopPhaser = new Phaser (1 );
6972 setReassignmentRecheckInterval (PersistentTasksClusterService .CLUSTER_TASKS_ALLOCATION_RECHECK_INTERVAL_SETTING .get (settings ));
7073 clusterService .addLocalNodeMasterListener (this );
7174 clusterService .getClusterSettings ().addSettingsUpdateConsumer (
@@ -90,6 +93,23 @@ public void offMaster() {
9093 lastUpdateTime = null ;
9194 }
9295
96+ /**
97+ * Wait for all outstanding searches to complete.
98+ * After returning, no new searches can be started.
99+ */
100+ public void stop () {
101+ logger .trace ("ML memory tracker stop called" );
102+ // We never terminate the phaser
103+ assert stopPhaser .isTerminated () == false ;
104+ // If there are no registered parties or no unarrived parties then there is a flaw
105+ // in the register/arrive/unregister logic in another method that uses the phaser
106+ assert stopPhaser .getRegisteredParties () > 0 ;
107+ assert stopPhaser .getUnarrivedParties () > 0 ;
108+ stopPhaser .arriveAndAwaitAdvance ();
109+ assert stopPhaser .getPhase () > 0 ;
110+ logger .debug ("ML memory tracker stopped" );
111+ }
112+
93113 @ Override
94114 public String executorName () {
95115 return MachineLearning .UTILITY_THREAD_POOL_NAME ;
@@ -153,13 +173,13 @@ public boolean asyncRefresh() {
153173 try {
154174 ActionListener <Void > listener = ActionListener .wrap (
155175 aVoid -> logger .trace ("Job memory requirement refresh request completed successfully" ),
156- e -> logger .error ("Failed to refresh job memory requirements" , e )
176+ e -> logger .warn ("Failed to refresh job memory requirements" , e )
157177 );
158178 threadPool .executor (executorName ()).execute (
159179 () -> refresh (clusterService .state ().getMetaData ().custom (PersistentTasksCustomMetaData .TYPE ), listener ));
160180 return true ;
161181 } catch (EsRejectedExecutionException e ) {
162- logger .debug ("Couldn't schedule ML memory update - node might be shutting down" , e );
182+ logger .warn ("Couldn't schedule ML memory update - node might be shutting down" , e );
163183 }
164184 }
165185
@@ -253,25 +273,43 @@ public void refreshJobMemory(String jobId, ActionListener<Long> listener) {
253273 return ;
254274 }
255275
276+ // The phaser prevents searches being started after the memory tracker's stop() method has returned
277+ if (stopPhaser .register () != 0 ) {
278+ // Phases above 0 mean we've been stopped, so don't do any operations that involve external interaction
279+ stopPhaser .arriveAndDeregister ();
280+ listener .onFailure (new EsRejectedExecutionException ("Couldn't run ML memory update - node is shutting down" ));
281+ return ;
282+ }
283+ ActionListener <Long > phaserListener = ActionListener .wrap (
284+ r -> {
285+ stopPhaser .arriveAndDeregister ();
286+ listener .onResponse (r );
287+ },
288+ e -> {
289+ stopPhaser .arriveAndDeregister ();
290+ listener .onFailure (e );
291+ }
292+ );
293+
256294 try {
257295 jobResultsProvider .getEstablishedMemoryUsage (jobId , null , null ,
258296 establishedModelMemoryBytes -> {
259297 if (establishedModelMemoryBytes <= 0L ) {
260- setJobMemoryToLimit (jobId , listener );
298+ setJobMemoryToLimit (jobId , phaserListener );
261299 } else {
262300 Long memoryRequirementBytes = establishedModelMemoryBytes + Job .PROCESS_MEMORY_OVERHEAD .getBytes ();
263301 memoryRequirementByJob .put (jobId , memoryRequirementBytes );
264- listener .onResponse (memoryRequirementBytes );
302+ phaserListener .onResponse (memoryRequirementBytes );
265303 }
266304 },
267305 e -> {
268306 logger .error ("[" + jobId + "] failed to calculate job established model memory requirement" , e );
269- setJobMemoryToLimit (jobId , listener );
307+ setJobMemoryToLimit (jobId , phaserListener );
270308 }
271309 );
272310 } catch (Exception e ) {
273311 logger .error ("[" + jobId + "] failed to calculate job established model memory requirement" , e );
274- setJobMemoryToLimit (jobId , listener );
312+ setJobMemoryToLimit (jobId , phaserListener );
275313 }
276314 }
277315
0 commit comments