Skip to content

Commit 7a69952

Browse files
authored
Improve connection lock handling; always use context manager (#1895)
1 parent 61fa0b2 commit 7a69952

File tree

1 file changed

+151
-126
lines changed

1 file changed

+151
-126
lines changed

kafka/conn.py

Lines changed: 151 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -593,21 +593,30 @@ def _try_authenticate_plain(self, future):
593593
self.config['sasl_plain_username'],
594594
self.config['sasl_plain_password']]).encode('utf-8'))
595595
size = Int32.encode(len(msg))
596-
try:
597-
with self._lock:
598-
if not self._can_send_recv():
599-
return future.failure(Errors.NodeNotReadyError(str(self)))
600-
self._send_bytes_blocking(size + msg)
601596

602-
# The server will send a zero sized message (that is Int32(0)) on success.
603-
# The connection is closed on failure
604-
data = self._recv_bytes_blocking(4)
597+
err = None
598+
close = False
599+
with self._lock:
600+
if not self._can_send_recv():
601+
err = Errors.NodeNotReadyError(str(self))
602+
close = False
603+
else:
604+
try:
605+
self._send_bytes_blocking(size + msg)
606+
607+
# The server will send a zero sized message (that is Int32(0)) on success.
608+
# The connection is closed on failure
609+
data = self._recv_bytes_blocking(4)
605610

606-
except (ConnectionError, TimeoutError) as e:
607-
log.exception("%s: Error receiving reply from server", self)
608-
error = Errors.KafkaConnectionError("%s: %s" % (self, e))
609-
self.close(error=error)
610-
return future.failure(error)
611+
except (ConnectionError, TimeoutError) as e:
612+
log.exception("%s: Error receiving reply from server", self)
613+
err = Errors.KafkaConnectionError("%s: %s" % (self, e))
614+
close = True
615+
616+
if err is not None:
617+
if close:
618+
self.close(error=err)
619+
return future.failure(err)
611620

612621
if data != b'\x00\x00\x00\x00':
613622
error = Errors.AuthenticationFailedError('Unrecognized response during authentication')
@@ -625,61 +634,67 @@ def _try_authenticate_gssapi(self, future):
625634
).canonicalize(gssapi.MechType.kerberos)
626635
log.debug('%s: GSSAPI name: %s', self, gssapi_name)
627636

628-
self._lock.acquire()
629-
if not self._can_send_recv():
630-
return future.failure(Errors.NodeNotReadyError(str(self)))
631-
# Establish security context and negotiate protection level
632-
# For reference RFC 2222, section 7.2.1
633-
try:
634-
# Exchange tokens until authentication either succeeds or fails
635-
client_ctx = gssapi.SecurityContext(name=gssapi_name, usage='initiate')
636-
received_token = None
637-
while not client_ctx.complete:
638-
# calculate an output token from kafka token (or None if first iteration)
639-
output_token = client_ctx.step(received_token)
640-
641-
# pass output token to kafka, or send empty response if the security
642-
# context is complete (output token is None in that case)
643-
if output_token is None:
644-
self._send_bytes_blocking(Int32.encode(0))
645-
else:
646-
msg = output_token
637+
err = None
638+
close = False
639+
with self._lock:
640+
if not self._can_send_recv():
641+
err = Errors.NodeNotReadyError(str(self))
642+
close = False
643+
else:
644+
# Establish security context and negotiate protection level
645+
# For reference RFC 2222, section 7.2.1
646+
try:
647+
# Exchange tokens until authentication either succeeds or fails
648+
client_ctx = gssapi.SecurityContext(name=gssapi_name, usage='initiate')
649+
received_token = None
650+
while not client_ctx.complete:
651+
# calculate an output token from kafka token (or None if first iteration)
652+
output_token = client_ctx.step(received_token)
653+
654+
# pass output token to kafka, or send empty response if the security
655+
# context is complete (output token is None in that case)
656+
if output_token is None:
657+
self._send_bytes_blocking(Int32.encode(0))
658+
else:
659+
msg = output_token
660+
size = Int32.encode(len(msg))
661+
self._send_bytes_blocking(size + msg)
662+
663+
# The server will send a token back. Processing of this token either
664+
# establishes a security context, or it needs further token exchange.
665+
# The gssapi will be able to identify the needed next step.
666+
# The connection is closed on failure.
667+
header = self._recv_bytes_blocking(4)
668+
(token_size,) = struct.unpack('>i', header)
669+
received_token = self._recv_bytes_blocking(token_size)
670+
671+
# Process the security layer negotiation token, sent by the server
672+
# once the security context is established.
673+
674+
# unwraps message containing supported protection levels and msg size
675+
msg = client_ctx.unwrap(received_token).message
676+
# Kafka currently doesn't support integrity or confidentiality security layers, so we
677+
# simply set QoP to 'auth' only (first octet). We reuse the max message size proposed
678+
# by the server
679+
msg = Int8.encode(SASL_QOP_AUTH & Int8.decode(io.BytesIO(msg[0:1]))) + msg[1:]
680+
# add authorization identity to the response, GSS-wrap and send it
681+
msg = client_ctx.wrap(msg + auth_id.encode(), False).message
647682
size = Int32.encode(len(msg))
648683
self._send_bytes_blocking(size + msg)
649684

