diff --git a/modules/sasl/__init__.py b/modules/sasl/__init__.py index a4076729..408ba319 100644 --- a/modules/sasl/__init__.py +++ b/modules/sasl/__init__.py @@ -2,6 +2,13 @@ import base64, hashlib, hmac, uuid from src import ModuleManager, utils from . import scram +USERPASS_MECHANISMS = [ + "SCRAM-SHA-512", + "SCRAM-SHA-256", + "SCRAM-SHA-1", + "PLAIN" +] + def _validate(self, s): mechanism, _, arguments = s.partition(" ") return {"mechanism": mechanism, "args": arguments} @@ -19,6 +26,11 @@ def _scram_xor(s1, s2): "help": "Set the sasl username/password for this server", "validate": _validate}) class Module(ModuleManager.BaseModule): + def _best_userpass_mechanism(self, mechanisms): + for potential_mechanism in USERPASS_MECHANISMS: + if potential_mechanism in mechanisms: + return potential_mechanism + @utils.hook("received.cap.new") @utils.hook("received.cap.ls") def on_cap(self, event): @@ -29,8 +41,11 @@ class Module(ModuleManager.BaseModule): if has_sasl and our_sasl: if not event["capabilities"]["sasl"] == None: our_mechanism = our_sasl["mechanism"].upper() - do_sasl = our_mechanism in event["capabilities" - ]["sasl"].split(",") + server_mechanisms = event["capabilities"]["sasl"].split(",") + if our_mechanism == "USERPASS": + our_mechanism = self._best_userpass_mechanism( + server_mechanisms) + do_sasl = our_mechanism in server_mechanisms else: do_sasl = True @@ -41,13 +56,21 @@ class Module(ModuleManager.BaseModule): def on_cap_ack(self, event): if "sasl" in event["capabilities"]: sasl = event["server"].get_setting("sasl") - event["server"].send_authenticate(sasl["mechanism"].upper()) + mechanism = sasl["mechanism"].upper() + if mechanism == "USERPASS": + server_mechanisms = event["server"].server_capabilities["sasl"] + server_mechanisms = server_mechanisms or [ + USERPASS_MECHANISMS[0]] + mechanism = self._best_userpass_mechanism(server_mechanisms) + + event["server"].send_authenticate(mechanism) + event["server"].sasl_mechanism = mechanism event["server"].wait_for_capability("sasl") @utils.hook("received.authenticate") def on_authenticate(self, event): sasl = event["server"].get_setting("sasl") - mechanism = sasl["mechanism"].upper() + mechanism = event["server"].sasl_mechanism auth_text = None if mechanism == "PLAIN": @@ -108,6 +131,13 @@ class Module(ModuleManager.BaseModule): def _end_sasl(self, server): server.capability_done("sasl") + @utils.hook("received.numeric.908") + def sasl_mechanisms(self, event): + server_mechanisms = event["args"][1].split(",") + mechanism = self._best_userpass_mechanism(server_mechanimsms) + event["server"].sasl_mechanism = mechanism + event["server"].send_authenticate(mechanism) + @utils.hook("received.numeric.903") def sasl_success(self, event): self._end_sasl(event["server"])