Skip to content

Commit 1df58bf

Browse files
authored
Check for disconnects during ssl handshake and sasl authentication (#1249)
1 parent 5c17cf0 commit 1df58bf

File tree

1 file changed

+42
-31
lines changed

1 file changed

+42
-31
lines changed

kafka/conn.py

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -299,12 +299,15 @@ def connect(self):
299299
self._sock.setsockopt(*option)
300300

301301
self._sock.setblocking(False)
302+
self.last_attempt = time.time()
303+
self.state = ConnectionStates.CONNECTING
302304
if self.config['security_protocol'] in ('SSL', 'SASL_SSL'):
303305
self._wrap_ssl()
304-
log.info('%s: connecting to %s:%d', self, self.host, self.port)
305-
self.state = ConnectionStates.CONNECTING
306-
self.last_attempt = time.time()
307-
self.config['state_change_callback'](self)
306+
# _wrap_ssl can alter the connection state -- disconnects on failure
307+
# so we need to double check that we are still connecting before
308+
if self.connecting():
309+
self.config['state_change_callback'](self)
310+
log.info('%s: connecting to %s:%d', self, self.host, self.port)
308311

309312
if self.state is ConnectionStates.CONNECTING:
310313
# in non-blocking mode, use repeated calls to socket.connect_ex
@@ -367,10 +370,12 @@ def connect(self):
367370
if self.state is ConnectionStates.AUTHENTICATING:
368371
assert self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL')
369372
if self._try_authenticate():
370-
log.debug('%s: Connection complete.', self)
371-
self.state = ConnectionStates.CONNECTED
372-
self._reset_reconnect_backoff()
373-
self.config['state_change_callback'](self)
373+
# _try_authenticate has side-effects: possibly disconnected on socket errors
374+
if self.state is ConnectionStates.AUTHENTICATING:
375+
log.debug('%s: Connection complete.', self)
376+
self.state = ConnectionStates.CONNECTED
377+
self._reset_reconnect_backoff()
378+
self.config['state_change_callback'](self)
374379

375380
return self.state
376381

@@ -397,10 +402,7 @@ def _wrap_ssl(self):
397402
password=self.config['ssl_password'])
398403
if self.config['ssl_crlfile']:
399404
if not hasattr(ssl, 'VERIFY_CRL_CHECK_LEAF'):
400-
error = 'No CRL support with this version of Python.'
401-
log.error('%s: %s Disconnecting.', self, error)
402-
self.close(Errors.ConnectionError(error))
403-
return
405+
raise RuntimeError('This version of Python does not support ssl_crlfile!')
404406
log.info('%s: Loading SSL CRL from %s', self, self.config['ssl_crlfile'])
405407
self._ssl_context.load_verify_locations(self.config['ssl_crlfile'])
406408
# pylint: disable=no-member
@@ -443,7 +445,9 @@ def _try_authenticate(self):
443445
self._sasl_auth_future = future
444446
self._recv()
445447
if self._sasl_auth_future.failed():
446-
raise self._sasl_auth_future.exception # pylint: disable-msg=raising-bad-type
448+
ex = self._sasl_auth_future.exception
449+
if not isinstance(ex, Errors.ConnectionError):
450+
raise ex # pylint: disable-msg=raising-bad-type
447451
return self._sasl_auth_future.succeeded()
448452

449453
def _handle_sasl_handshake_response(self, future, response):
@@ -463,6 +467,19 @@ def _handle_sasl_handshake_response(self, future, response):
463467
'kafka-python does not support SASL mechanism %s' %
464468
self.config['sasl_mechanism']))
465469

470+
def _recv_bytes_blocking(self, n):
471+
self._sock.setblocking(True)
472+
try:
473+
data = b''
474+
while len(data) < n:
475+
fragment = self._sock.recv(n - len(data))
476+
if not fragment:
477+
raise ConnectionError('Connection reset during recv')
478+
data += fragment
479+
return data
480+
finally:
481+
self._sock.setblocking(False)
482+
466483
def _try_authenticate_plain(self, future):
467484
if self.config['security_protocol'] == 'SASL_PLAINTEXT':
468485
log.warning('%s: Sending username and password in the clear', self)
@@ -476,30 +493,23 @@ def _try_authenticate_plain(self, future):
476493
self.config['sasl_plain_password']]).encode('utf-8'))
477494
size = Int32.encode(len(msg))
478495
self._sock.sendall(size + msg)
496+
self._sock.setblocking(False)
479497

480498
# The server will send a zero sized message (that is Int32(0)) on success.
481499
# The connection is closed on failure
482-
while len(data) < 4:
483-
fragment = self._sock.recv(4 - len(data))
484-
if not fragment:
485-
log.error('%s: Authentication failed for user %s', self, self.config['sasl_plain_username'])
486-
error = Errors.AuthenticationFailedError(
487-
'Authentication failed for user {0}'.format(
488-
self.config['sasl_plain_username']))
489-
future.failure(error)
490-
raise error
491-
data += fragment
492-
self._sock.setblocking(False)
493-
except (AssertionError, ConnectionError) as e:
500+
self._recv_bytes_blocking(4)
501+
502+
except ConnectionError as e:
494503
log.exception("%s: Error receiving reply from server", self)
495504
error = Errors.ConnectionError("%s: %s" % (self, e))
496-
future.failure(error)
497505
self.close(error=error)
506+
return future.failure(error)
498507

499508
if data != b'\x00\x00\x00\x00':
500-
return future.failure(Errors.AuthenticationFailedError())
509+
error = Errors.AuthenticationFailedError('Unrecognized response during authentication')
510+
return future.failure(error)
501511

502-
log.info('%s: Authenticated as %s', self, self.config['sasl_plain_username'])
512+
log.info('%s: Authenticated as %s via PLAIN', self, self.config['sasl_plain_username'])
503513
return future.success(True)
504514

505515
def _try_authenticate_gssapi(self, future):
@@ -524,14 +534,15 @@ def _try_authenticate_gssapi(self, future):
524534
msg = output_token
525535
size = Int32.encode(len(msg))
526536
self._sock.sendall(size + msg)
537+
self._sock.setblocking(False)
538+
527539
# The server will send a token back. Processing of this token either
528540
# establishes a security context, or it needs further token exchange.
529541
# The gssapi will be able to identify the needed next step.
530542
# The connection is closed on failure.
531-
header = self._sock.recv(4)
543+
header = self._recv_bytes_blocking(4)
532544
token_size = struct.unpack('>i', header)
533-
received_token = self._sock.recv(token_size)
534-
self._sock.setblocking(False)
545+
received_token = self._recv_bytes_blocking(token_size)
535546

536547
except ConnectionError as e:
537548
log.exception("%s: Error receiving reply from server", self)

0 commit comments

Comments
 (0)