From d291cd506373046d703450bc9d97082b89cbe8be Mon Sep 17 00:00:00 2001 From: jesopo Date: Sat, 11 May 2019 18:22:40 +0100 Subject: [PATCH] Revamp how CAPs are tracked through REQ and ACK/NAK etc --- modules/line_handler/ircv3.py | 55 ++++++++++++++++++++++------------- modules/resume.py | 10 +++---- modules/sasl/__init__.py | 28 +++++++++--------- src/IRCServer.py | 21 ++++--------- src/utils/irc/__init__.py | 9 ++++++ 5 files changed, 68 insertions(+), 55 deletions(-) diff --git a/modules/line_handler/ircv3.py b/modules/line_handler/ircv3.py index ea68c67b..64a349b4 100644 --- a/modules/line_handler/ircv3.py +++ b/modules/line_handler/ircv3.py @@ -19,14 +19,20 @@ CAPABILITIES = [ utils.irc.Capability(None, "draft/setname") ] -def _match_caps(capabilities): - matched = [] - for capability in CAPABILITIES: - available = capability.available(capabilities) +def _match_caps(our_capabilities, offered_capabilities): + matched = {} + for capability in our_capabilities: + available = capability.available(offered_capabilities) if available: - matched.append(available) + matched[available] = capability return matched +def _caps_offered(server, caps): + blacklist = server.get_setting("blacklisted-caps", []) + for cap_name, cap in caps.items(): + if not cap_name in blacklist: + server.capability_queue[cap_name] = cap + def cap(events, event): capabilities = utils.parse.keyvalue(event["args"][-1]) subcommand = event["args"][1].lower() @@ -36,20 +42,18 @@ def cap(events, event): event["server"].cap_started = True event["server"].server_capabilities.update(capabilities) if not is_multiline: - matched_caps = _match_caps( - list(event["server"].server_capabilities.keys())) - blacklisted_caps = event["server"].get_setting( - "blacklisted-caps", []) - matched_caps = list( - set(matched_caps)-set(blacklisted_caps)) + server_caps = list(event["server"].server_capabilities.keys()) + matched_caps = _match_caps(CAPABILITIES, server_caps) - event["server"].queue_capabilities(matched_caps) - - events.on("received.cap.ls").call( + module_caps = events.on("received.cap.ls").call( capabilities=event["server"].server_capabilities, server=event["server"]) + module_caps = list(filter(None, module_caps)) + matched_caps.update(_match_caps(module_caps, server_caps)) - if event["server"].has_capability_queue(): + _caps_offered(event["server"], matched_caps) + + if event["server"].capability_queue: event["server"].send_capability_queue() else: event["server"].send_capability_end() @@ -58,12 +62,15 @@ def cap(events, event): event["server"].server_capabilities.update(capabilities) matched_caps = _match_caps(list(capabilities_keys)) - event["server"].queue_capabilities(matched_caps) - events.on("received.cap.new").call(server=event["server"], - capabilities=capabilities) + module_caps = events.on("received.cap.new").call( + server=event["server"], capabilities=capabilities) + module_caps = list(filter(None, module_caps)) + matched_caps.update(_match_caps(module_caps, capabilities_keys)) - if event["server"].has_capability_queue(): + _caps_offered(event["server"], matched_caps) + + if event["server"].capability_queue: event["server"].send_capability_queue() elif subcommand == "del": for capability in capabilities.keys(): @@ -78,11 +85,17 @@ def cap(events, event): server=event["server"]) if subcommand == "ack" or subcommand == "nak": + ack = subcommand == "ack" for capability in capabilities: - event["server"].requested_capabilities.remove(capability) + cap_obj = event["server"].capability_queue[capability] + del event["server"].capability_queue[capability] + if ack: + cap_obj.ack() + else: + cap_obj.nak() if (event["server"].cap_started and - not event["server"].requested_capabilities and + not event["server"].capability_queue and not event["server"].waiting_for_capabilities()): event["server"].cap_started = False event["server"].send_capability_end() diff --git a/modules/resume.py b/modules/resume.py index ebb5f29f..da1c6893 100644 --- a/modules/resume.py +++ b/modules/resume.py @@ -23,12 +23,12 @@ class Module(ModuleManager.BaseModule): @utils.hook("received.cap.ls") def on_cap_ls(self, event): if CAP in event["capabilities"]: - event["server"].queue_capability(CAP) + cap = utils.irc.Capability(CAP) + cap.on_ack(lambda: self._cap_ack(event["server"])) + return cap - @utils.hook("received.cap.ack") - def on_cap_ack(self, event): - if CAP in event["capabilities"]: - event["server"].wait_for_capability("resume") + def _cap_ack(self, server): + server.wait_for_capability("resume") @utils.hook("received.resume") def on_resume(self, event): diff --git a/modules/sasl/__init__.py b/modules/sasl/__init__.py index f6292d4a..b0635b4d 100644 --- a/modules/sasl/__init__.py +++ b/modules/sasl/__init__.py @@ -41,22 +41,22 @@ class Module(ModuleManager.BaseModule): do_sasl = True if do_sasl: - event["server"].queue_capability("sasl") + cap = utils.irc.Capability("sasl") + cap.on_ack(lambda: self._sasl_ack(event["server"])) + return cap - @utils.hook("received.cap.ack") - def on_cap_ack(self, event): - if "sasl" in event["capabilities"]: - sasl = event["server"].get_setting("sasl") - mechanism = sasl["mechanism"].upper() - if mechanism == "USERPASS": - server_mechanisms = event["server"].server_capabilities["sasl"] - server_mechanisms = server_mechanisms or [ - USERPASS_MECHANISMS[0]] - mechanism = self._best_userpass_mechanism(server_mechanisms) + def _sasl_ack(self, server): + sasl = server.get_setting("sasl") + mechanism = sasl["mechanism"].upper() + if mechanism == "USERPASS": + server_mechanisms = server.server_capabilities["sasl"] + server_mechanisms = server_mechanisms or [ + USERPASS_MECHANISMS[0]] + mechanism = self._best_userpass_mechanism(server_mechanisms) - event["server"].send_authenticate(mechanism) - event["server"].sasl_mechanism = mechanism - event["server"].wait_for_capability("sasl") + server.send_authenticate(mechanism) + server.sasl_mechanism = mechanism + server.wait_for_capability("sasl") @utils.hook("received.authenticate") def on_authenticate(self, event): diff --git a/src/IRCServer.py b/src/IRCServer.py index 1a556ee5..01010fce 100644 --- a/src/IRCServer.py +++ b/src/IRCServer.py @@ -24,10 +24,10 @@ class Server(IRCObject.Object): self.realname = None # type: typing.Optional[str] self.hostname = None # type: typing.Optional[str] - self._capability_queue = set([]) # type: typing.Set[str] + self.capability_queue = { + } # type: typing.Dict[str, utils.irc.Capability] self._capabilities_waiting = set([]) # type: typing.Set[str] self.agreed_capabilities = set([]) # type: typing.Set[str] - self.requested_capabilities = [] # type: typing.List[str] self.server_capabilities = {} # type: typing.Dict[str, str] self.batches = {} # type: typing.Dict[str, IRCLine.ParsedLine] self.cap_started = False @@ -269,21 +269,12 @@ class Server(IRCObject.Object): def send_capibility_ls(self) -> IRCLine.SentLine: return self.send(utils.irc.protocol.capability_ls()) - def queue_capability(self, capability: str): - self._capability_queue.add(capability) - def queue_capabilities(self, capabilities: typing.List[str]): - self._capability_queue.update(capabilities) def send_capability_queue(self): - if self.has_capability_queue(): - capability_queue = list(self._capability_queue) - self._capability_queue.clear() + capability_queue = [cap for cap in self.capability_queue.keys()] - for i in range(0, len(capability_queue), 10): - capability_batch = capability_queue[i:i+10] - self.requested_capabilities += capability_batch - self.send_capability_request(" ".join(capability_batch)) - def has_capability_queue(self): - return bool(len(self._capability_queue)) + for i in range(0, len(capability_queue), 10): + capability_batch = capability_queue[i:i+10] + self.send_capability_request(" ".join(capability_batch)) def send_capability_request(self, capability: str) -> IRCLine.SentLine: return self.send(utils.irc.protocol.capability_request(capability)) def send_capability_end(self) -> IRCLine.SentLine: diff --git a/src/utils/irc/__init__.py b/src/utils/irc/__init__.py index 15850cea..2e1ef6e4 100644 --- a/src/utils/irc/__init__.py +++ b/src/utils/irc/__init__.py @@ -281,8 +281,17 @@ class Capability(object): self._caps = set([name, draft_name]) self._name = name self._draft_name = draft_name + self._on_ack_callbacks = [] def available(self, capabilities: typing.List[str]) -> str: match = list(set(capabilities)&self._caps) return match[0] if match else None def enabled(self, capability: str) -> bool: return capability in self._caps + + def on_ack(self, callback: typing.Callable[[], None]): + self._on_ack_callbacks.append(callback) + def ack(self): + for callback in self._on_ack_callbacks: + callback() + def nak(self): + pass