3
3
import base64
4
4
import hashlib
5
5
import hmac
6
+ import logging
6
7
import uuid
7
8
9
+
10
+ from kafka .sasl .abc import SaslMechanism
8
11
from kafka .vendor import six
9
12
10
13
14
+ log = logging .getLogger (__name__ )
15
+
16
+
11
17
if six .PY2 :
12
18
def xor_bytes (left , right ):
13
19
return bytearray (ord (lb ) ^ ord (rb ) for lb , rb in zip (left , right ))
@@ -16,17 +22,58 @@ def xor_bytes(left, right):
16
22
return bytes (lb ^ rb for lb , rb in zip (left , right ))
17
23
18
24
25
+ class SaslMechanismScram (SaslMechanism ):
26
+
27
+ def __init__ (self , ** config ):
28
+ assert config ['sasl_plain_username' ] is not None , 'sasl_plain_username required for SCRAM sasl'
29
+ assert config ['sasl_plain_password' ] is not None , 'sasl_plain_password required for SCRAM sasl'
30
+ if config ['security_protocol' ] == 'SASL_PLAINTEXT' :
31
+ log .warning ('Exchanging credentials in the clear during Sasl Authentication' )
32
+
33
+ self ._scram_client = ScramClient (
34
+ config ['sasl_plain_username' ],
35
+ config ['sasl_plain_password' ],
36
+ config ['sasl_mechanism' ]
37
+ )
38
+ self ._state = 0
39
+
40
+ def auth_bytes (self ):
41
+ if self ._state == 0 :
42
+ return self ._scram_client .first_message ()
43
+ elif self ._state == 1 :
44
+ return self ._scram_client .final_message ()
45
+ else :
46
+ raise ValueError ('No auth_bytes for state: %s' % self ._state )
47
+
48
+ def receive (self , auth_bytes ):
49
+ if self ._state == 0 :
50
+ self ._scram_client .process_server_first_message (auth_bytes )
51
+ elif self ._state == 1 :
52
+ self ._scram_client .process_server_final_message (auth_bytes )
53
+ else :
54
+ raise ValueError ('Cannot receive bytes in state: %s' % self ._state )
55
+ self ._state += 1
56
+ return self .is_done ()
57
+
58
+ def is_done (self ):
59
+ return self ._state == 2
60
+
61
+ def is_authenticated (self ):
62
+ # receive raises if authentication fails...?
63
+ return self ._state == 2
64
+
65
+
19
66
class ScramClient :
20
67
MECHANISMS = {
21
68
'SCRAM-SHA-256' : hashlib .sha256 ,
22
69
'SCRAM-SHA-512' : hashlib .sha512
23
70
}
24
71
25
72
def __init__ (self , user , password , mechanism ):
26
- self .nonce = str (uuid .uuid4 ()).replace ('-' , '' )
27
- self .auth_message = ''
73
+ self .nonce = str (uuid .uuid4 ()).replace ('-' , '' ). encode ( 'utf-8' )
74
+ self .auth_message = b ''
28
75
self .salted_password = None
29
- self .user = user
76
+ self .user = user . encode ( 'utf-8' )
30
77
self .password = password .encode ('utf-8' )
31
78
self .hashfunc = self .MECHANISMS [mechanism ]
32
79
self .hashname = '' .join (mechanism .lower ().split ('-' )[1 :3 ])
@@ -38,29 +85,29 @@ def __init__(self, user, password, mechanism):
38
85
self .server_signature = None
39
86
40
87
def first_message (self ):
41
- client_first_bare = 'n={},r={}' . format ( self .user , self .nonce )
88
+ client_first_bare = b 'n=' + self .user + b',r=' + self .nonce
42
89
self .auth_message += client_first_bare
43
- return 'n,,' + client_first_bare
90
+ return b 'n,,' + client_first_bare
44
91
45
92
def process_server_first_message (self , server_first_message ):
46
- self .auth_message += ',' + server_first_message
47
- params = dict (pair .split ('=' , 1 ) for pair in server_first_message .split (',' ))
48
- server_nonce = params ['r' ]
93
+ self .auth_message += b ',' + server_first_message
94
+ params = dict (pair .split ('=' , 1 ) for pair in server_first_message .decode ( 'utf-8' ). split (',' ))
95
+ server_nonce = params ['r' ]. encode ( 'utf-8' )
49
96
if not server_nonce .startswith (self .nonce ):
50
97
raise ValueError ("Server nonce, did not start with client nonce!" )
51
98
self .nonce = server_nonce
52
- self .auth_message += ',c=biws,r=' + self .nonce
99
+ self .auth_message += b ',c=biws,r=' + self .nonce
53
100
54
101
salt = base64 .b64decode (params ['s' ].encode ('utf-8' ))
55
102
iterations = int (params ['i' ])
56
103
self .create_salted_password (salt , iterations )
57
104
58
105
self .client_key = self .hmac (self .salted_password , b'Client Key' )
59
106
self .stored_key = self .hashfunc (self .client_key ).digest ()
60
- self .client_signature = self .hmac (self .stored_key , self .auth_message . encode ( 'utf-8' ) )
107
+ self .client_signature = self .hmac (self .stored_key , self .auth_message )
61
108
self .client_proof = xor_bytes (self .client_key , self .client_signature )
62
109
self .server_key = self .hmac (self .salted_password , b'Server Key' )
63
- self .server_signature = self .hmac (self .server_key , self .auth_message . encode ( 'utf-8' ) )
110
+ self .server_signature = self .hmac (self .server_key , self .auth_message )
64
111
65
112
def hmac (self , key , msg ):
66
113
return hmac .new (key , msg , digestmod = self .hashfunc ).digest ()
@@ -71,11 +118,9 @@ def create_salted_password(self, salt, iterations):
71
118
)
72
119
73
120
def final_message (self ):
74
- return 'c=biws,r={},p={}' . format ( self .nonce , base64 .b64encode (self .client_proof ). decode ( 'utf-8' ) )
121
+ return b 'c=biws,r=' + self .nonce + b',p=' + base64 .b64encode (self .client_proof )
75
122
76
123
def process_server_final_message (self , server_final_message ):
77
- params = dict (pair .split ('=' , 1 ) for pair in server_final_message .split (',' ))
124
+ params = dict (pair .split ('=' , 1 ) for pair in server_final_message .decode ( 'utf-8' ). split (',' ))
78
125
if self .server_signature != base64 .b64decode (params ['v' ].encode ('utf-8' )):
79
126
raise ValueError ("Server sent wrong signature!" )
80
-
81
-
0 commit comments