import base64, enum, hashlib, hmac, uuid def _scram_nonce(): return uuid.uuid4().hex.encode("utf8") def _scram_escape(s): return s.replace(b"=", b"=3D").replace(b",", b"=2C") def _scram_unescape(s): return s.replace(b"=3D", b"=").replace(b"=2C", b",") def _scram_xor(s1, s2): return bytes(a ^ b for a, b in zip(s1, s2)) class SCRAMState(enum.Enum): Uninitialised = 0 ClientFirst = 1 ClientFinal = 2 Success = 3 VerifyFailed = 4 class SCRAMError(Exception): pass class SCRAM(object): def __init__(self, algo, username, password): self._algo = algo self._username = username.encode("utf8") self._password = password.encode("utf8") self.state = SCRAMState.Uninitialised self._client_first = None self._salted_password = None self._auth_message = None def _get_data(self, message): data = base64.b64decode(message) return data, dict(piece.split(b"=", 1) for piece in data.split(b",")) def client_first(self): self.state = SCRAMState.ClientFirst # start SCRAM handshake self._client_first = b"n=%s,r=%s" % ( _scram_escape(self._username), _scram_nonce()) return b"n,,%s" % self._client_first def server_first(self, message): self.state = SCRAMState.ClientFinal data, pieces = self._get_data(message) # server-first-message nonce = pieces[b"r"] salt = base64.b64decode(pieces[b"s"]) iterations = pieces[b"i"] password = self._password salted_password = hashlib.pbkdf2_hmac(self._algo, password, salt, int(iterations), dklen=None) self._salted_password = salted_password client_key = hmac.digest(salted_password, b"Client Key", self._algo) stored_key = hashlib.new(self._algo, client_key).digest() channel = base64.b64encode(b"n,,") auth_noproof = b"c=%s,r=%s" % (channel, nonce) auth_message = b"%s,%s,%s" % (self._client_first, data, auth_noproof) self._auth_message = auth_message client_signature = hmac.digest(stored_key, auth_message, self._algo) client_proof = base64.b64encode( _scram_xor(client_key, client_signature)) return auth_noproof + (b",p=%s" % client_proof) def server_final(self, message): # server-final-message data, pieces = self._get_data(message) verifier = pieces[b"v"] server_key = hmac.digest(self._salted_password, b"Server Key", self._algo) server_signature = hmac.digest(server_key, self._auth_message, self._algo) if server_signature != base64.b64decode(verifier): self.state = SCRAMState.VerifyFailed return None else: self.state = SCRAMState.Success return "+"