import base64, enum, hashlib, hmac, typing, uuid

def _scram_nonce():
    return uuid.uuid4().hex.encode("utf8")
def _scram_escape(s):
    return s.replace(b"=", b"=3D").replace(b",", b"=2C")
def _scram_unescape(s):
    return s.replace(b"=3D", b"=").replace(b"=2C", b",")
def _scram_xor(s1, s2):
    return bytes(a ^ b for a, b in zip(s1, s2))

class SCRAMState(enum.Enum):
    """Progress of a SCRAM client exchange, as recorded in ``SCRAM.state``."""
    Uninitialised = 0  # initial value set in SCRAM.__init__
    ClientFirst = 1  # set by client_first()
    ClientFinal = 2  # set by server_first()
    Success = 3  # set by server_final() when the server signature matches
    VerifyFailed = 4  # set by server_final() when the server signature does not match

class SCRAMError(Exception):
    """Error type for the SCRAM client in this module."""

class SCRAM(object):
    """Client side of a SCRAM authentication exchange (RFC 5802).

    Only the ``n,,`` GS2 header is produced, i.e. no channel binding and no
    authorization identity. Call the methods in this order:

    1. ``client_first()`` -> client-first-message to send;
    2. ``server_first(reply)`` -> client-final-message to send;
    3. ``server_final(reply)`` -> True if the server's signature verified.

    Progress is recorded in ``state`` (a SCRAMState). A server-reported
    error (``e=`` attribute) is stored as text in ``error``.

    :param algo: hash name accepted by hashlib/hmac, e.g. ``"sha256"``.
    :param username: user name (str); UTF-8 encoded and SCRAM-escaped.
    :param password: password (str); UTF-8 encoded as-is (no SASLprep).
    """

    def __init__(self, algo, username, password):
        self._algo = algo
        self._username = username.encode("utf8")
        self._password = password.encode("utf8")

        self.state = SCRAMState.Uninitialised
        self.error = None # typing.Optional[str]

        # Nonce sent in client_first(); the server's nonce must start with it.
        self._client_nonce = None # typing.Optional[bytes]
        # client-first-message-bare ("n=...,r=..."); part of the AuthMessage.
        self._client_first = None # typing.Optional[bytes]
        self._salted_password = None # typing.Optional[bytes]
        self._auth_message = None # typing.Optional[bytes]

    def _get_pieces(self, data: bytes) -> typing.Dict[bytes, bytes]:
        """Parse a ``k=v,k=v,...`` SCRAM message into an attribute dict.

        Each attribute is split on its first ``=`` only, so base64 padding
        in values survives. A repeated key keeps its last value. An
        attribute without ``=`` raises IndexError.
        """
        pairs = (piece.split(b"=", 1) for piece in data.split(b","))
        return {pair[0]: pair[1] for pair in pairs}

    def _hmac(self, key: bytes, msg: bytes) -> bytes:
        """Return the HMAC of *msg* under *key* using the configured hash."""
        return hmac.digest(key, msg, self._algo)

    def _hash(self, msg: bytes) -> bytes:
        """Return the plain digest of *msg* using the configured hash."""
        return hashlib.new(self._algo, msg).digest()

    def client_first(self) -> bytes:
        """Start the exchange and return the client-first-message.

        Generates a new client nonce on every call.
        """
        self.state = SCRAMState.ClientFirst
        self._client_nonce = _scram_nonce()
        self._client_first = b"n=%s,r=%s" % (
            _scram_escape(self._username), self._client_nonce)
        # "n,," is the GS2 header: no channel binding, no authzid.
        return b"n,,%s" % self._client_first

    def server_first(self, data: bytes) -> bytes:
        """Process the server-first-message and return the client-final-message.

        :param data: server-first-message, e.g. ``b"r=...,s=...,i=4096"``.
        :returns: ``c=biws,r=<nonce>,p=<base64 client proof>``.
        :raises SCRAMError: if client_first() has not been called, the
            server sent a mandatory extension (``m=``), or the server nonce
            does not start with our client nonce (RFC 5802 section 5.1).
        :raises KeyError: if the ``r``, ``s`` or ``i`` attribute is missing.
        """
        if self._client_nonce is None:
            raise SCRAMError("server_first() called before client_first()")
        self.state = SCRAMState.ClientFinal

        pieces = self._get_pieces(data)
        if b"m" in pieces:
            # RFC 5802: a mandatory extension we do not understand must fail.
            raise SCRAMError("unsupported mandatory SCRAM extension")
        nonce = pieces[b"r"]
        if not nonce.startswith(self._client_nonce):
            # The server must echo our nonce so its reply is bound to this
            # exchange; otherwise we would compute a proof for a foreign nonce.
            raise SCRAMError("server nonce does not start with client nonce")
        salt = base64.b64decode(pieces[b"s"])
        iterations = int(pieces[b"i"])

        salted_password = hashlib.pbkdf2_hmac(
            self._algo, self._password, salt, iterations, dklen=None)
        self._salted_password = salted_password

        client_key = self._hmac(salted_password, b"Client Key")
        stored_key = self._hash(client_key)

        # base64("n,,") echoes the GS2 header sent in client_first().
        channel = base64.b64encode(b"n,,")
        auth_noproof = b"c=%s,r=%s" % (channel, nonce)
        # AuthMessage = client-first-bare "," server-first "," client-final-without-proof
        auth_message = b"%s,%s,%s" % (self._client_first, data, auth_noproof)
        self._auth_message = auth_message

        client_signature = self._hmac(stored_key, auth_message)
        client_proof = base64.b64encode(
            _scram_xor(client_key, client_signature))

        return auth_noproof + (b",p=%s" % client_proof)

    def server_final(self, data: bytes) -> bool:
        """Verify the server-final-message.

        :param data: server-final-message, ``v=<base64 signature>`` or
            ``e=<error>``.
        :returns: True (state Success) if the server signature matches.
            Returns False if it does not (state VerifyFailed), or if the
            server reported an error; the error text is then stored in
            ``self.error`` and ``state`` is left unchanged.
        :raises SCRAMError: if server_first() has not been processed yet.
        :raises KeyError: if neither ``e`` nor ``v`` is present.
        """
        pieces = self._get_pieces(data)
        if b"e" in pieces:
            self.error = pieces[b"e"].decode("utf8")
            return False

        if self._salted_password is None or self._auth_message is None:
            raise SCRAMError("server_final() called before server_first()")

        verifier = pieces[b"v"]

        server_key = self._hmac(self._salted_password, b"Server Key")
        server_signature = self._hmac(server_key, self._auth_message)

        # Constant-time comparison does not leak how many leading bytes match.
        if hmac.compare_digest(server_signature, base64.b64decode(verifier)):
            self.state = SCRAMState.Success
            return True
        self.state = SCRAMState.VerifyFailed
        return False