From 257659aa73b1b1488ace4ef83ba78e1217aa6031 Mon Sep 17 00:00:00 2001 From: jesopo Date: Mon, 5 Nov 2018 20:51:33 +0000 Subject: [PATCH] Change modules/sts.py to reference connection_params, fix some typos and logig issues --- modules/sts.py | 32 +++++++++++++------------------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/modules/sts.py b/modules/sts.py index b4f3c83e..0bdc46c2 100644 --- a/modules/sts.py +++ b/modules/sts.py @@ -10,22 +10,21 @@ class Module(ModuleManager.BaseModule): def _remove_policy(self, server): server.del_setting("sts-policy") - def _set_policy(self, server, port, duration, one_shot): + def set_policy(self, server, port, duration): expiration = None self._set_policy(server, { "port": port, "from": time.time(), - "duration": duration, - "one-shot": one_shot}) - def _change_duration(self, server, info): + "duration": duration}) + def change_duration(self, server, info): duration = int(info["duration"]) if duration == 0: self._remove_policy(server) else: - port = event["server"].port + port = server.connection_params.port if "port" in info: port = int(info["port"]) - self._set_policy(server, port, duration, False) + self.set_policy(server, port, duration) def _get_sts(self, capabilities): return capabilities.get("sts", capabilities.get("draft/sts", None)) @@ -35,37 +34,32 @@ class Module(ModuleManager.BaseModule): sts = self._get_sts(event["capabilities"]) if sts: info = utils.parse.keyvalue(sts, delimiter=",") - if not event["server"].tls: - self._set_policy(event["server"], int(info["port"]), - None, True) + if not event["server"].connection_params.tls: + self.set_policy(event["server"], int(info["port"]), None) event["server"].disconnect() self.bot.reconnect(event["server"].id, event["server"].connection_params) else: - self._change_duration(event["server"], info) + self.change_duration(event["server"], info) @utils.hook("received.cap.new") def on_cap_new(self, event): sts = self._get_sts(event["capabilities"]) - if sts and event["server"].tls: + if sts and event["server"].connection_params.tls: info = utils.parse.keyvalue(sts, delimiter=",") - if event["server"].tls: - self._change_duration(event["server"], info) + self.change_duration(event["server"], info) @utils.hook("new.server") def new_server(self, event): sts_policy = self._get_policy(event["server"]) if sts_policy: - if sts_policy["one-shot"]: - self._remove_policy(event["server"]) - if not event["server"].tls: - expiration = sts_policy["from"]+sts_policy + if not event["server"].connection_params.tls: if not sts_policy["duration"] or time.time() <= ( sts_policy["from"]+sts_policy["duration"]): self.log.trace("Applying STS policy for '%s'", [str(event["server"])]) - event["server"].tls = True - event["server"].port = sts_policy["port"] + event["server"].connection_params.tls = True + event["server"].connection_params.port = sts_policy["port"] @utils.hook("server.disconnect") def on_disconnect(self, event):