diff --git a/modules/sasl/__init__.py b/modules/sasl/__init__.py index e9424ba3..40604c38 100644 --- a/modules/sasl/__init__.py +++ b/modules/sasl/__init__.py @@ -64,19 +64,26 @@ class Module(ModuleManager.BaseModule): auth_text = "+" elif mechanism.startswith("SCRAM-"): - algo = mechanism.split("SCRAM-", 1)[1].replace("-", "") - sasl_username, sasl_password = sasl["args"].split(":", 1) + if event["message"] == "+": # start SCRAM handshake + + # create SCRAM helper + sasl_username, sasl_password = sasl["args"].split(":", 1) + algo = mechanism.split("SCRAM-", 1)[1].replace("-", "") event["server"]._scram = scram.SCRAM( algo, sasl_username, sasl_password) + + # generate client-first-message auth_text = event["server"]._scram.client_first() else: current_scram = event["server"]._scram data = base64.b64decode(event["message"]) if current_scram.state == scram.SCRAMState.ClientFirst: + # use server-first-message to generate client-final-message auth_text = current_scram.server_first(data) elif current_scram.state == scram.SCRAMState.ClientFinal: + # use server-final-message to check server proof verified = current_scram.server_final(data) del event["server"]._scram @@ -84,6 +91,7 @@ class Module(ModuleManager.BaseModule): auth_text = "+" else: if current_scram.state == scram.SCRAMState.VerifyFailed: + # server gave a bad verification so we should panic event["server"].disconnect() raise ValueError("Server SCRAM verification failed")