From aaf0c8cf2a413a2e4e7196dc3e08b7538ea60027 Mon Sep 17 00:00:00 2001 From: jesopo Date: Mon, 5 Nov 2018 14:12:21 +0000 Subject: [PATCH] Reschedule STS expiration on disconnect --- modules/sts.py | 32 +++++++++++++++++++++++--------- src/IRCBot.py | 1 + 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/modules/sts.py b/modules/sts.py index c16e8214..184f7dfd 100644 --- a/modules/sts.py +++ b/modules/sts.py @@ -2,18 +2,24 @@ import time from src import ModuleManager, utils class Module(ModuleManager.BaseModule): + def _get_policy(self, server): + return server.get_setting("sts-policy", None) + def _set_policy(self, server, policy): + server.set_setting("sts-policy", policy) + def _remove_policy(self, server): + server.del_setting("sts-policy") + def _set_policy(self, server, port, duration, one_shot): expiration = None - if duration: - expiration = time.time()+duration - server.set_setting("sts-policy", { + self._set_policy(server, { "port": port, - "expiration": expiration, + "from": time.time() + "duration": duration, "one-shot": one_shot}) def _change_duration(self, server, info): duration = int(info["duration"]) if duration == 0: - server.del_setting("sts-policy") + self._remove_policy(server) else: port = event["server"].port if "port" in info: @@ -43,14 +49,22 @@ class Module(ModuleManager.BaseModule): @utils.hook("new.server") def new_server(self, event): - sts_policy = event["server"].get_setting("sts-policy") + sts_policy = self._get_policy(event["server"]) if sts_policy: if sts_policy["one-shot"]: - event["server"].del_setting("sts-policy") + self._remove_policy(event["server"]) if not event["server"].tls: - expiration = sts_policy["expiration"] - if not expiration or time.time() <= expiration: + expiration = sts_policy["from"]+sts_policy + if not sts_policy["duration"] or time.time() <= ( + sts_policy["from"]+sts_policy["duration"]): self.log.debug("Applying STS policy for '%s'", [str(event["server"])]) event["server"].tls = True event["server"].port = sts_policy["port"] + + @utils.hook("server.disconnect") + def on_disconnect(self, event): + sts_policy = self._get_policy(event["server"]) + if sts_policy: + sts_policy["from"] = time.time() + self._set_policy(event["server"], sts_policy diff --git a/src/IRCBot.py b/src/IRCBot.py index f9548e6e..95c8f617 100644 --- a/src/IRCBot.py +++ b/src/IRCBot.py @@ -199,6 +199,7 @@ class Bot(object): server.send_ping() server.ping_sent = True if not server.connected: + self._events.on("server.disconnect").call(server=server) self.disconnect(server) reconnect_delay = self.config.get("reconnect-delay", 10)