From 51a4b8ef4ebd05cd1f5ec82b8d7ddc49cd2e7b4b Mon Sep 17 00:00:00 2001 From: jesopo Date: Tue, 5 Feb 2019 12:17:25 +0000 Subject: [PATCH] Support SCRAM SASL mechanisms (sasl.py) --- modules/sasl.py | 114 +++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 98 insertions(+), 16 deletions(-) diff --git a/modules/sasl.py b/modules/sasl.py index 505bdfe4..32673d0b 100644 --- a/modules/sasl.py +++ b/modules/sasl.py @@ -1,4 +1,4 @@ -import base64 +import base64, hashlib, hmac, uuid from src import ModuleManager, utils def _validate(self, s): @@ -7,6 +7,15 @@ def _validate(self, s): mechanism, arguments = s.split(" ", 1) return {"mechanism": mechanism, "args": arguments} +def _scram_nonce(): + return str(uuid.uuid4().hex) +def _scram_escape(s): + return s.replace("=", "=3D").replace(",", "=2C") +def _scram_unescape(s): + return s.replace("=3D", "=").replace("=2C", ",") +def _scram_xor(s1, s2): + return bytes(a ^ b for a, b in zip(s1, s2)) + @utils.export("serverset", {"setting": "sasl", "help": "Set the sasl username/password for this server", "validate": _validate}) @@ -37,24 +46,97 @@ class Module(ModuleManager.BaseModule): @utils.hook("received.authenticate") def on_authenticate(self, event): - if event["message"] != "+": - event["server"].send_authenticate("*") - else: - sasl = event["server"].get_setting("sasl") - mechanism = sasl["mechanism"].upper() + sasl = event["server"].get_setting("sasl") + mechanism = sasl["mechanism"].upper() - if mechanism == "PLAIN": - sasl_nick, sasl_pass = sasl["args"].split(":", 1) - auth_text = "%s\0%s\0%s" % (sasl_nick, sasl_nick, sasl_pass) - elif mechanism == "EXTERNAL": - auth_text = "+" + if mechanism == "PLAIN": + if event["message"] != "+": + event["server"].send_authenticate("*") else: - raise ValueError("unknown sasl mechanism '%s'" % mechanism) + sasl_username, sasl_password = sasl["args"].split(":", 1) + auth_text = ("%s\0%s\0%s" % ( + sasl_username, sasl_username, sasl_password)).encode("utf8") - if not auth_text == "+": - auth_text = base64.b64encode(auth_text.encode("utf8")) - auth_text = auth_text.decode("utf8") - event["server"].send_authenticate(auth_text) + elif mechanism == "EXTERNAL": + if event["message"] != "+": + event["server"].send_authenticate("*") + else: + auth_text = "+" + + elif mechanism.startswith("SCRAM-"): + algo = mechanism.split("SCRAM-", 1)[1].replace("-", "") + sasl_username, sasl_password = sasl["args"].split(":", 1) + if event["message"] == "+": + # start SCRAM handshake + first_base = "n=%s,r=%s" % ( + _scram_escape(sasl_username), _scram_nonce()) + first_withchannel = "n,,%s" % first_base + auth_text = first_withchannel.encode("utf8") + event["server"]._scram_first = first_base.encode("utf8") + self.log.debug("SCRAM client-first-message: %s", + [first_withchannel]) + else: + data = base64.b64decode(event["message"]).decode("utf8") + pieces = dict(piece.split("=", 1) for piece in data.split(",")) + if "s" in pieces: + # server-first-message + self.log.debug("SCRAM server-first-message: %s", [data]) + + nonce = pieces["r"].encode("utf8") + salt = base64.b64decode(pieces["s"]) + iterations = pieces["i"] + password = sasl_password.encode("utf8") + self.log.debug("SCRAM server-first-message salt: %s", + [salt]) + + salted_password = hashlib.pbkdf2_hmac(algo, password, salt, + int(iterations), dklen=None) + self.log.debug("SCRAM server-first-message salted: %s", + [salted_password]) + event["server"]._scram_salted_password = salted_password + + client_key = hmac.digest(salted_password, b"Client Key", + algo) + stored_key = hashlib.new(algo, client_key).digest() + + channel = base64.b64encode(b"n,,") + auth_noproof = b"c=%s,r=%s" % (channel, nonce) + auth_message = b"%s,%s,%s" % (event["server"]._scram_first, + data.encode("utf8"), auth_noproof) + self.log.debug("SCRAM server-first-message auth msg: %s", + [auth_message]) + event["server"]._scram_auth_message = auth_message + + client_signature = hmac.digest(stored_key, auth_message, + algo) + client_proof = base64.b64encode( + _scram_xor(client_key, client_signature)) + + auth_text = auth_noproof + (b",p=%s" % client_proof) + elif "v" in pieces: + # server-final-message + verifier = pieces["v"] + + salted_password = event["server"]._scram_salted_password + auth_message = event["server"]._scram_auth_message + server_key = hmac.digest(salted_password, b"Server Key", + algo) + server_signature = hmac.digest(server_key, auth_message, + algo) + + if server_signature != base64.b64decode(verifier): + raise ValueError("SCRAM %s authentication failed " + % algo) + event["server"].disconnect() + auth_text = "+" + + else: + raise ValueError("unknown sasl mechanism '%s'" % mechanism) + + if not auth_text == "+": + auth_text = base64.b64encode(auth_text) + auth_text = auth_text.decode("utf8") + event["server"].send_authenticate(auth_text) def _end_sasl(self, server): server.capability_done("sasl")