refactor sasl a little and fix logic to retry on 908

This commit is contained in:
jesopo 2019-11-08 13:30:08 +00:00
parent f8083ae7b4
commit e26d769b7b

View file

@ -38,50 +38,53 @@ class Module(ModuleManager.BaseModule):
@utils.hook("new.server") @utils.hook("new.server")
def new_server(self, event): def new_server(self, event):
event["server"]._sasl_timeout = None event["server"]._sasl_timeout = None
event["server"]._sasl_retry = False
def _best_userpass_mechanism(self, mechanisms): def _best_userpass_mechanism(self, mechanisms):
for potential_mechanism in USERPASS_MECHANISMS: for potential_mechanism in USERPASS_MECHANISMS:
if potential_mechanism in mechanisms: if potential_mechanism in mechanisms:
return potential_mechanism return potential_mechanism
def _mech_match(self, server, server_mechanisms):
our_sasl = server.get_setting("sasl", None)
our_mechanism = our_sasl["mechanism"].upper()
if not server_mechanisms and our_mechanism in ALL_MECHANISMS:
return our_mechanism
elif our_mechanism in server_mechanisms:
return our_mechanism
elif our_mechanism == "USERPASS":
if server_mechanisms:
return self._best_userpass_mechanism(server_mechanisms)
else:
return USERPASS_MECHANISMS[0]
return None
@utils.hook("received.cap.new") @utils.hook("received.cap.new")
@utils.hook("received.cap.ls") @utils.hook("received.cap.ls")
def on_cap(self, event): def on_cap(self, event):
has_sasl = "sasl" in event["capabilities"] has_sasl = "sasl" in event["capabilities"]
our_sasl = event["server"].get_setting("sasl", None) if has_sasl:
server_mechanisms = event["capabilities"]["sasl"]
do_sasl = False if server_mechanisms:
if has_sasl and our_sasl: server_mechanisms = server_mechanisms.split(",")
if not event["capabilities"]["sasl"] == None:
our_mechanism = our_sasl["mechanism"].upper()
server_mechanisms = event["capabilities"]["sasl"].split(",")
if our_mechanism == "USERPASS":
our_mechanism = self._best_userpass_mechanism(
server_mechanisms)
do_sasl = our_mechanism in server_mechanisms
else: else:
do_sasl = True server_mechanisms = []
if do_sasl: mechanism = self._mech_match(event["server"], server_mechanisms)
if mechanism:
cap = CAP.copy() cap = CAP.copy()
cap.on_ack(lambda: self._sasl_ack(event["server"])) cap.on_ack(
lambda: self._sasl_ack(event["server"], mechanism))
return cap return cap
def _sasl_ack(self, server): def _sasl_ack(self, server, mechanism):
sasl = server.get_setting("sasl")
mechanism = sasl["mechanism"].upper()
if mechanism == "USERPASS":
server_mechanisms = server.server_capabilities["sasl"]
server_mechanisms = server_mechanisms or [
USERPASS_MECHANISMS[0]]
mechanism = self._best_userpass_mechanism(server_mechanisms)
server.send_authenticate(mechanism) server.send_authenticate(mechanism)
timer = self.timers.add("sasl-timeout", self._sasl_timeout, server._sasl_timeout = self.timers.add("sasl-timeout",
SASL_TIMEOUT, server=server) self._sasl_timeout, SASL_TIMEOUT, server=server)
server._sasl_timeout = timer server._sasl_mechanism = mechanism
server.sasl_mechanism = mechanism
server.wait_for_capability("sasl") server.wait_for_capability("sasl")
def _sasl_timeout(self, timer): def _sasl_timeout(self, timer):
@ -91,7 +94,7 @@ class Module(ModuleManager.BaseModule):
@utils.hook("received.authenticate") @utils.hook("received.authenticate")
def on_authenticate(self, event): def on_authenticate(self, event):
sasl = event["server"].get_setting("sasl") sasl = event["server"].get_setting("sasl")
mechanism = event["server"].sasl_mechanism mechanism = event["server"]._sasl_mechanism
auth_text = None auth_text = None
if mechanism == "PLAIN": if mechanism == "PLAIN":
@ -157,17 +160,22 @@ class Module(ModuleManager.BaseModule):
@utils.hook("received.908") @utils.hook("received.908")
def sasl_mechanisms(self, event): def sasl_mechanisms(self, event):
server_mechanisms = event["line"].args[1].split(",") server_mechanisms = event["line"].args[1].split(",")
mechanism = self._best_userpass_mechanism(server_mechanisms) mechanism = self._mech_match(event["server"], server_mechanisms)
event["server"].sasl_mechanism = mechanism if mechanism:
event["server"]._sasl_mechanism = mechanism
event["server"].send_authenticate(mechanism) event["server"].send_authenticate(mechanism)
event["server"]._sasl_retry = True
@utils.hook("received.903") @utils.hook("received.903")
def sasl_success(self, event): def sasl_success(self, event):
self._end_sasl(event["server"]) self._end_sasl(event["server"])
@utils.hook("received.904") @utils.hook("received.904")
def sasl_failure(self, event): def sasl_failure(self, event):
if not event["server"]._sasl_retry:
self._panic(event["server"], "ERR_SASLFAIL (%s)" % self._panic(event["server"], "ERR_SASLFAIL (%s)" %
event["line"].args[1]) event["line"].args[1])
else:
event["server"]._sasl_retry = False
@utils.hook("received.907") @utils.hook("received.907")
def sasl_already(self, event): def sasl_already(self, event):