@@ -228,20 +228,49 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
228228
229229 def handle (self ):
230230 from pyspark .accumulators import _accumulatorRegistry
231- while not self .server .server_shutdown :
232- # Poll every 1 second for new data -- don't block in case of shutdown.
233- r , _ , _ = select .select ([self .rfile ], [], [], 1 )
234- if self .rfile in r :
235- num_updates = read_int (self .rfile )
236- for _ in range (num_updates ):
237- (aid , update ) = pickleSer ._read_with_length (self .rfile )
238- _accumulatorRegistry [aid ] += update
239- # Write a byte in acknowledgement
240- self .wfile .write (struct .pack ("!b" , 1 ))
231+ auth_token = self .server .auth_token
232+
233+ def poll (func ):
234+ while not self .server .server_shutdown :
235+ # Poll every 1 second for new data -- don't block in case of shutdown.
236+ r , _ , _ = select .select ([self .rfile ], [], [], 1 )
237+ if self .rfile in r :
238+ if func ():
239+ break
240+
241+ def accum_updates ():
242+ num_updates = read_int (self .rfile )
243+ for _ in range (num_updates ):
244+ (aid , update ) = pickleSer ._read_with_length (self .rfile )
245+ _accumulatorRegistry [aid ] += update
246+ # Write a byte in acknowledgement
247+ self .wfile .write (struct .pack ("!b" , 1 ))
248+ return False
249+
250+ def authenticate_and_accum_updates ():
251+ received_token = self .rfile .read (len (auth_token ))
252+ if isinstance (received_token , bytes ):
253+ received_token = received_token .decode ("utf-8" )
254+ if (received_token == auth_token ):
255+ accum_updates ()
256+ # we've authenticated, we can break out of the first loop now
257+ return True
258+ else :
259+ raise Exception (
260+ "The value of the provided token to the AccumulatorServer is not correct." )
261+
262+ # first we keep polling till we've received the authentication token
263+ poll (authenticate_and_accum_updates )
264+ # now we've authenticated, don't need to check for the token anymore
265+ poll (accum_updates )
241266
242267
243268class AccumulatorServer (SocketServer .TCPServer ):
244269
270+ def __init__ (self , server_address , RequestHandlerClass , auth_token ):
271+ SocketServer .TCPServer .__init__ (self , server_address , RequestHandlerClass )
272+ self .auth_token = auth_token
273+
245274 """
246275 A simple TCP server that intercepts shutdown() in order to interrupt
247276 our continuous polling on the handler.
@@ -254,9 +283,9 @@ def shutdown(self):
254283 self .server_close ()
255284
256285
257- def _start_update_server ():
286+ def _start_update_server (auth_token ):
258287 """Start a TCP server to receive accumulator updates in a daemon thread, and returns it"""
259- server = AccumulatorServer (("localhost" , 0 ), _UpdateRequestHandler )
288+ server = AccumulatorServer (("localhost" , 0 ), _UpdateRequestHandler , auth_token )
260289 thread = threading .Thread (target = server .serve_forever )
261290 thread .daemon = True
262291 thread .start ()
0 commit comments