@@ -589,11 +589,14 @@ def _try_authenticate_plain(self, future):
589589 self .config ['sasl_plain_password' ]]).encode ('utf-8' ))
590590 size = Int32 .encode (len (msg ))
591591 try :
592- self ._send_bytes_blocking (size + msg )
592+ with self ._lock :
593+ if not self ._can_send_recv ():
594+ return future .failure (Errors .NodeNotReadyError (str (self )))
595+ self ._send_bytes_blocking (size + msg )
593596
594- # The server will send a zero sized message (that is Int32(0)) on success.
595- # The connection is closed on failure
596- data = self ._recv_bytes_blocking (4 )
597+ # The server will send a zero sized message (that is Int32(0)) on success.
598+ # The connection is closed on failure
599+ data = self ._recv_bytes_blocking (4 )
597600
598601 except ConnectionError as e :
599602 log .exception ("%s: Error receiving reply from server" , self )
@@ -617,6 +620,9 @@ def _try_authenticate_gssapi(self, future):
617620 ).canonicalize (gssapi .MechType .kerberos )
618621 log .debug ('%s: GSSAPI name: %s' , self , gssapi_name )
619622
623+ self ._lock .acquire ()
624+ if not self ._can_send_recv ():
625+ return future .failure (Errors .NodeNotReadyError (str (self )))
620626 # Establish security context and negotiate protection level
621627 # For reference RFC 2222, section 7.2.1
622628 try :
@@ -659,13 +665,16 @@ def _try_authenticate_gssapi(self, future):
659665 self ._send_bytes_blocking (size + msg )
660666
661667 except ConnectionError as e :
668+ self ._lock .release ()
662669 log .exception ("%s: Error receiving reply from server" , self )
663670 error = Errors .KafkaConnectionError ("%s: %s" % (self , e ))
664671 self .close (error = error )
665672 return future .failure (error )
666673 except Exception as e :
674+ self ._lock .release ()
667675 return future .failure (e )
668676
677+ self ._lock .release ()
669678 log .info ('%s: Authenticated as %s via GSSAPI' , self , gssapi_name )
670679 return future .success (True )
671680
@@ -674,6 +683,9 @@ def _try_authenticate_oauth(self, future):
674683
675684 msg = bytes (self ._build_oauth_client_request ().encode ("utf-8" ))
676685 size = Int32 .encode (len (msg ))
686+ self ._lock .acquire ()
687+ if not self ._can_send_recv ():
688+ return future .failure (Errors .NodeNotReadyError (str (self )))
677689 try :
678690 # Send SASL OAuthBearer request with OAuth token
679691 self ._send_bytes_blocking (size + msg )
@@ -683,11 +695,14 @@ def _try_authenticate_oauth(self, future):
683695 data = self ._recv_bytes_blocking (4 )
684696
685697 except ConnectionError as e :
698+ self ._lock .release ()
686699 log .exception ("%s: Error receiving reply from server" , self )
687700 error = Errors .KafkaConnectionError ("%s: %s" % (self , e ))
688701 self .close (error = error )
689702 return future .failure (error )
690703
704+ self ._lock .release ()
705+
691706 if data != b'\x00 \x00 \x00 \x00 ' :
692707 error = Errors .AuthenticationFailedError ('Unrecognized response during authentication' )
693708 return future .failure (error )
@@ -787,26 +802,33 @@ def close(self, error=None):
787802 will be failed with this exception.
788803 Default: kafka.errors.KafkaConnectionError.
789804 """
790- if self .state is ConnectionStates .DISCONNECTED :
791- if error is not None :
792- log .warning ('%s: Duplicate close() with error: %s' , self , error )
793- return
794- log .info ('%s: Closing connection. %s' , self , error or '' )
795- self .state = ConnectionStates .DISCONNECTING
796- self .config ['state_change_callback' ](self )
797- self ._update_reconnect_backoff ()
798- self ._close_socket ()
799- self .state = ConnectionStates .DISCONNECTED
800- self ._sasl_auth_future = None
801- self ._protocol = KafkaProtocol (
802- client_id = self .config ['client_id' ],
803- api_version = self .config ['api_version' ])
804- if error is None :
805- error = Errors .Cancelled (str (self ))
806- while self .in_flight_requests :
807- (_correlation_id , (future , _timestamp )) = self .in_flight_requests .popitem ()
805+ with self ._lock :
806+ if self .state is ConnectionStates .DISCONNECTED :
807+ return
808+ log .info ('%s: Closing connection. %s' , self , error or '' )
809+ self .state = ConnectionStates .DISCONNECTING
810+ self .config ['state_change_callback' ](self )
811+ self ._update_reconnect_backoff ()
812+ self ._close_socket ()
813+ self .state = ConnectionStates .DISCONNECTED
814+ self ._sasl_auth_future = None
815+ self ._protocol = KafkaProtocol (
816+ client_id = self .config ['client_id' ],
817+ api_version = self .config ['api_version' ])
818+ if error is None :
819+ error = Errors .Cancelled (str (self ))
820+ ifrs = list (self .in_flight_requests .items ())
821+ self .in_flight_requests .clear ()
822+ self .config ['state_change_callback' ](self )
823+
824+ # drop lock before processing futures
825+ for (_correlation_id , (future , _timestamp )) in ifrs :
808826 future .failure (error )
809- self .config ['state_change_callback' ](self )
827+
def _can_send_recv(self):
    """Tell whether the connection is in a state that permits socket I/O.

    Returns:
        bool: True when ``self.state`` is AUTHENTICATING or CONNECTED,
        False for every other connection state.
    """
    ready_states = (ConnectionStates.AUTHENTICATING,
                    ConnectionStates.CONNECTED)
    return self.state in ready_states
810832
811833 def send (self , request , blocking = True ):
812834 """Queue request for async network send, return Future()"""
@@ -820,18 +842,20 @@ def send(self, request, blocking=True):
820842 return self ._send (request , blocking = blocking )
821843
822844 def _send (self , request , blocking = True ):
823- assert self .state in (ConnectionStates .AUTHENTICATING , ConnectionStates .CONNECTED )
824845 future = Future ()
825846 with self ._lock :
847+ if not self ._can_send_recv ():
848+ return future .failure (Errors .NodeNotReadyError (str (self )))
849+
826850 correlation_id = self ._protocol .send_request (request )
827851
828- log .debug ('%s Request %d: %s' , self , correlation_id , request )
829- if request .expect_response ():
830- sent_time = time .time ()
831- assert correlation_id not in self .in_flight_requests , 'Correlation ID already in-flight!'
832- self .in_flight_requests [correlation_id ] = (future , sent_time )
833- else :
834- future .success (None )
852+ log .debug ('%s Request %d: %s' , self , correlation_id , request )
853+ if request .expect_response ():
854+ sent_time = time .time ()
855+ assert correlation_id not in self .in_flight_requests , 'Correlation ID already in-flight!'
856+ self .in_flight_requests [correlation_id ] = (future , sent_time )
857+ else :
858+ future .success (None )
835859
836860 # Attempt to replicate behavior from prior to introduction of
837861 # send_pending_requests() / async sends
@@ -842,16 +866,15 @@ def _send(self, request, blocking=True):
842866
843867 def send_pending_requests (self ):
844868 """Can block on network if request is larger than send_buffer_bytes"""
845- if self .state not in (ConnectionStates .AUTHENTICATING ,
846- ConnectionStates .CONNECTED ):
847- return Errors .NodeNotReadyError (str (self ))
848- with self ._lock :
849- data = self ._protocol .send_bytes ()
850869 try :
851- # In the future we might manage an internal write buffer
852- # and send bytes asynchronously. For now, just block
853- # sending each request payload
854- total_bytes = self ._send_bytes_blocking (data )
870+ with self ._lock :
871+ if not self ._can_send_recv ():
872+ return Errors .NodeNotReadyError (str (self ))
873+ # In the future we might manage an internal write buffer
874+ # and send bytes asynchronously. For now, just block
875+ # sending each request payload
876+ data = self ._protocol .send_bytes ()
877+ total_bytes = self ._send_bytes_blocking (data )
855878 if self ._sensors :
856879 self ._sensors .bytes_sent .record (total_bytes )
857880 return total_bytes
@@ -871,18 +894,6 @@ def recv(self):
871894
872895 Return list of (response, future) tuples
873896 """
874- if not self .connected () and not self .state is ConnectionStates .AUTHENTICATING :
875- log .warning ('%s cannot recv: socket not connected' , self )
876- # If requests are pending, we should close the socket and
877- # fail all the pending request futures
878- if self .in_flight_requests :
879- self .close (Errors .KafkaConnectionError ('Socket not connected during recv with in-flight-requests' ))
880- return ()
881-
882- elif not self .in_flight_requests :
883- log .warning ('%s: No in-flight-requests to recv' , self )
884- return ()
885-
886897 responses = self ._recv ()
887898 if not responses and self .requests_timed_out ():
888899 log .warning ('%s timed out after %s ms. Closing connection.' ,
@@ -895,7 +906,8 @@ def recv(self):
895906 # augment respones w/ correlation_id, future, and timestamp
896907 for i , (correlation_id , response ) in enumerate (responses ):
897908 try :
898- (future , timestamp ) = self .in_flight_requests .pop (correlation_id )
909+ with self ._lock :
910+ (future , timestamp ) = self .in_flight_requests .pop (correlation_id )
899911 except KeyError :
900912 self .close (Errors .KafkaConnectionError ('Received unrecognized correlation id' ))
901913 return ()
@@ -911,6 +923,12 @@ def recv(self):
911923 def _recv (self ):
912924 """Take all available bytes from socket, return list of any responses from parser"""
913925 recvd = []
926+ self ._lock .acquire ()
927+ if not self ._can_send_recv ():
928+ log .warning ('%s cannot recv: socket not connected' , self )
929+ self ._lock .release ()
930+ return ()
931+
914932 while len (recvd ) < self .config ['sock_chunk_buffer_count' ]:
915933 try :
916934 data = self ._sock .recv (self .config ['sock_chunk_bytes' ])
@@ -920,6 +938,7 @@ def _recv(self):
920938 # without an exception raised
921939 if not data :
922940 log .error ('%s: socket disconnected' , self )
941+ self ._lock .release ()
923942 self .close (error = Errors .KafkaConnectionError ('socket disconnected' ))
924943 return []
925944 else :
@@ -932,11 +951,13 @@ def _recv(self):
932951 break
933952 log .exception ('%s: Error receiving network data'
934953 ' closing socket' , self )
954+ self ._lock .release ()
935955 self .close (error = Errors .KafkaConnectionError (e ))
936956 return []
937957 except BlockingIOError :
938958 if six .PY3 :
939959 break
960+ self ._lock .release ()
940961 raise
941962
942963 recvd_data = b'' .join (recvd )
@@ -946,20 +967,23 @@ def _recv(self):
946967 try :
947968 responses = self ._protocol .receive_bytes (recvd_data )
948969 except Errors .KafkaProtocolError as e :
970+ self ._lock .release ()
949971 self .close (e )
950972 return []
951973 else :
974+ self ._lock .release ()
952975 return responses
953976
def requests_timed_out(self):
    """Return True if the oldest in-flight request has exceeded the timeout.

    ``self._lock`` is held while ``self.in_flight_requests`` is inspected.
    Each value in that mapping is indexed at position 1 to obtain the time
    the request was sent; the earliest of those is compared against
    ``self.config['request_timeout_ms']``.

    Returns:
        bool: True when at least ``request_timeout_ms`` milliseconds have
        elapsed since the oldest in-flight request was sent; False when
        nothing is in flight or the oldest request is still within the
        timeout.
    """
    with self._lock:
        if not self.in_flight_requests:
            return False
        # Only the earliest send time matters: if it has not expired,
        # no later request can have expired either.
        oldest_at = min(entry[1]
                        for entry in self.in_flight_requests.values())
        # Config value is in milliseconds; time.time() works in seconds.
        timeout = self.config['request_timeout_ms'] / 1000.0
        return time.time() >= oldest_at + timeout
963987
964988 def _handle_api_version_response (self , response ):
965989 error_type = Errors .for_code (response .error_code )
0 commit comments