650-
# The server will send a token back. Processing of this token either
651-
# establishes a security context, or it needs further token exchange.
652-
# The gssapi will be able to identify the needed next step.
653-
# The connection is closed on failure.
654-
header = self._recv_bytes_blocking(4)
655-
(token_size,) = struct.unpack('>i', header)
656-
received_token = self._recv_bytes_blocking(token_size)
657-
658-
# Process the security layer negotiation token, sent by the server
659-
# once the security context is established.
660-
661-
# unwraps message containing supported protection levels and msg size
662-
msg = client_ctx.unwrap(received_token).message
663-
# Kafka currently doesn't support integrity or confidentiality security layers, so we
664-
# simply set QoP to 'auth' only (first octet). We reuse the max message size proposed
665-
# by the server
666-
msg = Int8.encode(SASL_QOP_AUTH & Int8.decode(io.BytesIO(msg[0:1]))) + msg[1:]
667-
# add authorization identity to the response, GSS-wrap and send it
668-
msg = client_ctx.wrap(msg + auth_id.encode(), False).message
669-
size = Int32.encode(len(msg))
670-
self._send_bytes_blocking(size + msg)
685+
except (ConnectionError, TimeoutError) as e:
686+
log.exception("%s: Error receiving reply from server", self)
687+
err = Errors.KafkaConnectionError("%s: %s" % (self, e))
688+
close = True
689+
except Exception as e:
690+
err = e
691+
close = True
671692

672-
except (ConnectionError, TimeoutError) as e:
673-
self._lock.release()
674-
log.exception("%s: Error receiving reply from server", self)
675-
error = Errors.KafkaConnectionError("%s: %s" % (self, e))
676-
self.close(error=error)
677-
return future.failure(error)
678-
except Exception as e:
679-
self._lock.release()
680-
return future.failure(e)
693+
if err is not None:
694+
if close:
695+
self.close(error=err)
696+
return future.failure(err)
681697

682-
self._lock.release()
683698
log.info('%s: Authenticated as %s via GSSAPI', self, gssapi_name)
684699
return future.success(True)
685700

@@ -688,25 +703,31 @@ def _try_authenticate_oauth(self, future):
688703

689704
msg = bytes(self._build_oauth_client_request().encode("utf-8"))
690705
size = Int32.encode(len(msg))
691-
self._lock.acquire()
692-
if not self._can_send_recv():
693-
return future.failure(Errors.NodeNotReadyError(str(self)))
694-
try:
695-
# Send SASL OAuthBearer request with OAuth token
696-
self._send_bytes_blocking(size + msg)
697706

698-
# The server will send a zero sized message (that is Int32(0)) on success.
699-
# The connection is closed on failure
700-
data = self._recv_bytes_blocking(4)
707+
err = None
708+
close = False
709+
with self._lock:
710+
if not self._can_send_recv():
711+
err = Errors.NodeNotReadyError(str(self))
712+
close = False
713+
else:
714+
try:
715+
# Send SASL OAuthBearer request with OAuth token
716+
self._send_bytes_blocking(size + msg)
701717

702-
except (ConnectionError, TimeoutError) as e:
703-
self._lock.release()
704-
log.exception("%s: Error receiving reply from server", self)
705-
error = Errors.KafkaConnectionError("%s: %s" % (self, e))
706-
self.close(error=error)
707-
return future.failure(error)
718+
# The server will send a zero sized message (that is Int32(0)) on success.
719+
# The connection is closed on failure
720+
data = self._recv_bytes_blocking(4)
708721

709-
self._lock.release()
722+
except (ConnectionError, TimeoutError) as e:
723+
log.exception("%s: Error receiving reply from server", self)
724+
err = Errors.KafkaConnectionError("%s: %s" % (self, e))
725+
close = True
726+
727+
if err is not None:
728+
if close:
729+
self.close(error=err)
730+
return future.failure(err)
710731

