SCRAM.error should be within standardised errors (put raw in self.raw_error)

This commit is contained in:
jesopo 2019-02-15 20:09:32 +00:00
parent e51c653c1e
commit 686d852e2b

View file

@ -7,6 +7,20 @@ import base64, enum, hashlib, hmac, os, typing
ALGORITHMS = [ ALGORITHMS = [
"MD5", "SHA-1", "SHA-224", "SHA-256", "SHA-384", "SHA-512"] "MD5", "SHA-1", "SHA-224", "SHA-256", "SHA-384", "SHA-512"]
SCRAM_ERRORS = [
"invalid-encoding",
"extensions-not-supported", # unrecognized 'm' value
"invalid-proof",
"channel-bindings-dont-match",
"server-does-support-channel-binding",
"channel-binding-not-supported",
"unsupported-channel-binding-type",
"unknown-user",
"invalid-username-encoding", # invalid utf8 of bad SASLprep
"no-resources",
"other-error"
]
def _scram_nonce() -> bytes: def _scram_nonce() -> bytes:
return base64.b64encode(os.urandom(32)) return base64.b64encode(os.urandom(32))
def _scram_escape(s: bytes) -> bytes: def _scram_escape(s: bytes) -> bytes:
@ -38,6 +52,7 @@ class SCRAM(object):
self.state = SCRAMState.Uninitialised self.state = SCRAMState.Uninitialised
self.error = "" self.error = ""
self.raw_error = ""
self._client_first = b"" self._client_first = b""
self._salted_password = b"" self._salted_password = b""
@ -93,7 +108,13 @@ class SCRAM(object):
def server_final(self, data: bytes) -> bool: def server_final(self, data: bytes) -> bool:
pieces = self._get_pieces(data) pieces = self._get_pieces(data)
if b"e" in pieces: if b"e" in pieces:
self.error = pieces[b"e"].decode("utf8") error = pieces[b"e"].decode("utf8")
self.raw_error = error
if error in SCRAM_ERRORS:
self.error = error
else:
self.error = "other-error"
self.state = SCRAMState.Failed self.state = SCRAMState.Failed
return False return False