diff --git a/modules/ircv3_chathistory.py b/modules/ircv3_chathistory.py index 36609368..5f72850d 100644 --- a/modules/ircv3_chathistory.py +++ b/modules/ircv3_chathistory.py @@ -12,8 +12,7 @@ class Module(ModuleManager.BaseModule): @utils.hook("received.cap.ls") @utils.hook("received.cap.new") def on_cap(self, event): - if EVENTPLAYBACK_CAP.available(event["capabilities"]): - return EVENTPLAYBACK_CAP.copy() + return EVENTPLAYBACK_CAP.copy() @utils.hook("received.batch.end") def batch_end(self, event): diff --git a/modules/ircv3_labeled_responses.py b/modules/ircv3_labeled_responses.py index 36fe8b35..952cddfa 100644 --- a/modules/ircv3_labeled_responses.py +++ b/modules/ircv3_labeled_responses.py @@ -23,8 +23,7 @@ class Module(ModuleManager.BaseModule): @utils.hook("received.cap.ls") @utils.hook("received.cap.new") def on_cap(self, event): - if CAP.available(event["capabilities"]): - return CAP.copy() + return CAP.copy() @utils.hook("preprocess.send") def raw_send(self, event): diff --git a/modules/ircv3_metadata.py b/modules/ircv3_metadata.py index 55193dce..dd210d54 100644 --- a/modules/ircv3_metadata.py +++ b/modules/ircv3_metadata.py @@ -7,9 +7,8 @@ class Module(ModuleManager.BaseModule): @utils.hook("received.cap.ls") def on_cap(self, event): cap = CAP.copy() - if cap.available(event["capabilities"]): - cap.on_ack(lambda: self._ack(event["server"])) - return cap + cap.on_ack(lambda: self._ack(event["server"])) + return cap def _ack(self, server): url = self.bot.get_setting("bot-url", IRCBot.SOURCE) diff --git a/modules/ircv3_multi_line.py b/modules/ircv3_multi_line.py index a014566c..e8b65c3b 100644 --- a/modules/ircv3_multi_line.py +++ b/modules/ircv3_multi_line.py @@ -7,8 +7,7 @@ class Module(ModuleManager.BaseModule): @utils.hook("received.cap.ls") @utils.hook("received.cap.new") def on_cap(self, event): - if CAP.available(event["capabilities"]): - return CAP.copy() + return CAP.copy() @utils.hook("preprocess.send.privmsg") def preprocess_send_privmsg(self, event): diff --git a/modules/ircv3_resume.py b/modules/ircv3_resume.py index 19afb3b4..1750e954 100644 --- a/modules/ircv3_resume.py +++ b/modules/ircv3_resume.py @@ -24,10 +24,9 @@ class Module(ModuleManager.BaseModule): @utils.hook("received.cap.ls") def on_cap_ls(self, event): - if CAP.available(event["capabilities"]): - cap = CAP.copy() - cap.on_ack(lambda: self._cap_ack(event["server"])) - return cap + cap = CAP.copy() + cap.on_ack(lambda: self._cap_ack(event["server"])) + return cap def _cap_ack(self, server): server.wait_for_capability("resume") @@ -39,7 +38,7 @@ class Module(ModuleManager.BaseModule): if event["args"][0] == "SUCCESS": resume_channels = event["server"].get_setting("resume-channels", []) self.log.info("Successfully resumed session") - event["server"].cap_started = False + event["server"].clear_waiting_capabilities() elif event["args"][0] == "TOKEN": token = self._get_token(event["server"]) diff --git a/modules/ircv3_server_time.py b/modules/ircv3_server_time.py index e363b341..31126b45 100644 --- a/modules/ircv3_server_time.py +++ b/modules/ircv3_server_time.py @@ -7,8 +7,7 @@ class Module(ModuleManager.BaseModule): @utils.hook("received.cap.ls") @utils.hook("received.cap.new") def on_cap(self, event): - if CAP.available(event["capabilities"]): - return CAP.copy() + return CAP.copy() @utils.hook("raw.received") def raw_recv(self, event): diff --git a/modules/line_handler/ircv3.py b/modules/line_handler/ircv3.py index 9111a537..3c389c85 100644 --- a/modules/line_handler/ircv3.py +++ b/modules/line_handler/ircv3.py @@ -17,60 +17,22 @@ CAPABILITIES = [ utils.irc.Capability(None, "draft/setname") ] -def _match_caps(our_capabilities, offered_capabilities): - matched = {} - for capability in our_capabilities: - available = capability.available(offered_capabilities) - if available: - matched[available] = capability - return matched - -def _caps_offered(server, caps): +def _cap_match(server, caps): + matched_caps = {} blacklist = server.get_setting("blacklisted-caps", []) - for cap_name, cap in caps.items(): - if not cap_name in blacklist: - server.capability_queue[cap_name] = cap + for cap in caps: + available = cap.available(server.server_capabilities) + if (available and not server.has_capability(cap) and + not available in blacklist): + matched_caps[available] = cap + return matched_caps def cap(events, event): capabilities = utils.parse.keyvalue(event["args"][-1]) subcommand = event["args"][1].upper() is_multiline = len(event["args"]) > 3 and event["args"][2] == "*" - if subcommand == "LS": - event["server"].cap_started = True - event["server"].server_capabilities.update(capabilities) - if not is_multiline: - server_caps = list(event["server"].server_capabilities.keys()) - matched_caps = _match_caps(CAPABILITIES, server_caps) - - module_caps = events.on("received.cap.ls").call( - capabilities=event["server"].server_capabilities, - server=event["server"]) - module_caps = list(filter(None, module_caps)) - matched_caps.update(_match_caps(module_caps, server_caps)) - - _caps_offered(event["server"], matched_caps) - - if event["server"].capability_queue: - event["server"].send_capability_queue() - else: - event["server"].send_capability_end() - elif subcommand == "NEW": - capabilities_keys = capabilities.keys() - event["server"].server_capabilities.update(capabilities) - - matched_caps = _match_caps(CAPABILITIES, list(capabilities_keys)) - - module_caps = events.on("received.cap.new").call( - server=event["server"], capabilities=capabilities) - module_caps = list(filter(None, module_caps)) - matched_caps.update(_match_caps(module_caps, capabilities_keys)) - - _caps_offered(event["server"], matched_caps) - - if event["server"].capability_queue: - event["server"].send_capability_queue() - elif subcommand == "DEL": + if subcommand == "DEL": for capability in capabilities.keys(): event["server"].agreed_capabilities.discard(capability) del event["server"].server_capabilities[capability] @@ -82,6 +44,27 @@ def cap(events, event): events.on("received.cap.ack").call(capabilities=capabilities, server=event["server"]) + if subcommand == "LS" or subcommand == "NEW": + event["server"].server_capabilities.update(capabilities) + if not is_multiline: + server_caps = list(event["server"].server_capabilities.keys()) + all_caps = CAPABILITIES[:] + + module_caps = events.on("received.cap.ls").call( + capabilities=event["server"].server_capabilities, + server=event["server"]) + module_caps = list(filter(None, module_caps)) + all_caps.extend(module_caps) + + matched_caps = _cap_match(event["server"], all_caps) + event["server"].capability_queue.update(matched_caps) + + if event["server"].capability_queue: + event["server"].send_capability_queue() + else: + event["server"].send_capability_end() + + if subcommand == "ACK" or subcommand == "NAK": ack = subcommand == "ACK" for capability in capabilities: @@ -92,10 +75,8 @@ def cap(events, event): else: cap_obj.nak() - if (event["server"].cap_started and - not event["server"].capabilities_requested and + if (not event["server"].capabilities_requested and not event["server"].waiting_for_capabilities()): - event["server"].cap_started = False event["server"].send_capability_end() def authenticate(events, event): diff --git a/src/IRCServer.py b/src/IRCServer.py index 78d3ac09..17d0795c 100644 --- a/src/IRCServer.py +++ b/src/IRCServer.py @@ -35,7 +35,6 @@ class Server(IRCObject.Object): self.agreed_capabilities = set([]) # type: typing.Set[str] self.server_capabilities = {} # type: typing.Dict[str, str] self.batches = {} # type: typing.Dict[str, utils.irc.IRCBatch] - self.cap_started = False self.users = {} # type: typing.Dict[str, IRCUser.User] self.new_users = set([]) #type: typing.Set[IRCUser.User] @@ -323,9 +322,12 @@ class Server(IRCObject.Object): def wait_for_capability(self, capability: str): self._capabilities_waiting.add(capability) def capability_done(self, capability: str): - self._capabilities_waiting.discard(capability) - if self.cap_started and not self._capabilities_waiting: - self.send_capability_end() + if capability in self._capabilities_waiting: + self._capabilities_waiting.discard(capability) + if not self._capabilities_waiting: + self.send_capability_end() + def clear_waiting_capabilities(self): + self._capabilities_waiting.clear() def send_pass(self, password: str) -> typing.Optional[IRCLine.SentLine]: return self.send(utils.irc.protocol.password(password))