711732
if data != b'\x00\x00\x00\x00':
712733
error = Errors.AuthenticationFailedError('Unrecognized response during authentication')
@@ -857,6 +878,9 @@ def _send(self, request, blocking=True):
857878
future = Future()
858879
with self._lock:
859880
if not self._can_send_recv():
881+
# In this case, since we created the future above,
882+
# we know there are no callbacks/errbacks that could fire w/
883+
# lock. So failing + returning inline should be safe
860884
return future.failure(Errors.NodeNotReadyError(str(self)))
861885

862886
correlation_id = self._protocol.send_request(request)
@@ -935,56 +959,57 @@ def recv(self):
935959
def _recv(self):
936960
"""Take all available bytes from socket, return list of any responses from parser"""
937961
recvd = []
938-
self._lock.acquire()
939-
if not self._can_send_recv():
940-
log.warning('%s cannot recv: socket not connected', self)
941-
self._lock.release()
942-
return ()
943-
944-
while len(recvd) < self.config['sock_chunk_buffer_count']:
945-
try:
946-
data = self._sock.recv(self.config['sock_chunk_bytes'])
947-
# We expect socket.recv to raise an exception if there are no
948-
# bytes available to read from the socket in non-blocking mode.
949-
# but if the socket is disconnected, we will get empty data
950-
# without an exception raised
951-
if not data:
952-
log.error('%s: socket disconnected', self)
953-
self._lock.release()
954-
self.close(error=Errors.KafkaConnectionError('socket disconnected'))
955-
return []
956-
else:
957-
recvd.append(data)
962+
err = None
963+
with self._lock:
964+
if not self._can_send_recv():
965+
log.warning('%s cannot recv: socket not connected', self)
966+
return ()
958967

959-
except SSLWantReadError:
960-
break
961-
except (ConnectionError, TimeoutError) as e:
962-
if six.PY2 and e.errno == errno.EWOULDBLOCK:
968+
while len(recvd) < self.config['sock_chunk_buffer_count']:
969+
try:
970+
data = self._sock.recv(self.config['sock_chunk_bytes'])
971+
# We expect socket.recv to raise an exception if there are no
972+
# bytes available to read from the socket in non-blocking mode.
973+
# but if the socket is disconnected, we will get empty data
974+
# without an exception raised
975+
if not data:
976+
log.error('%s: socket disconnected', self)
977+
err = Errors.KafkaConnectionError('socket disconnected')
978+
break
979+
else:
980+
recvd.append(data)
981+
982+
except SSLWantReadError:
963983
break
964-
log.exception('%s: Error receiving network data'
965-
' closing socket', self)
966-
self._lock.release()
967-
self.close(error=Errors.KafkaConnectionError(e))
968-
return []
969-
except BlockingIOError:
970-
if six.PY3:
984+
except (ConnectionError, TimeoutError) as e:
985+
if six.PY2 and e.errno == errno.EWOULDBLOCK:
986+
break
987+
log.exception('%s: Error receiving network data'
988+
' closing socket', self)
989+
err = Errors.KafkaConnectionError(e)
971990
break
972-
self._lock.release()
973-
raise
974-
975-
recvd_data = b''.join(recvd)
976-
if self._sensors:
977-
self._sensors.bytes_received.record(len(recvd_data))
978-
979-
try:
980-
responses = self._protocol.receive_bytes(recvd_data)
981-
except Errors.KafkaProtocolError as e:
982-
self._lock.release()
983-
self.close(e)
984-
return []
985-
else:
986-
self._lock.release()
987-
return responses
991+
except BlockingIOError:
992+
if six.PY3:
993+
break
994+
# For PY2 this is a catchall and should be re-raised
995+
raise
996+
997+
# Only process bytes if there was no connection exception
998+
if err is None:
999+
recvd_data = b''.join(recvd)
1000+
if self._sensors:
1001+
self._sensors.bytes_received.record(len(recvd_data))
1002+
1003+
# We need to keep the lock through protocol receipt
1004+
# so that we ensure that the processed byte order is the
1005+
# same as the received byte order
1006+
try:
1007+
return self._protocol.receive_bytes(recvd_data)
1008+
except Errors.KafkaProtocolError as e:
1009+
err = e
1010+
1011+
self.close(error=err)
1012+
return ()
9881013

9891014
def requests_timed_out(self):
9901015
with self._lock:

0 commit comments

Comments
 (0)