diff --git a/modules/sasl/__init__.py b/modules/sasl/__init__.py index 0c74cc97..d7a768ae 100644 --- a/modules/sasl/__init__.py +++ b/modules/sasl/__init__.py @@ -74,10 +74,11 @@ class Module(ModuleManager.BaseModule): auth_text = event["server"]._scram.client_first() else: current_scram = event["server"]._scram + data = base64.b64decode(event["message"]) if current_scram.state == scram.SCRAMState.ClientFirst: - auth_text = current_scram.server_first(event["message"]) + auth_text = current_scram.server_first(data) elif current_scram.state == scram.SCRAMState.ClientFinal: - auth_text = current_scram.server_final(event["message"]) + auth_text = current_scram.server_final(data) del event["server"]._scram if current_scram.state == scram.SCRAMState.VerifyFailed: diff --git a/modules/sasl/scram.py b/modules/sasl/scram.py index 65dfa790..2ac402e1 100644 --- a/modules/sasl/scram.py +++ b/modules/sasl/scram.py @@ -1,4 +1,4 @@ -import base64, enum, hashlib, hmac, uuid +import base64, enum, hashlib, hmac, typing, uuid def _scram_nonce(): return uuid.uuid4().hex.encode("utf8") @@ -30,26 +30,25 @@ class SCRAM(object): self._salted_password = None self._auth_message = None - def _get_data(self, message): - data = base64.b64decode(message) - return data, dict(piece.split(b"=", 1) for piece in data.split(b",")) + def _get_pieces(self, data: bytes) -> typing.Dict[bytes, bytes]: + return dict(piece.split(b"=", 1) for piece in data.split(b",")) - def _hmac(self, key, msg): + def _hmac(self, key: bytes, msg: bytes) -> bytes: return hmac.digest(key, msg, self._algo) - def _hash(self, msg): + def _hash(self, msg: bytes) -> bytes: return hashlib.new(self._algo, msg).digest() - def client_first(self): + def client_first(self) -> bytes: self.state = SCRAMState.ClientFirst # start SCRAM handshake self._client_first = b"n=%s,r=%s" % ( _scram_escape(self._username), _scram_nonce()) return b"n,,%s" % self._client_first - def server_first(self, message): + def server_first(self, data: bytes) -> bytes: self.state = SCRAMState.ClientFinal - data, pieces = self._get_data(message) + pieces = self._get_pieces(data) # server-first-message nonce = pieces[b"r"] salt = base64.b64decode(pieces[b"s"]) @@ -74,9 +73,9 @@ class SCRAM(object): return auth_noproof + (b",p=%s" % client_proof) - def server_final(self, message): + def server_final(self, data: bytes) -> bytes: # server-final-message - data, pieces = self._get_data(message) + pieces = self._get_pieces(data) verifier = pieces[b"v"] server_key = self._hmac(self._salted_password, b"Server Key")