From 1fe20a2c98ed5e4042d10c415d7923aebfaa362d Mon Sep 17 00:00:00 2001 From: jesopo Date: Tue, 5 Feb 2019 15:54:20 +0000 Subject: [PATCH] Move sasl.py to a directory module and move SCRAM logic to a different file, move `github/module.py` to `github/__init__.py` --- modules/github/{module.py => __init__.py} | 0 modules/{sasl.py => sasl/__init__.py} | 66 +++-------------- modules/sasl/scram.py | 87 +++++++++++++++++++++++ 3 files changed, 98 insertions(+), 55 deletions(-) rename modules/github/{module.py => __init__.py} (100%) rename modules/{sasl.py => sasl/__init__.py} (53%) create mode 100644 modules/sasl/scram.py diff --git a/modules/github/module.py b/modules/github/__init__.py similarity index 100% rename from modules/github/module.py rename to modules/github/__init__.py diff --git a/modules/sasl.py b/modules/sasl/__init__.py similarity index 53% rename from modules/sasl.py rename to modules/sasl/__init__.py index ed4faabf..b961ba00 100644 --- a/modules/sasl.py +++ b/modules/sasl/__init__.py @@ -1,5 +1,6 @@ import base64, hashlib, hmac, uuid from src import ModuleManager, utils +from . import scram def _validate(self, s): mechanism = s @@ -68,62 +69,17 @@ class Module(ModuleManager.BaseModule): sasl_username, sasl_password = sasl["args"].split(":", 1) if event["message"] == "+": # start SCRAM handshake - first_base = "n=%s,r=%s" % ( - _scram_escape(sasl_username), _scram_nonce()) - first_withchannel = "n,,%s" % first_base - auth_text = first_withchannel.encode("utf8") - event["server"]._scram_first = first_base.encode("utf8") + event["server"]._scram = scram.SCRAM( + algo, sasl_username, sasl_password) + auth_text = event["server"]._scram.client_first() + print(auth_text) else: - data = base64.b64decode(event["message"]).decode("utf8") - pieces = dict(piece.split("=", 1) for piece in data.split(",")) - if "s" in pieces: - # server-first-message - nonce = pieces["r"].encode("utf8") - salt = base64.b64decode(pieces["s"]) - iterations = pieces["i"] - password = sasl_password.encode("utf8") - - salted_password = hashlib.pbkdf2_hmac(algo, password, salt, - int(iterations), dklen=None) - event["server"]._scram_salted_password = salted_password - - client_key = hmac.digest(salted_password, b"Client Key", - algo) - stored_key = hashlib.new(algo, client_key).digest() - - channel = base64.b64encode(b"n,,") - auth_noproof = b"c=%s,r=%s" % (channel, nonce) - auth_message = b"%s,%s,%s" % (event["server"]._scram_first, - data.encode("utf8"), auth_noproof) - event["server"]._scram_auth_message = auth_message - - client_signature = hmac.digest(stored_key, auth_message, - algo) - client_proof = base64.b64encode( - _scram_xor(client_key, client_signature)) - - auth_text = auth_noproof + (b",p=%s" % client_proof) - elif "v" in pieces: - # server-final-message - verifier = pieces["v"] - - salted_password = event["server"]._scram_salted_password - auth_message = event["server"]._scram_auth_message - server_key = hmac.digest(salted_password, b"Server Key", - algo) - server_signature = hmac.digest(server_key, auth_message, - algo) - - del event["server"]._scram_first - del event["server"]._scram_salted_password - del event["server"]._scram_auth_message - - if server_signature != base64.b64decode(verifier): - raise ValueError("SCRAM %s authentication failed " - % algo) - event["server"].disconnect() - auth_text = "+" - + current_scram = event["server"]._scram + if current_scram.state == scram.SCRAMState.ClientFirst: + auth_text = current_scram.server_first(event["message"]) + elif current_scram.state == scram.SCRAMState.ClientFinal: + auth_text = current_scram.server_final(event["message"]) + del event["server"]._scram else: raise ValueError("unknown sasl mechanism '%s'" % mechanism) diff --git a/modules/sasl/scram.py b/modules/sasl/scram.py new file mode 100644 index 00000000..7d2551b4 --- /dev/null +++ b/modules/sasl/scram.py @@ -0,0 +1,87 @@ +import base64, enum, hashlib, hmac, uuid + +def _scram_nonce(): + return uuid.uuid4().hex.encode("utf8") +def _scram_escape(s): + return s.replace(b"=", b"=3D").replace(b",", b"=2C") +def _scram_unescape(s): + return s.replace(b"=3D", b"=").replace(b"=2C", b",") +def _scram_xor(s1, s2): + return bytes(a ^ b for a, b in zip(s1, s2)) + +class SCRAMState(enum.Enum): + Uninitialised = 0 + ClientFirst = 1 + ClientFinal = 2 + Success = 3 + VerifyFailed = 4 + +class SCRAMError(Exception): + pass + +class SCRAM(object): + def __init__(self, algo, username, password): + self._algo = algo + self._username = username.encode("utf8") + self._password = password.encode("utf8") + + self.state = SCRAMState.Uninitialised + self._client_first = None + self._salted_password = None + self._auth_message = None + + def _get_data(self, message): + data = base64.b64decode(message) + return data, dict(piece.split(b"=", 1) for piece in data.split(b",")) + + def client_first(self): + self.state = SCRAMState.ClientFirst + # start SCRAM handshake + self._client_first = b"n=%s,r=%s" % ( + _scram_escape(self._username), _scram_nonce()) + return b"n,,%s" % self._client_first + + def server_first(self, message): + self.state = SCRAMState.ClientFinal + + data, pieces = self._get_data(message) + # server-first-message + nonce = pieces[b"r"] + salt = base64.b64decode(pieces[b"s"]) + iterations = pieces[b"i"] + password = self._password + + salted_password = hashlib.pbkdf2_hmac(self._algo, password, salt, + int(iterations), dklen=None) + self._salted_password = salted_password + + client_key = hmac.digest(salted_password, b"Client Key", self._algo) + stored_key = hashlib.new(self._algo, client_key).digest() + + channel = base64.b64encode(b"n,,") + auth_noproof = b"c=%s,r=%s" % (channel, nonce) + auth_message = b"%s,%s,%s" % (self._client_first, data, auth_noproof) + self._auth_message = auth_message + + client_signature = hmac.digest(stored_key, auth_message, self._algo) + client_proof = base64.b64encode( + _scram_xor(client_key, client_signature)) + + return auth_noproof + (b",p=%s" % client_proof) + + def server_final(self, message): + # server-final-message + data, pieces = self._get_data(message) + verifier = pieces[b"v"] + + server_key = hmac.digest(self._salted_password, b"Server Key", + self._algo) + server_signature = hmac.digest(server_key, self._auth_message, + self._algo) + + if server_signature != base64.b64decode(verifier): + self.state = SCRAMState.VerifyFailed + return None + else: + self.state = SCRAMState.Success + return "+"