Support a USERPASS sasl mechanism that picks the best user:pass mech (sasl)

This commit is contained in:
jesopo 2019-02-14 11:57:53 +00:00
parent 708ba2ddc7
commit d0ad43b027

View file

@ -2,6 +2,13 @@ import base64, hashlib, hmac, uuid
from src import ModuleManager, utils from src import ModuleManager, utils
from . import scram from . import scram
USERPASS_MECHANISMS = [
"SCRAM-SHA-512",
"SCRAM-SHA-256",
"SCRAM-SHA-1",
"PLAIN"
]
def _validate(self, s): def _validate(self, s):
mechanism, _, arguments = s.partition(" ") mechanism, _, arguments = s.partition(" ")
return {"mechanism": mechanism, "args": arguments} return {"mechanism": mechanism, "args": arguments}
@ -19,6 +26,11 @@ def _scram_xor(s1, s2):
"help": "Set the sasl username/password for this server", "help": "Set the sasl username/password for this server",
"validate": _validate}) "validate": _validate})
class Module(ModuleManager.BaseModule): class Module(ModuleManager.BaseModule):
def _best_userpass_mechanism(self, mechanisms):
for potential_mechanism in USERPASS_MECHANISMS:
if potential_mechanism in mechanisms:
return potential_mechanism
@utils.hook("received.cap.new") @utils.hook("received.cap.new")
@utils.hook("received.cap.ls") @utils.hook("received.cap.ls")
def on_cap(self, event): def on_cap(self, event):
@ -29,8 +41,11 @@ class Module(ModuleManager.BaseModule):
if has_sasl and our_sasl: if has_sasl and our_sasl:
if not event["capabilities"]["sasl"] == None: if not event["capabilities"]["sasl"] == None:
our_mechanism = our_sasl["mechanism"].upper() our_mechanism = our_sasl["mechanism"].upper()
do_sasl = our_mechanism in event["capabilities" server_mechanisms = event["capabilities"]["sasl"].split(",")
]["sasl"].split(",") if our_mechanism == "USERPASS":
our_mechanism = self._best_userpass_mechanism(
server_mechanisms)
do_sasl = our_mechanism in server_mechanisms
else: else:
do_sasl = True do_sasl = True
@ -41,13 +56,21 @@ class Module(ModuleManager.BaseModule):
def on_cap_ack(self, event): def on_cap_ack(self, event):
if "sasl" in event["capabilities"]: if "sasl" in event["capabilities"]:
sasl = event["server"].get_setting("sasl") sasl = event["server"].get_setting("sasl")
event["server"].send_authenticate(sasl["mechanism"].upper()) mechanism = sasl["mechanism"].upper()
if mechanism == "USERPASS":
server_mechanisms = event["server"].server_capabilities["sasl"]
server_mechanisms = server_mechanisms or [
USERPASS_MECHANISMS[0]]
mechanism = self._best_userpass_mechanism(server_mechanisms)
event["server"].send_authenticate(mechanism)
event["server"].sasl_mechanism = mechanism
event["server"].wait_for_capability("sasl") event["server"].wait_for_capability("sasl")
@utils.hook("received.authenticate") @utils.hook("received.authenticate")
def on_authenticate(self, event): def on_authenticate(self, event):
sasl = event["server"].get_setting("sasl") sasl = event["server"].get_setting("sasl")
mechanism = sasl["mechanism"].upper() mechanism = event["server"].sasl_mechanism
auth_text = None auth_text = None
if mechanism == "PLAIN": if mechanism == "PLAIN":
@ -108,6 +131,13 @@ class Module(ModuleManager.BaseModule):
def _end_sasl(self, server): def _end_sasl(self, server):
server.capability_done("sasl") server.capability_done("sasl")
@utils.hook("received.numeric.908")
def sasl_mechanisms(self, event):
server_mechanisms = event["args"][1].split(",")
mechanism = self._best_userpass_mechanism(server_mechanimsms)
event["server"].sasl_mechanism = mechanism
event["server"].send_authenticate(mechanism)
@utils.hook("received.numeric.903") @utils.hook("received.numeric.903")
def sasl_success(self, event): def sasl_success(self, event):
self._end_sasl(event["server"]) self._end_sasl(event["server"])