Handle error on server-final-message (sasl.scram)

This commit is contained in:
jesopo 2019-02-06 15:28:17 +00:00
parent 403466dee3
commit cfa590eef7

View file

@ -26,9 +26,11 @@ class SCRAM(object):
self._password = password.encode("utf8") self._password = password.encode("utf8")
self.state = SCRAMState.Uninitialised self.state = SCRAMState.Uninitialised
self._client_first = None self.error = None # typing.Optional[str]
self._salted_password = None
self._auth_message = None self._client_first = None # typing.Optional[bytes]
self._salted_password = None # typing.Optional[bytes]
self._auth_message = None # typing.Optional[bytes]
def _get_pieces(self, data: bytes) -> typing.Dict[bytes, bytes]: def _get_pieces(self, data: bytes) -> typing.Dict[bytes, bytes]:
pieces = (piece.split(b"=", 1) for piece in data.split(b",")) pieces = (piece.split(b"=", 1) for piece in data.split(b","))
@ -41,7 +43,6 @@ class SCRAM(object):
def client_first(self) -> bytes: def client_first(self) -> bytes:
self.state = SCRAMState.ClientFirst self.state = SCRAMState.ClientFirst
# start SCRAM handshake
self._client_first = b"n=%s,r=%s" % ( self._client_first = b"n=%s,r=%s" % (
_scram_escape(self._username), _scram_nonce()) _scram_escape(self._username), _scram_nonce())
return b"n,,%s" % self._client_first return b"n,,%s" % self._client_first
@ -50,7 +51,6 @@ class SCRAM(object):
self.state = SCRAMState.ClientFinal self.state = SCRAMState.ClientFinal
pieces = self._get_pieces(data) pieces = self._get_pieces(data)
# server-first-message
nonce = pieces[b"r"] nonce = pieces[b"r"]
salt = base64.b64decode(pieces[b"s"]) salt = base64.b64decode(pieces[b"s"])
iterations = pieces[b"i"] iterations = pieces[b"i"]
@ -75,8 +75,11 @@ class SCRAM(object):
return auth_noproof + (b",p=%s" % client_proof) return auth_noproof + (b",p=%s" % client_proof)
def server_final(self, data: bytes) -> bool: def server_final(self, data: bytes) -> bool:
# server-final-message
pieces = self._get_pieces(data) pieces = self._get_pieces(data)
if b"e" in pieces:
self.error = pieces[b"e"].decode("utf8")
return False
verifier = pieces[b"v"] verifier = pieces[b"v"]
server_key = self._hmac(self._salted_password, b"Server Key") server_key = self._hmac(self._salted_password, b"Server Key")